feat: add IPv6 support (#161)

* Add support for IPv6 lookups.

* Refactor DNSResolver.

* Make listener support IPv6.
This commit is contained in:
Ledorub 2024-08-18 07:34:09 +03:00 committed by GitHub
parent ab4d6819c7
commit 15163ca5fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 638 additions and 91 deletions

27
dns/addrselect/LICENSE Normal file
View File

@ -0,0 +1,27 @@
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,377 @@
package addrselect
import (
"net"
"net/netip"
"sort"
)
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Minimal RFC 6724 address selection.
func SortByRFC6724(addrs []net.IPAddr) {
if len(addrs) < 2 {
return
}
sortByRFC6724withSrcs(addrs, srcAddrs(addrs))
}
func sortByRFC6724withSrcs(addrs []net.IPAddr, srcs []netip.Addr) {
if len(addrs) != len(srcs) {
panic("internal error")
}
addrAttr := make([]ipAttr, len(addrs))
srcAttr := make([]ipAttr, len(srcs))
for i, v := range addrs {
addrAttrIP, _ := netip.AddrFromSlice(v.IP)
addrAttr[i] = ipAttrOf(addrAttrIP)
srcAttr[i] = ipAttrOf(srcs[i])
}
sort.Stable(&byRFC6724{
addrs: addrs,
addrAttr: addrAttr,
srcs: srcs,
srcAttr: srcAttr,
})
}
// srcAddrs tries to UDP-connect to each address to see if it has a
// route. (This doesn't send any packets). The destination port
// number is irrelevant.
func srcAddrs(addrs []net.IPAddr) []netip.Addr {
srcs := make([]netip.Addr, len(addrs))
dst := net.UDPAddr{Port: 9}
for i := range addrs {
dst.IP = addrs[i].IP
dst.Zone = addrs[i].Zone
c, err := net.DialUDP("udp", nil, &dst)
if err == nil {
if src, ok := c.LocalAddr().(*net.UDPAddr); ok {
srcs[i], _ = netip.AddrFromSlice(src.IP)
}
c.Close()
}
}
return srcs
}
type ipAttr struct {
Scope scope
Precedence uint8
Label uint8
}
func ipAttrOf(ip netip.Addr) ipAttr {
if !ip.IsValid() {
return ipAttr{}
}
match := rfc6724policyTable.Classify(ip)
return ipAttr{
Scope: classifyScope(ip),
Precedence: match.Precedence,
Label: match.Label,
}
}
type byRFC6724 struct {
addrs []net.IPAddr // addrs to sort
addrAttr []ipAttr
srcs []netip.Addr // or not valid addr if unreachable
srcAttr []ipAttr
}
func (s *byRFC6724) Len() int { return len(s.addrs) }
func (s *byRFC6724) Swap(i, j int) {
s.addrs[i], s.addrs[j] = s.addrs[j], s.addrs[i]
s.srcs[i], s.srcs[j] = s.srcs[j], s.srcs[i]
s.addrAttr[i], s.addrAttr[j] = s.addrAttr[j], s.addrAttr[i]
s.srcAttr[i], s.srcAttr[j] = s.srcAttr[j], s.srcAttr[i]
}
// Less reports whether i is a better destination address for this
// host than j.
//
// The algorithm and variable names comes from RFC 6724 section 6.
func (s *byRFC6724) Less(i, j int) bool {
DA := s.addrs[i].IP
DB := s.addrs[j].IP
SourceDA := s.srcs[i]
SourceDB := s.srcs[j]
attrDA := &s.addrAttr[i]
attrDB := &s.addrAttr[j]
attrSourceDA := &s.srcAttr[i]
attrSourceDB := &s.srcAttr[j]
const preferDA = true
const preferDB = false
// Rule 1: Avoid unusable destinations.
// If DB is known to be unreachable or if Source(DB) is undefined, then
// prefer DA. Similarly, if DA is known to be unreachable or if
// Source(DA) is undefined, then prefer DB.
if !SourceDA.IsValid() && !SourceDB.IsValid() {
return false // "equal"
}
if !SourceDB.IsValid() {
return preferDA
}
if !SourceDA.IsValid() {
return preferDB
}
// Rule 2: Prefer matching scope.
// If Scope(DA) = Scope(Source(DA)) and Scope(DB) <> Scope(Source(DB)),
// then prefer DA. Similarly, if Scope(DA) <> Scope(Source(DA)) and
// Scope(DB) = Scope(Source(DB)), then prefer DB.
if attrDA.Scope == attrSourceDA.Scope && attrDB.Scope != attrSourceDB.Scope {
return preferDA
}
if attrDA.Scope != attrSourceDA.Scope && attrDB.Scope == attrSourceDB.Scope {
return preferDB
}
// Rule 3: Avoid deprecated addresses.
// If Source(DA) is deprecated and Source(DB) is not, then prefer DB.
// Similarly, if Source(DA) is not deprecated and Source(DB) is
// deprecated, then prefer DA.
// TODO(bradfitz): implement? low priority for now.
// Rule 4: Prefer home addresses.
// If Source(DA) is simultaneously a home address and care-of address
// and Source(DB) is not, then prefer DA. Similarly, if Source(DB) is
// simultaneously a home address and care-of address and Source(DA) is
// not, then prefer DB.
// TODO(bradfitz): implement? low priority for now.
// Rule 5: Prefer matching label.
// If Label(Source(DA)) = Label(DA) and Label(Source(DB)) <> Label(DB),
// then prefer DA. Similarly, if Label(Source(DA)) <> Label(DA) and
// Label(Source(DB)) = Label(DB), then prefer DB.
if attrSourceDA.Label == attrDA.Label &&
attrSourceDB.Label != attrDB.Label {
return preferDA
}
if attrSourceDA.Label != attrDA.Label &&
attrSourceDB.Label == attrDB.Label {
return preferDB
}
// Rule 6: Prefer higher precedence.
// If Precedence(DA) > Precedence(DB), then prefer DA. Similarly, if
// Precedence(DA) < Precedence(DB), then prefer DB.
if attrDA.Precedence > attrDB.Precedence {
return preferDA
}
if attrDA.Precedence < attrDB.Precedence {
return preferDB
}
// Rule 7: Prefer native transport.
// If DA is reached via an encapsulating transition mechanism (e.g.,
// IPv6 in IPv4) and DB is not, then prefer DB. Similarly, if DB is
// reached via encapsulation and DA is not, then prefer DA.
// TODO(bradfitz): implement? low priority for now.
// Rule 8: Prefer smaller scope.
// If Scope(DA) < Scope(DB), then prefer DA. Similarly, if Scope(DA) >
// Scope(DB), then prefer DB.
if attrDA.Scope < attrDB.Scope {
return preferDA
}
if attrDA.Scope > attrDB.Scope {
return preferDB
}
// Rule 9: Use the longest matching prefix.
// When DA and DB belong to the same address family (both are IPv6 or
// both are IPv4 [but see below]): If CommonPrefixLen(Source(DA), DA) >
// CommonPrefixLen(Source(DB), DB), then prefer DA. Similarly, if
// CommonPrefixLen(Source(DA), DA) < CommonPrefixLen(Source(DB), DB),
// then prefer DB.
//
// However, applying this rule to IPv4 addresses causes
// problems (see issues 13283 and 18518), so limit to IPv6.
if DA.To4() == nil && DB.To4() == nil {
commonA := commonPrefixLen(SourceDA, DA)
commonB := commonPrefixLen(SourceDB, DB)
if commonA > commonB {
return preferDA
}
if commonA < commonB {
return preferDB
}
}
// Rule 10: Otherwise, leave the order unchanged.
// If DA preceded DB in the original list, prefer DA.
// Otherwise, prefer DB.
return false // "equal"
}
type policyTableEntry struct {
Prefix netip.Prefix
Precedence uint8
Label uint8
}
type policyTable []policyTableEntry
// RFC 6724 section 2.1.
// Items are sorted by the size of their Prefix.Mask.Size,
var rfc6724policyTable = policyTable{
{
// "::1/128"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}), 128),
Precedence: 50,
Label: 0,
},
{
// "::ffff:0:0/96"
// IPv4-compatible, etc.
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}), 96),
Precedence: 35,
Label: 4,
},
{
// "::/96"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 96),
Precedence: 1,
Label: 3,
},
{
// "2001::/32"
// Teredo
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01}), 32),
Precedence: 5,
Label: 5,
},
{
// "2002::/16"
// 6to4
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x02}), 16),
Precedence: 30,
Label: 2,
},
{
// "3ffe::/16"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x3f, 0xfe}), 16),
Precedence: 1,
Label: 12,
},
{
// "fec0::/10"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0xfe, 0xc0}), 10),
Precedence: 1,
Label: 11,
},
{
// "fc00::/7"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0xfc}), 7),
Precedence: 3,
Label: 13,
},
{
// "::/0"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
Precedence: 40,
Label: 1,
},
}
// Classify returns the policyTableEntry of the entry with the longest
// matching prefix that contains ip.
// The table t must be sorted from largest mask size to smallest.
func (t policyTable) Classify(ip netip.Addr) policyTableEntry {
// Prefix.Contains() will not match an IPv6 prefix for an IPv4 address.
if ip.Is4() {
ip = netip.AddrFrom16(ip.As16())
}
for _, ent := range t {
if ent.Prefix.Contains(ip) {
return ent
}
}
return policyTableEntry{}
}
// RFC 6724 section 3.1.
type scope uint8
const (
scopeInterfaceLocal scope = 0x1
scopeLinkLocal scope = 0x2
scopeAdminLocal scope = 0x4
scopeSiteLocal scope = 0x5
scopeOrgLocal scope = 0x8
scopeGlobal scope = 0xe
)
func classifyScope(ip netip.Addr) scope {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() {
return scopeLinkLocal
}
ipv6 := ip.Is6() && !ip.Is4In6()
ipv6AsBytes := ip.As16()
if ipv6 && ip.IsMulticast() {
return scope(ipv6AsBytes[1] & 0xf)
}
// Site-local addresses are defined in RFC 3513 section 2.5.6
// (and deprecated in RFC 3879).
if ipv6 && ipv6AsBytes[0] == 0xfe && ipv6AsBytes[1]&0xc0 == 0xc0 {
return scopeSiteLocal
}
return scopeGlobal
}
// commonPrefixLen reports the length of the longest prefix (looking
// at the most significant, or leftmost, bits) that the
// two addresses have in common, up to the length of a's prefix (i.e.,
// the portion of the address not including the interface ID).
//
// If a or b is an IPv4 address as an IPv6 address, the IPv4 addresses
// are compared (with max common prefix length of 32).
// If a and b are different IP versions, 0 is returned.
//
// See https://tools.ietf.org/html/rfc6724#section-2.2
func commonPrefixLen(a netip.Addr, b net.IP) (cpl int) {
if b4 := b.To4(); b4 != nil {
b = b4
}
aAsSlice := a.AsSlice()
if len(aAsSlice) != len(b) {
return 0
}
// If IPv6, only up to the prefix (first 64 bits)
if len(aAsSlice) > 8 {
aAsSlice = aAsSlice[:8]
b = b[:8]
}
for len(aAsSlice) > 0 {
if aAsSlice[0] == b[0] {
cpl += 8
aAsSlice = aAsSlice[1:]
b = b[1:]
continue
}
bits := 8
ab, bb := aAsSlice[0], b[0]
for {
ab >>= 1
bb >>= 1
bits--
if ab == bb {
cpl += bits
return
}
}
}
return
}

