mirror of
https://github.com/xvzc/SpoofDPI.git
synced 2024-12-22 06:15:51 +00:00
chore: refactor dns/resolver
This commit is contained in:
parent
85e1c0aa33
commit
442ae6840b
@ -41,11 +41,7 @@ func NewDOHClient(host string) *DOHResolver {
|
||||
}
|
||||
|
||||
func (r *DOHResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
|
||||
sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
return r.dohExchange(ctx, msg)
|
||||
}
|
||||
|
||||
resultCh := lookup(ctx, host, qTypes, sendMsg)
|
||||
resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange)
|
||||
addrs, err := processResults(ctx, resultCh)
|
||||
return addrs, err
|
||||
}
|
||||
@ -54,7 +50,7 @@ func (r *DOHResolver) String() string {
|
||||
return fmt.Sprintf("doh resolver(%s)", r.upstream)
|
||||
}
|
||||
|
||||
func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
func (r *DOHResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
pack, err := msg.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -91,18 +87,9 @@ func (r *DOHResolver) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resultMsg, nil
|
||||
}
|
||||
|
||||
func (r *DOHResolver) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
res, err := r.dohQuery(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.Rcode != dns.RcodeSuccess {
|
||||
if resultMsg.Rcode != dns.RcodeSuccess {
|
||||
return nil, errors.New("doh rcode wasn't successful")
|
||||
}
|
||||
|
||||
return res, nil
|
||||
return resultMsg, nil
|
||||
}
|
||||
|
@ -26,12 +26,7 @@ func NewGeneralClient(server string) *GeneralResolver {
|
||||
}
|
||||
|
||||
func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
|
||||
sendMsg := func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
resp, _, err := r.client.Exchange(msg, r.server)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resultCh := lookup(ctx, host, qTypes, sendMsg)
|
||||
resultCh := lookupAllTypes(ctx, host, qTypes, r.exchange)
|
||||
addrs, err := processResults(ctx, resultCh)
|
||||
return addrs, err
|
||||
}
|
||||
@ -39,3 +34,8 @@ func (r *GeneralResolver) Resolve(ctx context.Context, host string, qTypes []uin
|
||||
func (c *GeneralResolver) String() string {
|
||||
return fmt.Sprintf("general resolver(%s)", c.server)
|
||||
}
|
||||
|
||||
func (r *GeneralResolver) exchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
resp, _, err := r.client.Exchange(msg, r.server)
|
||||
return resp, err
|
||||
}
|
||||
|
@ -12,13 +12,13 @@ import (
|
||||
"github.com/xvzc/SpoofDPI/dns/addrselect"
|
||||
)
|
||||
|
||||
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
|
||||
|
||||
type Resolver interface {
|
||||
Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error)
|
||||
String() string
|
||||
}
|
||||
|
||||
type exchangeFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
|
||||
|
||||
func recordTypeIDToName(id uint16) string {
|
||||
switch id {
|
||||
case 1:
|
||||
@ -47,32 +47,31 @@ func sortAddrs(addrs []net.IPAddr) {
|
||||
addrselect.SortByRFC6724(addrs)
|
||||
}
|
||||
|
||||
func lookup(ctx context.Context, host string, queryTypes []uint16, sendMsg exchangeFunc) <-chan *DNSResult {
|
||||
func lookupAllTypes(ctx context.Context, host string, qTypes []uint16, exchange exchangeFunc) <-chan *DNSResult {
|
||||
var wg sync.WaitGroup
|
||||
resCh := make(chan *DNSResult)
|
||||
|
||||
lookup := func(qType uint16) {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case resCh <- query(ctx, host, qType, sendMsg):
|
||||
}
|
||||
}
|
||||
|
||||
for _, queryType := range queryTypes {
|
||||
for _, qType := range qTypes {
|
||||
wg.Add(1)
|
||||
go lookup(queryType)
|
||||
go func(qType uint16) {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case resCh <- lookupType(ctx, host, qType, exchange):
|
||||
}
|
||||
}(qType)
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resCh)
|
||||
}()
|
||||
|
||||
return resCh
|
||||
}
|
||||
|
||||
func query(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult {
|
||||
func lookupType(ctx context.Context, host string, queryType uint16, exchange exchangeFunc) *DNSResult {
|
||||
msg := newMsg(host, queryType)
|
||||
resp, err := exchange(ctx, msg)
|
||||
if err != nil {
|
||||
@ -103,7 +102,7 @@ func processResults(ctx context.Context, resCh <-chan *DNSResult) ([]net.IPAddr,
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, errors.New("cancelled")
|
||||
return nil, errors.New("canceled")
|
||||
default:
|
||||
if len(addrs) == 0 {
|
||||
return addrs, errors.Join(errs...)
|
||||
|
@ -19,7 +19,7 @@ func (r *SystemResolver) String() string {
|
||||
return "system resolver"
|
||||
}
|
||||
|
||||
func (r *SystemResolver) Resolve(ctx context.Context, host string, qTypes []uint16) ([]net.IPAddr, error) {
|
||||
func (r *SystemResolver) Resolve(ctx context.Context, host string, _ []uint16) ([]net.IPAddr, error) {
|
||||
addrs, err := r.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return []net.IPAddr{}, err
|
||||
|
Loading…
Reference in New Issue
Block a user