diff --git a/dns/dns.go b/dns/dns.go index 53b2c2a..eee15c6 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -41,7 +41,7 @@ func (d *DnsResolver) Lookup(domain string, useSystemDns bool) (string, error) { if d.enableDoh { log.Debug("[DNS] ", domain, " resolving with dns over https") - return dohLookup(domain) + return dohLookup(d.host, domain) } log.Debug("[DNS] ", domain, " resolving with custom dns") @@ -59,7 +59,7 @@ func customLookup(host string, port string, domain string) (string, error) { response, _, err := c.Exchange(msg, dnsServer) if err != nil { - return "", errors.New("couldn not resolve the domain(custom)") + return "", errors.New("could not resolve the domain(custom)") } for _, answer := range response.Answer { @@ -76,7 +76,7 @@ func systemLookup(domain string) (string, error) { systemResolver := net.Resolver{PreferGo: true} ips, err := systemResolver.LookupIPAddr(context.Background(), domain) if err != nil { - return "", errors.New("couldn not resolve the domain(system)") + return "", errors.New("could not resolve the domain(system)") } for _, ip := range ips { @@ -86,21 +86,25 @@ func systemLookup(domain string) (string, error) { return "", errors.New("no record found(system)") } -func dohLookup(domain string) (string, error) { +func dohLookup(host string, domain string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - log.Debug("[DoH] ", domain, " resolving with dns over https") - dnsUpstream := util.GetConfig().DnsAddr - client := GetDoHClient(*dnsUpstream) - resp, err := client.Resolve(ctx, domain, []uint16{dns.TypeA, dns.TypeAAAA}) - if err == nil { - if len(resp) == 0 { // yes this happens - return "", errors.New("no record found(doh)") - } + client := getDOHClient(host) - return resp[0], nil + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), dns.TypeA) + + response, err := client.dohExchange(ctx, msg) + if err != nil { + return "", errors.New("could not resolve the domain(doh)") } - return "", errors.New("could not resolve the domain(doh)") + for _, answer := range response.Answer { + if record, ok := answer.(*dns.A); ok { + return record.A.String(), nil + } + } + + return "", errors.New("no record found(doh)") } diff --git a/dns/doh.go b/dns/doh.go index 43e52bc..d597642 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -4,60 +4,55 @@ import ( "bytes" "context" "encoding/base64" + "errors" "fmt" "net" "net/http" - "strings" + "regexp" "sync" "time" "github.com/miekg/dns" - log "github.com/sirupsen/logrus" ) -type DoHClient struct { +type DOHClient struct { upstream string - c *http.Client + httpClient *http.Client } -var client *DoHClient +var dohClient *DOHClient var clientOnce sync.Once -func GetDoHClient(upstream string) *DoHClient { +func getDOHClient(host string) *DOHClient { + if dohClient != nil { + return dohClient + } + clientOnce.Do(func() { - if client == nil { - if !strings.HasPrefix(upstream, "https://") { - upstream = "https://" + upstream - } + h := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 3 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 5 * time.Second, + MaxIdleConnsPerHost: 100, + MaxIdleConns: 100, + }, + } - if !strings.HasSuffix(upstream, "/dns-query") { - upstream = upstream + "/dns-query" - } - - c := &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 3 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 5 * time.Second, - MaxIdleConnsPerHost: 100, - MaxIdleConns: 100, - }, - } - - client = &DoHClient{ - upstream: upstream, - c: c, - } + host = regexp.MustCompile(`^https:\/\/|\/dns-query$`).ReplaceAllString(host, "") + dohClient = &DOHClient{ + upstream: "https://" + host + "/dns-query", + httpClient: h, } }) - return client + return dohClient } -func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { +func (d *DOHClient) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err @@ -72,53 +67,40 @@ func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, e req = req.WithContext(ctx) req.Header.Set("Accept", "application/dns-message") - resp, err := d.c.Do(req) + resp, err := d.httpClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - log.Debug("[DoH] Error while resolving ", url, " : ", resp.Status) + return nil, errors.New("doh status error") } buf := bytes.Buffer{} - buf.ReadFrom(resp.Body) - - ret_msg := new(dns.Msg) - err = ret_msg.Unpack(buf.Bytes()) + _, err = buf.ReadFrom(resp.Body) if err != nil { return nil, err } - return ret_msg, nil -} - -func (d *DoHClient) Resolve(ctx context.Context, domain string, dnsTypes []uint16) ([]string, error) { - var ret []string - - for _, dnsType := range dnsTypes { - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain), dnsType) - - resp, err := d.doGetRequest(ctx, msg) - if err != nil { - return nil, err - } - - if resp.Rcode != dns.RcodeSuccess { - continue - } - - for _, answer := range resp.Answer { - if t, ok := answer.(*dns.A); ok { - ret = append(ret, t.A.String()) - } - if t, ok := answer.(*dns.AAAA); ok { - ret = append(ret, t.AAAA.String()) - } - } + resultMsg := new(dns.Msg) + err = resultMsg.Unpack(buf.Bytes()) + if err != nil { + return nil, err } - return ret, nil + return resultMsg, nil +} + +func (d *DOHClient) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + res, err := d.dohQuery(ctx, msg) + if err != nil { + return nil, err + } + + if res.Rcode != dns.RcodeSuccess { + return nil, errors.New("doh rcode wasn't successful") + } + + return res, nil }