Merge pull request #136 from xvzc/refactor

chore: refactor doh
This commit is contained in:
xvzc 2024-08-13 21:33:09 +09:00 committed by GitHub
commit ea22bd2451
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 67 additions and 81 deletions

View File

@ -41,7 +41,7 @@ func (d *DnsResolver) Lookup(domain string, useSystemDns bool) (string, error) {
if d.enableDoh { if d.enableDoh {
log.Debug("[DNS] ", domain, " resolving with dns over https") log.Debug("[DNS] ", domain, " resolving with dns over https")
return dohLookup(domain) return dohLookup(d.host, domain)
} }
log.Debug("[DNS] ", domain, " resolving with custom dns") log.Debug("[DNS] ", domain, " resolving with custom dns")
@ -59,7 +59,7 @@ func customLookup(host string, port string, domain string) (string, error) {
response, _, err := c.Exchange(msg, dnsServer) response, _, err := c.Exchange(msg, dnsServer)
if err != nil { if err != nil {
return "", errors.New("couldn not resolve the domain(custom)") return "", errors.New("could not resolve the domain(custom)")
} }
for _, answer := range response.Answer { for _, answer := range response.Answer {
@ -76,7 +76,7 @@ func systemLookup(domain string) (string, error) {
systemResolver := net.Resolver{PreferGo: true} systemResolver := net.Resolver{PreferGo: true}
ips, err := systemResolver.LookupIPAddr(context.Background(), domain) ips, err := systemResolver.LookupIPAddr(context.Background(), domain)
if err != nil { if err != nil {
return "", errors.New("couldn not resolve the domain(system)") return "", errors.New("could not resolve the domain(system)")
} }
for _, ip := range ips { for _, ip := range ips {
@ -86,21 +86,25 @@ func systemLookup(domain string) (string, error) {
return "", errors.New("no record found(system)") return "", errors.New("no record found(system)")
} }
func dohLookup(domain string) (string, error) { func dohLookup(host string, domain string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
log.Debug("[DoH] ", domain, " resolving with dns over https")
dnsUpstream := util.GetConfig().DnsAddr client := getDOHClient(host)
client := GetDoHClient(*dnsUpstream)
resp, err := client.Resolve(ctx, domain, []uint16{dns.TypeA, dns.TypeAAAA})
if err == nil {
if len(resp) == 0 { // yes this happens
return "", errors.New("no record found(doh)")
}
return resp[0], nil msg := new(dns.Msg)
} msg.SetQuestion(dns.Fqdn(domain), dns.TypeA)
response, err := client.dohExchange(ctx, msg)
if err != nil {
return "", errors.New("could not resolve the domain(doh)") return "", errors.New("could not resolve the domain(doh)")
}
for _, answer := range response.Answer {
if record, ok := answer.(*dns.A); ok {
return record.A.String(), nil
}
}
return "", errors.New("no record found(doh)")
} }

View File

@ -4,37 +4,32 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strings" "regexp"
"sync" "sync"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus"
) )
type DoHClient struct { type DOHClient struct {
upstream string upstream string
c *http.Client httpClient *http.Client
} }
var client *DoHClient var dohClient *DOHClient
var clientOnce sync.Once var clientOnce sync.Once
func GetDoHClient(upstream string) *DoHClient { func getDOHClient(host string) *DOHClient {
if dohClient != nil {
return dohClient
}
clientOnce.Do(func() { clientOnce.Do(func() {
if client == nil { h := &http.Client{
if !strings.HasPrefix(upstream, "https://") {
upstream = "https://" + upstream
}
if !strings.HasSuffix(upstream, "/dns-query") {
upstream = upstream + "/dns-query"
}
c := &http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
@ -47,17 +42,17 @@ func GetDoHClient(upstream string) *DoHClient {
}, },
} }
client = &DoHClient{ host = regexp.MustCompile(`^https:\/\/|\/dns-query$`).ReplaceAllString(host, "")
upstream: upstream, dohClient = &DOHClient{
c: c, upstream: "https://" + host + "/dns-query",
} httpClient: h,
} }
}) })
return client return dohClient
} }
func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { func (d *DOHClient) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack() pack, err := msg.Pack()
if err != nil { if err != nil {
return nil, err return nil, err
@ -72,53 +67,40 @@ func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, e
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Accept", "application/dns-message") req.Header.Set("Accept", "application/dns-message")
resp, err := d.c.Do(req) resp, err := d.httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Debug("[DoH] Error while resolving ", url, " : ", resp.Status) return nil, errors.New("doh status error")
} }
buf := bytes.Buffer{} buf := bytes.Buffer{}
buf.ReadFrom(resp.Body) _, err = buf.ReadFrom(resp.Body)
ret_msg := new(dns.Msg)
err = ret_msg.Unpack(buf.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ret_msg, nil resultMsg := new(dns.Msg)
} err = resultMsg.Unpack(buf.Bytes())
func (d *DoHClient) Resolve(ctx context.Context, domain string, dnsTypes []uint16) ([]string, error) {
var ret []string
for _, dnsType := range dnsTypes {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), dnsType)
resp, err := d.doGetRequest(ctx, msg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if resp.Rcode != dns.RcodeSuccess { return resultMsg, nil
continue }
}
func (d *DOHClient) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
for _, answer := range resp.Answer { res, err := d.dohQuery(ctx, msg)
if t, ok := answer.(*dns.A); ok { if err != nil {
ret = append(ret, t.A.String()) return nil, err
} }
if t, ok := answer.(*dns.AAAA); ok {
ret = append(ret, t.AAAA.String()) if res.Rcode != dns.RcodeSuccess {
} return nil, errors.New("doh rcode wasn't successful")
} }
}
return res, nil
return ret, nil
} }