fix: re-implement cancellation context

This commit is contained in:
ohaiibuzzle 2024-08-10 14:15:22 +07:00
parent 714daeab99
commit c93ddd67e0
2 changed files with 9 additions and 4 deletions

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"regexp" "regexp"
"strconv" "strconv"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -86,13 +87,15 @@ func systemLookup(domain string) (string, error) {
} }
func dohLookup(domain string) (string, error) { func dohLookup(domain string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
log.Debug("[DoH] ", domain, " resolving with dns over https") log.Debug("[DoH] ", domain, " resolving with dns over https")
dnsUpstream := util.GetConfig().DnsAddr dnsUpstream := util.GetConfig().DnsAddr
client := GetDoHClient(*dnsUpstream) client := GetDoHClient(*dnsUpstream)
// try up to 3 times // try up to 3 times
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
resp, err := client.Resolve(domain, []uint16{dns.TypeA, dns.TypeAAAA}) resp, err := client.Resolve(ctx, domain, []uint16{dns.TypeA, dns.TypeAAAA})
if err == nil { if err == nil {
if len(resp) == 0 { // yes this happens if len(resp) == 0 { // yes this happens
return "", errors.New("no record found(doh)") return "", errors.New("no record found(doh)")

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
@ -52,7 +53,7 @@ func GetDoHClient(upstream string) *DoHClient {
return client return client
} }
func (d *DoHClient) doGetRequest(msg *dns.Msg) (*dns.Msg, error) { func (d *DoHClient) doGetRequest(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack() pack, err := msg.Pack()
if err != nil { if err != nil {
return nil, err return nil, err
@ -64,6 +65,7 @@ func (d *DoHClient) doGetRequest(msg *dns.Msg) (*dns.Msg, error) {
return nil, err return nil, err
} }
req = req.WithContext(ctx)
req.Header.Set("Accept", "application/dns-message") req.Header.Set("Accept", "application/dns-message")
resp, err := d.c.Do(req) resp, err := d.c.Do(req)
@ -88,14 +90,14 @@ func (d *DoHClient) doGetRequest(msg *dns.Msg) (*dns.Msg, error) {
return ret_msg, nil return ret_msg, nil
} }
func (d *DoHClient) Resolve(domain string, dnsTypes []uint16) ([]string, error) { func (d *DoHClient) Resolve(ctx context.Context, domain string, dnsTypes []uint16) ([]string, error) {
var ret []string var ret []string
for _, dnsType := range dnsTypes { for _, dnsType := range dnsTypes {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), dnsType) msg.SetQuestion(dns.Fqdn(domain), dnsType)
resp, err := d.doGetRequest(msg) resp, err := d.doGetRequest(ctx, msg)
if err != nil { if err != nil {
return nil, err return nil, err
} }