feat: middleware (WIP)

This commit is contained in:
2025-11-28 18:05:19 +08:00
parent e46690cb21
commit be35040e26
7 changed files with 334 additions and 122 deletions

View File

@@ -1,12 +1,13 @@
use std::fmt;
use aes::Aes128;
use base64::{Engine as _, engine::general_purpose::STANDARD};
use base64::{Engine, engine::general_purpose::STANDARD};
use cbc::{
Decryptor,
cipher::{BlockDecryptMut, KeyIvInit, block_padding::Pkcs7},
Decryptor, Encryptor,
cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit, block_padding::Pkcs7},
};
type Aes128CbcEnc = Encryptor<Aes128>;
type Aes128CbcDec = Decryptor<Aes128>;
#[derive(Debug)]
@@ -14,6 +15,7 @@ pub enum CryptoError {
Base64(base64::DecodeError),
Aes(cbc::cipher::InvalidLength),
Unpad(cbc::cipher::block_padding::UnpadError),
Pad(aes::cipher::inout::PadError),
}
impl fmt::Display for CryptoError {
@@ -22,6 +24,7 @@ impl fmt::Display for CryptoError {
CryptoError::Base64(e) => write!(f, "Base64 decode failed: {}", e),
CryptoError::Aes(e) => write!(f, "AES decryption failed: {}", e),
CryptoError::Unpad(e) => write!(f, "PKCS7 unpadding failed: {}", e),
CryptoError::Pad(e) => write!(f, "PKCS7 padding failed: {}", e),
}
}
}
@@ -52,3 +55,19 @@ pub fn decrypt(ciphertext_b64: &str, key: &str, iv: &str) -> Result<String, Cryp
// 4. Convert plaintext to a UTF-8 string
Ok(String::from_utf8_lossy(plaintext_slice).to_string())
}
pub fn encrypt(plaintext: &str, key: &str, iv: &str) -> Result<String, CryptoError> {
// 1. Initialize AES-128 in CBC mode
let key_bytes = key.as_bytes();
let iv_bytes = iv.as_bytes();
let encryptor = Aes128CbcEnc::new_from_slices(key_bytes, iv_bytes).map_err(CryptoError::Aes)?;
// 2. Encrypt the plaintext with PKCS7 padding
let mut buffer = plaintext.as_bytes().to_vec();
let ciphertext = encryptor
.encrypt_padded_mut::<Pkcs7>(&mut buffer, plaintext.len())
.map_err(CryptoError::Pad)?;
// 3. Base64 encode the ciphertext
Ok(STANDARD.encode(ciphertext))
}

View File

@@ -1 +1,19 @@
use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Serialize, Deserialize)]
pub struct Request {
pub method: String,
pub params: Option<serde_json::Map<String, Value>>,
pub id: Value,
pub jsonrpc: Option<String>
}
#[derive(Serialize, Deserialize)]
pub struct Response {
pub result: Option<Value>,
pub params: Option<HashMap<String, Value>>,
pub id: Value,
pub jsonrpc: Option<String>
}

View File

@@ -1,13 +1,17 @@
use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use axum::{Router, routing::any};
use axum::handler::Handler;
use dotenvy::dotenv;
use serde_json::Value;
use tokio::sync::RwLock;
use tower_http::compression::CompressionLayer;
use tracing::{error, info, level_filters::LevelFilter};
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
mod crypto;
mod jsonrpc;
mod middleware;
mod proxy;
#[derive(Clone)]
@@ -16,6 +20,8 @@ pub struct AppState {
pub target_url: reqwest::Url,
pub key: String,
pub iv: String,
pub command_queue: Arc<RwLock<Vec<Value>>>,
pub saved_tactics: Arc<RwLock<Option<Value>>>,
}
const DEFAULT_TARGET_URL: &str = "https://cloud.linspirer.com:883";
@@ -65,11 +71,16 @@ async fn main() -> anyhow::Result<()> {
target_url,
key,
iv,
command_queue: Arc::new(RwLock::new(Vec::new())),
saved_tactics: Arc::new(RwLock::new(None)),
});
// Build our application with a single route
let app = Router::new()
.route("/{*path}", any(proxy::proxy_handler))
// Build our application
let app = proxy::proxy_handler
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::log_middleware,
))
.layer(CompressionLayer::new().gzip(true))
.with_state(state);

178
src/middleware.rs Normal file
View File

