Change types of flags that can't be negative to uint*. (#231)

This commit is contained in:
Ledorub 2024-08-31 17:50:02 +03:00 committed by GitHub
parent 7d6bc4c696
commit 19ec6980ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 14 deletions

View File

@ -1,22 +1,25 @@
package util package util
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"strconv"
"unsafe"
) )
type Args struct { type Args struct {
Addr string Addr string
Port int Port uint16
DnsAddr string DnsAddr string
DnsPort int DnsPort uint16
EnableDoh bool EnableDoh bool
Debug bool Debug bool
Banner bool Banner bool
SystemProxy bool SystemProxy bool
Timeout int Timeout uint16
AllowedPattern StringArray AllowedPattern StringArray
WindowSize int WindowSize uint16
Version bool Version bool
} }
@ -35,15 +38,15 @@ func ParseArgs() *Args {
args := new(Args) args := new(Args)
flag.StringVar(&args.Addr, "addr", "127.0.0.1", "listen address") flag.StringVar(&args.Addr, "addr", "127.0.0.1", "listen address")
flag.IntVar(&args.Port, "port", 8080, "port") uintNVar(&args.Port, "port", 8080, "port")
flag.StringVar(&args.DnsAddr, "dns-addr", "8.8.8.8", "dns address") flag.StringVar(&args.DnsAddr, "dns-addr", "8.8.8.8", "dns address")
flag.IntVar(&args.DnsPort, "dns-port", 53, "port number for dns") uintNVar(&args.DnsPort, "dns-port", 53, "port number for dns")
flag.BoolVar(&args.EnableDoh, "enable-doh", false, "enable 'dns-over-https'") flag.BoolVar(&args.EnableDoh, "enable-doh", false, "enable 'dns-over-https'")
flag.BoolVar(&args.Debug, "debug", false, "enable debug output") flag.BoolVar(&args.Debug, "debug", false, "enable debug output")
flag.BoolVar(&args.Banner, "banner", true, "enable banner") flag.BoolVar(&args.Banner, "banner", true, "enable banner")
flag.BoolVar(&args.SystemProxy, "system-proxy", true, "enable system-wide proxy") flag.BoolVar(&args.SystemProxy, "system-proxy", true, "enable system-wide proxy")
flag.IntVar(&args.Timeout, "timeout", 0, "timeout in milliseconds; no timeout when not given") uintNVar(&args.Timeout, "timeout", 0, "timeout in milliseconds; no timeout when not given")
flag.IntVar(&args.WindowSize, "window-size", 0, `chunk size, in number of bytes, for fragmented client hello, uintNVar(&args.WindowSize, "window-size", 0, `chunk size, in number of bytes, for fragmented client hello,
try lower values if the default value doesn't bypass the DPI; try lower values if the default value doesn't bypass the DPI;
when not given, the client hello packet will be sent in two parts: when not given, the client hello packet will be sent in two parts:
fragmentation for the first data packet and the rest fragmentation for the first data packet and the rest
@ -59,3 +62,59 @@ fragmentation for the first data packet and the rest
return args return args
} }
var (
errParse = errors.New("parse error")
errRange = errors.New("value out of range")
)
type unsigned interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
}
func uintNVar[T unsigned](p *T, name string, value T, usage string) {
flag.CommandLine.Var(newUintNValue(value, p), name, usage)
}
type uintNValue[T unsigned] struct {
val *T
}
func newUintNValue[T unsigned](val T, p *T) *uintNValue[T] {
*p = val
return &uintNValue[T]{val: p}
}
func (u *uintNValue[T]) Set(s string) error {
size := int(unsafe.Sizeof(*u.val) * 8)
v, err := strconv.ParseUint(s, 0, size)
if err != nil {
err = numError(err)
}
*u.val = T(v)
return err
}
func (u *uintNValue[T]) Get() any {
if u.val == nil {
return T(0)
}
return *u.val
}
func (u *uintNValue[T]) String() string {
if u.val == nil {
return "0"
}
return strconv.FormatUint(uint64(*u.val), 10)
}
func numError(err error) error {
if errors.Is(err, strconv.ErrSyntax) {
return errParse
}
if errors.Is(err, strconv.ErrRange) {
return errRange
}
return err
}

View File

@ -15,7 +15,7 @@ type Config struct {
DnsPort int DnsPort int
EnableDoh bool EnableDoh bool
Debug bool Debug bool
Banner bool Banner bool
SystemProxy bool SystemProxy bool
Timeout int Timeout int
WindowSize int WindowSize int
@ -33,16 +33,16 @@ func GetConfig() *Config {
func (c *Config) Load(args *Args) { func (c *Config) Load(args *Args) {
c.Addr = args.Addr c.Addr = args.Addr
c.Port = args.Port c.Port = int(args.Port)
c.DnsAddr = args.DnsAddr c.DnsAddr = args.DnsAddr
c.DnsPort = args.DnsPort c.DnsPort = int(args.DnsPort)
c.Debug = args.Debug c.Debug = args.Debug
c.EnableDoh = args.EnableDoh c.EnableDoh = args.EnableDoh
c.Banner = args.Banner c.Banner = args.Banner
c.SystemProxy = args.SystemProxy c.SystemProxy = args.SystemProxy
c.Timeout = args.Timeout c.Timeout = int(args.Timeout)
c.AllowedPatterns = parseAllowedPattern(args.AllowedPattern) c.AllowedPatterns = parseAllowedPattern(args.AllowedPattern)
c.WindowSize = args.WindowSize c.WindowSize = int(args.WindowSize)
} }
func parseAllowedPattern(patterns StringArray) []*regexp.Regexp { func parseAllowedPattern(patterns StringArray) []*regexp.Regexp {
@ -67,7 +67,7 @@ func PrintColoredBanner() {
{Level: 0, Text: "DEBUG : " + fmt.Sprint(config.Debug)}, {Level: 0, Text: "DEBUG : " + fmt.Sprint(config.Debug)},
}).Render() }).Render()
pterm.DefaultBasicText.Println("Press 'CTRL + c' to quit") pterm.DefaultBasicText.Println("Press 'CTRL + c' to quit")
} }
func PrintSimpleInfo() { func PrintSimpleInfo() {