feat: web frontend; middleware; serde (WIP?)

This commit is contained in:
2025-11-30 09:41:37 +08:00
parent be35040e26
commit 531ac029af
45 changed files with 6806 additions and 82 deletions

View File

@@ -0,0 +1,47 @@
use std::sync::Arc;
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use crate::{auth, state::AppState};
pub async fn auth_middleware(
State(state): State<Arc<AppState>>,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
// Skip authentication for login endpoint
if request.uri().path().ends_with("/login") {
return Ok(next.run(request).await);
}
// Skip authentication for static files (non-API routes)
if !request.uri().path().starts_with("/api/") {
return Ok(next.run(request).await);
}
// Get Authorization header
let auth_header = request
.headers()
.get("Authorization")
.and_then(|h| h.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
// Check if it's a Bearer token
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let token = &auth_header[7..]; // Skip "Bearer "
// Validate token
auth::validate_token(state.jwt_secret.as_bytes(), token)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
// Token is valid, proceed
Ok(next.run(request).await)
}

366
src/admin/handlers.rs Normal file
View File

@@ -0,0 +1,366 @@
use std::sync::Arc;
use axum::{
Json,
extract::{Path, State},
http::StatusCode,
};
use tracing::error;
use crate::{AppState, auth, crypto, db};
use super::models::*;
// Authentication handlers
pub async fn login(
State(state): State<Arc<AppState>>,
Json(req): Json<LoginRequest>,
) -> Result<Json<LoginResponse>, (StatusCode, Json<ApiError>)> {
// Get stored password hash from config
let password_hash = match db::repositories::config::get(&state.db, "admin_password_hash").await
{
Ok(Some(hash)) => hash,
Ok(None) => {
error!("Admin password hash not found in config");
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Authentication not configured")),
));
}
Err(e) => {
error!("Failed to get password hash: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Authentication error")),
));
}
};
// Verify password
match auth::verify_password(&req.password, &password_hash) {
Ok(true) => {
// Generate JWT token
match auth::generate_token(state.jwt_secret.as_bytes(), "admin") {
Ok(token) => Ok(Json(LoginResponse { token })),
Err(e) => {
error!("Failed to generate token: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to generate token")),
))
}
}
}
Ok(false) => Err((
StatusCode::UNAUTHORIZED,
Json(ApiError::new("Invalid password")),
)),
Err(e) => {
error!("Password verification error: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Authentication error")),
))
}
}
}
pub async fn change_password(
State(state): State<Arc<AppState>>,
Json(req): Json<super::models::ChangePasswordRequest>,
) -> Result<StatusCode, (StatusCode, Json<ApiError>)> {
// Get current password hash from config
let current_hash = match db::repositories::config::get(&state.db, "admin_password_hash").await {
Ok(Some(hash)) => hash,
Ok(None) => {
error!("Admin password hash not found in config");
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Authentication not configured")),
));
}
Err(e) => {
error!("Failed to get password hash: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to get current password")),
));
}
};
// Verify old password
match auth::verify_password(&req.old_password, &current_hash) {
Ok(true) => {
// Hash new password
let new_hash = match auth::hash_password(&req.new_password) {
Ok(hash) => hash,
Err(e) => {
error!("Failed to hash new password: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to hash new password")),
));
}
};
// Update password hash in config
if let Err(e) = db::repositories::config::set(
&state.db,
"admin_password_hash",
&new_hash,
Some("Hashed admin password"),
)
.await
{
error!("Failed to update password hash: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to update password")),
));
}
Ok(StatusCode::OK)
}
Ok(false) => Err((
StatusCode::UNAUTHORIZED,
Json(ApiError::new("Invalid old password")),
)),
Err(e) => {
error!("Password verification error: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Password verification error")),
))
}
}
}
// Rules handlers
pub async fn list_rules(
State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<RuleResponse>>, (StatusCode, Json<ApiError>)> {
match db::repositories::rules::list_all(&state.db).await {
Ok(rules) => Ok(Json(rules.into_iter().map(Into::into).collect())),
Err(e) => {
error!("Failed to list rules: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to fetch rules")),
))
}
}
}
pub async fn create_rule(
State(state): State<Arc<AppState>>,
Json(req): Json<CreateRuleRequest>,
) -> Result<Json<RuleResponse>, (StatusCode, Json<ApiError>)> {
// Validate action
if !matches!(req.action.as_str(), "passthrough" | "modify" | "replace") {
return Err((
StatusCode::BAD_REQUEST,
Json(ApiError::new(
"Invalid action. Must be 'passthrough', 'modify', or 'replace'",
)),
));
}
// If action is replace, validate and encrypt custom_response
let custom_response = if req.action == "replace" {
match req.custom_response {
Some(resp) => Some(resp),
None => {
return Err((
StatusCode::BAD_REQUEST,
Json(ApiError::new(
"custom_response is required when action is 'replace'",
)),
));
}
}
} else {
None
};
match db::repositories::rules::create(
&state.db,
&req.method_name,
&req.action,
custom_response.as_deref(),
)
.await
{
Ok(id) => match db::repositories::rules::find_by_id(&state.db, id).await {
Ok(Some(rule)) => Ok(Json(rule.into())),
_ => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to fetch created rule")),
)),
},
Err(e) => {
error!("Failed to create rule: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new(format!("Failed to create rule: {}", e))),
))
}
}
}
pub async fn update_rule(
State(state): State<Arc<AppState>>,
Path(id): Path<i64>,
Json(req): Json<UpdateRuleRequest>,
) -> Result<Json<RuleResponse>, (StatusCode, Json<ApiError>)> {
// Validate action if provided
if let Some(ref action) = req.action
&& !matches!(action.as_str(), "passthrough" | "modify" | "replace")
{
return Err((
StatusCode::BAD_REQUEST,
Json(ApiError::new("Invalid action")),
));
}
// Encrypt custom_response if provided
let custom_response = if let Some(resp) = req.custom_response {
match crypto::encrypt(&resp, &state.key, &state.iv) {
Ok(encrypted) => Some(encrypted),
Err(e) => {
error!("Failed to encrypt custom response: {}", e);
return Err((
StatusCode::BAD_REQUEST,
Json(ApiError::new("Failed to encrypt custom response")),
));
}
}
} else {
None
};
match db::repositories::rules::update(
&state.db,
id,
req.method_name.as_deref(),
req.action.as_deref(),
custom_response.as_deref(),
req.is_enabled,
)
.await
{
Ok(_) => match db::repositories::rules::find_by_id(&state.db, id).await {
Ok(Some(rule)) => Ok(Json(rule.into())),
_ => Err((StatusCode::NOT_FOUND, Json(ApiError::new("Rule not found")))),
},
Err(e) => {
error!("Failed to update rule: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to update rule")),
))
}
}
}
pub async fn delete_rule(
State(state): State<Arc<AppState>>,
Path(id): Path<i64>,
) -> Result<StatusCode, (StatusCode, Json<ApiError>)> {
match db::repositories::rules::delete(&state.db, id).await {
Ok(_) => Ok(StatusCode::NO_CONTENT),
Err(e) => {
error!("Failed to delete rule: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to delete rule")),
))
}
}
}
// Commands handlers
pub async fn list_commands(
State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<CommandResponse>>, (StatusCode, Json<ApiError>)> {
match db::repositories::commands::list_all(&state.db).await {
Ok(commands) => Ok(Json(commands.into_iter().map(Into::into).collect())),
Err(e) => {
error!("Failed to list commands: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to fetch commands")),
))
}
}
}
pub async fn verify_command(
State(state): State<Arc<AppState>>,
Path(id): Path<i64>,
Json(req): Json<UpdateCommandRequest>,
) -> Result<Json<CommandResponse>, (StatusCode, Json<ApiError>)> {
match db::repositories::commands::update_status(
&state.db,
id,
&req.status,
req.notes.as_deref(),
)
.await
{
Ok(_) => match db::repositories::commands::find_by_id(&state.db, id).await {
Ok(Some(cmd)) => {
// Also update in-memory queue if status is verified
if req.status == "verified"
&& let Ok(cmd_value) = serde_json::from_str(&cmd.command_json)
{
state.commands.verified.write().await.push(cmd_value);
}
Ok(Json(cmd.into()))
}
_ => Err((
StatusCode::NOT_FOUND,
Json(ApiError::new("Command not found")),
)),
},
Err(e) => {
error!("Failed to update command: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to update command")),
))
}
}
}
// Config handlers
pub async fn get_config(
State(state): State<Arc<AppState>>,
) -> Result<Json<std::collections::HashMap<String, String>>, (StatusCode, Json<ApiError>)> {
match db::repositories::config::get_all(&state.db).await {
Ok(config) => Ok(Json(config)),
Err(e) => {
error!("Failed to get config: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new("Failed to fetch configuration")),
))
}
}
}
pub async fn update_config(
State(state): State<Arc<AppState>>,
Json(config): Json<std::collections::HashMap<String, String>>,
) -> Result<StatusCode, (StatusCode, Json<ApiError>)> {
for (key, value) in config {
if let Err(e) = db::repositories::config::set(&state.db, &key, &value, None).await {
error!("Failed to update config {}: {}", key, e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ApiError::new(format!("Failed to update config: {}", e))),
));
}
}
Ok(StatusCode::OK)
}

