From 19ec6980ba769b460cc8e2e4c47b3e27a3c283bf Mon Sep 17 00:00:00 2001 From: Ledorub Date: Sat, 31 Aug 2024 17:50:02 +0300 Subject: [PATCH] Change types of flags that can't be negative to uint*. (#231) --- util/args.go | 75 ++++++++++++++++++++++++++++++++++++++++++++------ util/config.go | 12 ++++---- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/util/args.go b/util/args.go index c7f35c7..c592b98 100644 --- a/util/args.go +++ b/util/args.go @@ -1,22 +1,25 @@ package util import ( + "errors" "flag" "fmt" + "strconv" + "unsafe" ) type Args struct { Addr string - Port int + Port uint16 DnsAddr string - DnsPort int + DnsPort uint16 EnableDoh bool Debug bool Banner bool SystemProxy bool - Timeout int + Timeout uint16 AllowedPattern StringArray - WindowSize int + WindowSize uint16 Version bool } @@ -35,15 +38,15 @@ func ParseArgs() *Args { args := new(Args) flag.StringVar(&args.Addr, "addr", "127.0.0.1", "listen address") - flag.IntVar(&args.Port, "port", 8080, "port") + uintNVar(&args.Port, "port", 8080, "port") flag.StringVar(&args.DnsAddr, "dns-addr", "8.8.8.8", "dns address") - flag.IntVar(&args.DnsPort, "dns-port", 53, "port number for dns") + uintNVar(&args.DnsPort, "dns-port", 53, "port number for dns") flag.BoolVar(&args.EnableDoh, "enable-doh", false, "enable 'dns-over-https'") flag.BoolVar(&args.Debug, "debug", false, "enable debug output") flag.BoolVar(&args.Banner, "banner", true, "enable banner") flag.BoolVar(&args.SystemProxy, "system-proxy", true, "enable system-wide proxy") - flag.IntVar(&args.Timeout, "timeout", 0, "timeout in milliseconds; no timeout when not given") - flag.IntVar(&args.WindowSize, "window-size", 0, `chunk size, in number of bytes, for fragmented client hello, + uintNVar(&args.Timeout, "timeout", 0, "timeout in milliseconds; no timeout when not given") + uintNVar(&args.WindowSize, "window-size", 0, `chunk size, in number of bytes, for fragmented client hello, try lower values if the default value doesn't bypass the DPI; when not given, the client hello packet will be sent in two parts: fragmentation for the first data packet and the rest @@ -59,3 +62,59 @@ fragmentation for the first data packet and the rest return args } + +var ( + errParse = errors.New("parse error") + errRange = errors.New("value out of range") +) + +type unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +func uintNVar[T unsigned](p *T, name string, value T, usage string) { + flag.CommandLine.Var(newUintNValue(value, p), name, usage) +} + +type uintNValue[T unsigned] struct { + val *T +} + +func newUintNValue[T unsigned](val T, p *T) *uintNValue[T] { + *p = val + return &uintNValue[T]{val: p} +} + +func (u *uintNValue[T]) Set(s string) error { + size := int(unsafe.Sizeof(*u.val) * 8) + v, err := strconv.ParseUint(s, 0, size) + if err != nil { + err = numError(err) + } + *u.val = T(v) + return err +} + +func (u *uintNValue[T]) Get() any { + if u.val == nil { + return T(0) + } + return *u.val +} + +func (u *uintNValue[T]) String() string { + if u.val == nil { + return "0" + } + return strconv.FormatUint(uint64(*u.val), 10) +} + +func numError(err error) error { + if errors.Is(err, strconv.ErrSyntax) { + return errParse + } + if errors.Is(err, strconv.ErrRange) { + return errRange + } + return err +} diff --git a/util/config.go b/util/config.go index aa7e8b3..c6ec2a8 100644 --- a/util/config.go +++ b/util/config.go @@ -15,7 +15,7 @@ type Config struct { DnsPort int EnableDoh bool Debug bool - Banner bool + Banner bool SystemProxy bool Timeout int WindowSize int @@ -33,16 +33,16 @@ func GetConfig() *Config { func (c *Config) Load(args *Args) { c.Addr = args.Addr - c.Port = args.Port + c.Port = int(args.Port) c.DnsAddr = args.DnsAddr - c.DnsPort = args.DnsPort + c.DnsPort = int(args.DnsPort) c.Debug = args.Debug c.EnableDoh = args.EnableDoh c.Banner = args.Banner c.SystemProxy = args.SystemProxy - c.Timeout = args.Timeout + c.Timeout = int(args.Timeout) c.AllowedPatterns = parseAllowedPattern(args.AllowedPattern) - c.WindowSize = args.WindowSize + c.WindowSize = int(args.WindowSize) } func parseAllowedPattern(patterns StringArray) []*regexp.Regexp { @@ -67,7 +67,7 @@ func PrintColoredBanner() { {Level: 0, Text: "DEBUG : " + fmt.Sprint(config.Debug)}, }).Render() - pterm.DefaultBasicText.Println("Press 'CTRL + c' to quit") + pterm.DefaultBasicText.Println("Press 'CTRL + c' to quit") } func PrintSimpleInfo() {