use std::sync::Arc; use axum::{ extract::State, http::{Request, StatusCode}, middleware::Next, response::{IntoResponse, Response}, }; use http_body_util::BodyExt; use serde_json::Value; use tracing::{debug, warn}; use crate::{ AppContext, db::{ self, models::{InterceptionAction, InterceptionRule}, }, }; struct Processed { original: String, final_: Option<(String, InterceptionAction)>, encrypted: String, } /// Main middleware to intercept, decrypt, modify, and log requests and responses. pub async fn middleware( State(ctx): State>, req: Request, next: Next, ) -> impl IntoResponse { let (parts, body) = req.into_parts(); let body_bytes; let body = match body.collect().await { Ok(body) => { body_bytes = body.to_bytes(); str::from_utf8(&body_bytes).unwrap_or_else(|e| { warn!("Received request with invalid UTF-8: {e}. Replacing with an empty string"); "" }) } Err(e) => { warn!("Failed to read request body: {}", e); return ( StatusCode::INTERNAL_SERVER_ERROR, "Failed to read request body".to_string(), ) .into_response(); } }; // Process request: decrypt, deserialize, modify, re-encrypt let (processed_req, method) = process_request(body, &ctx).await; // Pass modified request to the next handler let req = Request::from_parts(parts, axum::body::Body::from(processed_req.encrypted)); let res = next.run(req).await; // Early exit if next handler returned error if !res.status().is_success() { return res; } // Process response: decrypt, deserialize, modify, re-encrypt let (resp_parts, body) = res.into_parts(); let body_bytes; let body = match body.collect().await { Ok(b) => { body_bytes = b.to_bytes(); str::from_utf8(&body_bytes) .unwrap_or_else(|e| { warn!( "Received response with invalid UTF-8: {e}. Replacing with an empty string" ); "" }) .to_string() } Err(e) => { warn!("Failed to read response body: {}", e); return ( StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response body".to_string(), ) .into_response(); } }; let processed_resp = process_response(body, &method, &ctx).await; let (req_body, req_action) = processed_req .final_ .map_or((None, None), |(a, b)| (Some(a), Some(b))); let (resp_body, resp_action) = processed_resp .final_ .map_or((None, None), |(a, b)| (Some(a), Some(b))); // Write log to database if let Err(e) = db::repositories::logs::create( &ctx.db, method, processed_req.original, processed_resp.original, req_body, resp_body, req_action, resp_action, ) .await { warn!("Failed to log request: {}", e); } // Build and return the final response let mut response_builder = Response::builder().status(resp_parts.status); if !resp_parts.headers.is_empty() { *response_builder.headers_mut().unwrap() = resp_parts.headers; } response_builder .body(axum::body::Body::from(processed_resp.encrypted)) .unwrap() } /// Processes the incoming request body. /// Returns the re-encrypted body for the next handler, the decrypted JSON for logging, and the method name. async fn process_request(body: &str, ctx: &Arc) -> (Processed, String) { let mut plain_request: Value = serde_json::from_str(body).unwrap_or_else(|err| { warn!( "Failed to deserialize request body: {}. Using fallback string value.", err ); Value::String("Could not deserialize request".into()) }); decrypt_params(&mut plain_request, ctx); let original = serde_json::to_string(&plain_request).expect("deserialization succeeded"); let method = plain_request .get("method") .and_then(Value::as_str) .unwrap_or_else(|| { warn!("No JSON-RPC method found in request body, fallback to an empty string"); "" }) .to_string(); let action = match modify_request(&mut plain_request, &method, ctx).await { Ok(action) => action, Err(e) => { warn!("Failed to modify request: {}", e); None } }; let final_ = action.map(|action| { ( serde_json::to_string(&plain_request).expect("deserialization succeeded"), action, ) }); let mut encrypted_request = plain_request.clone(); encrypt_params(&mut encrypted_request, ctx); ( Processed { original, final_, encrypted: serde_json::to_string(&encrypted_request) .expect("deserialization succeeded"), }, method, ) } /// Processes the outgoing response body. /// Returns the final encrypted body for the client and the decrypted JSON for logging. async fn process_response(body: String, method: &str, ctx: &Arc) -> Processed { let decrypted_body = ctx.cryptor.decrypt(body.clone()).unwrap_or_else(|err| { warn!( "Failed to decrypt response body: {}. Assuming it's not encrypted.", err ); body }); let mut response_value: Value = serde_json::from_str(&decrypted_body).unwrap_or_else(|err| { warn!( "Failed to deserialize response body: {}. Using string value.", err ); Value::String(decrypted_body.clone()) }); let action = match modify_response(&mut response_value, method, ctx).await { Ok(action) => action, Err(e) => { warn!("Failed to modify response: {}", e); None } }; let modified_body_str = serde_json::to_string(&response_value).expect("serialization succeeded"); let encrypted = ctx.cryptor.encrypt(modified_body_str.clone()); let final_ = action.map(|action| (modified_body_str.clone(), action)); Processed { original: decrypted_body, final_, encrypted, } } /// Placeholder for request modification logic. async fn modify_request( _request_json: &mut Value, _method: &str, _ctx: &Arc, ) -> anyhow::Result> { // TODO: Implement request modification logic based on rules or other criteria. Ok(None) } /// Applies modification rules to the response. async fn modify_response( response_json: &mut Value, method: &str, ctx: &Arc, ) -> anyhow::Result> { // Check for generic method interception (e.g., replace response from DB) if let Some((intercepted, action)) = intercept_response(method, ctx).await? { debug!("Intercepting response for method: {}", method); *response_json = serde_json::from_str(&intercepted).unwrap_or_else(|e| { warn!( "Failed to parse intercepted response as JSON: {}. Using as string.", e ); Value::String(intercepted) }); return Ok(Some(action)); } // Special handling for getcommand // TODO: Return interception rule if method == "com.linspirer.device.getcommand" && let Err(e) = handle_getcommand_response(response_json, ctx).await { warn!( "Failed to handle getcommand response: {}. Responding with empty command list.", e ); if let Some(obj) = response_json.as_object_mut() { obj.insert("result".to_string(), Value::Array(vec![])); } } Ok(None) } /// Handles the 'getcommand' response by injecting verified commands. async fn handle_getcommand_response( response_json: &mut Value, ctx: &Arc, ) -> anyhow::Result<()> { if let Some(result) = response_json.get_mut("result") && let Some(commands) = result.as_array_mut() && !commands.is_empty() { // Persist commands to database for cmd in commands.iter() { let cmd_json = serde_json::to_string(cmd)?; if let Err(e) = crate::db::repositories::commands::insert(&ctx.db, &cmd_json, "unverified").await { warn!("Failed to persist command to database: {}", e); } debug!("Added command to the queue: {:?}", cmd); } } if let Some(obj) = response_json.as_object_mut() { // Get verified commands from database let verified_cmds = match crate::db::repositories::commands::list_by_status(&ctx.db, "verified").await { Ok(cmds) => cmds, Err(e) => { warn!("Failed to fetch verified commands from database: {}", e); Vec::new() } }; // Convert to JSON values let verified_values: Vec = verified_cmds .iter() .filter_map(|c| serde_json::from_str(&c.command_json).ok()) .collect(); obj.insert("result".to_string(), Value::Array(verified_values)); // Clear verified commands from database after sending if let Err(e) = crate::db::repositories::commands::clear_verified(&ctx.db).await { warn!("Failed to clear verified commands from database: {}", e); } } Ok(()) } /// Checks for and applies response interception rules. async fn intercept_response( method: &str, ctx: &Arc, ) -> anyhow::Result> { // Check if there's an interception rule for this method let rule = crate::db::repositories::rules::find_by_method(&ctx.db, method).await?; match rule { Some(InterceptionRule { action: InterceptionAction::Replace, custom_response, .. }) => Ok(custom_response.map(|resp| (resp, InterceptionAction::Replace))), Some(InterceptionRule { action: InterceptionAction::Modify, .. }) => { // TODO: Apply modifications Ok(None) } _ => Ok(None), // Passthrough } } fn decrypt_params(request: &mut Value, ctx: &Arc) { if let Some(params) = request.get_mut("params") { *params = match params.take() { Value::String(params) => { let params = ctx.cryptor.decrypt(params).unwrap_or_else(|err| { warn!("Failed to decrypt request params: {err}. Using fallback string value.",); "\"Failed to decrypt params\"".to_string() }); serde_json::from_str(¶ms).unwrap_or_else(|_| { Value::String("Failed to deserialize decrypted params".into()) }) } other => { warn!("'params' is not an encrypted string"); other } }; }; } fn encrypt_params(request: &mut Value, ctx: &Arc) { if let Some(params) = request.get_mut("params") { let plaintext = serde_json::to_string(params).unwrap_or_else(|err| { warn!("Failed to serialize params: {err}. Using fallback string value."); "".to_string() }); *params = Value::String(ctx.cryptor.encrypt(plaintext)); }; }