5
src/admin/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod auth_middleware;
pub mod handlers;
pub mod models;
pub mod routes;
pub mod static_files;

101
src/admin/models.rs Normal file
View File

@@ -0,0 +1,101 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Authentication models
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
pub password: String,
}
#[derive(Debug, Serialize)]
pub struct LoginResponse {
pub token: String,
}
#[derive(Debug, Deserialize)]
pub struct ChangePasswordRequest {
pub old_password: String,
pub new_password: String,
}
// Request models
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
pub method_name: String,
pub action: String,
pub custom_response: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest {
pub method_name: Option<String>,
pub action: Option<String>,
pub custom_response: Option<String>,
pub is_enabled: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateCommandRequest {
pub status: String,
pub notes: Option<String>,
}
// Response models
#[derive(Debug, Serialize)]
pub struct RuleResponse {
pub id: i64,
pub method_name: String,
pub action: String,
pub custom_response: Option<String>,
pub is_enabled: bool,
pub created_at: String,
pub updated_at: String,
}
impl From<crate::db::models::InterceptionRule> for RuleResponse {
fn from(rule: crate::db::models::InterceptionRule) -> Self {
Self {
id: rule.id,
method_name: rule.method_name,
action: rule.action,
custom_response: rule.custom_response,
is_enabled: rule.is_enabled,
created_at: rule.created_at,
updated_at: rule.updated_at,
}
}
}
#[derive(Debug, Serialize)]
pub struct CommandResponse {
pub id: i64,
pub command: Value,
pub status: String,
pub received_at: String,
pub processed_at: Option<String>,
pub notes: Option<String>,
}
impl From<crate::db::models::Command> for CommandResponse {
fn from(cmd: crate::db::models::Command) -> Self {
Self {
id: cmd.id,
command: serde_json::from_str(&cmd.command_json).unwrap_or(Value::Null),
status: cmd.status,
received_at: cmd.received_at,
processed_at: cmd.processed_at,
notes: cmd.notes,
}
}
}
#[derive(Debug, Serialize)]
pub struct ApiError {
pub error: String,
}
impl ApiError {
pub fn new(msg: impl Into<String>) -> Self {
Self { error: msg.into() }
}
}

38
src/admin/routes.rs Normal file
View File

@@ -0,0 +1,38 @@
use std::sync::Arc;
use axum::{
Router,
routing::{delete, get, post, put},
};
use crate::AppState;
use super::{auth_middleware, handlers, static_files};
pub fn admin_routes(state: Arc<AppState>) -> Router<Arc<AppState>> {
// Public routes (no authentication required)
let public_routes = Router::new().route("/api/login", post(handlers::login));
// Protected API routes (require authentication)
let protected_routes = Router::new()
.route("/api/password", put(handlers::change_password))
.route("/api/rules", get(handlers::list_rules))
.route("/api/rules", post(handlers::create_rule))
.route("/api/rules/{:id}", put(handlers::update_rule))
.route("/api/rules/{:id}", delete(handlers::delete_rule))
.route("/api/commands", get(handlers::list_commands))
.route("/api/commands/{:id}", post(handlers::verify_command))
.route("/api/config", get(handlers::get_config))
.route("/api/config", put(handlers::update_config))
.layer(axum::middleware::from_fn_with_state(
state,
auth_middleware::auth_middleware,
));
// Combine routes
Router::new()
.merge(public_routes)
.merge(protected_routes)
// Static files (frontend)
.fallback(static_files::serve_static)
}

49
src/admin/static_files.rs Normal file
View File

@@ -0,0 +1,49 @@
use axum::{
body::Body,
http::{StatusCode, Uri, header},
response::{IntoResponse, Response},
};
use rust_embed::RustEmbed;
use tracing::info;
#[derive(RustEmbed)]
#[folder = "frontend/dist"]
pub struct Assets;
pub async fn serve_static(uri: Uri) -> impl IntoResponse {
let mut path = uri.path().trim_start_matches("/admin/").to_string();
// Default to index.html for root or directories
if path.is_empty() || path.ends_with('/') {
path = "index.html".to_string();
}
info!("{path}");
match Assets::get(path.trim_start_matches('/')) {
Some(content) => {
let mime = mime_guess::from_path(&path).first_or_octet_stream();
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, mime.as_ref())
.body(Body::from(content.data))
.unwrap()
}
None => {
// For SPA routing, serve index.html for non-asset paths
if !path.contains('.')
&& let Some(index) = Assets::get("index.html")
{
return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html")
.body(Body::from(index.data))
.unwrap();
}
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("404 Not Found"))
.unwrap()
}
}
}

