From 0e53ef4488fb1c627c33c13fbed8908dde6dd27e Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Thu, 4 Dec 2025 17:35:26 +0800 Subject: [PATCH] refactor: modularize middleware & crypto --- Cargo.lock | 1 + Cargo.toml | 1 + frontend/src/components/ChangePassword.tsx | 2 +- src/admin/auth_middleware.rs | 4 +- src/admin/handlers.rs | 34 +- src/admin/models.rs | 18 -- src/admin/routes.rs | 6 +- src/context.rs | 67 ++++ src/crypto.rs | 108 +++---- src/jsonrpc.rs | 2 + src/jsonrpc/tactics.rs | 2 + src/main.rs | 26 +- src/middleware.rs | 349 ++++++++++++--------- src/proxy.rs | 4 +- src/state.rs | 33 -- 15 files changed, 347 insertions(+), 310 deletions(-) create mode 100644 src/context.rs delete mode 100644 src/state.rs diff --git a/Cargo.lock b/Cargo.lock index 199f654..63f52f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1298,6 +1298,7 @@ dependencies = [ "axum", "base64", "bcrypt", + "bytes", "cbc", "chrono", "concat-idents", diff --git a/Cargo.toml b/Cargo.toml index aefb437..9b68b35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,3 +30,4 @@ rust-embed = "8.0" mime_guess = "2.0" jsonwebtoken = "9" bcrypt = "0.17" +bytes = "1.11.0" diff --git a/frontend/src/components/ChangePassword.tsx b/frontend/src/components/ChangePassword.tsx index 9170606..445997f 100644 --- a/frontend/src/components/ChangePassword.tsx +++ b/frontend/src/components/ChangePassword.tsx @@ -2,7 +2,7 @@ import { Component, createSignal } from 'solid-js'; import { authApi } from '../api/client'; import { authStore } from '../api/auth'; import { Button } from './ui/Button'; -import { Card, CardContent, CardFooter, CardHeader } from './ui/Card'; +import { CardContent, CardFooter, CardHeader } from './ui/Card'; import { Input } from './ui/Input'; import { Modal, ModalContent } from './ui/Modal'; diff --git a/src/admin/auth_middleware.rs b/src/admin/auth_middleware.rs index 25188eb..cd6e292 100644 --- a/src/admin/auth_middleware.rs +++ b/src/admin/auth_middleware.rs @@ -7,10 +7,10 @@ use axum::{ response::Response, }; -use crate::{auth, state::AppState}; +use crate::{auth, context::AppContext}; pub async fn auth_middleware( - State(state): State>, + State(state): State>, request: Request, next: Next, ) -> Result { diff --git a/src/admin/handlers.rs b/src/admin/handlers.rs index 7904c7b..6d02bd7 100644 --- a/src/admin/handlers.rs +++ b/src/admin/handlers.rs @@ -8,7 +8,7 @@ use axum::{ use tracing::error; use crate::{ - AppState, auth, + AppContext, auth, db::{self, models::RequestLog}, }; @@ -16,7 +16,7 @@ use super::models::*; // Authentication handlers pub async fn login( - State(state): State>, + State(state): State>, Json(req): Json, ) -> Result, (StatusCode, Json)> { // Get stored password hash from config @@ -69,7 +69,7 @@ pub async fn login( } pub async fn change_password( - State(state): State>, + State(state): State>, Json(req): Json, ) -> Result)> { // Get current password hash from config @@ -140,7 +140,7 @@ pub async fn change_password( // Rules handlers pub async fn list_rules( - State(state): State>, + State(state): State>, ) -> Result>, (StatusCode, Json)> { match db::repositories::rules::list_all(&state.db).await { Ok(rules) => Ok(Json(rules.into_iter().map(Into::into).collect())), @@ -155,7 +155,7 @@ pub async fn list_rules( } pub async fn create_rule( - State(state): State>, + State(state): State>, Json(req): Json, ) -> Result, (StatusCode, Json)> { // Validate action @@ -211,7 +211,7 @@ pub async fn create_rule( } pub async fn update_rule( - State(state): State>, + State(state): State>, Path(id): Path, Json(req): Json, ) -> Result, (StatusCode, Json)> { @@ -250,7 +250,7 @@ pub async fn update_rule( } pub async fn delete_rule( - State(state): State>, + State(state): State>, Path(id): Path, ) -> Result)> { match db::repositories::rules::delete(&state.db, id).await { @@ -267,7 +267,7 @@ pub async fn delete_rule( // Commands handlers pub async fn list_commands( - State(state): State>, + State(state): State>, ) -> Result>, (StatusCode, Json)> { match db::repositories::commands::list_all(&state.db).await { Ok(commands) => Ok(Json(commands.into_iter().map(Into::into).collect())), @@ -282,7 +282,7 @@ pub async fn list_commands( } pub async fn verify_command( - State(state): State>, + State(state): State>, Path(id): Path, Json(req): Json, ) -> Result, (StatusCode, Json)> { @@ -295,15 +295,7 @@ pub async fn verify_command( .await { Ok(_) => match db::repositories::commands::find_by_id(&state.db, id).await { - Ok(Some(cmd)) => { - // Also update in-memory queue if status is verified - if req.status == "verified" - && let Ok(cmd_value) = serde_json::from_str(&cmd.command_json) - { - state.commands.verified.write().await.push(cmd_value); - } - Ok(Json(cmd.into())) - } + Ok(Some(cmd)) => Ok(Json(cmd.into())), _ => Err(( StatusCode::NOT_FOUND, Json(ApiError::new("Command not found")), @@ -321,7 +313,7 @@ pub async fn verify_command( // Config handlers pub async fn get_config( - State(state): State>, + State(state): State>, ) -> Result>, (StatusCode, Json)> { match db::repositories::config::get_all(&state.db).await { Ok(config) => Ok(Json(config)), @@ -336,7 +328,7 @@ pub async fn get_config( } pub async fn update_config( - State(state): State>, + State(state): State>, Json(config): Json>, ) -> Result)> { for (key, value) in config { @@ -354,7 +346,7 @@ pub async fn update_config( // Log handlers pub async fn list_logs( - State(state): State>, + State(state): State>, ) -> Result>, (StatusCode, Json)> { match db::repositories::logs::list_all(&state.db).await { Ok(logs) => Ok(Json(logs)), diff --git a/src/admin/models.rs b/src/admin/models.rs index b904767..f66dfde 100644 --- a/src/admin/models.rs +++ b/src/admin/models.rs @@ -99,21 +99,3 @@ impl ApiError { Self { error: msg.into() } } } - -#[derive(Debug, Serialize)] -pub struct RequestDetails { - pub headers: Vec
, - pub body: Body, -} - -#[derive(Debug, Serialize)] -pub struct Body { - pub value: Value, - pub modified: bool, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Header { - pub value: String, - pub modified: bool, -} diff --git a/src/admin/routes.rs b/src/admin/routes.rs index 270bc9d..11fc9ba 100644 --- a/src/admin/routes.rs +++ b/src/admin/routes.rs @@ -5,11 +5,11 @@ use axum::{ routing::{delete, get, post, put}, }; -use crate::AppState; +use crate::AppContext; use super::{auth_middleware, handlers, static_files}; -pub fn admin_routes(state: Arc) -> Router> { +pub fn admin_routes(ctx: Arc) -> Router> { // Public routes (no authentication required) let public_routes = Router::new().route("/api/login", post(handlers::login)); @@ -26,7 +26,7 @@ pub fn admin_routes(state: Arc) -> Router> { .route("/api/config", put(handlers::update_config)) .route("/api/logs", get(handlers::list_logs)) .layer(axum::middleware::from_fn_with_state( - state, + ctx, auth_middleware::auth_middleware, )); diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..b0a20cf --- /dev/null +++ b/src/context.rs @@ -0,0 +1,67 @@ +use sqlx::SqlitePool; + +use crate::crypto::Cryptor; + +pub struct AppContext { + pub client: reqwest::Client, + pub target_url: reqwest::Url, + pub cryptor: Cryptor, + pub jwt_secret: String, + pub db: SqlitePool, +} + +impl AppContext { + pub fn builder<'a>( + target_url: reqwest::Url, + key: &'a [u8], + iv: &'a [u8], + jwt_secret: String, + db: SqlitePool, + ) -> AppContextBuilder<'a> { + AppContextBuilder { + client: None, + target_url, + key, + iv, + jwt_secret, + db, + } + } +} + +pub struct AppContextBuilder<'a> { + client: Option, + target_url: reqwest::Url, + key: &'a [u8], + iv: &'a [u8], + jwt_secret: String, + db: SqlitePool, +} + +impl AppContextBuilder<'_> { + pub fn with_client(self, client: reqwest::Client) -> Self { + Self { + client: Some(client), + ..self + } + } + + pub fn build(self) -> anyhow::Result { + let AppContextBuilder { + client, + target_url, + key, + iv, + jwt_secret, + db, + } = self; + + Ok(AppContext { + client: client.unwrap_or_default(), + cryptor: Cryptor::new(key, iv)?, + target_url, + jwt_secret, + db, + }) + } +} diff --git a/src/crypto.rs b/src/crypto.rs index 33321e7..babf8e3 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -7,71 +7,59 @@ use cbc::{ cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit, block_padding::Pkcs7}, }; -type Aes128CbcEnc = Encryptor; -type Aes128CbcDec = Decryptor; - -#[derive(Debug)] -pub enum CryptoError { - Base64(base64::DecodeError), - Aes(cbc::cipher::InvalidLength), - Unpad(cbc::cipher::block_padding::UnpadError), - Pad(aes::cipher::inout::PadError), +pub struct Cryptor { + encryptor: Encryptor, + decryptor: Decryptor, } -impl fmt::Display for CryptoError { +impl Cryptor { + pub fn new(key: &[u8], iv: &[u8]) -> Result { + Ok(Self { + encryptor: Encryptor::new_from_slices(key, iv)?, + decryptor: Decryptor::new_from_slices(key, iv)?, + }) + } + + pub fn decrypt(&self, ciphertext: String) -> Result { + let mut ciphertext = STANDARD.decode(ciphertext).map_err(DecryptError::Base64)?; + + let plaintext = self + .decryptor + .clone() + .decrypt_padded_mut::(ciphertext.as_mut_slice()) + .map_err(DecryptError::Unpad)?; + + Ok(String::from_utf8_lossy(plaintext).to_string()) + } + + pub fn encrypt(&self, plaintext: String) -> String { + // Allocate buffer with extra space for padding (AES block size is 16 bytes) + let plaintext_bytes = plaintext.as_bytes(); + let mut buffer = vec![0u8; 16 * (plaintext_bytes.len() / 16 + 1)]; + buffer[..plaintext_bytes.len()].copy_from_slice(plaintext_bytes); + let ciphertext = self + .encryptor + .clone() + .encrypt_padded_mut::(&mut buffer, plaintext_bytes.len()) + .expect("enough space for encrypting is allocated"); + + STANDARD.encode(ciphertext) + } +} + +#[derive(Debug)] +pub enum DecryptError { + Base64(base64::DecodeError), + Unpad(cbc::cipher::block_padding::UnpadError), +} + +impl fmt::Display for DecryptError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - CryptoError::Base64(e) => write!(f, "Base64 decode failed: {}", e), - CryptoError::Aes(e) => write!(f, "AES decryption failed: {}", e), - CryptoError::Unpad(e) => write!(f, "PKCS7 unpadding failed: {}", e), - CryptoError::Pad(e) => write!(f, "PKCS7 padding failed: {}", e), + DecryptError::Base64(e) => write!(f, "Base64 decode failed: {}", e), + DecryptError::Unpad(e) => write!(f, "PKCS7 unpadding failed: {}", e), } } } -impl std::error::Error for CryptoError {} - -pub fn decrypt(ciphertext_b64: &str, key: &str, iv: &str) -> Result { - // 1. Base64 decode the ciphertext - let ciphertext = STANDARD - .decode(ciphertext_b64) - .map_err(CryptoError::Base64)?; - - // 2. Initialize AES-128 in CBC mode - let key_bytes = key.as_bytes(); - let iv_bytes = iv.as_bytes(); - let decryptor = Aes128CbcDec::new_from_slices(key_bytes, iv_bytes).map_err(CryptoError::Aes)?; - - // 3. Decrypt the ciphertext, handling padding - let decrypted_len = ciphertext.len(); - let mut plaintext = vec![0u8; decrypted_len]; - let copy_len = ciphertext.len(); - plaintext[..copy_len].copy_from_slice(&ciphertext); - - let plaintext_slice = decryptor - .decrypt_padded_mut::(&mut plaintext[..decrypted_len]) - .map_err(CryptoError::Unpad)?; - - // 4. Convert plaintext to a UTF-8 string - Ok(String::from_utf8_lossy(plaintext_slice).to_string()) -} - -pub fn encrypt(plaintext: &str, key: &str, iv: &str) -> Result { - // 1. Initialize AES-128 in CBC mode - let key_bytes = key.as_bytes(); - let iv_bytes = iv.as_bytes(); - let encryptor = Aes128CbcEnc::new_from_slices(key_bytes, iv_bytes).map_err(CryptoError::Aes)?; - - // 2. Encrypt the plaintext with PKCS7 padding - // Allocate buffer with extra space for padding (AES block size is 16 bytes) - let plaintext_bytes = plaintext.as_bytes(); - let mut buffer = vec![0u8; plaintext_bytes.len() + 16]; - buffer[..plaintext_bytes.len()].copy_from_slice(plaintext_bytes); - - let ciphertext = encryptor - .encrypt_padded_mut::(&mut buffer, plaintext_bytes.len()) - .map_err(CryptoError::Pad)?; - - // 3. Base64 encode the ciphertext - Ok(STANDARD.encode(ciphertext)) -} +impl std::error::Error for DecryptError {} diff --git a/src/jsonrpc.rs b/src/jsonrpc.rs index 5d7c3fe..7f5610a 100644 --- a/src/jsonrpc.rs +++ b/src/jsonrpc.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/src/jsonrpc/tactics.rs b/src/jsonrpc/tactics.rs index c52279d..0f6df9d 100644 --- a/src/jsonrpc/tactics.rs +++ b/src/jsonrpc/tactics.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use serde_with::{BoolFromInt, serde_as}; diff --git a/src/main.rs b/src/main.rs index 3f49543..b2b57c9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,14 +10,14 @@ use tracing_subscriber::{EnvFilter, fmt, prelude::*}; mod admin; mod auth; +mod context; mod crypto; mod db; mod jsonrpc; mod middleware; mod proxy; -mod state; -use state::AppState; +use context::AppContext; const DEFAULT_TARGET_URL: &str = "https://cloud.linspirer.com:883"; const DEFAULT_HOST: &str = "0.0.0.0"; @@ -69,31 +69,27 @@ async fn main() -> anyhow::Result<()> { .danger_accept_invalid_certs(true) .build()?; - // Create shared state - let state = Arc::new(AppState { - client, - target_url, - key, - iv, - jwt_secret, - db, - commands: Default::default(), - }); + // Create shared context + let ctx = Arc::new( + AppContext::builder(target_url, key.as_bytes(), iv.as_bytes(), jwt_secret, db) + .with_client(client) + .build()?, + ); let proxy_middleware = - axum::middleware::from_fn_with_state(state.clone(), middleware::middleware); + axum::middleware::from_fn_with_state(ctx.clone(), middleware::middleware); // Build our application let app = Router::new() // Admin routes - .nest("/admin", admin::routes::admin_routes(state.clone())) + .nest("/admin", admin::routes::admin_routes(ctx.clone())) // Proxy Linspirer APIs .route( "/public-interface.php", any(proxy::proxy_handler.layer(proxy_middleware)), ) .layer(CompressionLayer::new().gzip(true)) - .with_state(state); + .with_state(ctx); // Run the server info!("Proxy started on {} => {}", addr_str, target_url_str); diff --git a/src/middleware.rs b/src/middleware.rs index 19514e1..1af17c6 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,4 +1,3 @@ -use std::str; use std::sync::Arc; use axum::{ @@ -7,19 +6,16 @@ use axum::{ middleware::Next, response::{IntoResponse, Response}, }; +use bytes::Bytes; use http_body_util::BodyExt; use serde_json::Value; use tracing::{debug, info, warn}; -use crate::{AppState, crypto, db}; - -enum ResponseBody { - Original(String), - Modified(Value), -} +use crate::{AppContext, db}; +/// Main middleware to intercept, decrypt, modify, and log requests and responses. pub async fn middleware( - State(state): State>, + State(ctx): State>, req: Request, next: Next, ) -> impl IntoResponse { @@ -36,91 +32,31 @@ pub async fn middleware( } }; - let (decrypted_request, method) = match str::from_utf8(&body_bytes) - .map_err(anyhow::Error::from) - .and_then(|body| process_and_log_request(body, &state.key, &state.iv)) - { - Ok(request_data) => { - let method = request_data - .get("method") - .and_then(Value::as_str) - .map(str::to_string); - (request_data, method) - } - Err(e) => { - warn!("Failed to process request for logging: {}", e); - let val = Value::String("Could not decrypt request".to_string()); - (val, None) - } - }; + // Process request: decrypt, deserialize, modify, re-encrypt + let (processed_req_body, decrypted_request, method) = process_request(body_bytes, &ctx).await; - let req = Request::from_parts(parts, axum::body::Body::from(body_bytes)); + // Pass modified request to the next handler + let req = Request::from_parts(parts, axum::body::Body::from(processed_req_body)); let res = next.run(req).await; - let (resp_parts, body_bytes) = { - let (parts, body) = res.into_parts(); - let bytes = match body.collect().await { - Ok(b) => b.to_bytes(), - Err(e) => { - warn!("Failed to read response body: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to read response body".to_string(), - ) - .into_response(); - } - }; - (parts, bytes) - }; - let resp_body_text = String::from_utf8(body_bytes.clone().to_vec()).unwrap_or_default(); - - // Check for generic method interception first - let response_body = if let Some(method_str) = &method { - if let Ok(Some(intercepted)) = intercept_response(method_str, &resp_body_text, &state).await - { - info!("Intercepting response for method: {}", method_str); - ResponseBody::Original(intercepted) - } else if Some("com.linspirer.device.getcommand") == method.as_deref() { - // Special handling for getcommand - match handle_getcommand_response(&resp_body_text, &state).await { - Ok(new_body) => ResponseBody::Modified(new_body), - Err(e) => { - warn!( - "Failed to handle getcommand response: {}. Responding with empty command list.", - e - ); - let mut empty_response = - serde_json::from_str::(&resp_body_text).unwrap_or(Value::Null); - if let Some(obj) = empty_response.as_object_mut() { - obj.insert("result".to_string(), Value::Array(vec![])); - } - ResponseBody::Modified(empty_response) - } - } - } else { - ResponseBody::Original(resp_body_text.clone()) - } - } else { - ResponseBody::Original(resp_body_text.clone()) - }; - - let (decrypted_response, final_response_body) = match response_body { - ResponseBody::Original(body_text) => { - let decrypted = - decrypt_and_format(&body_text, &state.key, &state.iv).unwrap_or_else(|_| { - Value::String("Could not decrypt or format response".to_string()) - }); - (decrypted, body_text) - } - ResponseBody::Modified(response_body_value) => { - let pretty_printed = - serde_json::to_string_pretty(&response_body_value).unwrap_or_default(); - let encrypted = crypto::encrypt(&pretty_printed, &state.key, &state.iv) - .unwrap_or_else(|_| "Failed to encrypt modified response".to_string()); - (response_body_value, encrypted) + // Process response: decrypt, deserialize, modify, re-encrypt + let (resp_parts, body) = res.into_parts(); + let body_bytes = match body.collect().await { + Ok(b) => b.to_bytes(), + Err(e) => { + warn!("Failed to read response body: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to read response body".to_string(), + ) + .into_response(); } }; + let (final_response_body, decrypted_response) = + process_response(body_bytes, &method, &ctx).await; + + // Log the decrypted request and response debug!( "\nRequest:\n{}\nResponse:\n{}\n{}", serde_json::to_string_pretty(&decrypted_request).unwrap_or_default(), @@ -128,22 +64,21 @@ pub async fn middleware( "-".repeat(80), ); - if let Some(method) = method { - // TODO: interception action - if let Err(e) = db::repositories::logs::create( - &state.db, - method, - decrypted_request, - decrypted_response, - "".into(), - "".into(), - ) - .await - { - warn!("Failed to log request: {}", e); - } + // Write log to database + if let Err(e) = db::repositories::logs::create( + &ctx.db, + method, + decrypted_request, + decrypted_response, + "".into(), + "".into(), + ) + .await + { + warn!("Failed to log request: {}", e); } + // Build and return the final response let mut response_builder = Response::builder().status(resp_parts.status); if !resp_parts.headers.is_empty() { *response_builder.headers_mut().unwrap() = resp_parts.headers; @@ -153,34 +88,128 @@ pub async fn middleware( .unwrap() } -fn process_and_log_request(body: &str, key: &str, iv: &str) -> anyhow::Result { - let mut request_data: Value = serde_json::from_str(body)?; +/// Processes the incoming request body. +/// Returns the re-encrypted body for the next handler, the decrypted JSON for logging, and the method name. +async fn process_request(body_bytes: Bytes, ctx: &Arc) -> (String, Value, String) { + let body_str = String::from_utf8(body_bytes.to_vec()).unwrap_or_default(); - if let Some(params_value) = request_data.get_mut("params") - && let Some(params_str) = params_value.as_str() - { - let params_str_owned = params_str.to_string(); - match crypto::decrypt(¶ms_str_owned, key, iv) { - Ok(decrypted_str) => { - let decrypted_params: Value = - serde_json::from_str(&decrypted_str).unwrap_or(Value::String(decrypted_str)); - *params_value = decrypted_params; - } - Err(e) => { - *params_value = Value::String(format!("decrypt failed: {}", e)); - } - } + let mut plain_request: Value = serde_json::from_str(&body_str).unwrap_or_else(|err| { + warn!( + "Failed to deserialize request body: {}. Using fallback string value.", + err + ); + Value::String("Could not deserialize request".into()) + }); + decrypt_params(&mut plain_request, ctx); + + let method = plain_request + .get("method") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + + if let Err(e) = modify_request(&mut plain_request, &method, ctx).await { + warn!("Failed to modify request: {}", e); } - Ok(request_data) + + let mut crypted_request = plain_request.clone(); + encrypt_params(&mut crypted_request, ctx); + + ( + serde_json::to_string(&crypted_request).expect("deserialization succeeded"), + plain_request, + method, + ) } -async fn handle_getcommand_response( - body_text: &str, - state: &Arc, -) -> anyhow::Result { - let decrypted = crypto::decrypt(body_text, &state.key, &state.iv)?; - let mut response_json: Value = serde_json::from_str(&decrypted)?; +/// Processes the outgoing response body. +/// Returns the final encrypted body for the client and the decrypted JSON for logging. +async fn process_response( + body_bytes: Bytes, + method: &str, + ctx: &Arc, +) -> (String, Value) { + let body_str = String::from_utf8(body_bytes.to_vec()).unwrap_or_default(); + let decrypted_body = ctx.cryptor.decrypt(body_str.clone()).unwrap_or_else(|err| { + warn!( + "Failed to decrypt response body: {}. Assuming it's not encrypted.", + err + ); + body_str + }); + + let mut response_value: Value = serde_json::from_str(&decrypted_body).unwrap_or_else(|err| { + warn!( + "Failed to deserialize response body: {}. Using string value.", + err + ); + Value::String(decrypted_body.clone()) + }); + + let decrypted = response_value.clone(); + + if let Err(e) = modify_response(&mut response_value, method, ctx, &decrypted_body).await { + warn!("Failed to modify response: {}", e); + } + + let modified_body_str = serde_json::to_string(&response_value).unwrap_or_default(); + let encrypted = ctx.cryptor.encrypt(modified_body_str); + + (encrypted, decrypted) +} + +/// Placeholder for request modification logic. +async fn modify_request( + _request_json: &mut Value, + _method: &str, + _ctx: &Arc, +) -> anyhow::Result<()> { + // TODO: Implement request modification logic based on rules or other criteria. + Ok(()) +} + +/// Applies modification rules to the response. +async fn modify_response( + response_json: &mut Value, + method: &str, + ctx: &Arc, + original_decrypted: &str, +) -> anyhow::Result<()> { + // Check for generic method interception (e.g., replace response from DB) + if let Some(intercepted) = intercept_response(method, original_decrypted, ctx).await? { + info!("Intercepting response for method: {}", method); + *response_json = serde_json::from_str(&intercepted).unwrap_or_else(|e| { + warn!( + "Failed to parse intercepted response as JSON: {}. Using as string.", + e + ); + Value::String(intercepted) + }); + return Ok(()); + } + + // Special handling for getcommand + if method == "com.linspirer.device.getcommand" + && let Err(e) = handle_getcommand_response(response_json, ctx).await + { + warn!( + "Failed to handle getcommand response: {}. Responding with empty command list.", + e + ); + if let Some(obj) = response_json.as_object_mut() { + obj.insert("result".to_string(), Value::Array(vec![])); + } + } + + Ok(()) +} + +/// Handles the 'getcommand' response by injecting verified commands. +async fn handle_getcommand_response( + response_json: &mut Value, + ctx: &Arc, +) -> anyhow::Result<()> { if let Some(result) = response_json.get_mut("result") && let Some(commands) = result.as_array_mut() && !commands.is_empty() @@ -189,22 +218,18 @@ async fn handle_getcommand_response( for cmd in commands.iter() { let cmd_json = serde_json::to_string(cmd)?; if let Err(e) = - crate::db::repositories::commands::insert(&state.db, &cmd_json, "unverified").await + crate::db::repositories::commands::insert(&ctx.db, &cmd_json, "unverified").await { warn!("Failed to persist command to database: {}", e); } info!("Added command to the queue: {:?}", cmd); } - - // Also add to in-memory queue for backwards compatibility - let mut queue = state.commands.unverified.write().await; - queue.extend(commands.drain(..)); } if let Some(obj) = response_json.as_object_mut() { // Get verified commands from database let verified_cmds = - match crate::db::repositories::commands::list_by_status(&state.db, "verified").await { + match crate::db::repositories::commands::list_by_status(&ctx.db, "verified").await { Ok(cmds) => cmds, Err(e) => { warn!("Failed to fetch verified commands from database: {}", e); @@ -218,48 +243,62 @@ async fn handle_getcommand_response( .filter_map(|c| serde_json::from_str(&c.command_json).ok()) .collect(); - // Also include in-memory verified commands - let mem_verified = std::mem::take(&mut *state.commands.verified.write().await); - let mut all_verified = verified_values; - all_verified.extend(mem_verified); - - obj.insert("result".to_string(), Value::Array(all_verified.clone())); + obj.insert("result".to_string(), Value::Array(verified_values)); // Clear verified commands from database after sending - if let Err(e) = crate::db::repositories::commands::clear_verified(&state.db).await { + if let Err(e) = crate::db::repositories::commands::clear_verified(&ctx.db).await { warn!("Failed to clear verified commands from database: {}", e); } } - Ok(response_json) -} - -fn decrypt_and_format(body_text: &str, key: &str, iv: &str) -> anyhow::Result { - let decrypted = crypto::decrypt(body_text, key, iv)?; - Ok(serde_json::from_str(&decrypted)?) + Ok(()) } +/// Checks for and applies response interception rules. async fn intercept_response( method: &str, - _original_response: &str, - state: &Arc, + _orignal_response: &str, + ctx: &Arc, ) -> anyhow::Result> { // Check if there's an interception rule for this method - let rule = crate::db::repositories::rules::find_by_method(&state.db, method).await?; + let rule = crate::db::repositories::rules::find_by_method(&ctx.db, method).await?; match rule { - Some(r) if r.action == "replace" => { - if let Some(custom_response) = &r.custom_response { - Ok(crypto::encrypt(custom_response, &state.key, &state.iv).map(Some)?) - } else { - Ok(None) - } - } + Some(r) if r.action == "replace" => Ok(r.custom_response.as_ref().cloned()), Some(r) if r.action == "modify" => { - // Future: Apply transformations - // For now, just pass through + // TODO: Apply modifications Ok(None) } _ => Ok(None), // Passthrough } } + +fn decrypt_params(request: &mut Value, ctx: &Arc) { + if let Some(params) = request.get_mut("params") { + *params = match params.take() { + Value::String(params) => { + let params = ctx.cryptor.decrypt(params).unwrap_or_else(|err| { + warn!("Failed to decrypt request params: {err}. Using fallback string value.",); + "\"Failed to decrypt params\"".to_string() + }); + serde_json::from_str(¶ms).unwrap_or_else(|_| { + Value::String("Failed to deserialize decrypted params".into()) + }) + } + other => { + warn!("'params' is not an encrypted string"); + other + } + }; + }; +} + +fn encrypt_params(request: &mut Value, ctx: &Arc) { + if let Some(params) = request.get_mut("params") { + let plaintext = serde_json::to_string(params).unwrap_or_else(|err| { + warn!("Failed to serialize params: {err}. Using fallback string value."); + "".to_string() + }); + *params = Value::String(ctx.cryptor.encrypt(plaintext)); + }; +} diff --git a/src/proxy.rs b/src/proxy.rs index 048b8e7..ad9221b 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -8,10 +8,10 @@ use axum::{ use http_body_util::BodyExt; use tracing::error; -use crate::AppState; +use crate::AppContext; pub async fn proxy_handler( - State(state): State>, + State(state): State>, OriginalUri(uri): OriginalUri, req: Request, ) -> impl IntoResponse { diff --git a/src/state.rs b/src/state.rs deleted file mode 100644 index 5da2344..0000000 --- a/src/state.rs +++ /dev/null @@ -1,33 +0,0 @@ -use serde_json::Value; -use sqlx::SqlitePool; -use tokio::sync::RwLock; - -pub struct AppState { - pub client: reqwest::Client, - pub target_url: reqwest::Url, - pub key: String, - pub iv: String, - pub jwt_secret: String, - pub db: SqlitePool, - pub commands: Commands, -} - -pub struct Commands { - pub unverified: RwLock>, - pub verified: RwLock>, -} - -impl Commands { - pub fn new() -> Self { - Self { - unverified: RwLock::default(), - verified: RwLock::default(), - } - } -} - -impl Default for Commands { - fn default() -> Self { - Self::new() - } -}