diff --git a/dns/dns.go b/dns/dns.go index 53b2c2a..dc19a53 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -59,7 +59,7 @@ func customLookup(host string, port string, domain string) (string, error) { response, _, err := c.Exchange(msg, dnsServer) if err != nil { - return "", errors.New("couldn not resolve the domain(custom)") + return "", errors.New("could not resolve the domain(custom)") } for _, answer := range response.Answer { @@ -76,7 +76,7 @@ func systemLookup(domain string) (string, error) { systemResolver := net.Resolver{PreferGo: true} ips, err := systemResolver.LookupIPAddr(context.Background(), domain) if err != nil { - return "", errors.New("couldn not resolve the domain(system)") + return "", errors.New("could not resolve the domain(system)") } for _, ip := range ips { @@ -89,18 +89,22 @@ func systemLookup(domain string) (string, error) { func dohLookup(domain string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - log.Debug("[DoH] ", domain, " resolving with dns over https") - dnsUpstream := util.GetConfig().DnsAddr - client := GetDoHClient(*dnsUpstream) - resp, err := client.Resolve(ctx, domain, []uint16{dns.TypeA, dns.TypeAAAA}) - if err == nil { - if len(resp) == 0 { // yes this happens - return "", errors.New("no record found(doh)") - } + client := GetDOHClient(*util.GetConfig().DnsAddr) - return resp[0], nil + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), dns.TypeA) + + response, err := client.Exchange(ctx, domain, msg) + if err != nil { + return "", errors.New("could not resolve the domain(doh)") } - return "", errors.New("could not resolve the domain(doh)") + for _, answer := range response.Answer { + if record, ok := answer.(*dns.A); ok { + return record.A.String(), nil + } + } + + return "", errors.New("no record found(system)") } diff --git a/dns/doh.go b/dns/doh.go index 43e52bc..dfab6d7 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/base64" + "errors" "fmt" "net" "net/http" @@ -12,18 +13,17 @@ import ( "time" "github.com/miekg/dns" - log "github.com/sirupsen/logrus" ) -type DoHClient struct { +type DOHClient struct { upstream string - c *http.Client + client *http.Client } -var client *DoHClient +var client *DOHClient var clientOnce sync.Once -func GetDoHClient(upstream string) *DoHClient { +func GetDOHClient(upstream string) *DOHClient { clientOnce.Do(func() { if client == nil { if !strings.HasPrefix(upstream, "https://") { @@ -47,9 +47,9 @@ func GetDoHClient(upstream string) *DoHClient { }, } - client = &DoHClient{ + client = &DOHClient{ upstream: upstream, - c: c, + client: c, } } }) @@ -57,7 +57,7 @@ func GetDoHClient(upstream string) *DoHClient { return client } -func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { +func (d *DOHClient) query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err @@ -72,53 +72,40 @@ func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, e req = req.WithContext(ctx) req.Header.Set("Accept", "application/dns-message") - resp, err := d.c.Do(req) + resp, err := d.client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - log.Debug("[DoH] Error while resolving ", url, " : ", resp.Status) + return nil, errors.New("doh status error") } buf := bytes.Buffer{} - buf.ReadFrom(resp.Body) - - ret_msg := new(dns.Msg) - err = ret_msg.Unpack(buf.Bytes()) + _, err = buf.ReadFrom(resp.Body) if err != nil { return nil, err } - return ret_msg, nil -} - -func (d *DoHClient) Resolve(ctx context.Context, domain string, dnsTypes []uint16) ([]string, error) { - var ret []string - - for _, dnsType := range dnsTypes { - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain), dnsType) - - resp, err := d.doGetRequest(ctx, msg) - if err != nil { - return nil, err - } - - if resp.Rcode != dns.RcodeSuccess { - continue - } - - for _, answer := range resp.Answer { - if t, ok := answer.(*dns.A); ok { - ret = append(ret, t.A.String()) - } - if t, ok := answer.(*dns.AAAA); ok { - ret = append(ret, t.AAAA.String()) - } - } + resultMsg := new(dns.Msg) + err = resultMsg.Unpack(buf.Bytes()) + if err != nil { + return nil, err } - return ret, nil + return resultMsg, nil +} + +func (d *DOHClient) Exchange(ctx context.Context, domain string, msg *dns.Msg) (*dns.Msg, error) { + res, err := d.query(ctx, msg) + if err != nil { + return nil, err + } + + if res.Rcode != dns.RcodeSuccess { + return nil, errors.New("doh rcode wasn't successful") + } + + return res, nil }