Files
mylinspirer/cmd/proxy/main.go

178 lines
4.1 KiB
Go

package main
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"git.imxyy.top/imxyy1soope1/MyLinspirer/internal/utils/crypto"
"git.imxyy.top/imxyy1soope1/MyLinspirer/internal/utils/format"
)
// targetURL is the upstream Linspirer endpoint that every proxied request is
// forwarded to.
const targetURL = "https://cloud.linspirer.com:883"

// Package-level state shared by init, main and the logging hooks.
var (
	// proxy is the reverse proxy constructed in main.
	proxy *httputil.ReverseProxy

	// host and port hold the -a and -p command-line flags.
	host string
	port string

	// key and iv are loaded from the environment in init and passed to
	// crypto.Decrypt when logging traffic.
	key string
	iv  string

	// once wraps the environment loading performed in init.
	once sync.Once
)
// init reads the decryption key and IV from the LINSPIRER_KEY and
// LINSPIRER_IV environment variables and terminates the process via
// log.Fatalf when either one is empty. Because it is an init function, this
// happens before main runs (and therefore before flags are parsed).
//
// NOTE(review): Go already runs each init function exactly once, so the
// once.Do guard is redundant; it is kept because the package-level once
// variable exists for it.
func init() {
	once.Do(func() {
		key, iv = os.Getenv("LINSPIRER_KEY"), os.Getenv("LINSPIRER_IV")
		if key != "" && iv != "" {
			return
		}
		log.Fatalf("LINSPIRER_KEY or LINSPIRER_IV is not set")
	})
}
// main parses the listen flags, builds a reverse proxy to targetURL whose
// Rewrite hook records each request (recordRequest) and whose ModifyResponse
// hook logs the decrypted exchange (logResponse), then serves until the
// listener fails.
//
// Flags:
//
//	-a  listening host; empty (the default) listens on all interfaces.
//	    IPv6 literals must be bracketed, e.g. -a '[::1]'.
//	-p  listening port, 1-65535 (default 8080).
func main() {
	flag.StringVar(&host, "a", "", "listening host")
	flag.StringVar(&port, "p", "8080", "listening port")
	flag.Parse()

	portNum, err := strconv.Atoi(port)
	if err != nil {
		log.Fatalf("invalid port: %v", err)
	}
	if portNum < 1 || portNum > 65535 {
		log.Fatalf("port out of range (1-65535): %d", portNum)
	}
	// Use the validated number rather than the raw flag text: Atoi accepts
	// forms such as "+80" that the listener would reject. An empty host
	// yields ":<port>", i.e. all interfaces.
	addr := host + ":" + strconv.Itoa(portNum)

	target, err := url.Parse(targetURL)
	if err != nil {
		log.Fatalf("invalid target URL %q: %v", targetURL, err)
	}

	proxy = &httputil.ReverseProxy{
		Rewrite: func(req *httputil.ProxyRequest) {
			// Timestamp the request before it is sent upstream.
			recordRequest(req, time.Now())
			req.SetURL(target)
		},
		Transport: &http.Transport{
			// NOTE(review): upstream certificate verification is disabled;
			// presumably the upstream's certificate does not verify. Confirm
			// this is still required.
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		},
		ModifyResponse: logResponse,
	}

	srv := &http.Server{
		Addr:    addr,
		Handler: proxy,
		// Bound how long a client may take to send request headers so idle
		// or slow connections cannot pin server resources indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("Proxy started on %s => %s", addr, targetURL)
	log.Fatal(srv.ListenAndServe())
}
// recordRequest runs from the proxy's Rewrite hook before a request is
// forwarded. It buffers the outbound body (re-attaching the bytes so the
// upstream still receives them) and tries to decode it as a JSON object whose
// "params" field is an encrypted string. It then stores the decoded form and
// startTime in the outbound request's context under the keys
// "decryptedRequest" and "startTime", where logResponse reads them.
//
// The context values are stored for every request, including ones without a
// body: httputil.ReverseProxy sets Out.Body to nil when ContentLength is 0
// (e.g. a plain GET), and skipping those would leave logResponse without its
// inputs. The stored map is never nil. When it has no "method" entry, one is
// added holding the HTTP method and path, so the log line always has a label.
//
// NOTE(review): the context keys are plain strings (staticcheck SA1029).
// Switching to a private key type must be done together with
// extractContextValue and logResponse.
func recordRequest(req *httputil.ProxyRequest, startTime time.Time) {
	var requestData map[string]any
	if req.Out.Body != nil {
		body, err := io.ReadAll(req.Out.Body)
		// Re-attach whatever was read so the proxy can still forward it.
		req.Out.Body = io.NopCloser(bytes.NewReader(body))
		if err != nil {
			// The forwarded body will be incomplete; record why for the log.
			requestData = map[string]any{"request": fmt.Sprintf("body read error: %v", err)}
		} else if err := json.Unmarshal(body, &requestData); err != nil {
			requestData = map[string]any{"request": fmt.Sprintf("JSON parse error: %v", err)}
		} else if paramsEnc, ok := requestData["params"].(string); ok {
			requestData["params"] = decodeParams(paramsEnc)
		}
	}
	if requestData == nil {
		// No body at all, or a body that was the JSON literal null.
		requestData = map[string]any{}
	}
	if _, exists := requestData["method"]; !exists {
		requestData["method"] = req.In.Method + " " + req.In.URL.Path
	}

	ctx := context.WithValue(req.Out.Context(), "startTime", startTime)
	ctx = context.WithValue(ctx, "decryptedRequest", requestData)
	req.Out = req.Out.WithContext(ctx)
}

// decodeParams decrypts an encrypted "params" value with key/iv. It returns
// the plaintext parsed as a JSON object when possible, and the raw plaintext
// string otherwise. If decryption fails, it returns a "decrypt failed" message.
func decodeParams(paramsEnc string) any {
	decrypted, err := crypto.Decrypt(paramsEnc, key, iv)
	if err != nil {
		return fmt.Sprintf("decrypt failed: %v", err)
	}
	var params map[string]any
	if err := json.Unmarshal([]byte(decrypted), &params); err != nil {
		return decrypted
	}
	return params
}
// extractContextValue fetches the value stored in ctx under the string key
// name and asserts it to type T. If the key is absent (or holds a nil
// interface), or the value has a different type, it logs an "[ERROR]" line and
// returns the zero T with ok == false.
func extractContextValue[T any](ctx context.Context, name string) (val T, ok bool) {
	switch v := ctx.Value(name).(type) {
	case nil:
		log.Printf("[ERROR] %s not found in context", name)
	case T:
		val, ok = v, true
	default:
		log.Printf("[ERROR] invalid %s type: %T", name, v)
	}
	return val, ok
}
// logResponse is the proxy's ModifyResponse hook. It reads the start time and
// decoded request that recordRequest stored in the outbound request's context.
// It then decrypts the upstream response body with key/iv, after undoing a gzip
// Content-Encoding, and logs the request/response pair. The body is buffered
// and re-attached unchanged, so the client receives exactly what the upstream
// sent.
//
// If the context values are missing, it logs nothing and returns nil. It
// returns an error only when the upstream body cannot be read. ReverseProxy
// then answers through its error handler instead of forwarding a truncated
// response.
func logResponse(resp *http.Response) error {
	ctx := resp.Request.Context()
	startTime, timeOk := extractContextValue[time.Time](ctx, "startTime")
	decryptedRequest, reqOk := extractContextValue[map[string]any](ctx, "decryptedRequest")
	if !timeOk || !reqOk {
		return nil
	}

	body, readErr := io.ReadAll(resp.Body)
	// The upstream body is swapped for an in-memory copy below, and the proxy
	// only closes whatever resp.Body points to, so close the original here.
	_ = resp.Body.Close()
	resp.Body = io.NopCloser(bytes.NewReader(body))
	if readErr != nil {
		return fmt.Errorf("reading upstream response body: %w", readErr)
	}

	var respPlaintext string
	if payload, err := decodeResponseBody(resp.Header.Get("Content-Encoding"), body); err != nil {
		respPlaintext = err.Error()
	} else if decrypted, err := crypto.Decrypt(string(payload), key, iv); err != nil {
		respPlaintext = fmt.Sprintf("decrypt failed: %v", err)
	} else {
		respPlaintext = format.FormatJSON(decrypted)
	}

	// Non-JSON requests, or JSON without a string "method", would panic on a
	// bare type assertion; label those with the HTTP method and path instead.
	method, ok := decryptedRequest["method"].(string)
	if !ok {
		method = resp.Request.Method + " " + resp.Request.URL.Path
	}

	var requestText string
	if requestJSON, err := json.MarshalIndent(decryptedRequest, "", " "); err != nil {
		requestText = fmt.Sprintf("marshal failed: %v", err)
	} else {
		requestText = format.FormatJSON(string(requestJSON))
	}

	log.Printf("[%s] %s\nRequest:\n%s\nResponse:\n%s\n%s\n",
		startTime.Format("2006/01/02 15:04:05"),
		method,
		requestText,
		respPlaintext,
		strings.Repeat("-", 80),
	)
	return nil
}

// decodeResponseBody undoes a "gzip" Content-Encoding so the payload can be
// decrypted. Bodies with any other encoding are returned unchanged.
func decodeResponseBody(contentEncoding string, body []byte) ([]byte, error) {
	if contentEncoding != "gzip" {
		return body, nil
	}
	gz, err := gzip.NewReader(bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("gzip init failed: %w", err)
	}
	defer gz.Close()
	decompressed, err := io.ReadAll(gz)
	if err != nil {
		return nil, fmt.Errorf("gzip decompress failed: %w", err)
	}
	return decompressed, nil
}