diff --git a/dns/resolver/doh.go b/dns/resolver/doh.go index 266e4b9..1bce0de 100644 --- a/dns/resolver/doh.go +++ b/dns/resolver/doh.go @@ -41,11 +41,7 @@ func NewDOHClient(host string) *DOHResolver { } func (r *DOHResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) { - sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - return r.dohExchange(ctx, msg) - } - - resultCh := lookup(ctx, host, qTypes, sendMsg) + resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange) addrs, err := processResults(ctx, resultCh) return addrs, err } @@ -54,7 +50,7 @@ func (r *DOHResolver) String() string { return fmt.Sprintf("doh resolver(%s)", r.upstream) } -func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { +func (r *DOHResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err @@ -91,18 +87,9 @@ func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, err return nil, err } - return resultMsg, nil -} - -func (r *DOHResolver) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - res, err := r.dohQuery(ctx, msg) - if err != nil { - return nil, err - } - - if res.Rcode != dns.RcodeSuccess { + if resultMsg.Rcode != dns.RcodeSuccess { return nil, errors.New("doh rcode wasn't successful") } - return res, nil + return resultMsg, nil } diff --git a/dns/resolver/general.go b/dns/resolver/general.go index 44441c4..2f83ffd 100644 --- a/dns/resolver/general.go +++ b/dns/resolver/general.go @@ -26,12 +26,7 @@ func NewGeneralClient(server string) *GeneralResolver { } func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) { - sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - resp, _, err := r.client.Exchange(msg, r.server) - return resp, err - } - - resultCh := lookup(ctx, host, qTypes, sendMsg) + resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange) addrs, err := processResults(ctx, resultCh) return addrs, err } @@ -39,3 +34,8 @@ func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uin func (c *GeneralResolver) String() string { return fmt.Sprintf("general resolver(%s)", c.server) } + +func (r *GeneralResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + resp, _, err := r.client.Exchange(msg, r.server) + return resp, err +} diff --git a/dns/resolver/resolver.go b/dns/resolver/resolver.go index f2033ac..031f73e 100644 --- a/dns/resolver/resolver.go +++ b/dns/resolver/resolver.go @@ -12,13 +12,13 @@ import ( "github.com/xvzc/SpoofDPI/dns/addrselect" ) +type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) + type Resolver interface { Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) String() string } -type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) - func recordTypeIDToName(id uint16) string { switch id { case 1: @@ -47,32 +47,31 @@ func sortAddrs(addrs []net.IPAddr) { addrselect.SortByRFC6724(addrs) } -func lookup(ctx context.Context, host string, queryTypes []uint16, sendMsg exchangeFunc) <-chan *DNSResult { +func lookupAllTypes(ctx context.Context, host string, qTypes []uint16, exchange exchangeFunc) <-chan *DNSResult { var wg sync.WaitGroup resCh := make(chan *DNSResult) - lookup := func(qType uint16) { - defer wg.Done() - select { - case <-ctx.Done(): - return - case resCh <- query(ctx, host, qType, sendMsg): - } - } - - for _, queryType := range queryTypes { + for _, qType := range qTypes { wg.Add(1) - go lookup(queryType) + go func(qType uint16) { + defer wg.Done() + select { + case <-ctx.Done(): + return + case resCh <- lookupType(ctx, host, qType, exchange): + } + }(qType) } go func() { wg.Wait() close(resCh) }() + return resCh } -func query(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult { +func lookupType(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult { msg := newMsg(host, queryType) resp, err := exchange(ctx, msg) if err != nil { @@ -103,7 +102,7 @@ func processResults(ctx context.Context, resCh <-chan *DNSResult) ([]net.IPAddr, } select { case <-ctx.Done(): - return nil, errors.New("cancelled") + return nil, errors.New("canceled") default: if len(addrs) == 0 { return addrs, errors.Join(errs...) diff --git a/dns/resolver/system.go b/dns/resolver/system.go index 5be0b9f..928be7f 100644 --- a/dns/resolver/system.go +++ b/dns/resolver/system.go @@ -19,7 +19,7 @@ func (r *SystemResolver) String() string { return "system resolver" } -func (r *SystemResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) { +func (r *SystemResolver) Resolve(ctx context.Context, host string, _ []uint16) ([]net.IPAddr, error) { addrs, err := r.LookupIPAddr(ctx, host) if err != nil { return []net.IPAddr{}, err