View File

@ -3,108 +3,256 @@ package dns
import (
"context"
"errors"
"net"
"regexp"
"strconv"
"time"
"fmt"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/xvzc/SpoofDPI/dns/addrselect"
"github.com/xvzc/SpoofDPI/util"
"net"
"net/netip"
"strconv"
"sync"
"time"
)
type DnsResolver struct {
host string
port string
enableDoh bool
type client interface {
Resolve(ctx context.Context, host string) ([]net.IPAddr, error)
String() string
}
func NewResolver(config *util.Config) *DnsResolver {
return &DnsResolver{
host: *config.DnsAddr,
port: strconv.Itoa(*config.DnsPort),
enableDoh: *config.EnableDoh,
type Resolver struct {
host string
port string
enableDoh bool
systemClient client
customClient client
}
func NewResolver(config *util.Config) *Resolver {
addr := *config.DnsAddr
port := strconv.Itoa(*config.DnsPort)
server := net.JoinHostPort(addr, port)
var systemClient client
if config.AllowedPatterns != nil {
systemClient = NewSystemClient()
}
var customClient client
if *config.EnableDoh {
customClient = NewDoHClient(addr)
} else {
customClient = NewCustomClient(server)
}
return &Resolver{
host: *config.DnsAddr,
port: port,
enableDoh: *config.EnableDoh,
systemClient: systemClient,
customClient: customClient,
}
}
func (d *DnsResolver) Lookup(domain string, useSystemDns bool) (string, error) {
ipRegex := "^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
if r, _ := regexp.MatchString(ipRegex, domain); r {
func (d *Resolver) Lookup(domain string, useSystemDns bool) (string, error) {
if _, err := parseAddr(domain); err == nil {
return domain, nil
}
if useSystemDns {
log.Debug("[DNS] ", domain, " resolving with system dns")
return systemLookup(domain)
}
if d.enableDoh {
log.Debug("[DNS] ", domain, " resolving with dns over https")
return dohLookup(d.host, domain)
}
log.Debug("[DNS] ", domain, " resolving with custom dns")
return customLookup(d.host, d.port, domain)
}
func customLookup(host string, port string, domain string) (string, error) {
dnsServer := host + ":" + port
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), dns.TypeA)
c := new(dns.Client)
response, _, err := c.Exchange(msg, dnsServer)
if err != nil {
return "", errors.New("could not resolve the domain(custom)")
}
for _, answer := range response.Answer {
if record, ok := answer.(*dns.A); ok {
return record.A.String(), nil
}
}
return "", errors.New("no record found(custom)")
}
func systemLookup(domain string) (string, error) {
systemResolver := net.Resolver{PreferGo: true}
ips, err := systemResolver.LookupIPAddr(context.Background(), domain)
if err != nil {
return "", errors.New("could not resolve the domain(system)")
}
for _, ip := range ips {
return ip.String(), nil
}
return "", errors.New("no record found(system)")
}
func dohLookup(host string, domain string) (string, error) {
clt := d.getClient(useSystemDns)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
client := getDOHClient(host)
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), dns.TypeA)
response, err := client.dohExchange(ctx, msg)
log.Debugf("[DNS] resolving %s using %s", domain, clt)
t := time.Now()
addrs, err := clt.Resolve(ctx, domain)
if err != nil {
return "", errors.New("could not resolve the domain(doh)")
return "", fmt.Errorf("%s: %w", clt, err)
}
lookupTime := time.Since(t).Milliseconds()
for _, answer := range response.Answer {
if record, ok := answer.(*dns.A); ok {
return record.A.String(), nil
addr := addrs[0].String()
log.Debugf("[DNS] resolved %s to %s in %d ms", domain, addr, lookupTime)
return addr, nil
}
func (d *Resolver) getClient(useSystemDns bool) client {
if useSystemDns {
return d.systemClient
} else {
return d.customClient
}
}
type SystemClient struct {
client *net.Resolver
}
func NewSystemClient() *SystemClient {
return &SystemClient{
client: &net.Resolver{PreferGo: true},
}
}
func (c *SystemClient) String() string {
return "SystemClient"
}
func (c *SystemClient) Resolve(ctx context.Context, host string) ([]net.IPAddr, error) {
addrs, err := c.client.LookupIPAddr(ctx, host)
if err != nil {
return []net.IPAddr{}, err
}
return addrs, nil
}
type sendMsgFunc = func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
type customDNSResult struct {
msg *dns.Msg
err error
}
type CustomClient struct {
server string
sendMsgFn sendMsgFunc
}
func (c *CustomClient) Resolve(ctx context.Context, host string) ([]net.IPAddr, error) {
queryTypes := []uint16{dns.TypeAAAA, dns.TypeA}
resultCh := c.makeLookups(ctx, host, queryTypes)
addrs, err := c.processResults(ctx, resultCh)
return addrs, err
}
func (c *CustomClient) makeLookups(ctx context.Context, host string, queryTypes []uint16) <-chan *customDNSResult {
var wg sync.WaitGroup
resCh := make(chan *customDNSResult)
lookup := func(qType uint16) {
defer wg.Done()
select {
case <-ctx.Done():
return
case resCh <- c.makeLookup(ctx, host, qType):
}
}
return "", errors.New("no record found(doh)")
for _, queryType := range queryTypes {
wg.Add(1)
go lookup(queryType)
}
go func() {
wg.Wait()
close(resCh)
}()
return resCh
}
func (c *CustomClient) makeLookup(ctx context.Context, host string, queryType uint16) *customDNSResult {
msg := c.newMsg(host, queryType)
resp, err := c.sendMsg(ctx, msg)
if err != nil {
queryName := recordTypeIDToName(queryType)
err = fmt.Errorf("resolving %s, query type %s: %w", host, queryName, err)
return &customDNSResult{err: err}
}
return &customDNSResult{msg: resp}
}
func (c *CustomClient) newMsg(host string, qType uint16) *dns.Msg {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(host), qType)
return msg
}
func (c *CustomClient) sendMsg(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
resp, err := c.sendMsgFn(ctx, msg)
return resp, err
}
func (c *CustomClient) processResults(ctx context.Context, resCh <-chan *customDNSResult) ([]net.IPAddr, error) {
var errs []error
var addrs []net.IPAddr
for result := range resCh {
if result.err != nil {
errs = append(errs, result.err)
continue
}
resultAddrs := parseAddrsFromMsg(result.msg)
addrs = append(addrs, resultAddrs...)
}
select {
case <-ctx.Done():
return nil, errors.New("cancelled")
default:
if len(addrs) == 0 {
return addrs, errors.Join(errs...)
}
}
sortAddrs(addrs)
return addrs, nil
}
func (c *CustomClient) String() string {
return fmt.Sprintf("CustomClient for %s", c.server)
}
func NewCustomClient(server string) *CustomClient {
clt := &dns.Client{}
return &CustomClient{
server: server,
sendMsgFn: func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
resp, _, err := clt.Exchange(msg, server)
return resp, err
},
}
}
func NewDoHClient(host string) *CustomClient {
server := net.JoinHostPort(host, "443")
clt := getDOHClient(server)
return &CustomClient{
server: server,
sendMsgFn: func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
return clt.dohExchange(ctx, msg)
},
}
}
func recordTypeIDToName(id uint16) string {
switch id {
case 1:
return "A"
case 28:
return "AAAA"
}
return strconv.FormatUint(uint64(id), 10)
}
func parseAddrsFromMsg(msg *dns.Msg) []net.IPAddr {
var addrs []net.IPAddr
for _, record := range msg.Answer {
switch ipRecord := record.(type) {
case *dns.A:
addrs = append(addrs, net.IPAddr{IP: ipRecord.A})
case *dns.AAAA:
addrs = append(addrs, net.IPAddr{IP: ipRecord.AAAA})
}
}
return addrs
}
func parseAddr(addr string) (net.IP, error) {
parsed, err := netip.ParseAddr(addr)
if err != nil {
return net.IP{}, fmt.Errorf("parsing %s as an IP address: %w", addr, err)
}
return parsed.AsSlice(), nil
}
func sortAddrs(addrs []net.IPAddr) {
addrselect.SortByRFC6724(addrs)
}

View File

@ -17,7 +17,7 @@ type Proxy struct {
addr string
port int
timeout int
resolver *dns.DnsResolver
resolver *dns.Resolver
windowSize int
allowedPattern []*regexp.Regexp
}
@ -34,7 +34,7 @@ func New(config *util.Config) *Proxy {
}
func (pxy *Proxy) Start() {
l, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP(pxy.addr), Port: pxy.port})
l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP(pxy.addr), Port: pxy.port})
if err != nil {
log.Fatal("[PROXY] error creating listener: ", err)
os.Exit(1)
@ -114,11 +114,6 @@ func (pxy *Proxy) patternMatches(bytes []byte) bool {
}
func isLoopedRequest(ip net.IP) bool {
// we don't handle IPv6 at all it seems
if ip.To4() == nil {
return false
}
if ip.IsLoopback() {
return true
}
@ -133,7 +128,7 @@ func isLoopedRequest(ip net.IP) bool {
for _, addr := range addr {
if ipnet, ok := addr.(*net.IPNet); ok {
if ipnet.IP.To4() != nil && ipnet.IP.To4().Equal(ip) {
if ipnet.IP.Equal(ip) {
return true
}
}