rust implementation

This commit is contained in:
2025-11-23 12:12:20 +08:00
parent f40ace4058
commit 226e8cbf4f
19 changed files with 2358 additions and 364 deletions

7
.env.example Normal file
View File

@@ -0,0 +1,7 @@
# Please replace these with your actual key and IV
LINSPIRER_KEY="0123456789abcdef"
LINSPIRER_IV="0123456789abcdef"
# Optional: Set the listening host and port
# LINSPIRER_HOST="0.0.0.0"
# LINSPIRER_PORT="8080"

1
.envrc Normal file
View File

@@ -0,0 +1 @@
use flake

4
.gitignore vendored
View File

@@ -1,2 +1,6 @@
/.direnv
/.env
/target
/*.log

1937
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

22
Cargo.toml Normal file
View File

@@ -0,0 +1,22 @@
[package]
name = "mylinspirer"
version = "0.1.0"
edition = "2024"
[dependencies]
axum = "0.8"
tokio = { version = "1", features = ["full"] }
hyper = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
base64 = "0.22"
aes = "0.8"
cbc = { version = "0.1", features = ["alloc"] }
dotenvy = "0.15"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
anyhow = "1"
thiserror = "2"
http-body-util = "0.1.1"
chrono = { version = "0.4", features = ["clock"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }

View File

@@ -1,177 +0,0 @@
package main
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"git.imxyy.top/imxyy1soope1/MyLinspirer/internal/utils/crypto"
"git.imxyy.top/imxyy1soope1/MyLinspirer/internal/utils/format"
)
const (
targetURL = "https://cloud.linspirer.com:883"
)
var (
proxy *httputil.ReverseProxy
host string
port string
key string
iv string
once sync.Once
)
func init() {
once.Do(func() {
key = os.Getenv("LINSPIRER_KEY")
iv = os.Getenv("LINSPIRER_IV")
if key == "" || iv == "" {
log.Fatalf("LINSPIRER_KEY or LINSPIRER_IV is not set")
}
})
}
func main() {
flag.StringVar(&host, "a", "", "listening host")
flag.StringVar(&port, "p", "8080", "listening port")
flag.Parse()
portNum, err := strconv.Atoi(port)
if err != nil {
log.Fatalf("invalid port: %v", err)
}
if portNum < 1 || portNum > 65535 {
log.Fatalf("port out of range (1-65535): %d", portNum)
}
addr := host + ":" + port
if host == "" {
addr = ":" + port
}
target, _ := url.Parse(targetURL)
proxy = &httputil.ReverseProxy{
Rewrite: func(req *httputil.ProxyRequest) {
startTime := time.Now()
recordRequest(req, startTime)
req.SetURL(target)
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
ModifyResponse: logResponse,
}
log.Printf("Proxy started on %s => %s", addr, targetURL)
log.Fatal(http.ListenAndServe(addr, http.HandlerFunc(proxy.ServeHTTP)))
}
func recordRequest(req *httputil.ProxyRequest, startTime time.Time) {
if req.Out.Body == nil {
return
}
body, _ := io.ReadAll(req.Out.Body)
req.Out.Body = io.NopCloser(bytes.NewBuffer(body))
var requestData map[string]any
if err := json.Unmarshal(body, &requestData); err == nil {
if paramsEnc, ok := requestData["params"].(string); ok {
if decrypted, err := crypto.Decrypt(paramsEnc, key, iv); err == nil {
var unmarshaledParams map[string]any
err = json.Unmarshal([]byte(decrypted), &unmarshaledParams)
if err == nil {
requestData["params"] = unmarshaledParams
} else {
requestData["params"] = decrypted
}
} else {
requestData["params"] = fmt.Sprintf("decrypt failed: %v", err)
}
}
} else {
requestData = map[string]any{"request": fmt.Sprintf("JSON parse error: %v", err)}
}
ctx := context.WithValue(req.Out.Context(), "startTime", startTime)
ctx = context.WithValue(ctx, "decryptedRequest", requestData)
req.Out = req.Out.WithContext(ctx)
}
func extractContextValue[T any](ctx context.Context, name string) (val T, ok bool) {
valAny := ctx.Value(name)
if valAny == nil {
log.Printf("[ERROR] %s not found in context", name)
return
}
val, ok = valAny.(T)
if !ok {
log.Printf("[ERROR] invalid %s type: %T", name, valAny)
return
}
return
}
func logResponse(resp *http.Response) error {
startTime, timeOk := extractContextValue[time.Time](resp.Request.Context(), "startTime")
decryptedRequest, reqOk := extractContextValue[map[string]any](resp.Request.Context(), "decryptedRequest")
if !timeOk || !reqOk {
return nil
}
body, _ := io.ReadAll(resp.Body)
resp.Body = io.NopCloser(bytes.NewBuffer(body))
var response []byte
if resp.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(bytes.NewReader(body))
if err == nil {
defer gzReader.Close()
if decompressed, err := io.ReadAll(gzReader); err == nil {
response = decompressed
} else {
fmt.Appendf(response, "gzip decompress failed: %v", err)
}
} else {
fmt.Appendf(response, "gzip init failed: %v", err)
}
} else {
response = body
}
respPlaintext := "N/A"
if decrypted, err := crypto.Decrypt(string(response), key, iv); err == nil {
respPlaintext = format.FormatJSON(decrypted)
} else {
respPlaintext = fmt.Sprintf("decrypt failed: %v", err)
}
requestJSON, _ := json.MarshalIndent(decryptedRequest, "", " ")
log.Printf("[%s] %s\nRequest:\n%s\nResponse:\n%s\n%s\n",
startTime.Format("2006/01/02 15:04:05"),
decryptedRequest["method"].(string),
format.FormatJSON(string(requestJSON)),
respPlaintext,
strings.Repeat("-", 80),
)
return nil
}

66
flake.lock generated Normal file
View File

@@ -0,0 +1,66 @@
{
"nodes": {
"fenix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1759301100,
"narHash": "sha256-hmiTEoVAqLnn80UkreCNunnRKPucKvcg5T4/CELEtbw=",
"owner": "nix-community",
"repo": "fenix",
"rev": "0956bc5d1df2ea800010172c6bc4470d9a22cb81",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "fenix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1759036355,
"narHash": "sha256-0m27AKv6ka+q270dw48KflE0LwQYrO7Fm4/2//KCVWg=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "e9f00bd893984bc8ce46c895c3bf7cac95331127",
"type": "github"
},
"original": {
"owner": "nixos",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"fenix": "fenix",
"nixpkgs": "nixpkgs"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1759245522,
"narHash": "sha256-H4Hx/EuMJ9qi1WzPV4UG2bbZiDCdREtrtDvYcHr0kmk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "a6bc4a4bbe6a65b71cbf76a0cf528c47a8d9f97f",
"type": "github"
},
"original": {
"owner": "rust-lang",
"ref": "nightly",
"repo": "rust-analyzer",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

36
flake.nix Normal file
View File

@@ -0,0 +1,36 @@
{
inputs = {
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
fenix.url = "github:nix-community/fenix";
fenix.inputs.nixpkgs.follows = "nixpkgs";
};
outputs = { nixpkgs, fenix, ... }:
let
forAllSystems = nixpkgs.lib.genAttrs nixpkgs.lib.systems.flakeExposed;
in
{
devShells = forAllSystems (system:
let pkgs = import nixpkgs { inherit system; config.allowUnfree = true; }; in
{
default =
pkgs.mkShell rec {
nativeBuildInputs = with pkgs; [
gcc
openssl
pkg-config
(fenix.packages.${system}.stable.withComponents [
"cargo"
"clippy"
"rust-src"
"rustc"
"rustfmt"
"rust-analyzer"
])
];
buildInputs = [];
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath (nativeBuildInputs ++ buildInputs);
};
}
);
};
}

12
go.mod
View File

@@ -1,12 +0,0 @@
module git.imxyy.top/imxyy1soope1/MyLinspirer
go 1.24.1
require (
github.com/bytedance/sonic v1.13.2 // indirect
github.com/bytedance/sonic/loader v0.2.4 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
)

29
go.sum
View File

@@ -1,29 +0,0 @@
github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ=
github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY=
github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=

View File

@@ -1,52 +0,0 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"errors"
"fmt"
)
func Decrypt(ciphertextB64, key, iv string) (string, error) {
ciphertext, err := base64.StdEncoding.DecodeString(ciphertextB64)
if err != nil {
return "", fmt.Errorf("base64 decode failed: %v", err)
}
block, err := aes.NewCipher([]byte(key))
if err != nil {
return "", fmt.Errorf("AES init failed: %v", err)
}
if len(iv) != aes.BlockSize {
return "", fmt.Errorf("IV must be %d bytes", aes.BlockSize)
}
mode := cipher.NewCBCDecrypter(block, []byte(iv))
plaintext := make([]byte, len(ciphertext))
mode.CryptBlocks(plaintext, ciphertext)
unpadded, err := pkcs7Unpad(plaintext)
if err != nil {
return "", fmt.Errorf("PKCS7 unpadding failed: %v", err)
}
return string(unpadded), nil
}
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {
return nil, errors.New("empty data")
}
padSize := int(data[len(data)-1])
if padSize > len(data) || padSize > aes.BlockSize {
return nil, fmt.Errorf("invalid padding size: %d", padSize)
}
for i := range padSize {
if data[len(data)-1-i] != byte(padSize) {
return nil, errors.New("invalid padding content")
}
}
return data[:len(data)-padSize], nil
}

View File

@@ -1,41 +0,0 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
)
func Encrypt(plaintext, key, iv string) (string, error) {
padded := pkcs7Pad([]byte(plaintext))
block, err := aes.NewCipher([]byte(key))
if err != nil {
return "", fmt.Errorf("AES init failed: %v", err)
}
if len(iv) != aes.BlockSize {
return "", fmt.Errorf("IV must be %d bytes", aes.BlockSize)
}
mode := cipher.NewCBCEncrypter(block, []byte(iv))
ciphertext := make([]byte, len(padded))
mode.CryptBlocks(ciphertext, padded)
ciphertextB64 := base64.StdEncoding.EncodeToString(ciphertext)
return ciphertextB64, nil
}
func pkcs7Pad(buf []byte) []byte {
bufLen := len(buf)
padLen := aes.BlockSize - bufLen%aes.BlockSize
padded := make([]byte, bufLen+padLen)
copy(padded, buf)
for i := range padLen {
padded[bufLen+i] = byte(padLen)
}
return padded
}

View File

@@ -1,31 +0,0 @@
package format
import (
"bytes"
"encoding/json"
)
func FormatJSON(input string) string {
var data any
if err := json.Unmarshal([]byte(input), &data); err != nil {
var out bytes.Buffer
if err := json.Indent(&out, []byte(input), "", " "); err == nil {
return out.String()
}
return input
}
var out bytes.Buffer
encoder := json.NewEncoder(&out)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")
if err := encoder.Encode(data); err != nil {
return input
}
formatted := out.String()
if len(formatted) > 0 && formatted[len(formatted)-1] == '\n' {
formatted = formatted[:len(formatted)-1]
}
return formatted
}

View File

@@ -1,5 +0,0 @@
package types
type Params interface {
param()
}

View File

@@ -1,10 +0,0 @@
package types
type Request struct {
Version int `json:"!version"`
ClientVersion string `json:"client_version"`
Id int `json:"id"`
JsonRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params Params `json:"params"`
}

View File

@@ -1,7 +0,0 @@
package types
type Response struct {
Code int32 `json:"code"`
Type string `json:"type"`
Data string `json:"data"`
}

54
src/crypto.rs Normal file
View File

@@ -0,0 +1,54 @@
use std::fmt;
use aes::Aes128;
use base64::{engine::general_purpose::STANDARD, Engine as _};
use cbc::{
cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit},
Decryptor,
};
type Aes128CbcDec = Decryptor<Aes128>;
#[derive(Debug)]
pub enum CryptoError {
Base64(base64::DecodeError),
Aes(cbc::cipher::InvalidLength),
Unpad(cbc::cipher::block_padding::UnpadError),
}
impl fmt::Display for CryptoError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CryptoError::Base64(e) => write!(f, "Base64 decode failed: {}", e),
CryptoError::Aes(e) => write!(f, "AES decryption failed: {}", e),
CryptoError::Unpad(e) => write!(f, "PKCS7 unpadding failed: {}", e),
}
}
}
impl std::error::Error for CryptoError {}
pub fn decrypt(ciphertext_b64: &str, key: &str, iv: &str) -> Result<String, CryptoError> {
// 1. Base64 decode the ciphertext
let ciphertext = STANDARD
.decode(ciphertext_b64)
.map_err(CryptoError::Base64)?;
// 2. Initialize AES-128 in CBC mode
let key_bytes = key.as_bytes();
let iv_bytes = iv.as_bytes();
let decryptor = Aes128CbcDec::new_from_slices(key_bytes, iv_bytes).map_err(CryptoError::Aes)?;
// 3. Decrypt the ciphertext, handling padding
let decrypted_len = ciphertext.len();
let mut plaintext = vec![0u8; decrypted_len];
let copy_len = ciphertext.len();
plaintext[..copy_len].copy_from_slice(&ciphertext);
let plaintext_slice = decryptor
.decrypt_padded_mut::<Pkcs7>(&mut plaintext[..decrypted_len])
.map_err(CryptoError::Unpad)?;
// 4. Convert plaintext to a UTF-8 string
Ok(String::from_utf8_lossy(plaintext_slice).to_string())
}

81
src/main.rs Normal file
View File

@@ -0,0 +1,81 @@
use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use axum::{routing::any, Router};
use dotenvy::dotenv;
use tracing::{error, info, level_filters::LevelFilter};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
mod crypto;
mod proxy;
#[derive(Clone)]
pub struct AppState {
pub client: reqwest::Client,
pub target_url: reqwest::Url,
pub key: String,
pub iv: String,
}
const DEFAULT_TARGET_URL: &str = "https://cloud.linspirer.com:883";
const DEFAULT_HOST: &str = "0.0.0.0";
const DEFAULT_PORT: &str = "8080";
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Load environment variables from .env file
dotenv().ok();
// Set up logging
tracing_subscriber::registry()
.with(fmt::layer())
.with(
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy(),
)
.init();
// Load configuration from environment
let key = std::env::var("LINSPIRER_KEY").context("LINSPIRER_KEY must be set")?;
let iv = std::env::var("LINSPIRER_IV").context("LINSPIRER_IV must be set")?;
let target_url_str = std::env::var("LINSPIRER_TARGET_URL");
let target_url_str = target_url_str.as_deref().unwrap_or(DEFAULT_TARGET_URL);
let target_url = reqwest::Url::parse(target_url_str)?;
let host_str = std::env::var("LINSPIRER_HOST");
let host = host_str.as_deref().unwrap_or(DEFAULT_HOST);
let port_str = std::env::var("LINSPIRER_PORT");
let port = port_str.as_deref().unwrap_or(DEFAULT_PORT);
let addr_str = format!("{}:{}", host, port);
let addr: SocketAddr = addr_str
.parse()
.context(format!("Invalid address format: {}", addr_str))?;
// Create a reqwest client that ignores SSL certificate verification
let client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.build()?;
// Create shared state
let state = Arc::new(AppState {
client,
target_url,
key,
iv,
});
// Build our application with a single route
let app = Router::new()
.route("/{*path}", any(proxy::proxy_handler))
.with_state(state);
// Run the server
info!("Proxy started on {} => {}", addr_str, target_url_str);
let listener = tokio::net::TcpListener::bind(&addr).await?;
if let Err(e) = axum::serve(listener, app.into_make_service()).await {
error!("Server error: {}", e);
}
Ok(())
}

150
src/proxy.rs Normal file
View File

@@ -0,0 +1,150 @@
use std::sync::Arc;
use axum::{
body::Bytes,
extract::{Path, State},
http::{Request, StatusCode},
response::{IntoResponse, Response},
};
use http_body_util::BodyExt;
use serde_json::Value;
use tracing::{error, info, warn};
use crate::{crypto, AppState};
pub async fn proxy_handler(
State(state): State<Arc<AppState>>,
Path(path): Path<String>,
req: Request<axum::body::Body>,
) -> impl IntoResponse {
let (parts, body) = req.into_parts();
let body_bytes = match body.collect().await {
Ok(body) => body.to_bytes(),
Err(e) => {
error!("Failed to read request body: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to read request body".to_string(),
)
.into_response();
}
};
let decrypted_request_log = match process_and_log_request(&body_bytes, &state.key, &state.iv) {
Ok(log) => log,
Err(e) => {
warn!("Failed to process request for logging: {}", e);
Value::String("Could not decrypt request".to_string())
}
};
let mut target_url = state.target_url.clone();
target_url.set_path(&path);
target_url.set_query(parts.uri.query());
let mut forwarded_req = state.client.request(parts.method, target_url);
if !parts.headers.is_empty() {
forwarded_req = forwarded_req.headers(parts.headers);
}
forwarded_req = forwarded_req.body(body_bytes);
let resp = match forwarded_req.send().await {
Ok(resp) => resp,
Err(e) => {
error!("Failed to forward request: {}", e);
return (
StatusCode::BAD_GATEWAY,
format!("Failed to forward request: {}", e),
)
.into_response();
}
};
let (resp_parts, resp_body_bytes) = match clone_response(resp).await {
Ok(tuple) => tuple,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to clone response".to_string(),
)
.into_response();
}
};
let decrypted_response_log =
match process_and_log_response(&resp_body_bytes, &state.key, &state.iv).await {
Ok(log) => log,
Err(e) => {
warn!("Failed to process response for logging: {}", e);
"Could not decrypt response".to_string()
}
};
let method = decrypted_request_log
.get("method")
.and_then(Value::as_str)
.unwrap_or("UNKNOWN");
info!(
"[{}] {}\nRequest:\n{}\nResponse:\n{}\n{}",
chrono::Local::now().format("%Y/%m/%d %H:%M:%S"),
method,
serde_json::to_string_pretty(&decrypted_request_log).unwrap_or_default(),
decrypted_response_log,
"-".repeat(80),
);
let mut response_builder = Response::builder().status(resp_parts.status);
if !resp_parts.headers.is_empty() {
*response_builder.headers_mut().unwrap() = resp_parts.headers;
}
response_builder
.body(axum::body::Body::from(resp_body_bytes))
.unwrap()
}
fn process_and_log_request(body_bytes: &Bytes, key: &str, iv: &str) -> anyhow::Result<Value> {
let mut request_data: Value = serde_json::from_slice(body_bytes)?;
if let Some(params_value) = request_data.get_mut("params")
&& let Some(params_str) = params_value.as_str() {
let params_str_owned = params_str.to_string();
match crypto::decrypt(&params_str_owned, key, iv) {
Ok(decrypted_str) => {
let decrypted_params: Value = serde_json::from_str(&decrypted_str)
.unwrap_or(Value::String(decrypted_str));
*params_value = decrypted_params;
}
Err(e) => {
*params_value = Value::String(format!("decrypt failed: {}", e));
}
}
}
Ok(request_data)
}
async fn process_and_log_response(
body_bytes: &Bytes,
key: &str,
iv: &str,
) -> anyhow::Result<String> {
let decrypted = crypto::decrypt(std::str::from_utf8(body_bytes)?, key, iv)?;
let formatted: Value = serde_json::from_str(&decrypted)?;
Ok(serde_json::to_string_pretty(&formatted)?)
}
async fn clone_response(
resp: reqwest::Response,
) -> anyhow::Result<(axum::http::response::Parts, Bytes)> {
let mut parts_builder = axum::http::response::Builder::new()
.status(resp.status())
.version(resp.version());
if !resp.headers().is_empty() {
*parts_builder.headers_mut().unwrap() = resp.headers().clone();
}
let parts = parts_builder.body(())?.into_parts().0;
let body_bytes = resp.bytes().await?;
Ok((parts, body_bytes))
}