56
src/auth.rs Normal file
View File

@@ -0,0 +1,56 @@
use anyhow::{Result, anyhow};
use bcrypt::{DEFAULT_COST, hash, verify};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
const TOKEN_EXPIRATION_HOURS: u64 = 24;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String,
pub exp: usize,
}
/// Hash a password using bcrypt
pub fn hash_password(password: &str) -> Result<String> {
hash(password, DEFAULT_COST).map_err(|e| anyhow!("Failed to hash password: {}", e))
}
/// Verify a password against a hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
verify(password, hash).map_err(|e| anyhow!("Failed to verify password: {}", e))
}
/// Generate a JWT token
pub fn generate_token(jwt_secret: &[u8], username: &str) -> Result<String> {
let expiration = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| anyhow!("Time error: {}", e))?
.as_secs()
+ (TOKEN_EXPIRATION_HOURS * 3600);
let claims = Claims {
sub: username.to_owned(),
exp: expiration as usize,
};
encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(jwt_secret),
)
.map_err(|e| anyhow!("Failed to generate token: {}", e))
}
/// Validate a JWT token and return the claims
pub fn validate_token(jwt_secret: &[u8], token: &str) -> Result<Claims> {
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(jwt_secret),
&Validation::default(),
)
.map_err(|e| anyhow!("Invalid token: {}", e))?;
Ok(token_data.claims)
}

View File

@@ -63,9 +63,13 @@ pub fn encrypt(plaintext: &str, key: &str, iv: &str) -> Result<String, CryptoErr
let encryptor = Aes128CbcEnc::new_from_slices(key_bytes, iv_bytes).map_err(CryptoError::Aes)?;
// 2. Encrypt the plaintext with PKCS7 padding
let mut buffer = plaintext.as_bytes().to_vec();
// Allocate buffer with extra space for padding (AES block size is 16 bytes)
let plaintext_bytes = plaintext.as_bytes();
let mut buffer = vec![0u8; plaintext_bytes.len() + 16];
buffer[..plaintext_bytes.len()].copy_from_slice(plaintext_bytes);
let ciphertext = encryptor
.encrypt_padded_mut::<Pkcs7>(&mut buffer, plaintext.len())
.encrypt_padded_mut::<Pkcs7>(&mut buffer, plaintext_bytes.len())
.map_err(CryptoError::Pad)?;
// 3. Base64 encode the ciphertext

26
src/db/mod.rs Normal file
View File

@@ -0,0 +1,26 @@
pub mod models;
pub mod repositories;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::str::FromStr;
use tracing::info;
pub async fn init_db(database_url: &str) -> anyhow::Result<SqlitePool> {
info!("Initializing database at: {}", database_url);
// Parse connection options
let options = SqliteConnectOptions::from_str(database_url)?.create_if_missing(true);
// Create connection pool
let pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
// Run migrations
sqlx::migrate!("./migrations").run(&pool).await?;
info!("Database initialized successfully");
Ok(pool)
}

51
src/db/models.rs Normal file
View File

