mirror of
https://github.com/m13253/dns-over-https.git
synced 2026-03-30 12:05:38 +00:00
feat: support dual-stack for interface binding
This commit is contained in:
@@ -92,24 +92,24 @@ func NewClient(conf *config.Config) (c *Client, err error) {
|
||||
}
|
||||
|
||||
if c.conf.Other.Interface != "" {
|
||||
// Setup UDP Dialer
|
||||
udpLocalAddr, err := c.bindToInterface("udp")
|
||||
localV4, localV6, err := c.getInterfaceIPs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to bind passthrough UDP to interface %s: %v", c.conf.Other.Interface, err)
|
||||
return nil, fmt.Errorf("failed to get interface IPs for %s: %v", c.conf.Other.Interface, err)
|
||||
}
|
||||
c.udpClient.Dialer = &net.Dialer{
|
||||
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
|
||||
LocalAddr: udpLocalAddr,
|
||||
var localAddr net.IP
|
||||
if localV4 != nil {
|
||||
localAddr = localV4
|
||||
} else {
|
||||
localAddr = localV6
|
||||
}
|
||||
|
||||
// Setup TCP Dialer
|
||||
tcpLocalAddr, err := c.bindToInterface("tcp")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to bind passthrough TCP to interface %s: %v", c.conf.Other.Interface, err)
|
||||
c.udpClient.Dialer = &net.Dialer{
|
||||
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
|
||||
LocalAddr: &net.UDPAddr{IP: localAddr},
|
||||
}
|
||||
c.tcpClient.Dialer = &net.Dialer{
|
||||
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
|
||||
LocalAddr: tcpLocalAddr,
|
||||
LocalAddr: &net.TCPAddr{IP: localAddr},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,11 +144,35 @@ func NewClient(conf *config.Config) (c *Client, err error) {
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
if c.conf.Other.Interface != "" {
|
||||
localAddr, err := c.bindToInterface(network)
|
||||
localV4, localV6, err := c.getInterfaceIPs()
|
||||
if err != nil {
|
||||
log.Printf("Bootstrap dial warning: %v", err)
|
||||
} else {
|
||||
d.LocalAddr = localAddr
|
||||
numServers := len(c.bootstrap)
|
||||
bootstrap := c.bootstrap[rand.Intn(numServers)]
|
||||
host, _, _ := net.SplitHostPort(bootstrap)
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
if ip.To4() != nil {
|
||||
if localV4 != nil {
|
||||
if strings.HasPrefix(network, "udp") {
|
||||
d.LocalAddr = &net.UDPAddr{IP: localV4}
|
||||
} else {
|
||||
d.LocalAddr = &net.TCPAddr{IP: localV4}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if localV6 != nil {
|
||||
if strings.HasPrefix(network, "udp") {
|
||||
d.LocalAddr = &net.UDPAddr{IP: localV6}
|
||||
} else {
|
||||
d.LocalAddr = &net.TCPAddr{IP: localV6}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
conn, err := d.DialContext(ctx, network, bootstrap)
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
numServers := len(c.bootstrap)
|
||||
@@ -266,22 +290,72 @@ func (c *Client) newHTTPClient() error {
|
||||
if c.httpTransport != nil {
|
||||
c.httpTransport.CloseIdleConnections()
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
|
||||
localV4, localV6, err := c.getInterfaceIPs()
|
||||
if err != nil {
|
||||
log.Printf("Interface binding error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
baseDialer := &net.Dialer{
|
||||
Timeout: time.Duration(c.conf.Other.Timeout) * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
// DualStack: true,
|
||||
Resolver: c.bootstrapResolver,
|
||||
}
|
||||
if c.conf.Other.Interface != "" {
|
||||
localAddr, err := c.bindToInterface("tcp")
|
||||
if err != nil {
|
||||
log.Printf("Failed to resolve interface %s: %v", c.conf.Other.Interface, err)
|
||||
return err
|
||||
}
|
||||
dialer.LocalAddr = localAddr
|
||||
Resolver: c.bootstrapResolver,
|
||||
}
|
||||
|
||||
c.httpTransport = &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if c.conf.Other.Interface == "" {
|
||||
return baseDialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
if network == "tcp4" && localV4 != nil {
|
||||
d := *baseDialer
|
||||
d.LocalAddr = &net.TCPAddr{IP: localV4}
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
if network == "tcp6" && localV6 != nil {
|
||||
d := *baseDialer
|
||||
d.LocalAddr = &net.TCPAddr{IP: localV6}
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
// Manual Dual-Stack: Resolve host and try compatible families sequentially
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
ips, err := c.bootstrapResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, ip := range ips {
|
||||
d := *baseDialer
|
||||
targetAddr := net.JoinHostPort(ip.String(), port)
|
||||
|
||||
if ip.IP.To4() != nil {
|
||||
if localV4 == nil {
|
||||
continue
|
||||
}
|
||||
d.LocalAddr = &net.TCPAddr{IP: localV4}
|
||||
} else {
|
||||
if localV6 == nil {
|
||||
continue
|
||||
}
|
||||
d.LocalAddr = &net.TCPAddr{IP: localV6}
|
||||
}
|
||||
|
||||
conn, err := d.DialContext(ctx, "tcp", targetAddr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, fmt.Errorf("connection to %s failed: no matching local/remote IP families on interface %s", addr, c.conf.Other.Interface)
|
||||
},
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
@@ -290,15 +364,18 @@ func (c *Client) newHTTPClient() error {
|
||||
TLSHandshakeTimeout: time.Duration(c.conf.Other.Timeout) * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.conf.Other.TLSInsecureSkipVerify},
|
||||
}
|
||||
|
||||
if c.conf.Other.NoIPv6 {
|
||||
originalDial := c.httpTransport.DialContext
|
||||
c.httpTransport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if strings.HasPrefix(network, "tcp") {
|
||||
network = "tcp4"
|
||||
}
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
return originalDial(ctx, network, address)
|
||||
}
|
||||
}
|
||||
err := http2.ConfigureTransport(c.httpTransport)
|
||||
|
||||
err = http2.ConfigureTransport(c.httpTransport)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -525,49 +602,37 @@ func (c *Client) findClientIP(w dns.ResponseWriter, r *dns.Msg) (ednsClientAddre
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) bindToInterface(network string) (net.Addr, error) {
|
||||
// getInterfaceIPs returns the first valid IPv4 and IPv6 addresses found on the interface
|
||||
func (c *Client) getInterfaceIPs() (v4, v6 net.IP, err error) {
|
||||
if c.conf.Other.Interface == "" {
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
ifi, err := net.InterfaceByName(c.conf.Other.Interface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
addrs, err := ifi.Addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Determine if we need IPv4 or IPv6 based on the network string (e.g., "tcp4", "udp6")
|
||||
wantIPv6 := strings.Contains(network, "6")
|
||||
wantIPv4 := strings.Contains(network, "4") || !wantIPv6 // Default to 4 if not specified, or if generic "tcp"/"udp"
|
||||
|
||||
for _, addr := range addrs {
|
||||
ip, _, err := net.ParseCIDR(addr.String())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if we want IPv4 but got IPv6
|
||||
if ip.To4() == nil && wantIPv4 && !wantIPv6 {
|
||||
continue
|
||||
}
|
||||
// Skip if we want IPv6 but got IPv4
|
||||
if ip.To4() != nil && wantIPv6 {
|
||||
continue
|
||||
}
|
||||
// Skip IPv6 if disabled in config
|
||||
if ip.To4() == nil && c.conf.Other.NoIPv6 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Return the appropriate address type
|
||||
if strings.HasPrefix(network, "tcp") {
|
||||
return &net.TCPAddr{IP: ip}, nil
|
||||
}
|
||||
if strings.HasPrefix(network, "udp") {
|
||||
return &net.UDPAddr{IP: ip}, nil
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
if v4 == nil {
|
||||
v4 = ip4
|
||||
}
|
||||
} else {
|
||||
if v6 == nil && !c.conf.Other.NoIPv6 {
|
||||
v6 = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no suitable address found on interface %s for network %s", c.conf.Other.Interface, network)
|
||||
if v4 == nil && v6 == nil {
|
||||
return nil, nil, fmt.Errorf("no valid IP addresses found on interface %s", c.conf.Other.Interface)
|
||||
}
|
||||
return v4, v6, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user