mirror of
https://github.com/xvzc/SpoofDPI.git
synced 2024-12-22 06:15:51 +00:00
79d255719e
* fix: http proxy * chore: make a function for setting connection timeout * chore: rename parameters for consistency
161 lines
3.6 KiB
Go
161 lines
3.6 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
|
|
"github.com/xvzc/SpoofDPI/dns"
|
|
"github.com/xvzc/SpoofDPI/packet"
|
|
"github.com/xvzc/SpoofDPI/proxy/handler"
|
|
"github.com/xvzc/SpoofDPI/util"
|
|
"github.com/xvzc/SpoofDPI/util/log"
|
|
)
|
|
|
|
const scopeProxy = "PROXY"
|
|
|
|
type Proxy struct {
|
|
addr string
|
|
port int
|
|
timeout int
|
|
resolver *dns.Dns
|
|
windowSize int
|
|
enableDoh bool
|
|
allowedPattern []*regexp.Regexp
|
|
}
|
|
|
|
type Handler interface {
|
|
Serve(ctx context.Context, lConn *net.TCPConn, pkt *packet.HttpRequest, ip string)
|
|
}
|
|
|
|
func New(config *util.Config) *Proxy {
|
|
return &Proxy{
|
|
addr: config.Addr,
|
|
port: config.Port,
|
|
timeout: config.Timeout,
|
|
windowSize: config.WindowSize,
|
|
enableDoh: config.EnableDoh,
|
|
allowedPattern: config.AllowedPatterns,
|
|
resolver: dns.NewDns(config),
|
|
}
|
|
}
|
|
|
|
func (pxy *Proxy) Start(ctx context.Context) {
|
|
ctx = util.GetCtxWithScope(ctx, scopeProxy)
|
|
logger := log.GetCtxLogger(ctx)
|
|
|
|
l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP(pxy.addr), Port: pxy.port})
|
|
if err != nil {
|
|
logger.Fatal().Msgf("error creating listener: %s", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if pxy.timeout > 0 {
|
|
logger.Info().Msgf("connection timeout is set to %d ms", pxy.timeout)
|
|
}
|
|
|
|
logger.Info().Msgf("created a listener on port %d", pxy.port)
|
|
if len(pxy.allowedPattern) > 0 {
|
|
logger.Info().Msgf("number of white-listed pattern: %d", len(pxy.allowedPattern))
|
|
}
|
|
|
|
for {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
logger.Fatal().Msgf("error accepting connection: %s", err)
|
|
continue
|
|
}
|
|
|
|
go func() {
|
|
ctx := util.GetCtxWithTraceId(ctx)
|
|
logger := log.GetCtxLogger(ctx)
|
|
|
|
pkt, err := packet.ReadHttpRequest(conn)
|
|
if err != nil {
|
|
logger.Debug().Msgf("error while parsing request: %s", err)
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
pkt.Tidy()
|
|
|
|
logger.Debug().Msgf("request from %s\n\n%s", conn.RemoteAddr(), string(pkt.Raw()))
|
|
|
|
if !pkt.IsValidMethod() {
|
|
logger.Debug().Msgf("unsupported method: %s", pkt.Method())
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
matched := pxy.patternMatches([]byte(pkt.Domain()))
|
|
useSystemDns := !matched
|
|
|
|
ip, err := pxy.resolver.ResolveHost(ctx, pkt.Domain(), pxy.enableDoh, useSystemDns)
|
|
if err != nil {
|
|
logger.Debug().Msgf("error while dns lookup: %s %s", pkt.Domain(), err)
|
|
conn.Write([]byte(pkt.Version() + " 502 Bad Gateway\r\n\r\n"))
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Avoid recursively querying self
|
|
if pkt.Port() == strconv.Itoa(pxy.port) && isLoopedRequest(ctx, net.ParseIP(ip)) {
|
|
logger.Error().Msg("looped request has been detected. aborting.")
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
var h Handler
|
|
if pkt.IsConnectMethod() {
|
|
h = handler.NewHttpsHandler(pxy.timeout, pxy.windowSize, pxy.allowedPattern, matched)
|
|
} else {
|
|
h = handler.NewHttpHandler(pxy.timeout)
|
|
}
|
|
|
|
h.Serve(ctx, conn.(*net.TCPConn), pkt, ip)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func (pxy *Proxy) patternMatches(bytes []byte) bool {
|
|
if pxy.allowedPattern == nil {
|
|
return true
|
|
}
|
|
|
|
for _, pattern := range pxy.allowedPattern {
|
|
if pattern.Match(bytes) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isLoopedRequest(ctx context.Context, ip net.IP) bool {
|
|
if ip.IsLoopback() {
|
|
return true
|
|
}
|
|
|
|
logger := log.GetCtxLogger(ctx)
|
|
|
|
// Get list of available addresses
|
|
// See `ip -4 addr show`
|
|
addr, err := net.InterfaceAddrs() // needs AF_NETLINK on linux
|
|
if err != nil {
|
|
logger.Error().Msgf("error while getting addresses of our network interfaces: %s", err)
|
|
return false
|
|
}
|
|
|
|
for _, addr := range addr {
|
|
if ipnet, ok := addr.(*net.IPNet); ok {
|
|
if ipnet.IP.Equal(ip) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|