From 84ad7288eac69dc81175f3dce4ec34a1c411dbd1 Mon Sep 17 00:00:00 2001 From: xvzc Date: Tue, 13 Aug 2024 13:30:59 +0900 Subject: [PATCH] fix: re-add handling leading https and trailing /dns-query --- dns/dns.go | 2 +- dns/doh.go | 48 ++++++++++++++++++++++++++---------------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/dns/dns.go b/dns/dns.go index 758b38d..eee15c6 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -106,5 +106,5 @@ func dohLookup(host string, domain string) (string, error) { } } - return "", errors.New("no record found(system)") + return "", errors.New("no record found(doh)") } diff --git a/dns/doh.go b/dns/doh.go index 6fb3fe7..d597642 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + "regexp" "sync" "time" @@ -16,36 +17,39 @@ import ( type DOHClient struct { upstream string - client *http.Client + httpClient *http.Client } -var client *DOHClient +var dohClient *DOHClient var clientOnce sync.Once func getDOHClient(host string) *DOHClient { - clientOnce.Do(func() { - if client == nil { - c := &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 3 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 5 * time.Second, - MaxIdleConnsPerHost: 100, - MaxIdleConns: 100, - }, - } + if dohClient != nil { + return dohClient + } - client = &DOHClient{ - upstream: "https://" + host + "/dns-query", - client: c, - } + clientOnce.Do(func() { + h := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 3 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 5 * time.Second, + MaxIdleConnsPerHost: 100, + MaxIdleConns: 100, + }, + } + + host = regexp.MustCompile(`^https:\/\/|\/dns-query$`).ReplaceAllString(host, "") + dohClient = &DOHClient{ + upstream: "https://" + host + "/dns-query", + httpClient: h, } }) - return client + return dohClient } func (d *DOHClient) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { @@ -63,7 +67,7 @@ func (d *DOHClient) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error req = req.WithContext(ctx) req.Header.Set("Accept", "application/dns-message") - resp, err := d.client.Do(req) + resp, err := d.httpClient.Do(req) if err != nil { return nil, err }