diff --git a/dns/dns.go b/dns/dns.go index 97d8485..32b7989 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -90,10 +90,17 @@ func dohLookup(domain string) (string, error) { dnsUpstream := util.GetConfig().DnsAddr client := GetDoHClient(*dnsUpstream) - resp, err := client.Resolve(domain, dns.TypeA) - if err != nil { - return "", errors.New("couldn not resolve the domain(doh)") + // try up to 3 times + for i := 0; i < 3; i++ { + resp, err := client.Resolve(domain, []uint16{dns.TypeA, dns.TypeAAAA}) + if err == nil { + if len(resp) == 0 { // yes this happens + return "", errors.New("no record found(doh)") + } + + return resp[0], nil + } } - return resp[0], nil + return "", errors.New("could not resolve the domain(doh)") } diff --git a/dns/doh.go b/dns/doh.go index 02ee557..561cb00 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -88,20 +88,31 @@ func (d *DoHClient) doGetRequest(msg *dns.Msg) (*dns.Msg, error) { return ret_msg, nil } -func (d *DoHClient) Resolve(domain string, dnsType uint16) ([]string, error) { - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain), dnsType) - - resp, err := d.doGetRequest(msg) - if err != nil { - return nil, err - } - +func (d *DoHClient) Resolve(domain string, dnsTypes []uint16) ([]string, error) { var ret []string - for _, ans := range resp.Answer { - if a, ok := ans.(*dns.A); ok { - ret = append(ret, a.A.String()) + + for _, dnsType := range dnsTypes { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), dnsType) + + resp, err := d.doGetRequest(msg) + if err != nil { + return nil, err + } + + if resp.Rcode != dns.RcodeSuccess { + continue + } + + for _, answer := range resp.Answer { + if t, ok := answer.(*dns.A); ok { + ret = append(ret, t.A.String()) + } + if t, ok := answer.(*dns.AAAA); ok { + ret = append(ret, t.AAAA.String()) + } } } + return ret, nil }