@@ -0,0 +1,51 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct InterceptionRule {
pub id: i64,
pub method_name: String,
pub action: String,
pub custom_response: Option<String>,
pub is_enabled: bool,
pub created_at: String,
pub updated_at: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Command {
pub id: i64,
pub command_json: String,
pub status: String,
pub received_at: String,
pub processed_at: Option<String>,
pub notes: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Config {
pub key: String,
pub value: String,
pub description: Option<String>,
pub updated_at: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct ResponseTemplate {
pub id: i64,
pub name: String,
pub method_name: String,
pub response_json: String,
pub description: Option<String>,
pub created_at: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct RequestLog {
pub id: i64,
pub method: Option<String>,
pub path: String,
pub request_body: String,
pub response_body: String,
pub status_code: i32,
pub timestamp: String,
}

View File

@@ -0,0 +1,78 @@
use crate::db::models::Command;
use sqlx::SqlitePool;
pub async fn list_all(pool: &SqlitePool) -> anyhow::Result<Vec<Command>> {
let commands = sqlx::query_as::<_, Command>("SELECT * FROM commands ORDER BY received_at DESC")
.fetch_all(pool)
.await?;
Ok(commands)
}
pub async fn list_by_status(pool: &SqlitePool, status: &str) -> anyhow::Result<Vec<Command>> {
let commands = sqlx::query_as::<_, Command>(
"SELECT * FROM commands WHERE status = ? ORDER BY received_at DESC",
)
.bind(status)
.fetch_all(pool)
.await?;
Ok(commands)
}
pub async fn find_by_id(pool: &SqlitePool, id: i64) -> anyhow::Result<Option<Command>> {
let command = sqlx::query_as::<_, Command>("SELECT * FROM commands WHERE id = ?")
.bind(id)
.fetch_optional(pool)
.await?;
Ok(command)
}
pub async fn insert(pool: &SqlitePool, command_json: &str, status: &str) -> anyhow::Result<i64> {
let result = sqlx::query("INSERT INTO commands (command_json, status) VALUES (?, ?)")
.bind(command_json)
.bind(status)
.execute(pool)
.await?;
Ok(result.last_insert_rowid())
}
pub async fn update_status(
pool: &SqlitePool,
id: i64,
status: &str,
notes: Option<&str>,
) -> anyhow::Result<()> {
sqlx::query(
"UPDATE commands SET status = ?, processed_at = CURRENT_TIMESTAMP, notes = ? WHERE id = ?",
)
.bind(status)
.bind(notes)
.bind(id)
.execute(pool)
.await?;
Ok(())
}
pub async fn delete_old(pool: &SqlitePool, days: i64) -> anyhow::Result<u64> {
let result = sqlx::query(
"DELETE FROM commands WHERE status IN ('verified', 'rejected')
AND processed_at < datetime('now', '-' || ? || ' days')",
)
.bind(days)
.execute(pool)
.await?;
Ok(result.rows_affected())
}
pub async fn clear_verified(pool: &SqlitePool) -> anyhow::Result<()> {
sqlx::query("DELETE FROM commands WHERE status = 'verified'")
.execute(pool)
.await?;
Ok(())
}

View File

@@ -0,0 +1,51 @@
use crate::db::models::Config;
use sqlx::SqlitePool;
use std::collections::HashMap;
pub async fn get_all(pool: &SqlitePool) -> anyhow::Result<HashMap<String, String>> {
let configs = sqlx::query_as::<_, Config>("SELECT * FROM config")
.fetch_all(pool)
.await?;
let map: HashMap<String, String> = configs.into_iter().map(|c| (c.key, c.value)).collect();
Ok(map)
}
pub async fn get(pool: &SqlitePool, key: &str) -> anyhow::Result<Option<String>> {
let config = sqlx::query_as::<_, Config>("SELECT * FROM config WHERE key = ?")
.bind(key)
.fetch_optional(pool)
.await?;
Ok(config.map(|c| c.value))
}
pub async fn set(
pool: &SqlitePool,
key: &str,
value: &str,
description: Option<&str>,
) -> anyhow::Result<()> {
sqlx::query(
"INSERT INTO config (key, value, description) VALUES (?, ?, ?)
ON CONFLICT(key) DO UPDATE SET value = ?, updated_at = CURRENT_TIMESTAMP",
)
.bind(key)
.bind(value)
.bind(description)
.bind(value)
.execute(pool)
.await?;
Ok(())
}
pub async fn delete(pool: &SqlitePool, key: &str) -> anyhow::Result<()> {
sqlx::query("DELETE FROM config WHERE key = ?")
.bind(key)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -0,0 +1,3 @@
pub mod commands;
pub mod config;
pub mod rules;

View File

@@ -0,0 +1,105 @@
use crate::db::models::InterceptionRule;
use sqlx::SqlitePool;
pub async fn list_all(pool: &SqlitePool) -> anyhow::Result<Vec<InterceptionRule>> {
let rules = sqlx::query_as::<_, InterceptionRule>(
"SELECT * FROM interception_rules ORDER BY created_at DESC",
)
.fetch_all(pool)
.await?;
Ok(rules)
}
pub async fn find_by_id(pool: &SqlitePool, id: i64) -> anyhow::Result<Option<InterceptionRule>> {
let rule =
sqlx::query_as::<_, InterceptionRule>("SELECT * FROM interception_rules WHERE id = ?")
.bind(id)
.fetch_optional(pool)
.await?;
Ok(rule)
}
pub async fn find_by_method(
pool: &SqlitePool,
method: &str,
) -> anyhow::Result<Option<InterceptionRule>> {
let rule = sqlx::query_as::<_, InterceptionRule>(
"SELECT * FROM interception_rules WHERE method_name = ? AND is_enabled = 1",
)
.bind(method)
.fetch_optional(pool)
.await?;
Ok(rule)
}
pub async fn create(
pool: &SqlitePool,
method_name: &str,
action: &str,
custom_response: Option<&str>,
) -> anyhow::Result<i64> {
let result = sqlx::query(
"INSERT INTO interception_rules (method_name, action, custom_response, is_enabled)
VALUES (?, ?, ?, 1)",
)
.bind(method_name)
.bind(action)
.bind(custom_response)
.execute(pool)
.await?;
Ok(result.last_insert_rowid())
}
pub async fn update(
pool: &SqlitePool,
id: i64,
method_name: Option<&str>,
action: Option<&str>,
custom_response: Option<&str>,
is_enabled: Option<bool>,
) -> anyhow::Result<()> {
let mut query = String::from("UPDATE interception_rules SET updated_at = CURRENT_TIMESTAMP");
let mut params: Vec<String> = Vec::new();
if let Some(m) = method_name {
query.push_str(", method_name = ?");
params.push(m.to_string());
}
if let Some(a) = action {
query.push_str(", action = ?");
params.push(a.to_string());
}
if custom_response.is_some() {
query.push_str(", custom_response = ?");
params.push(custom_response.unwrap_or("").to_string());
}
if let Some(e) = is_enabled {
query.push_str(", is_enabled = ?");
params.push(if e { "1" } else { "0" }.to_string());
}
query.push_str(" WHERE id = ?");
let mut q = sqlx::query(&query);
for param in params {
q = q.bind(param);
}
q = q.bind(id);
q.execute(pool).await?;
Ok(())
}
pub async fn delete(pool: &SqlitePool, id: i64) -> anyhow::Result<()> {
sqlx::query("DELETE FROM interception_rules WHERE id = ?")
.bind(id)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -1,19 +1,438 @@
use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Serialize, Deserialize)]
pub struct Request {
pub method: String,
pub params: Option<serde_json::Map<String, Value>>,
pub id: Value,
pub jsonrpc: Option<String>
mod tactics;
use tactics::Tactics;
macro_rules! data_wrapper {
($name:ident, $ty:literal) => {
::concat_idents::concat_idents! {
struct_name = $name, Wrapper,
{
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename = $ty)]
pub struct struct_name {
pub data: $name,
}
}
}
};
}
#[derive(Serialize, Deserialize)]
macro_rules! declare_request_enum {
(
$(#[$enum_meta:meta])*
$vis:vis enum $enum_name:ident {
$(
$key:literal => $variant:ident($ty:ty)
),*
$(,)?
_ => $generic_variant:ident($generic_ty:ident)
}
) => {
$(#[$enum_meta])*
#[serde(tag = "method", content = "params")]
$vis enum $enum_name {
$(
#[serde(rename = $key)]
$variant($ty),
)*
#[serde(untagged)]
$generic_variant($generic_ty),
}
impl<'de> serde::Deserialize<'de> for $enum_name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct __RawRequest {
method: String,
params: serde_json::Map<String, Value>,
}
let raw = __RawRequest::deserialize(deserializer)?;
match raw.method.as_str() {
$(
$key => {
let params = serde_json::from_value(Value::Object(raw.params))
.map_err(serde::de::Error::custom)?;
Ok($enum_name::$variant(params))
}
)*
_ => {
Ok($enum_name::$generic_variant($generic_ty {
method: raw.method,
params: raw.params,
}))
}
}
}
}
};
}
declare_request_enum! {
#[derive(Serialize, Debug)]
pub enum RequestContent {
"com.linspirer.tactics.gettactics" => GetTactics(GetTacticsParams),
_ => Generic(GenericRequestContent)
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Request {
#[serde(rename = "!version")]
pub version: i32,
pub client_version: String,
pub id: i32,
pub jsonrpc: String,
#[serde(flatten)]
pub content: RequestContent,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct BaseParams {
email: String,
model: String,
swdid: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct GetTacticsParams {
#[serde(flatten)]
base: BaseParams,
launcher_version: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct GenericRequestContent {
pub method: String,
#[serde(default)]
pub params: serde_json::Map<String, Value>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
pub result: Option<Value>,
pub params: Option<HashMap<String, Value>>,
pub id: Value,
pub jsonrpc: Option<String>
}
pub code: i32,
#[serde(flatten)]
pub data: ResponseData,
}
data_wrapper!(Tactics, "object");
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseData {
Tactics(Box<TacticsWrapper>),
Generic(GenericResponseContent),
}
#[derive(Serialize, Deserialize, Debug)]
pub struct GenericResponseContent {
pub r#type: String,
#[serde(default)]
pub data: Option<Value>,
}
#[cfg(test)]
mod test {
use indoc::indoc;
use serde_json::Map;
use super::*;
#[test]
fn deserialize_gettactics_response() {
const RESP: &str = indoc!(
r#"{
"code": 0,
"data": {
"app_status": true,
"app_tactics": {
"applist": [
{
"canuninstall": false,
"created_at": "2025-09-04 08:04:08",
"devicetype": "HITV102C",
"exception_white_url": 1,
"grant_to": 788955,
"grant_type": 5,
"groupid": 1,
"hide_icon_status": 0,
"id": 156727,
"is_trust": true,
"isforce": true,
"isnew": false,
"name": "青鹿作业5",
"packagename": "com.qljy.smarthomework.student",
"sha1": "FD:73:65:ED:4F:B9:32:DE:4D:8D:26:BF:B7:61:E4:CF:82:47:24:AB",
"sort_weight": 0,
"status": 1,
"target_sdk_version": 30,
"updated_at": "2025-11-19 17:23:59",
"versioncode": 1019,
"versionname": "5.5.6"
}
]
},
"device_setting": {
"alarm_clock_status": 0,
"allow_change_password_status": 0,
"calendar_status": 0,
"camera_status": 0,
"data_flow_status": 1,
"disable_reinstall_system_status": 0,
"enable_client_admin_status": 1,
"enable_gesture_pwd_status": 0,
"enable_gps_status": 1,
"enable_screenshots_status": 1,
"enable_system_upgrade_status": 1,
"enable_wifi_advanced_status": 0,
"gallery_status": 0,
"hide_accelerate_status": 0,
"hide_cleanup_status": 0,
"keep_alive_package": null,
"launch_app": {
"launch_mode": 1,
"launch_package": "cn.com.ava.ebook5"
},
"logout_status": 1,
"only_install_store_app_status": 1,
"otg_set": {
"pv_list": [],
"status": 0
},
"protected_eyes_status": {
"distance_status": 0,
"sensitive_status": 0,
"sitting_position_status": 0
},
"remind_duration": {
"duration": 0,
"remind_status": 0
},
"rotate_setting_status": 1,
"school_class_display_status": 0,
"sdcard_and_otg": 0,
"show_privacy_statement_status": 1,
"simcard": 0
},
"device_status": true,
"device_tactics": {
"deviceManage": {
"command_bluetooth": true,
"command_camera": true,
"command_connect_usb": false,
"command_data_flow": true,
"command_force_open_wifi": false,
"command_gps": true,
"command_otg": false,
"command_phone_msg": false,
"command_recording": true,
"command_sd_card": false,
"command_wifi_advanced": false,
"command_wifi_switch": true
}
},
"enable_amap_status": 1,
"free_control": 0,
"id": 116425,
"illegal_status": false,
"illegal_tactics": {
"already_root": {
"eliminate_data": false,
"enable": false,
"lock_workspace": false,
"notify_admin": false
},
"change_simcard": {
"eliminate_data": false,
"enable": false,
"lock_workspace": false,
"notify_admin": false
},
"prohibited_app": {
"eliminate_data": false,
"enable": false,
"lock_workspace": false,
"notify_admin": false
},
"usb_to_pc": {
"eliminate_data": false,
"enable": false,
"lock_workspace": false,
"notify_admin": false
}
},
"interest_applist": [],
"name": "whoami",
"release_control": 0,
"updated_at": "2025-09-04 08:05:18",
"usergroup": 1219237,
"wifi_status": false,
"wifi_tactics": [],
"wifi_tactics_2": [],
"workspace_status": false,
"workspace_tactics": {
"worktime": {}
}
},
"type": "object"
}"#
);
let resp: Response = serde_json::from_str(RESP).unwrap();
println!("{resp:#?}");
}
#[test]
fn serialize_gettactics_response() {
let tactics = Tactics {
app_status: true,
app_tactics: tactics::AppTactics {
applist: Vec::new(),
},
device_setting: tactics::DeviceSetting {
alarm_clock_status: true,
allow_change_password_status: true,
calendar_status: true,
camera_status: true,
data_flow_status: true,
disable_reinstall_system_status: true,
enable_client_admin_status: true,
enable_gesture_pwd_status: true,
enable_gps_status: true,
enable_screenshots_status: true,
enable_system_upgrade_status: true,
enable_wifi_advanced_status: true,
gallery_status: true,
hide_accelerate_status: false,
hide_cleanup_status: false,
keep_alive_package: Value::Null,
launch_app: tactics::LaunchApp {
launch_mode: 1,
launch_package: "cn.com.ava.ebook5".to_string(),
},
logout_status: true,
only_install_store_app_status: true,
otg_set: tactics::OtgSet {
pv_list: vec![],
status: false,
},
protected_eyes_status: tactics::ProtectedEyesStatus {
distance_status: false,
sensitive_status: false,
sitting_position_status: false,
},
remind_duration: tactics::RemindDuration {
duration: 0,
remind_status: false,
},
rotate_setting_status: true,
school_class_display_status: false,
sdcard_and_otg: false,
show_privacy_statement_status: true,
simcard: false,
},
device_status: true,
device_tactics: tactics::DeviceTactics {
device_manage: tactics::DeviceManage {
command_bluetooth: true,
command_data_flow: true,
command_gps: true,
command_otg: true,
command_camera: true,
command_sd_card: true,
command_phone_msg: true,
command_recording: true,
command_connect_usb: true,
command_wifi_switch: true,
command_wifi_advanced: true,
command_force_open_wifi: true,
},
},
enable_amap_status: true,
free_control: false,
id: 116425,
illegal_status: false,
illegal_tactics: tactics::IllegalTactics::default(),
interest_applist: vec![],
name: "whoami".to_string(),
release_control: false,
updated_at: "idunno".to_string(),
user_group: 1219237,
wifi_status: false,
wifi_tactics: vec![],
wifi_tactics_2: vec![],
workspace_status: false,
workspace_tactics: tactics::WorkspaceTactics {
worktime: Map::new(),
},
};
println!("{}", serde_json::to_string_pretty(&tactics).unwrap())
}
#[test]
fn deserialize_gettactics_request() {
const REQ: &str = indoc!(
r#"{
"!version": 6,
"client_version": "sxqinglu_product_5.04.105.1",
"id": 1,
"jsonrpc": "2.0",
"method": "com.linspirer.tactics.gettactics",
"params": {
"email": "idunno",
"launcher_version": "sxqinglu_product_5.04.105.1",
"model": "HITV102C",
"swdid": "idunno"
}
}"#
);
let req: Request = serde_json::from_str(REQ).unwrap();
println!("{:#?}", req);
}
#[test]
fn serialize_gettactics_request() {
let req = Request {
version: 6,
client_version: "sxqinglu_product_5.04.105.1".to_string(),
id: 1,
jsonrpc: "2.0".to_string(),
content: RequestContent::GetTactics(GetTacticsParams {
base: BaseParams {
email: "idunno".to_string(),
model: "HITV102C".to_string(),
swdid: "idunno".to_string(),
},
launcher_version: "sxqinglu_product_5.04.105.1".to_string(),
}),
};
// base params goes first
const REQ: &str = indoc!(
r#"{
"!version": 6,
"client_version": "sxqinglu_product_5.04.105.1",
"id": 1,
"jsonrpc": "2.0",
"method": "com.linspirer.tactics.gettactics",
"params": {
"email": "idunno",
"model": "HITV102C",
"swdid": "idunno",
"launcher_version": "sxqinglu_product_5.04.105.1"
}
}"#
);
let req = serde_json::to_string_pretty(&req).unwrap();
println!("{req}");
assert_eq!(req, REQ);
}
}

206
src/jsonrpc/tactics.rs Normal file
View File

@@ -0,0 +1,206 @@
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use serde_with::{BoolFromInt, serde_as};
#[serde_as]
#[derive(Deserialize, Serialize, Debug)]
pub struct Tactics {
pub app_status: bool,
pub app_tactics: AppTactics,
pub device_setting: DeviceSetting,
pub device_status: bool,
pub device_tactics: DeviceTactics,
#[serde_as(as = "BoolFromInt")]
pub enable_amap_status: bool,
#[serde_as(as = "BoolFromInt")]
pub free_control: bool,
pub id: u32,
pub illegal_status: bool,
pub illegal_tactics: IllegalTactics,
pub interest_applist: Vec<Value>,
pub name: String,
#[serde_as(as = "BoolFromInt")]
pub release_control: bool,
// TODO: std::time::Instant
pub updated_at: String,
#[serde(rename = "usergroup")]
pub user_group: u32,
pub wifi_status: bool,
pub wifi_tactics: Vec<Value>,
pub wifi_tactics_2: Vec<Value>,
pub workspace_status: bool,
pub workspace_tactics: WorkspaceTactics,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct AppTactics {
pub applist: Vec<App>,
}
#[serde_as]
#[derive(Deserialize, Serialize, Debug)]
pub struct App {
#[serde(rename = "canuninstall")]
can_uninstall: bool,
// TODO: std::time::Instant
created_at: String,
#[serde(rename = "devicetype")]
device_type: String,
#[serde_as(as = "BoolFromInt")]
exception_white_url: bool,
grant_to: u32,
grant_type: u32,
#[serde(rename = "groupid")]
group_id: u32,
#[serde_as(as = "BoolFromInt")]
hide_icon_status: bool,
id: u32,
is_trust: bool,
#[serde(rename = "isforce")]
is_force: bool,
#[serde(rename = "isnew")]
is_new: bool,
name: String,
#[serde(rename = "packagename")]
package_name: String,
sha1: String,
sort_weight: i32,
status: i32,
target_sdk_version: i32,
// TODO: std::time::Instant
updated_at: String,
#[serde(rename = "versioncode")]
version_code: i32,
#[serde(rename = "versionname")]
version_name: String,
}
#[serde_as]
#[derive(Deserialize, Serialize, Debug)]
pub struct DeviceSetting {
#[serde_as(as = "BoolFromInt")]
pub alarm_clock_status: bool,
#[serde_as(as = "BoolFromInt")]
pub allow_change_password_status: bool,
#[serde_as(as = "BoolFromInt")]
pub calendar_status: bool,
#[serde_as(as = "BoolFromInt")]
pub camera_status: bool,
#[serde_as(as = "BoolFromInt")]
pub data_flow_status: bool,
#[serde_as(as = "BoolFromInt")]
pub disable_reinstall_system_status: bool,
#[serde_as(as = "BoolFromInt")]
pub enable_client_admin_status: bool,
#[serde_as(as = "BoolFromInt")]
pub enable_gesture_pwd_status: bool,
#[serde_as(as = "BoolFromInt")]
pub enable_gps_status: bool,
#[serde_as(as = "BoolFromInt")]
pub enable_screenshots_status: bool,
#[serde_as(as = "BoolFromInt")]
pub enable_system_upgrade_status: bool,
#[serde_as(as = "BoolFromInt")]
pub enable_wifi_advanced_status: bool,
#[serde_as(as = "BoolFromInt")]
pub gallery_status: bool,
#[serde_as(as = "BoolFromInt")]
pub hide_accelerate_status: bool,
#[serde_as(as = "BoolFromInt")]
pub hide_cleanup_status: bool,
pub keep_alive_package: Value,
pub launch_app: LaunchApp,
#[serde_as(as = "BoolFromInt")]
pub logout_status: bool,
#[serde_as(as = "BoolFromInt")]
pub only_install_store_app_status: bool,
pub otg_set: OtgSet,
pub protected_eyes_status: ProtectedEyesStatus,
pub remind_duration: RemindDuration,
#[serde_as(as = "BoolFromInt")]
pub rotate_setting_status: bool,
#[serde_as(as = "BoolFromInt")]
pub school_class_display_status: bool,
#[serde_as(as = "BoolFromInt")]
pub sdcard_and_otg: bool,
#[serde_as(as = "BoolFromInt")]
pub show_privacy_statement_status: bool,
#[serde_as(as = "BoolFromInt")]
pub simcard: bool,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct LaunchApp {
pub launch_mode: u32,
pub launch_package: String,
}
#[serde_as]
#[derive(Deserialize, Serialize, Debug)]
pub struct OtgSet {
pub pv_list: Vec<Value>,
#[serde_as(as = "BoolFromInt")]
pub status: bool,
}
#[serde_as]
#[derive(Deserialize, Serialize, Debug)]
pub struct ProtectedEyesStatus {
#[serde_as(as = "BoolFromInt")]
pub distance_status: bool,
#[serde_as(as = "BoolFromInt")]
pub sensitive_status: bool,
#[serde_as(as = "BoolFromInt")]
pub sitting_position_status: bool,
}
#[serde_as]
#[derive(Deserialize, Serialize, Debug)]
pub struct RemindDuration {
pub duration: u32,
#[serde_as(as = "BoolFromInt")]
pub remind_status: bool,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct DeviceTactics {
#[serde(rename = "deviceManage")]
pub device_manage: DeviceManage,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct DeviceManage {
pub command_bluetooth: bool,
pub command_camera: bool,
pub command_connect_usb: bool,
pub command_data_flow: bool,
pub command_force_open_wifi: bool,
pub command_gps: bool,
pub command_otg: bool,
pub command_phone_msg: bool,
pub command_recording: bool,
pub command_sd_card: bool,
pub command_wifi_advanced: bool,
pub command_wifi_switch: bool,
}
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct IllegalTactics {
pub already_root: IllegalTactic,
pub change_simcard: IllegalTactic,
pub prohibited_app: IllegalTactic,
pub usb_to_pc: IllegalTactic,
}
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct IllegalTactic {
pub eliminate_data: bool,
pub enable: bool,
pub lock_workspace: bool,
pub notify_admin: bool,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct WorkspaceTactics {
pub worktime: Map<String, Value>,
}

View File

@@ -1,32 +1,28 @@
use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use axum::Router;
use axum::handler::Handler;
use dotenvy::dotenv;
use serde_json::Value;
use tokio::sync::RwLock;
use tower_http::compression::CompressionLayer;
use tracing::{error, info, level_filters::LevelFilter};
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
mod admin;
mod auth;
mod crypto;
mod db;
mod jsonrpc;
mod middleware;
mod proxy;
mod state;
#[derive(Clone)]
pub struct AppState {
pub client: reqwest::Client,
pub target_url: reqwest::Url,
pub key: String,
pub iv: String,
pub command_queue: Arc<RwLock<Vec<Value>>>,
pub saved_tactics: Arc<RwLock<Option<Value>>>,
}
use state::AppState;
const DEFAULT_TARGET_URL: &str = "https://cloud.linspirer.com:883";
const DEFAULT_HOST: &str = "0.0.0.0";
const DEFAULT_PORT: &str = "8080";
const DEFAULT_DB_PATH: &str = "sqlite://./data/linspirer.db";
#[tokio::main]
async fn main() -> anyhow::Result<()> {
@@ -50,6 +46,9 @@ async fn main() -> anyhow::Result<()> {
let target_url_str = target_url_str.as_deref().unwrap_or(DEFAULT_TARGET_URL);
let target_url = reqwest::Url::parse(target_url_str)?;
let db_path_str = std::env::var("LINSPIRER_DB_PATH");
let db_path = db_path_str.as_deref().unwrap_or(DEFAULT_DB_PATH);
let host_str = std::env::var("LINSPIRER_HOST");
let host = host_str.as_deref().unwrap_or(DEFAULT_HOST);
let port_str = std::env::var("LINSPIRER_PORT");
@@ -58,6 +57,11 @@ async fn main() -> anyhow::Result<()> {
let addr: SocketAddr = addr_str
.parse()
.context(format!("Invalid address format: {}", addr_str))?;
let jwt_secret = std::env::var("LINSPIRER_JWT_SECRET")
.map_err(|_| anyhow::anyhow!("LINSPIRER_JWT_SECRET not set"))?;
// Initialize database
let db = db::init_db(db_path).await?;
// Create a reqwest client that ignores SSL certificate verification
let client = reqwest::Client::builder()
@@ -71,16 +75,20 @@ async fn main() -> anyhow::Result<()> {
target_url,
key,
iv,
command_queue: Arc::new(RwLock::new(Vec::new())),
saved_tactics: Arc::new(RwLock::new(None)),
jwt_secret,
db,
commands: Default::default(),
});
let log_middleware =
axum::middleware::from_fn_with_state(state.clone(), middleware::log_middleware);
// Build our application
let app = proxy::proxy_handler
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::log_middleware,
))
let app = Router::new()
// Admin routes
.nest("/admin", admin::routes::admin_routes(state.clone()))
// Proxy all other routes (fallback)
.fallback(proxy::proxy_handler.layer(log_middleware))
.layer(CompressionLayer::new().gzip(true))
.with_state(state);

View File

@@ -9,7 +9,7 @@ use axum::{
};
use http_body_util::BodyExt;
use serde_json::Value;
use tracing::{info, warn};
use tracing::{debug, info, warn};
use crate::{AppState, crypto};
@@ -77,24 +77,35 @@ pub async fn log_middleware(
};
let resp_body_text = String::from_utf8(body_bytes.clone().to_vec()).unwrap_or_default();
let response_body_to_log = if Some("com.linspirer.device.getcommand") == method.as_deref() {
match handle_getcommand_response(&resp_body_text, &state).await {
Ok(new_body) => ResponseBody::Modified(new_body),
Err(e) => {
warn!(
"Failed to handle getcommand response: {}. Responding with empty command list.",
e
);
let mut empty_response =
serde_json::from_str::<Value>(&resp_body_text).unwrap_or(Value::Null);
if let Some(obj) = empty_response.as_object_mut() {
obj.insert("result".to_string(), Value::Array(vec![]));
// Check for generic method interception first
let response_body_to_log = if let Some(method_str) = &method {
if let Ok(Some(intercepted)) =
maybe_intercept_response(method_str, &resp_body_text, &state).await
{
info!("Intercepting response for method: {}", method_str);
ResponseBody::Original(intercepted)
} else if Some("com.linspirer.device.getcommand") == method.as_deref() {
// Special handling for getcommand
match handle_getcommand_response(&resp_body_text, &state).await {
Ok(new_body) => ResponseBody::Modified(new_body),
Err(e) => {
warn!(
"Failed to handle getcommand response: {}. Responding with empty command list.",
e
);
let mut empty_response =
serde_json::from_str::<Value>(&resp_body_text).unwrap_or(Value::Null);
if let Some(obj) = empty_response.as_object_mut() {
obj.insert("result".to_string(), Value::Array(vec![]));
}
ResponseBody::Modified(empty_response)
}
ResponseBody::Modified(empty_response)
}
} else {
ResponseBody::Original(resp_body_text.clone())
}
} else {
ResponseBody::Original(resp_body_text)
ResponseBody::Original(resp_body_text.clone())
};
let (decrypted_response_for_log, final_response_body) = match response_body_to_log {
@@ -111,7 +122,7 @@ pub async fn log_middleware(
}
};
info!(
debug!(
"{}\nRequest:\n{}\nResponse:\n{}\n{}",
path,
serde_json::to_string_pretty(&decrypted_request_log).unwrap_or_default(),
@@ -149,23 +160,61 @@ fn process_and_log_request(body: &str, key: &str, iv: &str) -> anyhow::Result<Va
Ok(request_data)
}
async fn handle_getcommand_response(body_text: &str, state: &Arc<AppState>) -> anyhow::Result<Value> {
async fn handle_getcommand_response(
body_text: &str,
state: &Arc<AppState>,
) -> anyhow::Result<Value> {
let decrypted = crypto::decrypt(body_text, &state.key, &state.iv)?;
let mut response_json: Value = serde_json::from_str(&decrypted)?;
if let Some(result) = response_json.get("result")
&& let Some(commands) = result.as_array()
if let Some(result) = response_json.get_mut("result")
&& let Some(commands) = result.as_array_mut()
&& !commands.is_empty()
{
let mut queue = state.command_queue.write().await;
for cmd in commands {
queue.push(cmd.clone());
// Persist commands to database
for cmd in commands.iter() {
let cmd_json = serde_json::to_string(cmd)?;
if let Err(e) =
crate::db::repositories::commands::insert(&state.db, &cmd_json, "unverified").await
{
warn!("Failed to persist command to database: {}", e);
}
info!("Added command to the queue: {:?}", cmd);
}
// Also add to in-memory queue for backwards compatibility
let mut queue = state.commands.unverified.write().await;
queue.extend(commands.drain(..));
}
if let Some(obj) = response_json.as_object_mut() {
obj.insert("result".to_string(), Value::Array(vec![]));
// Get verified commands from database
let verified_cmds =
match crate::db::repositories::commands::list_by_status(&state.db, "verified").await {
Ok(cmds) => cmds,
Err(e) => {
warn!("Failed to fetch verified commands from database: {}", e);
Vec::new()
}
};
// Convert to JSON values
let verified_values: Vec<Value> = verified_cmds
.iter()
.filter_map(|c| serde_json::from_str(&c.command_json).ok())
.collect();
// Also include in-memory verified commands
let mem_verified = std::mem::take(&mut *state.commands.verified.write().await);
let mut all_verified = verified_values;
all_verified.extend(mem_verified);
obj.insert("result".to_string(), Value::Array(all_verified.clone()));
// Clear verified commands from database after sending
if let Err(e) = crate::db::repositories::commands::clear_verified(&state.db).await {
warn!("Failed to clear verified commands from database: {}", e);
}
}
Ok(response_json)
@@ -175,4 +224,29 @@ fn decrypt_and_format(body_text: &str, key: &str, iv: &str) -> anyhow::Result<St
let decrypted = crypto::decrypt(body_text, key, iv)?;
let formatted: Value = serde_json::from_str(&decrypted)?;
Ok(serde_json::to_string_pretty(&formatted)?)
}
}
async fn maybe_intercept_response(
method: &str,
_original_response: &str,
state: &Arc<AppState>,
) -> anyhow::Result<Option<String>> {
// Check if there's an interception rule for this method
let rule = crate::db::repositories::rules::find_by_method(&state.db, method).await?;
match rule {
Some(r) if r.action == "replace" => {
if let Some(custom_response) = &r.custom_response {
Ok(crypto::encrypt(custom_response, &state.key, &state.iv).map(Some)?)
} else {
Ok(None)
}
}
Some(r) if r.action == "modify" => {
// Future: Apply transformations
// For now, just pass through
Ok(None)
}
_ => Ok(None), // Passthrough
}
}

33
src/state.rs Normal file
View File

@@ -0,0 +1,33 @@
use serde_json::Value;
use sqlx::SqlitePool;
use tokio::sync::RwLock;
pub struct AppState {
pub client: reqwest::Client,
pub target_url: reqwest::Url,
pub key: String,
pub iv: String,
pub jwt_secret: String,
pub db: SqlitePool,
pub commands: Commands,
}
pub struct Commands {
pub unverified: RwLock<Vec<Value>>,
pub verified: RwLock<Vec<Value>>,
}
impl Commands {
pub fn new() -> Self {
Self {
unverified: RwLock::default(),
verified: RwLock::default(),
}
}
}
impl Default for Commands {
fn default() -> Self {
Self::new()
}
}