@@ -0,0 +1,178 @@
use std::str;
use std::sync::Arc;
use axum::{
extract::{OriginalUri, State},
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use http_body_util::BodyExt;
use serde_json::Value;
use tracing::{info, warn};
use crate::{AppState, crypto};
enum ResponseBody {
Original(String),
Modified(Value),
}
pub async fn log_middleware(
State(state): State<Arc<AppState>>,
OriginalUri(uri): OriginalUri,
req: Request<axum::body::Body>,
next: Next,
) -> impl IntoResponse {
let (parts, body) = req.into_parts();
let body_bytes = match body.collect().await {
Ok(body) => body.to_bytes(),
Err(e) => {
warn!("Failed to read request body: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to read request body".to_string(),
)
.into_response();
}
};
let path = uri.path();
let (decrypted_request_log, method) = match str::from_utf8(&body_bytes)
.map_err(anyhow::Error::from)
.and_then(|body| process_and_log_request(body, &state.key, &state.iv))
{
Ok(request_data) => {
let method = request_data
.get("method")
.and_then(Value::as_str)
.map(str::to_string);
(request_data, method)
}
Err(e) => {
warn!("Failed to process request for logging: {}", e);
let val = Value::String("Could not decrypt request".to_string());
(val, None)
}
};
let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
let res = next.run(req).await;
let (resp_parts, body_bytes) = {
let (parts, body) = res.into_parts();
let bytes = match body.collect().await {
Ok(b) => b.to_bytes(),
Err(e) => {
warn!("Failed to read response body: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to read response body".to_string(),
)
.into_response();
}
};
(parts, bytes)
};
let resp_body_text = String::from_utf8(body_bytes.clone().to_vec()).unwrap_or_default();
let response_body_to_log = if Some("com.linspirer.device.getcommand") == method.as_deref() {
match handle_getcommand_response(&resp_body_text, &state).await {
Ok(new_body) => ResponseBody::Modified(new_body),
Err(e) => {
warn!(
"Failed to handle getcommand response: {}. Responding with empty command list.",
e
);
let mut empty_response =
serde_json::from_str::<Value>(&resp_body_text).unwrap_or(Value::Null);
if let Some(obj) = empty_response.as_object_mut() {
obj.insert("result".to_string(), Value::Array(vec![]));
}
ResponseBody::Modified(empty_response)
}
}
} else {
ResponseBody::Original(resp_body_text)
};
let (decrypted_response_for_log, final_response_body) = match response_body_to_log {
ResponseBody::Original(body_text) => {
let decrypted = decrypt_and_format(&body_text, &state.key, &state.iv)
.unwrap_or_else(|_| "Could not decrypt or format response".to_string());
(decrypted, body_text)
}
ResponseBody::Modified(body_value) => {
let pretty_printed = serde_json::to_string_pretty(&body_value).unwrap_or_default();
let encrypted = crypto::encrypt(&pretty_printed, &state.key, &state.iv)
.unwrap_or_else(|_| "Failed to encrypt modified response".to_string());
(pretty_printed, encrypted)
}
};
info!(
"{}\nRequest:\n{}\nResponse:\n{}\n{}",
path,
serde_json::to_string_pretty(&decrypted_request_log).unwrap_or_default(),
decrypted_response_for_log,
"-".repeat(80),
);
let mut response_builder = Response::builder().status(resp_parts.status);
if !resp_parts.headers.is_empty() {
*response_builder.headers_mut().unwrap() = resp_parts.headers;
}
response_builder
.body(axum::body::Body::from(final_response_body))
.unwrap()
}
fn process_and_log_request(body: &str, key: &str, iv: &str) -> anyhow::Result<Value> {
let mut request_data: Value = serde_json::from_str(body)?;
if let Some(params_value) = request_data.get_mut("params")
&& let Some(params_str) = params_value.as_str()
{
let params_str_owned = params_str.to_string();
match crypto::decrypt(&params_str_owned, key, iv) {
Ok(decrypted_str) => {
let decrypted_params: Value =
serde_json::from_str(&decrypted_str).unwrap_or(Value::String(decrypted_str));
*params_value = decrypted_params;
}
Err(e) => {
*params_value = Value::String(format!("decrypt failed: {}", e));
}
}
}
Ok(request_data)
}
async fn handle_getcommand_response(body_text: &str, state: &Arc<AppState>) -> anyhow::Result<Value> {
let decrypted = crypto::decrypt(body_text, &state.key, &state.iv)?;
let mut response_json: Value = serde_json::from_str(&decrypted)?;
if let Some(result) = response_json.get("result")
&& let Some(commands) = result.as_array()
&& !commands.is_empty()
{
let mut queue = state.command_queue.write().await;
for cmd in commands {
queue.push(cmd.clone());
info!("Added command to the queue: {:?}", cmd);
}
}
if let Some(obj) = response_json.as_object_mut() {
obj.insert("result".to_string(), Value::Array(vec![]));
}
Ok(response_json)
}
fn decrypt_and_format(body_text: &str, key: &str, iv: &str) -> anyhow::Result<String> {
let decrypted = crypto::decrypt(body_text, key, iv)?;
let formatted: Value = serde_json::from_str(&decrypted)?;
Ok(serde_json::to_string_pretty(&formatted)?)
}

View File

@@ -1,23 +1,21 @@
use std::sync::Arc;
use axum::{
extract::{Path, State},
extract::{OriginalUri, State},
http::{Request, StatusCode},
response::{IntoResponse, Response},
};
use http_body_util::BodyExt;
use serde_json::Value;
use tracing::{error, info, warn};
use tracing::error;
use crate::AppState;
mod crypto;
pub async fn proxy_handler(
State(state): State<Arc<AppState>>,
Path(path): Path<String>,
OriginalUri(uri): OriginalUri,
req: Request<axum::body::Body>,
) -> impl IntoResponse {
let path = uri.path();
let (parts, body) = req.into_parts();
let body = match body.collect().await {
Ok(body) => body.to_bytes(),
@@ -31,22 +29,8 @@ pub async fn proxy_handler(
}
};
let decrypted_request_log = match str::from_utf8(&body)
.map(|body| process_and_log_request(body, &state.key, &state.iv))
{
Ok(Ok(log)) => log,
Ok(Err(e)) => {
warn!("Failed to process request for logging: {}", e);
Value::String("Could not decrypt request".to_string())
}
Err(e) => {
warn!("Failed to decode request for logging: {}", e);
Value::String("Could not decrypt request".to_string())
}
};
let mut target_url = state.target_url.clone();
target_url.set_path(&path);
target_url.set_path(path);
target_url.set_query(parts.uri.query());
let mut forwarded_req = state.client.request(parts.method, target_url);
if !parts.headers.is_empty() {
@@ -77,21 +61,6 @@ pub async fn proxy_handler(
}
};
let decrypted_response_log = match decrypt_response(&resp_body, &state.key, &state.iv).await {
Ok(log) => log,
Err(e) => {
warn!("Failed to process response for logging: {}", e);
"Could not decrypt response".to_string()
}
};
info!(
"\nRequest:\n{}\nResponse:\n{}\n{}",
serde_json::to_string_pretty(&decrypted_request_log).unwrap_or_default(),
decrypted_response_log,
"-".repeat(80),
);
let mut response_builder = Response::builder().status(resp_parts.status);
if !resp_parts.headers.is_empty() {
*response_builder.headers_mut().unwrap() = resp_parts.headers;
@@ -101,33 +70,6 @@ pub async fn proxy_handler(
.unwrap()
}
fn process_and_log_request(body: &str, key: &str, iv: &str) -> anyhow::Result<Value> {
let mut request_data: Value = serde_json::from_str(body)?;
if let Some(params_value) = request_data.get_mut("params")
&& let Some(params_str) = params_value.as_str()
{
let params_str_owned = params_str.to_string();
match crypto::decrypt(&params_str_owned, key, iv) {
Ok(decrypted_str) => {
let decrypted_params: Value =
serde_json::from_str(&decrypted_str).unwrap_or(Value::String(decrypted_str));
*params_value = decrypted_params;
}
Err(e) => {
*params_value = Value::String(format!("decrypt failed: {}", e));
}
}
}
Ok(request_data)
}
async fn decrypt_response(body_text: &str, key: &str, iv: &str) -> anyhow::Result<String> {
let decrypted = crypto::decrypt(body_text, key, iv)?;
let formatted: Value = serde_json::from_str(&decrypted)?;
Ok(serde_json::to_string_pretty(&formatted)?)
}
async fn clone_response(
resp: reqwest::Response,
) -> anyhow::Result<(axum::http::response::Parts, String)> {