chore: refactor dns/resolver

This commit is contained in:
xvzc 2024-08-18 18:51:20 +09:00
parent 85e1c0aa33
commit 442ae6840b
4 changed files with 26 additions and 40 deletions

View File

@ -41,11 +41,7 @@ func NewDOHClient(host string) *DOHResolver {
} }
func (r *DOHResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) { func (r *DOHResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange)
return r.dohExchange(ctx, msg)
}
resultCh := lookup(ctx, host, qTypes, sendMsg)
addrs, err := processResults(ctx, resultCh) addrs, err := processResults(ctx, resultCh)
return addrs, err return addrs, err
} }
@ -54,7 +50,7 @@ func (r *DOHResolver) String() string {
return fmt.Sprintf("doh resolver(%s)", r.upstream) return fmt.Sprintf("doh resolver(%s)", r.upstream)
} }
func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { func (r *DOHResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack() pack, err := msg.Pack()
if err != nil { if err != nil {
return nil, err return nil, err
@ -91,18 +87,9 @@ func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
return nil, err return nil, err
} }
return resultMsg, nil if resultMsg.Rcode != dns.RcodeSuccess {
}
func (r *DOHResolver) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
res, err := r.dohQuery(ctx, msg)
if err != nil {
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
return nil, errors.New("doh rcode wasn't successful") return nil, errors.New("doh rcode wasn't successful")
} }
return res, nil return resultMsg, nil
} }

View File

@ -26,12 +26,7 @@ func NewGeneralClient(server string) *GeneralResolver {
} }
func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) { func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange)
resp, _, err := r.client.Exchange(msg, r.server)
return resp, err
}
resultCh := lookup(ctx, host, qTypes, sendMsg)
addrs, err := processResults(ctx, resultCh) addrs, err := processResults(ctx, resultCh)
return addrs, err return addrs, err
} }
@ -39,3 +34,8 @@ func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uin
func (c *GeneralResolver) String() string { func (c *GeneralResolver) String() string {
return fmt.Sprintf("general resolver(%s)", c.server) return fmt.Sprintf("general resolver(%s)", c.server)
} }
func (r *GeneralResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
resp, _, err := r.client.Exchange(msg, r.server)
return resp, err
}

View File

@ -12,13 +12,13 @@ import (
"github.com/xvzc/SpoofDPI/dns/addrselect" "github.com/xvzc/SpoofDPI/dns/addrselect"
) )
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
type Resolver interface { type Resolver interface {
Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error)
String() string String() string
} }
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
func recordTypeIDToName(id uint16) string { func recordTypeIDToName(id uint16) string {
switch id { switch id {
case 1: case 1:
@ -47,32 +47,31 @@ func sortAddrs(addrs []net.IPAddr) {
addrselect.SortByRFC6724(addrs) addrselect.SortByRFC6724(addrs)
} }
func lookup(ctx context.Context, host string, queryTypes []uint16, sendMsg exchangeFunc) <-chan *DNSResult { func lookupAllTypes(ctx context.Context, host string, qTypes []uint16, exchange exchangeFunc) <-chan *DNSResult {
var wg sync.WaitGroup var wg sync.WaitGroup
resCh := make(chan *DNSResult) resCh := make(chan *DNSResult)
lookup := func(qType uint16) { for _, qType := range qTypes {
wg.Add(1)
go func(qType uint16) {
defer wg.Done() defer wg.Done()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case resCh <- query(ctx, host, qType, sendMsg): case resCh <- lookupType(ctx, host, qType, exchange):
} }
} }(qType)
for _, queryType := range queryTypes {
wg.Add(1)
go lookup(queryType)
} }
go func() { go func() {
wg.Wait() wg.Wait()
close(resCh) close(resCh)
}() }()
return resCh return resCh
} }
func query(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult { func lookupType(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult {
msg := newMsg(host, queryType) msg := newMsg(host, queryType)
resp, err := exchange(ctx, msg) resp, err := exchange(ctx, msg)
if err != nil { if err != nil {
@ -103,7 +102,7 @@ func processResults(ctx context.Context, resCh <-chan *DNSResult) ([]net.IPAddr,
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, errors.New("cancelled") return nil, errors.New("canceled")
default: default:
if len(addrs) == 0 { if len(addrs) == 0 {
return addrs, errors.Join(errs...) return addrs, errors.Join(errs...)

View File

@ -19,7 +19,7 @@ func (r *SystemResolver) String() string {
return "system resolver" return "system resolver"
} }
func (r *SystemResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) { func (r *SystemResolver) Resolve(ctx context.Context, host string, _ []uint16) ([]net.IPAddr, error) {
addrs, err := r.LookupIPAddr(ctx, host) addrs, err := r.LookupIPAddr(ctx, host)
if err != nil { if err != nil {
return []net.IPAddr{}, err return []net.IPAddr{}, err