From 464afe937df88ab51088ad491a440d29a86fb202 Mon Sep 17 00:00:00 2001 From: xvzc Date: Tue, 13 Aug 2024 07:53:37 +0900 Subject: [PATCH] chore: refactor doh --- dns/dns.go | 8 ++++---- dns/doh.go | 19 +++++-------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/dns/dns.go b/dns/dns.go index dc19a53..758b38d 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -41,7 +41,7 @@ func (d *DnsResolver) Lookup(domain string, useSystemDns bool) (string, error) { if d.enableDoh { log.Debug("[DNS] ", domain, " resolving with dns over https") - return dohLookup(domain) + return dohLookup(d.host, domain) } log.Debug("[DNS] ", domain, " resolving with custom dns") @@ -86,16 +86,16 @@ func systemLookup(domain string) (string, error) { return "", errors.New("no record found(system)") } -func dohLookup(domain string) (string, error) { +func dohLookup(host string, domain string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - client := GetDOHClient(*util.GetConfig().DnsAddr) + client := getDOHClient(host) msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(domain), dns.TypeA) - response, err := client.Exchange(ctx, domain, msg) + response, err := client.dohExchange(ctx, msg) if err != nil { return "", errors.New("could not resolve the domain(doh)") } diff --git a/dns/doh.go b/dns/doh.go index dfab6d7..6fb3fe7 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/http" - "strings" "sync" "time" @@ -23,17 +22,9 @@ type DOHClient struct { var client *DOHClient var clientOnce sync.Once -func GetDOHClient(upstream string) *DOHClient { +func getDOHClient(host string) *DOHClient { clientOnce.Do(func() { if client == nil { - if !strings.HasPrefix(upstream, "https://") { - upstream = "https://" + upstream - } - - if !strings.HasSuffix(upstream, "/dns-query") { - upstream = upstream + "/dns-query" - } - c := &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ @@ -48,7 +39,7 @@ func GetDOHClient(upstream string) *DOHClient { } client = &DOHClient{ - upstream: upstream, + upstream: "https://" + host + "/dns-query", client: c, } } @@ -57,7 +48,7 @@ func GetDOHClient(upstream string) *DOHClient { return client } -func (d *DOHClient) query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { +func (d *DOHClient) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err @@ -97,8 +88,8 @@ func (d *DOHClient) query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { return resultMsg, nil } -func (d *DOHClient) Exchange(ctx context.Context, domain string, msg *dns.Msg) (*dns.Msg, error) { - res, err := d.query(ctx, msg) +func (d *DOHClient) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + res, err := d.dohQuery(ctx, msg) if err != nil { return nil, err }