chore: refactor doh

This commit is contained in:
xvzc 2024-08-13 07:37:20 +09:00
parent dbe7f32a1f
commit 76b16da2e8
2 changed files with 45 additions and 54 deletions

View File

@ -59,7 +59,7 @@ func customLookup(host string, port string, domain string) (string, error) {
response, _, err := c.Exchange(msg, dnsServer)
if err != nil {
return "", errors.New("couldn not resolve the domain(custom)")
return "", errors.New("could not resolve the domain(custom)")
}
for _, answer := range response.Answer {
@ -76,7 +76,7 @@ func systemLookup(domain string) (string, error) {
systemResolver := net.Resolver{PreferGo: true}
ips, err := systemResolver.LookupIPAddr(context.Background(), domain)
if err != nil {
return "", errors.New("couldn not resolve the domain(system)")
return "", errors.New("could not resolve the domain(system)")
}
for _, ip := range ips {
@ -89,18 +89,22 @@ func systemLookup(domain string) (string, error) {
func dohLookup(domain string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
log.Debug("[DoH] ", domain, " resolving with dns over https")
dnsUpstream := util.GetConfig().DnsAddr
client := GetDoHClient(*dnsUpstream)
resp, err := client.Resolve(ctx, domain, []uint16{dns.TypeA, dns.TypeAAAA})
if err == nil {
if len(resp) == 0 { // yes this happens
return "", errors.New("no record found(doh)")
}
client := GetDOHClient(*util.GetConfig().DnsAddr)
return resp[0], nil
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), dns.TypeA)
response, err := client.Exchange(ctx, domain, msg)
if err != nil {
return "", errors.New("could not resolve the domain(doh)")
}
for _, answer := range response.Answer {
if record, ok := answer.(*dns.A); ok {
return record.A.String(), nil
}
}
return "", errors.New("no record found(system)")
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
@ -12,18 +13,17 @@ import (
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type DoHClient struct {
type DOHClient struct {
upstream string
c *http.Client
client *http.Client
}
var client *DoHClient
var client *DOHClient
var clientOnce sync.Once
func GetDoHClient(upstream string) *DoHClient {
func GetDOHClient(upstream string) *DOHClient {
clientOnce.Do(func() {
if client == nil {
if !strings.HasPrefix(upstream, "https://") {
@ -47,9 +47,9 @@ func GetDoHClient(upstream string) *DoHClient {
},
}
client = &DoHClient{
client = &DOHClient{
upstream: upstream,
c: c,
client: c,
}
}
})
@ -57,7 +57,7 @@ func GetDoHClient(upstream string) *DoHClient {
return client
}
func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
func (d *DOHClient) query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack()
if err != nil {
return nil, err
@ -72,53 +72,40 @@ func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, e
req = req.WithContext(ctx)
req.Header.Set("Accept", "application/dns-message")
resp, err := d.c.Do(req)
resp, err := d.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Debug("[DoH] Error while resolving ", url, " : ", resp.Status)
return nil, errors.New("doh status error")
}
buf := bytes.Buffer{}
buf.ReadFrom(resp.Body)
ret_msg := new(dns.Msg)
err = ret_msg.Unpack(buf.Bytes())
_, err = buf.ReadFrom(resp.Body)
if err != nil {
return nil, err
}
return ret_msg, nil
}
func (d *DoHClient) Resolve(ctx context.Context, domain string, dnsTypes []uint16) ([]string, error) {
var ret []string
for _, dnsType := range dnsTypes {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), dnsType)
resp, err := d.doGetRequest(ctx, msg)
resultMsg := new(dns.Msg)
err = resultMsg.Unpack(buf.Bytes())
if err != nil {
return nil, err
}
if resp.Rcode != dns.RcodeSuccess {
continue
}
for _, answer := range resp.Answer {
if t, ok := answer.(*dns.A); ok {
ret = append(ret, t.A.String())
}
if t, ok := answer.(*dns.AAAA); ok {
ret = append(ret, t.AAAA.String())
}
}
}
return ret, nil
return resultMsg, nil
}
func (d *DOHClient) Exchange(ctx context.Context, domain string, msg *dns.Msg) (*dns.Msg, error) {
res, err := d.query(ctx, msg)
if err != nil {
return nil, err
}
if res.Rcode != dns.RcodeSuccess {
return nil, errors.New("doh rcode wasn't successful")
}
return res, nil
}