SpoofDPI/dns/resolver/resolver.go

120 lines
2.4 KiB
Go
Raw Normal View History

2024-08-18 06:55:47 +00:00
package client
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"sync"
"github.com/miekg/dns"
"github.com/xvzc/SpoofDPI/dns/addrselect"
)
2024-08-18 09:51:20 +00:00
// exchangeFunc is the signature of a function that takes a DNS query message
// and returns a DNS message or an error. lookupType calls it with the query it
// builds and treats the returned message as the response.
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
2024-08-18 07:11:04 +00:00
// Resolver is the interface for types that turn a host name into IP addresses.
type Resolver interface {
	// Resolve takes a host name and a list of DNS query type IDs and returns
	// the resulting addresses or an error.
	// NOTE(review): presumably qTypes holds record type IDs such as 1 (A) and
	// 28 (AAAA) — confirm against the implementations.
	Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error)

	// String returns a textual description of the resolver.
	String() string
}
2024-08-19 01:47:03 +00:00
// DNSResult carries the outcome of a single DNS exchange: lookupType fills in
// either msg (on success) or err (on failure), never both.
type DNSResult struct {
// msg is the response message returned by the exchange function.
msg *dns.Msg
// err is the wrapped error from a failed exchange; nil on success.
err error
}
2024-08-18 06:55:47 +00:00
// recordTypeIDToName maps a DNS record type ID to a printable name.
// IDs 1 and 28 become "A" and "AAAA"; every other ID is rendered as its
// decimal number.
func recordTypeIDToName(id uint16) string {
	const (
		typeA    uint16 = 1
		typeAAAA uint16 = 28
	)

	if id == typeA {
		return "A"
	}
	if id == typeAAAA {
		return "AAAA"
	}
	return strconv.FormatUint(uint64(id), 10)
}
// parseAddrsFromMsg collects the IP addresses carried by the A and AAAA
// records in msg.Answer, in answer order. Records of any other type are
// skipped. The result is nil when no such record is present. msg must be
// non-nil.
func parseAddrsFromMsg(msg *dns.Msg) []net.IPAddr {
	var found []net.IPAddr
	for _, rr := range msg.Answer {
		if v4, ok := rr.(*dns.A); ok {
			found = append(found, net.IPAddr{IP: v4.A})
			continue
		}
		if v6, ok := rr.(*dns.AAAA); ok {
			found = append(found, net.IPAddr{IP: v6.AAAA})
		}
	}
	return found
}
// sortAddrs orders addrs by delegating to addrselect.SortByRFC6724.
// NOTE(review): presumably this reorders the slice in place following
// RFC 6724 destination address selection — confirm against addrselect.
func sortAddrs(addrs []net.IPAddr) {
addrselect.SortByRFC6724(addrs)
}
2024-08-18 09:51:20 +00:00
func lookupAllTypes(ctx context.Context, host string, qTypes []uint16, exchange exchangeFunc) <-chan *DNSResult {
2024-08-18 06:55:47 +00:00
var wg sync.WaitGroup
resCh := make(chan *DNSResult)
2024-08-18 09:51:20 +00:00
for _, qType := range qTypes {
2024-08-18 06:55:47 +00:00
wg.Add(1)
2024-08-18 09:51:20 +00:00
go func(qType uint16) {
defer wg.Done()
select {
case <-ctx.Done():
return
case resCh <- lookupType(ctx, host, qType, exchange):
}
}(qType)
2024-08-18 06:55:47 +00:00
}
go func() {
wg.Wait()
close(resCh)
}()
2024-08-18 09:51:20 +00:00
2024-08-18 06:55:47 +00:00
return resCh
}
2024-08-18 09:51:20 +00:00
func lookupType(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult {
2024-08-18 06:55:47 +00:00
msg := newMsg(host, queryType)
resp, err := exchange(ctx, msg)
if err != nil {
queryName := recordTypeIDToName(queryType)
err = fmt.Errorf("resolving %s, query type %s: %w", host, queryName, err)
return &DNSResult{err: err}
}
return &DNSResult{msg: resp}
}
// newMsg returns a fresh DNS message whose question is set, via SetQuestion,
// to dns.Fqdn(host) with record type qType.
func newMsg(host string, qType uint16) *dns.Msg {
	query := &dns.Msg{}
	name := dns.Fqdn(host)
	query.SetQuestion(name, qType)
	return query
}
func processResults(ctx context.Context, resCh <-chan *DNSResult) ([]net.IPAddr, error) {
var errs []error
var addrs []net.IPAddr
for result := range resCh {
if result.err != nil {
errs = append(errs, result.err)
continue
}
resultAddrs := parseAddrsFromMsg(result.msg)
addrs = append(addrs, resultAddrs...)
}
select {
case <-ctx.Done():
2024-08-18 09:51:20 +00:00
return nil, errors.New("canceled")
2024-08-18 06:55:47 +00:00
default:
if len(addrs) == 0 {
return addrs, errors.Join(errs...)
}
}
sortAddrs(addrs)
return addrs, nil
}