chore: refactor dns/resolver

This commit is contained in:
xvzc 2024-08-18 18:51:20 +09:00
parent 85e1c0aa33
commit 442ae6840b
4 changed files with 26 additions and 40 deletions

View File

@ -41,11 +41,7 @@ func NewDOHClient(host string) *DOHResolver {
}
func (r *DOHResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
return r.dohExchange(ctx, msg)
}
resultCh := lookup(ctx, host, qTypes, sendMsg)
resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange)
addrs, err := processResults(ctx, resultCh)
return addrs, err
}
@ -54,7 +50,7 @@ func (r *DOHResolver) String() string {
return fmt.Sprintf("doh resolver(%s)", r.upstream)
}
func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
func (r *DOHResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack()
if err != nil {
return nil, err
@ -91,18 +87,9 @@ func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
return nil, err
}
return resultMsg, nil
}
func (r *DOHResolver) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
res, err := r.dohQuery(ctx, msg)
if err != nil {
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
if resultMsg.Rcode != dns.RcodeSuccess {
return nil, errors.New("doh rcode wasn't successful")
}
return res, nil
return resultMsg, nil
}

View File

@ -26,12 +26,7 @@ func NewGeneralClient(server string) *GeneralResolver {
}
func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
resp, _, err := r.client.Exchange(msg, r.server)
return resp, err
}
resultCh := lookup(ctx, host, qTypes, sendMsg)
resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange)
addrs, err := processResults(ctx, resultCh)
return addrs, err
}
@ -39,3 +34,8 @@ func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uin
func (c *GeneralResolver) String() string {
return fmt.Sprintf("general resolver(%s)", c.server)
}
func (r *GeneralResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
resp, _, err := r.client.Exchange(msg, r.server)
return resp, err
}

View File

@ -12,13 +12,13 @@ import (
"github.com/xvzc/SpoofDPI/dns/addrselect"
)
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
type Resolver interface {
Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error)
String() string
}
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
func recordTypeIDToName(id uint16) string {
switch id {
case 1:
@ -47,32 +47,31 @@ func sortAddrs(addrs []net.IPAddr) {
addrselect.SortByRFC6724(addrs)
}
func lookup(ctx context.Context, host string, queryTypes []uint16, sendMsg exchangeFunc) <-chan *DNSResult {
func lookupAllTypes(ctx context.Context, host string, qTypes []uint16, exchange exchangeFunc) <-chan *DNSResult {
var wg sync.WaitGroup
resCh := make(chan *DNSResult)
lookup := func(qType uint16) {
for _, qType := range qTypes {
wg.Add(1)
go func(qType uint16) {
defer wg.Done()
select {
case <-ctx.Done():
return
case resCh <- query(ctx, host, qType, sendMsg):
case resCh <- lookupType(ctx, host, qType, exchange):
}
}
for _, queryType := range queryTypes {
wg.Add(1)
go lookup(queryType)
}(qType)
}
go func() {
wg.Wait()
close(resCh)
}()
return resCh
}
func query(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult {
func lookupType(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult {
msg := newMsg(host, queryType)
resp, err := exchange(ctx, msg)
if err != nil {
@ -103,7 +102,7 @@ func processResults(ctx context.Context, resCh <-chan *DNSResult) ([]net.IPAddr,
}
select {
case <-ctx.Done():
return nil, errors.New("cancelled")
return nil, errors.New("canceled")
default:
if len(addrs) == 0 {
return addrs, errors.Join(errs...)

View File

@ -19,7 +19,7 @@ func (r *SystemResolver) String() string {
return "system resolver"
}
func (r *SystemResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
func (r *SystemResolver) Resolve(ctx context.Context, host string, _ []uint16) ([]net.IPAddr, error) {
addrs, err := r.LookupIPAddr(ctx, host)
if err != nil {
return []net.IPAddr{}, err