diff --git a/proxy/http.go b/proxy/http.go index 35d5c1d..183b0d6 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -39,7 +39,7 @@ func (pxy *Proxy) handleHttp(lConn *net.TCPConn, pkt *packet.HttpPacket, ip stri log.Debug("[HTTP] New connection to the server ", pkt.Domain(), " ", rConn.LocalAddr()) - go Serve(rConn, lConn, "[HTTP]", lConn.RemoteAddr().String(), pkt.Domain(), pxy.timeout) + go Serve(rConn, lConn, "[HTTP]", lConn.RemoteAddr().String(), pkt.Domain(), pxy.timeout, pxy.bufferSize) _, err = rConn.Write(pkt.Raw()) if err != nil { @@ -49,5 +49,5 @@ func (pxy *Proxy) handleHttp(lConn *net.TCPConn, pkt *packet.HttpPacket, ip stri log.Debug("[HTTP] Sent a request to ", pkt.Domain()) - Serve(lConn, rConn, "[HTTP]", lConn.RemoteAddr().String(), pkt.Domain(), pxy.timeout) + Serve(lConn, rConn, "[HTTP]", lConn.RemoteAddr().String(), pkt.Domain(), pxy.timeout, pxy.bufferSize) } diff --git a/proxy/https.go b/proxy/https.go index 0cef83a..634f671 100644 --- a/proxy/https.go +++ b/proxy/https.go @@ -45,7 +45,8 @@ func (pxy *Proxy) handleHttps(lConn *net.TCPConn, exploit bool, initPkt *packet. log.Debug("[HTTPS] Sent 200 Connection Estabalished to ", lConn.RemoteAddr()) // Read client hello - clientHello, err := ReadBytes(lConn) + tmpBuffer := make([]byte, pxy.bufferSize) + clientHello, err := ReadBytes(lConn, tmpBuffer) if err != nil { log.Debug("[HTTPS] Error reading client hello from the client", err) return @@ -60,7 +61,7 @@ func (pxy *Proxy) handleHttps(lConn *net.TCPConn, exploit bool, initPkt *packet. // lConn.SetLinger(3) // rConn.SetLinger(3) - go Serve(rConn, lConn, "[HTTPS]", rConn.RemoteAddr().String(), initPkt.Domain(), pxy.timeout) + go Serve(rConn, lConn, "[HTTPS]", rConn.RemoteAddr().String(), initPkt.Domain(), pxy.timeout, pxy.bufferSize) if exploit { log.Debug("[HTTPS] Writing chunked client hello to ", initPkt.Domain()) @@ -77,7 +78,7 @@ func (pxy *Proxy) handleHttps(lConn *net.TCPConn, exploit bool, initPkt *packet. } } - Serve(lConn, rConn, "[HTTPS]", lConn.RemoteAddr().String(), initPkt.Domain(), pxy.timeout) + Serve(lConn, rConn, "[HTTPS]", lConn.RemoteAddr().String(), initPkt.Domain(), pxy.timeout, pxy.bufferSize) } func splitInChunks(bytes []byte, size int) [][]byte { diff --git a/proxy/io.go b/proxy/io.go index a8a1fff..f166bb3 100644 --- a/proxy/io.go +++ b/proxy/io.go @@ -9,8 +9,6 @@ import ( log "github.com/sirupsen/logrus" ) -const BUF_SIZE = 1024 - func WriteChunks(conn *net.TCPConn, c [][]byte) (n int, err error) { total := 0 for i := 0; i < len(c); i++ { @@ -25,45 +23,42 @@ func WriteChunks(conn *net.TCPConn, c [][]byte) (n int, err error) { return total, nil } -func ReadBytes(conn *net.TCPConn) ([]byte, error) { - ret := make([]byte, 0) - buf := make([]byte, BUF_SIZE) - - for { - n, err := conn.Read(buf) - if err != nil { - switch err.(type) { - case *net.OpError: - return nil, errors.New("timed out") - default: - return nil, err - } - } - ret = append(ret, buf[:n]...) - - if n < BUF_SIZE { - break - } - } - - if len(ret) == 0 { - return nil, io.EOF - } - - return ret, nil +func ReadBytes(conn *net.TCPConn, dest []byte) ([]byte, error) { + n, err := readBytesInternal(conn, dest) + return dest[:n], err } -func Serve(from *net.TCPConn, to *net.TCPConn, proto string, fd string, td string, timeout int) { - proto += " " - +func readBytesInternal(in io.Reader, dest []byte) (int, error) { + totalRead := 0 for { - if timeout > 0 { - from.SetReadDeadline( - time.Now().Add(time.Millisecond * time.Duration(timeout)), - ) - } + numRead, readErr := in.Read(dest[totalRead:]) + totalRead += numRead + if readErr != nil { + switch readErr.(type) { + case *net.OpError: + return totalRead, errors.New("timed out") + default: + return totalRead, readErr + } + } + if totalRead == 0 { + return 0, io.EOF + } + return totalRead, nil + } +} - buf, err := ReadBytes(from) +func Serve(from *net.TCPConn, to *net.TCPConn, proto string, fd string, td string, timeout int, bufferSize int) { + proto += " " + buf := make([]byte, bufferSize) + for { + if timeout > 0 { + from.SetReadDeadline( + time.Now().Add(time.Millisecond * time.Duration(timeout)), + ) + } + + bytesRead, err := ReadBytes(from, buf) if err != nil { if err == io.EOF { log.Debug(proto, "Finished ", fd) @@ -73,7 +68,7 @@ func Serve(from *net.TCPConn, to *net.TCPConn, proto string, fd string, td strin return } - if _, err := to.Write(buf); err != nil { + if _, err := to.Write(bytesRead); err != nil { log.Debug(proto, "Error Writing to ", td) return } diff --git a/proxy/proxy.go b/proxy/proxy.go index 24b16e6..d4301ae 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -20,6 +20,7 @@ type Proxy struct { resolver *dns.DnsResolver windowSize int allowedPattern []*regexp.Regexp + bufferSize int } func New(config *util.Config) *Proxy { @@ -30,6 +31,7 @@ func New(config *util.Config) *Proxy { windowSize: *config.WindowSize, allowedPattern: config.AllowedPattern, resolver: dns.NewResolver(config), + bufferSize: *config.BufferSize, } } @@ -46,7 +48,7 @@ func (pxy *Proxy) Start() { log.Println("[PROXY] Created a listener on port", pxy.port) if len(pxy.allowedPattern) > 0 { - log.Println("[PROXY] Number of white-listed pattern:", len(pxy.allowedPattern)) + log.Println("[PROXY] Number of white-listed pattern:", len(pxy.allowedPattern)) } for { @@ -57,7 +59,8 @@ func (pxy *Proxy) Start() { } go func() { - b, err := ReadBytes(conn.(*net.TCPConn)) + tmpBuf := make([]byte, pxy.bufferSize) + b, err := ReadBytes(conn.(*net.TCPConn), tmpBuf) if err != nil { return } @@ -77,8 +80,8 @@ func (pxy *Proxy) Start() { return } - matched := pxy.patternMatches([]byte(pkt.Domain())) - useSystemDns := !matched + matched := pxy.patternMatches([]byte(pkt.Domain())) + useSystemDns := !matched ip, err := pxy.resolver.Lookup(pkt.Domain(), useSystemDns) if err != nil { @@ -113,11 +116,11 @@ func (pxy *Proxy) patternMatches(bytes []byte) bool { for _, pattern := range pxy.allowedPattern { if pattern.Match(bytes) { - return true - } + return true + } } - return false + return false } func isLoopedRequest(ip net.IP) bool { diff --git a/util/config.go b/util/config.go index aca391c..2b5a9d1 100644 --- a/util/config.go +++ b/util/config.go @@ -22,6 +22,7 @@ type Config struct { AllowedPattern []*regexp.Regexp WindowSize *int Version *bool + BufferSize *int } type StringArray []string @@ -58,7 +59,7 @@ when not given, the client hello packet will be sent in two parts: fragmentation for the first data packet and the rest `) config.Version = flag.Bool("v", false, "print spoof-dpi's version; this may contain some other relevant information") - + config.BufferSize = flag.Int("buffer-size", 1024, "buffer size, in number of bytes, is the maximum amount of data that can be read at once from a remote resource") var allowedPattern StringArray flag.Var( &allowedPattern,