362 lines
12 KiB
Rust
362 lines
12 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::{
|
|
extract::State,
|
|
http::{Request, StatusCode},
|
|
middleware::Next,
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use http_body_util::BodyExt;
|
|
use serde_json::Value;
|
|
use tracing::{debug, warn};
|
|
|
|
use crate::{
|
|
AppContext,
|
|
db::{
|
|
self,
|
|
models::{InterceptionAction, InterceptionRule},
|
|
},
|
|
};
|
|
|
|
struct Processed {
|
|
original: String,
|
|
final_: Option<(String, InterceptionAction)>,
|
|
encrypted: String,
|
|
}
|
|
|
|
/// Main middleware to intercept, decrypt, modify, and log requests and responses.
|
|
pub async fn middleware(
|
|
State(ctx): State<Arc<AppContext>>,
|
|
req: Request<axum::body::Body>,
|
|
next: Next,
|
|
) -> impl IntoResponse {
|
|
let (parts, body) = req.into_parts();
|
|
let body_bytes;
|
|
let body = match body.collect().await {
|
|
Ok(body) => {
|
|
body_bytes = body.to_bytes();
|
|
str::from_utf8(&body_bytes).unwrap_or_else(|e| {
|
|
warn!("Received request with invalid UTF-8: {e}. Replacing with an empty string");
|
|
""
|
|
})
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to read request body: {}", e);
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
"Failed to read request body".to_string(),
|
|
)
|
|
.into_response();
|
|
}
|
|
};
|
|
|
|
// Process request: decrypt, deserialize, modify, re-encrypt
|
|
let (processed_req, method) = process_request(body, &ctx).await;
|
|
|
|
// Pass modified request to the next handler
|
|
let req = Request::from_parts(parts, axum::body::Body::from(processed_req.encrypted));
|
|
let res = next.run(req).await;
|
|
|
|
// Early exit if next handler returned error
|
|
if !res.status().is_success() {
|
|
return res;
|
|
}
|
|
|
|
// Process response: decrypt, deserialize, modify, re-encrypt
|
|
let (resp_parts, body) = res.into_parts();
|
|
let body_bytes;
|
|
let body = match body.collect().await {
|
|
Ok(b) => {
|
|
body_bytes = b.to_bytes();
|
|
str::from_utf8(&body_bytes)
|
|
.unwrap_or_else(|e| {
|
|
warn!(
|
|
"Received response with invalid UTF-8: {e}. Replacing with an empty string"
|
|
);
|
|
""
|
|
})
|
|
.to_string()
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to read response body: {}", e);
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
"Failed to read response body".to_string(),
|
|
)
|
|
.into_response();
|
|
}
|
|
};
|
|
|
|
let processed_resp = process_response(body, &method, &ctx).await;
|
|
|
|
let (req_body, req_action) = processed_req
|
|
.final_
|
|
.map_or((None, None), |(a, b)| (Some(a), Some(b)));
|
|
let (resp_body, resp_action) = processed_resp
|
|
.final_
|
|
.map_or((None, None), |(a, b)| (Some(a), Some(b)));
|
|
// Write log to database
|
|
if let Err(e) = db::repositories::logs::create(
|
|
&ctx.db,
|
|
method,
|
|
processed_req.original,
|
|
processed_resp.original,
|
|
req_body,
|
|
resp_body,
|
|
req_action,
|
|
resp_action,
|
|
)
|
|
.await
|
|
{
|
|
warn!("Failed to log request: {}", e);
|
|
}
|
|
|
|
// Build and return the final response
|
|
let mut response_builder = Response::builder().status(resp_parts.status);
|
|
if !resp_parts.headers.is_empty() {
|
|
*response_builder.headers_mut().unwrap() = resp_parts.headers;
|
|
}
|
|
response_builder
|
|
.body(axum::body::Body::from(processed_resp.encrypted))
|
|
.unwrap()
|
|
}
|
|
|
|
/// Processes the incoming request body.
|
|
/// Returns the re-encrypted body for the next handler, the decrypted JSON for logging, and the method name.
|
|
async fn process_request(body: &str, ctx: &Arc<AppContext>) -> (Processed, String) {
|
|
let mut plain_request: Value = serde_json::from_str(body).unwrap_or_else(|err| {
|
|
warn!(
|
|
"Failed to deserialize request body: {}. Using fallback string value.",
|
|
err
|
|
);
|
|
Value::String("Could not deserialize request".into())
|
|
});
|
|
decrypt_params(&mut plain_request, ctx);
|
|
let original = serde_json::to_string(&plain_request).expect("deserialization succeeded");
|
|
|
|
let method = plain_request
|
|
.get("method")
|
|
.and_then(Value::as_str)
|
|
.unwrap_or_else(|| {
|
|
warn!("No JSON-RPC method found in request body, fallback to an empty string");
|
|
""
|
|
})
|
|
.to_string();
|
|
|
|
let action = match modify_request(&mut plain_request, &method, ctx).await {
|
|
Ok(action) => action,
|
|
Err(e) => {
|
|
warn!("Failed to modify request: {}", e);
|
|
None
|
|
}
|
|
};
|
|
let final_ = action.map(|action| {
|
|
(
|
|
serde_json::to_string(&plain_request).expect("deserialization succeeded"),
|
|
action,
|
|
)
|
|
});
|
|
|
|
let mut encrypted_request = plain_request.clone();
|
|
encrypt_params(&mut encrypted_request, ctx);
|
|
|
|
(
|
|
Processed {
|
|
original,
|
|
final_,
|
|
encrypted: serde_json::to_string(&encrypted_request)
|
|
.expect("deserialization succeeded"),
|
|
},
|
|
method,
|
|
)
|
|
}
|
|
|
|
/// Processes the outgoing response body.
|
|
/// Returns the final encrypted body for the client and the decrypted JSON for logging.
|
|
async fn process_response(body: String, method: &str, ctx: &Arc<AppContext>) -> Processed {
|
|
let decrypted_body = ctx.cryptor.decrypt(body.clone()).unwrap_or_else(|err| {
|
|
warn!(
|
|
"Failed to decrypt response body: {}. Assuming it's not encrypted.",
|
|
err
|
|
);
|
|
body
|
|
});
|
|
|
|
let mut response_value: Value = serde_json::from_str(&decrypted_body).unwrap_or_else(|err| {
|
|
warn!(
|
|
"Failed to deserialize response body: {}. Using string value.",
|
|
err
|
|
);
|
|
Value::String(decrypted_body.clone())
|
|
});
|
|
|
|
let action = match modify_response(&mut response_value, method, ctx).await {
|
|
Ok(action) => action,
|
|
Err(e) => {
|
|
warn!("Failed to modify response: {}", e);
|
|
None
|
|
}
|
|
};
|
|
|
|
let modified_body_str =
|
|
serde_json::to_string(&response_value).expect("serialization succeeded");
|
|
let encrypted = ctx.cryptor.encrypt(modified_body_str.clone());
|
|
let final_ = action.map(|action| (modified_body_str.clone(), action));
|
|
|
|
Processed {
|
|
original: decrypted_body,
|
|
final_,
|
|
encrypted,
|
|
}
|
|
}
|
|
|
|
/// Placeholder for request modification logic.
|
|
async fn modify_request(
|
|
_request_json: &mut Value,
|
|
_method: &str,
|
|
_ctx: &Arc<AppContext>,
|
|
) -> anyhow::Result<Option<InterceptionAction>> {
|
|
// TODO: Implement request modification logic based on rules or other criteria.
|
|
Ok(None)
|
|
}
|
|
|
|
/// Applies modification rules to the response.
|
|
async fn modify_response(
|
|
response_json: &mut Value,
|
|
method: &str,
|
|
ctx: &Arc<AppContext>,
|
|
) -> anyhow::Result<Option<InterceptionAction>> {
|
|
// Check for generic method interception (e.g., replace response from DB)
|
|
if let Some((intercepted, action)) = intercept_response(method, ctx).await? {
|
|
debug!("Intercepting response for method: {}", method);
|
|
*response_json = serde_json::from_str(&intercepted).unwrap_or_else(|e| {
|
|
warn!(
|
|
"Failed to parse intercepted response as JSON: {}. Using as string.",
|
|
e
|
|
);
|
|
Value::String(intercepted)
|
|
});
|
|
return Ok(Some(action));
|
|
}
|
|
|
|
// Special handling for getcommand
|
|
// TODO: Return interception rule
|
|
if method == "com.linspirer.device.getcommand"
|
|
&& let Err(e) = handle_getcommand_response(response_json, ctx).await
|
|
{
|
|
warn!(
|
|
"Failed to handle getcommand response: {}. Responding with empty command list.",
|
|
e
|
|
);
|
|
if let Some(obj) = response_json.as_object_mut() {
|
|
obj.insert("result".to_string(), Value::Array(vec![]));
|
|
}
|
|
}
|
|
|
|
Ok(None)
|
|
}
|
|
|
|
/// Handles the 'getcommand' response by injecting verified commands.
|
|
async fn handle_getcommand_response(
|
|
response_json: &mut Value,
|
|
ctx: &Arc<AppContext>,
|
|
) -> anyhow::Result<()> {
|
|
if let Some(result) = response_json.get_mut("result")
|
|
&& let Some(commands) = result.as_array_mut()
|
|
&& !commands.is_empty()
|
|
{
|
|
// Persist commands to database
|
|
for cmd in commands.iter() {
|
|
let cmd_json = serde_json::to_string(cmd)?;
|
|
if let Err(e) =
|
|
crate::db::repositories::commands::insert(&ctx.db, &cmd_json, "unverified").await
|
|
{
|
|
warn!("Failed to persist command to database: {}", e);
|
|
}
|
|
debug!("Added command to the queue: {:?}", cmd);
|
|
}
|
|
}
|
|
|
|
if let Some(obj) = response_json.as_object_mut() {
|
|
// Get verified commands from database
|
|
let verified_cmds =
|
|
match crate::db::repositories::commands::list_by_status(&ctx.db, "verified").await {
|
|
Ok(cmds) => cmds,
|
|
Err(e) => {
|
|
warn!("Failed to fetch verified commands from database: {}", e);
|
|
Vec::new()
|
|
}
|
|
};
|
|
|
|
// Convert to JSON values
|
|
let verified_values: Vec<Value> = verified_cmds
|
|
.iter()
|
|
.filter_map(|c| serde_json::from_str(&c.command_json).ok())
|
|
.collect();
|
|
|
|
obj.insert("result".to_string(), Value::Array(verified_values));
|
|
|
|
// Clear verified commands from database after sending
|
|
if let Err(e) = crate::db::repositories::commands::clear_verified(&ctx.db).await {
|
|
warn!("Failed to clear verified commands from database: {}", e);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Checks for and applies response interception rules.
|
|
async fn intercept_response(
|
|
method: &str,
|
|
ctx: &Arc<AppContext>,
|
|
) -> anyhow::Result<Option<(String, InterceptionAction)>> {
|
|
// Check if there's an interception rule for this method
|
|
let rule = crate::db::repositories::rules::find_by_method(&ctx.db, method).await?;
|
|
|
|
match rule {
|
|
Some(InterceptionRule {
|
|
action: InterceptionAction::Replace,
|
|
custom_response,
|
|
..
|
|
}) => Ok(custom_response.map(|resp| (resp, InterceptionAction::Replace))),
|
|
Some(InterceptionRule {
|
|
action: InterceptionAction::Modify,
|
|
..
|
|
}) => {
|
|
// TODO: Apply modifications
|
|
Ok(None)
|
|
}
|
|
_ => Ok(None), // Passthrough
|
|
}
|
|
}
|
|
|
|
fn decrypt_params(request: &mut Value, ctx: &Arc<AppContext>) {
|
|
if let Some(params) = request.get_mut("params") {
|
|
*params = match params.take() {
|
|
Value::String(params) => {
|
|
let params = ctx.cryptor.decrypt(params).unwrap_or_else(|err| {
|
|
warn!("Failed to decrypt request params: {err}. Using fallback string value.",);
|
|
"\"Failed to decrypt params\"".to_string()
|
|
});
|
|
serde_json::from_str(¶ms).unwrap_or_else(|_| {
|
|
Value::String("Failed to deserialize decrypted params".into())
|
|
})
|
|
}
|
|
other => {
|
|
warn!("'params' is not an encrypted string");
|
|
other
|
|
}
|
|
};
|
|
};
|
|
}
|
|
|
|
fn encrypt_params(request: &mut Value, ctx: &Arc<AppContext>) {
|
|
if let Some(params) = request.get_mut("params") {
|
|
let plaintext = serde_json::to_string(params).unwrap_or_else(|err| {
|
|
warn!("Failed to serialize params: {err}. Using fallback string value.");
|
|
"".to_string()
|
|
});
|
|
*params = Value::String(ctx.cryptor.encrypt(plaintext));
|
|
};
|
|
}
|