Files
mylinspirer/src/middleware.rs

305 lines
9.9 KiB
Rust

use std::sync::Arc;
use axum::{
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use bytes::Bytes;
use http_body_util::BodyExt;
use serde_json::Value;
use tracing::{debug, warn};
use crate::{AppContext, db};
/// Main middleware to intercept, decrypt, modify, and log requests and responses.
///
/// Flow:
/// 1. Buffer the request body and pass it to [`process_request`], which decrypts the
///    `params` field, applies request modifications and re-encrypts it.
/// 2. Forward the rewritten request to `next`.
/// 3. Buffer the response body and pass it to [`process_response`], which decrypts it,
///    applies interception/modification rules and re-encrypts it.
/// 4. Log both decrypted payloads at debug level and persist them via
///    `db::repositories::logs::create` (a persistence failure is only logged).
///
/// If either body cannot be read, the client receives `500 Internal Server Error`.
pub async fn middleware(
    State(ctx): State<Arc<AppContext>>,
    req: Request<axum::body::Body>,
    next: Next,
) -> impl IntoResponse {
    let (mut parts, body) = req.into_parts();
    let body_bytes = match body.collect().await {
        Ok(body) => body.to_bytes(),
        Err(e) => {
            warn!("Failed to read request body: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to read request body".to_string(),
            )
                .into_response();
        }
    };

    // Process request: decrypt, deserialize, modify, re-encrypt
    let (processed_req_body, decrypted_request, method) = process_request(body_bytes, &ctx).await;

    // The body was re-serialized (and possibly modified), so the client's length is stale.
    sync_content_length(&mut parts.headers, processed_req_body.len());
    let req = Request::from_parts(parts, axum::body::Body::from(processed_req_body));
    let res = next.run(req).await;

    // Process response: decrypt, deserialize, modify, re-encrypt
    let (mut resp_parts, body) = res.into_parts();
    let body_bytes = match body.collect().await {
        Ok(b) => b.to_bytes(),
        Err(e) => {
            warn!("Failed to read response body: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to read response body".to_string(),
            )
                .into_response();
        }
    };
    let (final_response_body, decrypted_response) =
        process_response(body_bytes, &method, &ctx).await;

    // Log the decrypted request and response
    debug!(
        "\nRequest:\n{}\nResponse:\n{}\n{}",
        serde_json::to_string_pretty(&decrypted_request).unwrap_or_default(),
        serde_json::to_string_pretty(&decrypted_response).unwrap_or_default(),
        "-".repeat(80),
    );

    // Write log to database
    if let Err(e) = db::repositories::logs::create(
        &ctx.db,
        method,
        decrypted_request,
        decrypted_response,
        "".into(),
        "".into(),
    )
    .await
    {
        warn!("Failed to log request: {}", e);
    }

    // The re-encrypted body differs in length from what upstream sent. hyper trusts a
    // user-supplied Content-Length, so a stale value would corrupt the response framing.
    sync_content_length(&mut resp_parts.headers, final_response_body.len());
    // Reusing the upstream parts keeps status, headers, HTTP version and extensions intact.
    Response::from_parts(resp_parts, axum::body::Body::from(final_response_body))
}

/// Overwrites the `Content-Length` header with `len` if the header is present.
///
/// Messages without a `Content-Length` (e.g. chunked transfers) are left untouched.
fn sync_content_length(headers: &mut axum::http::HeaderMap, len: usize) {
    if headers.contains_key(axum::http::header::CONTENT_LENGTH) {
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from(len),
        );
    }
}
/// Processes the incoming request body.
///
/// The body is parsed as JSON. Its `params` field is decrypted, [`modify_request`] is
/// applied, and `params` is re-encrypted for the next handler.
///
/// Returns `(body to forward, decrypted JSON for logging, method name)`. The method name
/// is the top-level `"method"` string, or empty when absent.
///
/// A body that is not valid JSON is forwarded unchanged. Its log value is the placeholder
/// string `"Could not deserialize request"` and its method name is empty. A body that is
/// not valid UTF-8 is treated as empty text.
async fn process_request(body_bytes: Bytes, ctx: &Arc<AppContext>) -> (String, Value, String) {
    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap_or_default();
    let mut plain_request: Value = match serde_json::from_str(&body_str) {
        Ok(value) => value,
        Err(err) => {
            warn!(
                "Failed to deserialize request body: {}. Forwarding it unchanged.",
                err
            );
            // Do not re-serialize a placeholder here: that would replace the client's
            // request with a bogus JSON string before it reaches the next handler.
            return (
                body_str,
                Value::String("Could not deserialize request".into()),
                String::new(),
            );
        }
    };

    decrypt_params(&mut plain_request, ctx);

    let method = plain_request
        .get("method")
        .and_then(Value::as_str)
        .unwrap_or_default()
        .to_string();

    if let Err(e) = modify_request(&mut plain_request, &method, ctx).await {
        warn!("Failed to modify request: {}", e);
    }

    let mut crypted_request = plain_request.clone();
    encrypt_params(&mut crypted_request, ctx);
    (
        serde_json::to_string(&crypted_request)
            .expect("serializing a serde_json::Value cannot fail"),
        plain_request,
        method,
    )
}
/// Processes the outgoing response body.
///
/// Steps:
/// 1. Decrypt the body text. If decryption fails, the raw text is used as-is.
/// 2. Parse the text as JSON. If parsing fails, it is wrapped in a JSON string.
/// 3. Snapshot that value for logging.
/// 4. Apply [`modify_response`].
/// 5. Serialize and encrypt the result.
///
/// Returns `(encrypted body for the client, decrypted value before modification)`.
async fn process_response(
    body_bytes: Bytes,
    method: &str,
    ctx: &Arc<AppContext>,
) -> (String, Value) {
    let raw = String::from_utf8(body_bytes.to_vec()).unwrap_or_default();

    let plaintext = match ctx.cryptor.decrypt(raw.clone()) {
        Ok(text) => text,
        Err(err) => {
            warn!(
                "Failed to decrypt response body: {}. Assuming it's not encrypted.",
                err
            );
            raw
        }
    };

    let mut response = match serde_json::from_str::<Value>(&plaintext) {
        Ok(value) => value,
        Err(err) => {
            warn!(
                "Failed to deserialize response body: {}. Using string value.",
                err
            );
            Value::String(plaintext.clone())
        }
    };

    // Keep the pre-modification value so the log shows what upstream actually sent.
    let snapshot = response.clone();
    if let Err(e) = modify_response(&mut response, method, ctx, &plaintext).await {
        warn!("Failed to modify response: {}", e);
    }

    let serialized = serde_json::to_string(&response).unwrap_or_default();
    (ctx.cryptor.encrypt(serialized), snapshot)
}
/// Hook for rewriting a decrypted request before it is re-encrypted and forwarded.
///
/// Currently a no-op that always succeeds. The arguments are already accepted so that
/// rule-based request rewriting can be added without changing the caller.
async fn modify_request(
    _request_json: &mut Value,
    _method: &str,
    _ctx: &Arc<AppContext>,
) -> anyhow::Result<()> {
    // TODO: Implement request modification logic based on rules or other criteria.
    let outcome = Ok(());
    outcome
}
/// Applies interception and modification rules to a decrypted response.
///
/// - If [`intercept_response`] yields a replacement body for `method`, `response_json`
///   is replaced by it. The body is parsed as JSON, or kept as a JSON string if it is not
///   valid JSON. No further processing happens in that case.
/// - Otherwise, `com.linspirer.device.getcommand` responses go through
///   [`handle_getcommand_response`]. If that fails, `result` is set to an empty array.
///
/// A failure to look up interception rules is logged and treated as "no rule", so the
/// getcommand filtering still runs.
async fn modify_response(
    response_json: &mut Value,
    method: &str,
    ctx: &Arc<AppContext>,
    original_decrypted: &str,
) -> anyhow::Result<()> {
    // Returning early on a rule-lookup error would skip the getcommand handling below and
    // forward unverified upstream commands to the device, so degrade to passthrough instead.
    let intercepted = match intercept_response(method, original_decrypted, ctx).await {
        Ok(intercepted) => intercepted,
        Err(e) => {
            warn!(
                "Failed to check interception rules for {}: {}. Passing through.",
                method, e
            );
            None
        }
    };

    // Check for generic method interception (e.g., replace response from DB)
    if let Some(intercepted) = intercepted {
        debug!("Intercepting response for method: {}", method);
        *response_json = serde_json::from_str(&intercepted).unwrap_or_else(|e| {
            warn!(
                "Failed to parse intercepted response as JSON: {}. Using as string.",
                e
            );
            Value::String(intercepted)
        });
        return Ok(());
    }

    // Special handling for getcommand
    if method == "com.linspirer.device.getcommand"
        && let Err(e) = handle_getcommand_response(response_json, ctx).await
    {
        warn!(
            "Failed to handle getcommand response: {}. Responding with empty command list.",
            e
        );
        if let Some(obj) = response_json.as_object_mut() {
            obj.insert("result".to_string(), Value::Array(vec![]));
        }
    }
    Ok(())
}
/// Handles the `getcommand` response by injecting verified commands.
///
/// 1. Every command in the upstream `result` array is stored in the database with status
///    `"unverified"`.
/// 2. For object-shaped responses, `result` is replaced by the commands currently stored
///    with status `"verified"`.
/// 3. The verified commands are cleared from the database once they have been fetched
///    and placed into the response.
///
/// Database failures are logged and handled in place. If the verified commands cannot be
/// fetched, `result` is emptied and the verified queue is left untouched for the next poll.
///
/// # Errors
/// Propagates a JSON serialization error for an upstream command. This is not expected
/// for `serde_json::Value` input.
async fn handle_getcommand_response(
    response_json: &mut Value,
    ctx: &Arc<AppContext>,
) -> anyhow::Result<()> {
    // Persist upstream commands to the database for review.
    if let Some(commands) = response_json.get("result").and_then(Value::as_array) {
        for cmd in commands {
            let cmd_json = serde_json::to_string(cmd)?;
            match crate::db::repositories::commands::insert(&ctx.db, &cmd_json, "unverified")
                .await
            {
                Ok(_) => debug!("Added command to the queue: {:?}", cmd),
                Err(e) => warn!("Failed to persist command to database: {}", e),
            }
        }
    }

    let Some(obj) = response_json.as_object_mut() else {
        return Ok(());
    };

    // Get verified commands from database
    let verified_cmds =
        match crate::db::repositories::commands::list_by_status(&ctx.db, "verified").await {
            Ok(cmds) => cmds,
            Err(e) => {
                warn!("Failed to fetch verified commands from database: {}", e);
                // Nothing was read, so nothing may be cleared: send no commands this time
                // and keep the verified queue for the next poll.
                obj.insert("result".to_string(), Value::Array(Vec::new()));
                return Ok(());
            }
        };

    // Convert to JSON values
    let verified_values: Vec<Value> = verified_cmds
        .iter()
        .filter_map(
            |c| match serde_json::from_str::<Value>(&c.command_json) {
                Ok(value) => Some(value),
                Err(e) => {
                    warn!("Dropping verified command with invalid JSON: {}", e);
                    None
                }
            },
        )
        .collect();
    obj.insert("result".to_string(), Value::Array(verified_values));

    // Clear verified commands from database after sending. Skipping the clear when nothing
    // was fetched avoids wiping a command verified just after the fetch.
    // NOTE(review): a command verified between the fetch and this clear can still be lost;
    // deleting by the fetched ids would close that window.
    if !verified_cmds.is_empty()
        && let Err(e) = crate::db::repositories::commands::clear_verified(&ctx.db).await
    {
        warn!("Failed to clear verified commands from database: {}", e);
    }
    Ok(())
}
/// Looks up the interception rule for `method` and returns a replacement body, if any.
///
/// - A `"replace"` rule yields its `custom_response`. If that is absent, the result is
///   `None` and the response passes through.
/// - A `"modify"` rule is recognised but not applied yet, so it yields `None`.
/// - Any other action, or no rule at all, yields `None` (passthrough).
///
/// # Errors
/// Propagates a failure of the rule lookup.
async fn intercept_response(
    method: &str,
    _original_response: &str,
    ctx: &Arc<AppContext>,
) -> anyhow::Result<Option<String>> {
    let Some(rule) = crate::db::repositories::rules::find_by_method(&ctx.db, method).await?
    else {
        // No rule configured: passthrough.
        return Ok(None);
    };

    if rule.action == "replace" {
        return Ok(rule.custom_response.clone());
    }

    // TODO: Apply modifications for "modify" rules; until then they pass through like
    // every other action.
    Ok(None)
}
/// Replaces the encrypted `params` string of a request with its decrypted JSON value.
///
/// - No `params` key: the request is left untouched.
/// - `params` is not a string: it is kept as-is and a warning is logged.
/// - Decryption fails: `params` becomes the JSON string `"Failed to decrypt params"`.
/// - The plaintext is not valid JSON: `params` becomes the JSON string
///   `"Failed to deserialize decrypted params"`.
///
/// Every fallback is logged.
fn decrypt_params(request: &mut Value, ctx: &Arc<AppContext>) {
    let Some(params) = request.get_mut("params") else {
        return;
    };
    *params = match params.take() {
        Value::String(ciphertext) => {
            let plaintext = ctx.cryptor.decrypt(ciphertext).unwrap_or_else(|err| {
                warn!("Failed to decrypt request params: {err}. Using fallback string value.");
                // JSON-quoted so the parse below turns it into a JSON string.
                "\"Failed to decrypt params\"".to_string()
            });
            serde_json::from_str(&plaintext).unwrap_or_else(|err| {
                warn!(
                    "Failed to deserialize decrypted params: {err}. Using fallback string value."
                );
                Value::String("Failed to deserialize decrypted params".into())
            })
        }
        other => {
            warn!("'params' is not an encrypted string");
            other
        }
    };
}
/// Serializes the request's `params` value to JSON text and replaces it with the
/// encrypted form of that text, stored as a JSON string.
///
/// Requests without a `params` key are left untouched. If serialization fails, the empty
/// string is encrypted instead and a warning is logged.
fn encrypt_params(request: &mut Value, ctx: &Arc<AppContext>) {
    let Some(params) = request.get_mut("params") else {
        return;
    };
    let serialized = match serde_json::to_string(params) {
        Ok(json) => json,
        Err(err) => {
            warn!("Failed to serialize params: {err}. Using fallback string value.");
            String::new()
        }
    };
    *params = Value::String(ctx.cryptor.encrypt(serialized));
}