diff --git a/net/conn.go b/net/conn.go index 2b5989c..17b3ee3 100644 --- a/net/conn.go +++ b/net/conn.go @@ -13,11 +13,15 @@ import ( const BUF_SIZE = 1024 type Conn struct { - conn net.Conn + conn *net.TCPConn } func (c *Conn) CloseWrite() { - c.conn.(*net.TCPConn).CloseWrite() + c.conn.CloseWrite() +} + +func (c *Conn) CloseRead() { + c.conn.CloseRead() } func (c *Conn) Close() { @@ -51,7 +55,7 @@ func (c *Conn) SetDeadLine(t time.Time) (error) { } func (c *Conn) SetKeepAlive(b bool) (error) { - c.conn.(*net.TCPConn).SetKeepAlive(b) + c.conn.SetKeepAlive(b) return nil } @@ -95,6 +99,8 @@ func (conn *Conn) ReadBytes() ([]byte, error) { func (lConn *Conn) HandleHttp(p *packet.HttpPacket) { defer func() { + lConn.CloseRead() + lConn.CloseWrite() lConn.Close() log.Debug("[HTTP] Closing client Connection.. ", lConn.RemoteAddr()) }() @@ -111,19 +117,21 @@ func (lConn *Conn) HandleHttp(p *packet.HttpPacket) { log.Debug("[DOH] Found ", ip, " with ", p.Domain()) // Create connection to server - var port = ":80" + var port = "80" if p.Port() != "" { - port = ":" + p.Port() + port = p.Port() } - rConn, err := Dial("tcp", ip + port) + rConn, err := DialTCP("tcp", ip, port) if err != nil { log.Debug("[HTTP] ", err) return } defer func() { - defer rConn.Close() + rConn.CloseRead() + rConn.CloseWrite() + rConn.Close() log.Debug("[HTTP] Closing server Connection.. ", p.Domain(), " ", rConn.LocalAddr()) }() @@ -144,6 +152,8 @@ func (lConn *Conn) HandleHttp(p *packet.HttpPacket) { func (lConn *Conn) HandleHttps(p *packet.HttpPacket) { defer func() { + lConn.CloseRead() + lConn.CloseWrite() lConn.Close() log.Debug("[HTTPS] Closing client Connection.. ", lConn.RemoteAddr()) }() @@ -158,19 +168,21 @@ func (lConn *Conn) HandleHttps(p *packet.HttpPacket) { log.Debug("[DOH] Found ", ip, " with ", p.Domain()) // Create a connection to the requested server - var port = ":443" + var port = "443" if p.Port() != "" { - port = ":" + p.Port() + port = p.Port() } - rConn, err := Dial("tcp", ip + port) + rConn, err := DialTCP("tcp4", ip, port) if err != nil { log.Debug("[HTTPS] ", err) return } defer func() { - defer rConn.Close() + rConn.CloseRead() + rConn.CloseWrite() + rConn.Close() log.Debug("[HTTPS] Closing server Connection.. ", p.Domain(), " ", rConn.LocalAddr()) }() diff --git a/net/dial.go b/net/dial.go index 7c49ab5..ed2d88a 100644 --- a/net/dial.go +++ b/net/dial.go @@ -2,10 +2,11 @@ package net import ( "net" + "strconv" ) -func Listen(network, address string) (Listener, error) { - l, err := net.Listen(network, address) +func ListenTCP(network string, addr *TCPAddr) (Listener, error) { + l, err := net.ListenTCP(network, addr.Addr) if err != nil { return Listener{}, err } @@ -13,8 +14,15 @@ func Listen(network, address string) (Listener, error) { return Listener{listener: l}, nil } -func Dial(network, address string) (*Conn, error) { - conn, err := net.Dial(network, address) +func DialTCP(network string, ip string, port string) (*Conn, error) { + p, _ := strconv.Atoi(port) + + addr := &net.TCPAddr{ + IP: net.ParseIP(ip), + Port: p, + } + + conn, err := net.DialTCP(network, nil, addr) if err != nil { return &Conn{}, err } diff --git a/net/listener.go b/net/listener.go index 7d99dc6..544b3dd 100644 --- a/net/listener.go +++ b/net/listener.go @@ -5,14 +5,14 @@ import ( ) type Listener struct { - listener net.Listener + listener *net.TCPListener } -func (l *Listener) Accept() (Conn, error) { - conn, err := l.listener.Accept() +func (l *Listener) Accept() (*Conn, error) { + conn, err := l.listener.AcceptTCP() if err != nil { - return Conn{}, err + return &Conn{}, err } - return Conn{conn: conn}, nil + return &Conn{conn: conn}, nil } diff --git a/net/tcp.go b/net/tcp.go new file mode 100644 index 0000000..cc90877 --- /dev/null +++ b/net/tcp.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +type TCPAddr struct { + Addr *net.TCPAddr +} + + +func TcpAddr(ip string, port int) (*TCPAddr) { + addr := &net.TCPAddr { + IP: net.ParseIP(ip), + Port: port, + } + + return &TCPAddr{ + Addr: addr, + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 187f21c..bcf9d29 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -9,27 +9,27 @@ import ( ) type Proxy struct { - port string addr string + port int } -func New(addr string, port string) *Proxy { +func New(addr string, port int) *Proxy { return &Proxy{ addr: addr, port: port, } } -func (p *Proxy) TcpAddr() string { - return p.addr + ":" + p.port +func (p *Proxy) TcpAddr() *net.TCPAddr { + return net.TcpAddr(p.addr, p.port) } -func (p *Proxy) Port() string { +func (p *Proxy) Port() int { return p.port } func (p *Proxy) Start() { - l, err := net.Listen("tcp", p.TcpAddr()) + l, err := net.ListenTCP("tcp4", p.TcpAddr()) if err != nil { log.Fatal("Error creating listener: ", err) os.Exit(1) diff --git a/util/os.go b/util/os.go index bc21b6b..2e15db5 100644 --- a/util/os.go +++ b/util/os.go @@ -1,12 +1,13 @@ package util import ( + "fmt" "os/exec" "runtime" "strings" ) -func SetOsProxy(port string) error { +func SetOsProxy(port int) error { if runtime.GOOS != "darwin" { return nil } @@ -17,12 +18,12 @@ func SetOsProxy(port string) error { return err } - _, err = exec.Command("sh", "-c", "networksetup -setwebproxy "+ "'" +strings.TrimSpace(string(network)) + "'" + " 127.0.0.1 "+port).Output() + _, err = exec.Command("sh", "-c", "networksetup -setwebproxy "+ "'" +strings.TrimSpace(string(network)) + "'" + " 127.0.0.1 "+ fmt.Sprint(port)).Output() if err != nil { return err } - _, err = exec.Command("sh", "-c", "networksetup -setsecurewebproxy " + "'" + strings.TrimSpace(string(network))+"'" + " 127.0.0.1 "+port).Output() + _, err = exec.Command("sh", "-c", "networksetup -setsecurewebproxy " + "'" + strings.TrimSpace(string(network))+"'" + " 127.0.0.1 "+ fmt.Sprint(port)).Output() if err != nil { return err } diff --git a/util/util.go b/util/util.go index 3c63abb..16364dd 100644 --- a/util/util.go +++ b/util/util.go @@ -7,9 +7,9 @@ import ( "github.com/pterm/pterm" ) -func ParseArgs() (string,string, string, bool) { +func ParseArgs() (string, int, string, bool) { addr := flag.String("addr", "127.0.0.1", "Listen addr") - port := flag.String("port", "8080", "port") + port := flag.Int("port", 8080, "port") dns := flag.String("dns", "8.8.8.8", "DNS server") debug := flag.Bool("debug", false, "true | false") @@ -18,14 +18,14 @@ func ParseArgs() (string,string, string, bool) { return *addr, *port, *dns, *debug } -func PrintWelcome(addr, port string, dns string, debug bool) { +func PrintWelcome(addr string, port int, dns string, debug bool) { cyan := pterm.NewLettersFromStringWithStyle("Spoof", pterm.NewStyle(pterm.FgCyan)) purple := pterm.NewLettersFromStringWithStyle("DPI", pterm.NewStyle(pterm.FgLightMagenta)) pterm.DefaultBigText.WithLetters(cyan, purple).Render() pterm.DefaultBulletList.WithItems([]pterm.BulletListItem{ {Level: 0, Text: "ADDR : " + addr}, - {Level: 0, Text: "PORT : " + port}, + {Level: 0, Text: "PORT : " + fmt.Sprint(port)}, {Level: 0, Text: "DNS : " + dns}, {Level: 0, Text: "DEBUG : " + fmt.Sprint(debug)}, }).Render()