package dns import ( "bytes" "context" "encoding/base64" "errors" "fmt" "net" "net/http" "regexp" "sync" "time" "github.com/miekg/dns" ) type DOHClient struct { upstream string httpClient *http.Client } var dohClient *DOHClient var clientOnce sync.Once func getDOHClient(host string) *DOHClient { if dohClient != nil { return dohClient } clientOnce.Do(func() { h := &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ DialContext: (&net.Dialer{ Timeout: 3 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, TLSHandshakeTimeout: 5 * time.Second, MaxIdleConnsPerHost: 100, MaxIdleConns: 100, }, } host = regexp.MustCompile(`^https:\/\/|\/dns-query$`).ReplaceAllString(host, "") dohClient = &DOHClient{ upstream: "https://" + host + "/dns-query", httpClient: h, } }) return dohClient } func (d *DOHClient) dohQuery(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err } url := fmt.Sprintf("%s?dns=%s", d.upstream, base64.RawStdEncoding.EncodeToString(pack)) req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } req = req.WithContext(ctx) req.Header.Set("Accept", "application/dns-message") resp, err := d.httpClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, errors.New("doh status error") } buf := bytes.Buffer{} _, err = buf.ReadFrom(resp.Body) if err != nil { return nil, err } resultMsg := new(dns.Msg) err = resultMsg.Unpack(buf.Bytes()) if err != nil { return nil, err } return resultMsg, nil } func (d *DOHClient) dohExchange(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { res, err := d.dohQuery(ctx, msg) if err != nil { return nil, err } if res.Rcode != dns.RcodeSuccess { return nil, errors.New("doh rcode wasn't successful") } return res, nil }