refactor: modularize middleware & crypto
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
use std::str;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
@@ -7,19 +6,16 @@ use axum::{
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use http_body_util::BodyExt;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::{AppState, crypto, db};
|
||||
|
||||
enum ResponseBody {
|
||||
Original(String),
|
||||
Modified(Value),
|
||||
}
|
||||
use crate::{AppContext, db};
|
||||
|
||||
/// Main middleware to intercept, decrypt, modify, and log requests and responses.
|
||||
pub async fn middleware(
|
||||
State(state): State<Arc<AppState>>,
|
||||
State(ctx): State<Arc<AppContext>>,
|
||||
req: Request<axum::body::Body>,
|
||||
next: Next,
|
||||
) -> impl IntoResponse {
|
||||
@@ -36,91 +32,31 @@ pub async fn middleware(
|
||||
}
|
||||
};
|
||||
|
||||
let (decrypted_request, method) = match str::from_utf8(&body_bytes)
|
||||
.map_err(anyhow::Error::from)
|
||||
.and_then(|body| process_and_log_request(body, &state.key, &state.iv))
|
||||
{
|
||||
Ok(request_data) => {
|
||||
let method = request_data
|
||||
.get("method")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
(request_data, method)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to process request for logging: {}", e);
|
||||
let val = Value::String("Could not decrypt request".to_string());
|
||||
(val, None)
|
||||
}
|
||||
};
|
||||
// Process request: decrypt, deserialize, modify, re-encrypt
|
||||
let (processed_req_body, decrypted_request, method) = process_request(body_bytes, &ctx).await;
|
||||
|
||||
let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
|
||||
// Pass modified request to the next handler
|
||||
let req = Request::from_parts(parts, axum::body::Body::from(processed_req_body));
|
||||
let res = next.run(req).await;
|
||||
|
||||
let (resp_parts, body_bytes) = {
|
||||
let (parts, body) = res.into_parts();
|
||||
let bytes = match body.collect().await {
|
||||
Ok(b) => b.to_bytes(),
|
||||
Err(e) => {
|
||||
warn!("Failed to read response body: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to read response body".to_string(),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
(parts, bytes)
|
||||
};
|
||||
let resp_body_text = String::from_utf8(body_bytes.clone().to_vec()).unwrap_or_default();
|
||||
|
||||
// Check for generic method interception first
|
||||
let response_body = if let Some(method_str) = &method {
|
||||
if let Ok(Some(intercepted)) = intercept_response(method_str, &resp_body_text, &state).await
|
||||
{
|
||||
info!("Intercepting response for method: {}", method_str);
|
||||
ResponseBody::Original(intercepted)
|
||||
} else if Some("com.linspirer.device.getcommand") == method.as_deref() {
|
||||
// Special handling for getcommand
|
||||
match handle_getcommand_response(&resp_body_text, &state).await {
|
||||
Ok(new_body) => ResponseBody::Modified(new_body),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to handle getcommand response: {}. Responding with empty command list.",
|
||||
e
|
||||
);
|
||||
let mut empty_response =
|
||||
serde_json::from_str::<Value>(&resp_body_text).unwrap_or(Value::Null);
|
||||
if let Some(obj) = empty_response.as_object_mut() {
|
||||
obj.insert("result".to_string(), Value::Array(vec![]));
|
||||
}
|
||||
ResponseBody::Modified(empty_response)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ResponseBody::Original(resp_body_text.clone())
|
||||
}
|
||||
} else {
|
||||
ResponseBody::Original(resp_body_text.clone())
|
||||
};
|
||||
|
||||
let (decrypted_response, final_response_body) = match response_body {
|
||||
ResponseBody::Original(body_text) => {
|
||||
let decrypted =
|
||||
decrypt_and_format(&body_text, &state.key, &state.iv).unwrap_or_else(|_| {
|
||||
Value::String("Could not decrypt or format response".to_string())
|
||||
});
|
||||
(decrypted, body_text)
|
||||
}
|
||||
ResponseBody::Modified(response_body_value) => {
|
||||
let pretty_printed =
|
||||
serde_json::to_string_pretty(&response_body_value).unwrap_or_default();
|
||||
let encrypted = crypto::encrypt(&pretty_printed, &state.key, &state.iv)
|
||||
.unwrap_or_else(|_| "Failed to encrypt modified response".to_string());
|
||||
(response_body_value, encrypted)
|
||||
// Process response: decrypt, deserialize, modify, re-encrypt
|
||||
let (resp_parts, body) = res.into_parts();
|
||||
let body_bytes = match body.collect().await {
|
||||
Ok(b) => b.to_bytes(),
|
||||
Err(e) => {
|
||||
warn!("Failed to read response body: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to read response body".to_string(),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let (final_response_body, decrypted_response) =
|
||||
process_response(body_bytes, &method, &ctx).await;
|
||||
|
||||
// Log the decrypted request and response
|
||||
debug!(
|
||||
"\nRequest:\n{}\nResponse:\n{}\n{}",
|
||||
serde_json::to_string_pretty(&decrypted_request).unwrap_or_default(),
|
||||
@@ -128,22 +64,21 @@ pub async fn middleware(
|
||||
"-".repeat(80),
|
||||
);
|
||||
|
||||
if let Some(method) = method {
|
||||
// TODO: interception action
|
||||
if let Err(e) = db::repositories::logs::create(
|
||||
&state.db,
|
||||
method,
|
||||
decrypted_request,
|
||||
decrypted_response,
|
||||
"".into(),
|
||||
"".into(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to log request: {}", e);
|
||||
}
|
||||
// Write log to database
|
||||
if let Err(e) = db::repositories::logs::create(
|
||||
&ctx.db,
|
||||
method,
|
||||
decrypted_request,
|
||||
decrypted_response,
|
||||
"".into(),
|
||||
"".into(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to log request: {}", e);
|
||||
}
|
||||
|
||||
// Build and return the final response
|
||||
let mut response_builder = Response::builder().status(resp_parts.status);
|
||||
if !resp_parts.headers.is_empty() {
|
||||
*response_builder.headers_mut().unwrap() = resp_parts.headers;
|
||||
@@ -153,34 +88,128 @@ pub async fn middleware(
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn process_and_log_request(body: &str, key: &str, iv: &str) -> anyhow::Result<Value> {
|
||||
let mut request_data: Value = serde_json::from_str(body)?;
|
||||
/// Processes the incoming request body.
|
||||
/// Returns the re-encrypted body for the next handler, the decrypted JSON for logging, and the method name.
|
||||
async fn process_request(body_bytes: Bytes, ctx: &Arc<AppContext>) -> (String, Value, String) {
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap_or_default();
|
||||
|
||||
if let Some(params_value) = request_data.get_mut("params")
|
||||
&& let Some(params_str) = params_value.as_str()
|
||||
{
|
||||
let params_str_owned = params_str.to_string();
|
||||
match crypto::decrypt(¶ms_str_owned, key, iv) {
|
||||
Ok(decrypted_str) => {
|
||||
let decrypted_params: Value =
|
||||
serde_json::from_str(&decrypted_str).unwrap_or(Value::String(decrypted_str));
|
||||
*params_value = decrypted_params;
|
||||
}
|
||||
Err(e) => {
|
||||
*params_value = Value::String(format!("decrypt failed: {}", e));
|
||||
}
|
||||
}
|
||||
let mut plain_request: Value = serde_json::from_str(&body_str).unwrap_or_else(|err| {
|
||||
warn!(
|
||||
"Failed to deserialize request body: {}. Using fallback string value.",
|
||||
err
|
||||
);
|
||||
Value::String("Could not deserialize request".into())
|
||||
});
|
||||
decrypt_params(&mut plain_request, ctx);
|
||||
|
||||
let method = plain_request
|
||||
.get("method")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
if let Err(e) = modify_request(&mut plain_request, &method, ctx).await {
|
||||
warn!("Failed to modify request: {}", e);
|
||||
}
|
||||
Ok(request_data)
|
||||
|
||||
let mut crypted_request = plain_request.clone();
|
||||
encrypt_params(&mut crypted_request, ctx);
|
||||
|
||||
(
|
||||
serde_json::to_string(&crypted_request).expect("deserialization succeeded"),
|
||||
plain_request,
|
||||
method,
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_getcommand_response(
|
||||
body_text: &str,
|
||||
state: &Arc<AppState>,
|
||||
) -> anyhow::Result<Value> {
|
||||
let decrypted = crypto::decrypt(body_text, &state.key, &state.iv)?;
|
||||
let mut response_json: Value = serde_json::from_str(&decrypted)?;
|
||||
/// Processes the outgoing response body.
|
||||
/// Returns the final encrypted body for the client and the decrypted JSON for logging.
|
||||
async fn process_response(
|
||||
body_bytes: Bytes,
|
||||
method: &str,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> (String, Value) {
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap_or_default();
|
||||
|
||||
let decrypted_body = ctx.cryptor.decrypt(body_str.clone()).unwrap_or_else(|err| {
|
||||
warn!(
|
||||
"Failed to decrypt response body: {}. Assuming it's not encrypted.",
|
||||
err
|
||||
);
|
||||
body_str
|
||||
});
|
||||
|
||||
let mut response_value: Value = serde_json::from_str(&decrypted_body).unwrap_or_else(|err| {
|
||||
warn!(
|
||||
"Failed to deserialize response body: {}. Using string value.",
|
||||
err
|
||||
);
|
||||
Value::String(decrypted_body.clone())
|
||||
});
|
||||
|
||||
let decrypted = response_value.clone();
|
||||
|
||||
if let Err(e) = modify_response(&mut response_value, method, ctx, &decrypted_body).await {
|
||||
warn!("Failed to modify response: {}", e);
|
||||
}
|
||||
|
||||
let modified_body_str = serde_json::to_string(&response_value).unwrap_or_default();
|
||||
let encrypted = ctx.cryptor.encrypt(modified_body_str);
|
||||
|
||||
(encrypted, decrypted)
|
||||
}
|
||||
|
||||
/// Placeholder for request modification logic.
|
||||
async fn modify_request(
|
||||
_request_json: &mut Value,
|
||||
_method: &str,
|
||||
_ctx: &Arc<AppContext>,
|
||||
) -> anyhow::Result<()> {
|
||||
// TODO: Implement request modification logic based on rules or other criteria.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Applies modification rules to the response.
|
||||
async fn modify_response(
|
||||
response_json: &mut Value,
|
||||
method: &str,
|
||||
ctx: &Arc<AppContext>,
|
||||
original_decrypted: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
// Check for generic method interception (e.g., replace response from DB)
|
||||
if let Some(intercepted) = intercept_response(method, original_decrypted, ctx).await? {
|
||||
info!("Intercepting response for method: {}", method);
|
||||
*response_json = serde_json::from_str(&intercepted).unwrap_or_else(|e| {
|
||||
warn!(
|
||||
"Failed to parse intercepted response as JSON: {}. Using as string.",
|
||||
e
|
||||
);
|
||||
Value::String(intercepted)
|
||||
});
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Special handling for getcommand
|
||||
if method == "com.linspirer.device.getcommand"
|
||||
&& let Err(e) = handle_getcommand_response(response_json, ctx).await
|
||||
{
|
||||
warn!(
|
||||
"Failed to handle getcommand response: {}. Responding with empty command list.",
|
||||
e
|
||||
);
|
||||
if let Some(obj) = response_json.as_object_mut() {
|
||||
obj.insert("result".to_string(), Value::Array(vec![]));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handles the 'getcommand' response by injecting verified commands.
|
||||
async fn handle_getcommand_response(
|
||||
response_json: &mut Value,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> anyhow::Result<()> {
|
||||
if let Some(result) = response_json.get_mut("result")
|
||||
&& let Some(commands) = result.as_array_mut()
|
||||
&& !commands.is_empty()
|
||||
@@ -189,22 +218,18 @@ async fn handle_getcommand_response(
|
||||
for cmd in commands.iter() {
|
||||
let cmd_json = serde_json::to_string(cmd)?;
|
||||
if let Err(e) =
|
||||
crate::db::repositories::commands::insert(&state.db, &cmd_json, "unverified").await
|
||||
crate::db::repositories::commands::insert(&ctx.db, &cmd_json, "unverified").await
|
||||
{
|
||||
warn!("Failed to persist command to database: {}", e);
|
||||
}
|
||||
info!("Added command to the queue: {:?}", cmd);
|
||||
}
|
||||
|
||||
// Also add to in-memory queue for backwards compatibility
|
||||
let mut queue = state.commands.unverified.write().await;
|
||||
queue.extend(commands.drain(..));
|
||||
}
|
||||
|
||||
if let Some(obj) = response_json.as_object_mut() {
|
||||
// Get verified commands from database
|
||||
let verified_cmds =
|
||||
match crate::db::repositories::commands::list_by_status(&state.db, "verified").await {
|
||||
match crate::db::repositories::commands::list_by_status(&ctx.db, "verified").await {
|
||||
Ok(cmds) => cmds,
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch verified commands from database: {}", e);
|
||||
@@ -218,48 +243,62 @@ async fn handle_getcommand_response(
|
||||
.filter_map(|c| serde_json::from_str(&c.command_json).ok())
|
||||
.collect();
|
||||
|
||||
// Also include in-memory verified commands
|
||||
let mem_verified = std::mem::take(&mut *state.commands.verified.write().await);
|
||||
let mut all_verified = verified_values;
|
||||
all_verified.extend(mem_verified);
|
||||
|
||||
obj.insert("result".to_string(), Value::Array(all_verified.clone()));
|
||||
obj.insert("result".to_string(), Value::Array(verified_values));
|
||||
|
||||
// Clear verified commands from database after sending
|
||||
if let Err(e) = crate::db::repositories::commands::clear_verified(&state.db).await {
|
||||
if let Err(e) = crate::db::repositories::commands::clear_verified(&ctx.db).await {
|
||||
warn!("Failed to clear verified commands from database: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response_json)
|
||||
}
|
||||
|
||||
fn decrypt_and_format(body_text: &str, key: &str, iv: &str) -> anyhow::Result<Value> {
|
||||
let decrypted = crypto::decrypt(body_text, key, iv)?;
|
||||
Ok(serde_json::from_str(&decrypted)?)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks for and applies response interception rules.
|
||||
async fn intercept_response(
|
||||
method: &str,
|
||||
_original_response: &str,
|
||||
state: &Arc<AppState>,
|
||||
_orignal_response: &str,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
// Check if there's an interception rule for this method
|
||||
let rule = crate::db::repositories::rules::find_by_method(&state.db, method).await?;
|
||||
let rule = crate::db::repositories::rules::find_by_method(&ctx.db, method).await?;
|
||||
|
||||
match rule {
|
||||
Some(r) if r.action == "replace" => {
|
||||
if let Some(custom_response) = &r.custom_response {
|
||||
Ok(crypto::encrypt(custom_response, &state.key, &state.iv).map(Some)?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
Some(r) if r.action == "replace" => Ok(r.custom_response.as_ref().cloned()),
|
||||
Some(r) if r.action == "modify" => {
|
||||
// Future: Apply transformations
|
||||
// For now, just pass through
|
||||
// TODO: Apply modifications
|
||||
Ok(None)
|
||||
}
|
||||
_ => Ok(None), // Passthrough
|
||||
}
|
||||
}
|
||||
|
||||
fn decrypt_params(request: &mut Value, ctx: &Arc<AppContext>) {
|
||||
if let Some(params) = request.get_mut("params") {
|
||||
*params = match params.take() {
|
||||
Value::String(params) => {
|
||||
let params = ctx.cryptor.decrypt(params).unwrap_or_else(|err| {
|
||||
warn!("Failed to decrypt request params: {err}. Using fallback string value.",);
|
||||
"\"Failed to decrypt params\"".to_string()
|
||||
});
|
||||
serde_json::from_str(¶ms).unwrap_or_else(|_| {
|
||||
Value::String("Failed to deserialize decrypted params".into())
|
||||
})
|
||||
}
|
||||
other => {
|
||||
warn!("'params' is not an encrypted string");
|
||||
other
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
fn encrypt_params(request: &mut Value, ctx: &Arc<AppContext>) {
|
||||
if let Some(params) = request.get_mut("params") {
|
||||
let plaintext = serde_json::to_string(params).unwrap_or_else(|err| {
|
||||
warn!("Failed to serialize params: {err}. Using fallback string value.");
|
||||
"".to_string()
|
||||
});
|
||||
*params = Value::String(ctx.cryptor.encrypt(plaintext));
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user