mirror of
https://github.com/m13253/dns-over-https.git
synced 2026-03-29 20:40:35 +00:00
Add passthrough feature, tests are welcome
This commit is contained in:
@@ -25,11 +25,13 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -37,11 +39,15 @@ import (
|
||||
"github.com/m13253/dns-over-https/json-dns"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conf *config
|
||||
bootstrap []string
|
||||
passthrough []string
|
||||
udpClient *dns.Client
|
||||
tcpClient *dns.Client
|
||||
udpServers []*dns.Server
|
||||
tcpServers []*dns.Server
|
||||
bootstrapResolver *net.Resolver
|
||||
@@ -69,6 +75,15 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
|
||||
udpHandler := dns.HandlerFunc(c.udpHandlerFunc)
|
||||
tcpHandler := dns.HandlerFunc(c.tcpHandlerFunc)
|
||||
c.udpClient = &dns.Client{
|
||||
Net: "udp",
|
||||
UDPSize: dns.DefaultMsgSize,
|
||||
Timeout: time.Duration(conf.Timeout) * time.Second,
|
||||
}
|
||||
c.tcpClient = &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: time.Duration(conf.Timeout) * time.Second,
|
||||
}
|
||||
for _, addr := range conf.Listen {
|
||||
c.udpServers = append(c.udpServers, &dns.Server{
|
||||
Addr: addr,
|
||||
@@ -105,6 +120,15 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
return conn, err
|
||||
},
|
||||
}
|
||||
if len(conf.Passthrough) != 0 {
|
||||
c.passthrough = make([]string, len(conf.Passthrough))
|
||||
for i, passthrough := range conf.Passthrough {
|
||||
if punycode, err := idna.ToASCII(passthrough); err != nil {
|
||||
passthrough = punycode
|
||||
}
|
||||
c.passthrough[i] = "." + strings.ToLower(strings.Trim(passthrough, ".")) + "."
|
||||
}
|
||||
}
|
||||
}
|
||||
// Most CDNs require Cookie support to prevent DDoS attack.
|
||||
// Disabling Cookie does not effectively prevent tracking,
|
||||
@@ -115,7 +139,7 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
c.cookieJar = nil;
|
||||
c.cookieJar = nil
|
||||
}
|
||||
|
||||
c.httpClientMux = new(sync.RWMutex)
|
||||
@@ -199,6 +223,65 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(r.Question) != 1 {
|
||||
log.Println("Number of questions is not 1")
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeFormatError
|
||||
w.WriteMsg(reply)
|
||||
return
|
||||
}
|
||||
question := &r.Question[0]
|
||||
questionName := question.Name
|
||||
questionClass := ""
|
||||
if qclass, ok := dns.ClassToString[question.Qclass]; ok {
|
||||
questionClass = qclass
|
||||
} else {
|
||||
questionClass = strconv.FormatUint(uint64(question.Qclass), 10)
|
||||
}
|
||||
questionType := ""
|
||||
if qtype, ok := dns.TypeToString[question.Qtype]; ok {
|
||||
questionType = qtype
|
||||
} else {
|
||||
questionType = strconv.FormatUint(uint64(question.Qtype), 10)
|
||||
}
|
||||
if c.conf.Verbose {
|
||||
fmt.Printf("%s - - [%s] \"%s %s %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType)
|
||||
}
|
||||
|
||||
shouldPassthrough := false
|
||||
passthroughQuestionName := questionName
|
||||
if punycode, err := idna.ToASCII(passthroughQuestionName); err != nil {
|
||||
passthroughQuestionName = punycode
|
||||
}
|
||||
passthroughQuestionName = "." + strings.ToLower(strings.Trim(passthroughQuestionName, ".")) + "."
|
||||
for _, passthrough := range c.passthrough {
|
||||
if strings.HasSuffix(passthroughQuestionName, passthrough) {
|
||||
shouldPassthrough = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if shouldPassthrough {
|
||||
numServers := len(c.bootstrap)
|
||||
upstream := c.bootstrap[rand.Intn(numServers)]
|
||||
log.Printf("Request %s %s %s is passed through %s.\n", questionName, questionClass, questionType, upstream)
|
||||
var reply *dns.Msg
|
||||
var err error
|
||||
if !isTCP {
|
||||
reply, _, err = c.udpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
reply, _, err = c.tcpClient.Exchange(r, upstream)
|
||||
}
|
||||
if err == nil || err == dns.ErrTruncated {
|
||||
w.WriteMsg(reply)
|
||||
return
|
||||
}
|
||||
log.Println(err)
|
||||
reply = jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeServerFailure
|
||||
w.WriteMsg(reply)
|
||||
return
|
||||
}
|
||||
|
||||
requestType := ""
|
||||
if len(c.conf.UpstreamIETF) == 0 {
|
||||
requestType = "application/dns-json"
|
||||
|
||||
@@ -34,6 +34,7 @@ type config struct {
|
||||
UpstreamGoogle []string `toml:"upstream_google"`
|
||||
UpstreamIETF []string `toml:"upstream_ietf"`
|
||||
Bootstrap []string `toml:"bootstrap"`
|
||||
Passthrough []string `toml:"passthrough"`
|
||||
Timeout uint `toml:"timeout"`
|
||||
NoCookies bool `toml:"no_cookies"`
|
||||
NoECS bool `toml:"no_ecs"`
|
||||
|
||||
@@ -58,6 +58,23 @@ bootstrap = [
|
||||
|
||||
]
|
||||
|
||||
# The domain names here are directly passed to bootstrap servers listed above,
|
||||
# allowing captive portal detection and systems without RTC to work.
|
||||
# Only effective if at least one bootstrap server is configured.
|
||||
passthrough = [
|
||||
"captive.apple.com",
|
||||
"connectivitycheck.gstatic.com",
|
||||
"msftconnecttest.com",
|
||||
"nmcheck.gnome.org",
|
||||
|
||||
"pool.ntp.org",
|
||||
"time.apple.com",
|
||||
"time.asia.apple.com",
|
||||
"time.euro.apple.com",
|
||||
"time.nist.gov",
|
||||
"time.windows.com",
|
||||
]
|
||||
|
||||
# Timeout for upstream request
|
||||
timeout = 30
|
||||
|
||||
|
||||
@@ -33,34 +33,28 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/m13253/dns-over-https/json-dns"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
|
||||
if len(r.Question) != 1 {
|
||||
log.Println("Number of questions is not 1")
|
||||
question := &r.Question[0]
|
||||
questionName := question.Name
|
||||
questionClass := question.Qclass
|
||||
if questionClass != dns.ClassINET {
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeFormatError
|
||||
w.WriteMsg(reply)
|
||||
return &DNSRequest{
|
||||
err: &dns.Error{},
|
||||
}
|
||||
}
|
||||
question := &r.Question[0]
|
||||
questionName := question.Name
|
||||
questionType := ""
|
||||
if qtype, ok := dns.TypeToString[question.Qtype]; ok {
|
||||
questionType = qtype
|
||||
} else {
|
||||
questionType = strconv.Itoa(int(question.Qtype))
|
||||
}
|
||||
|
||||
if c.conf.Verbose {
|
||||
fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType)
|
||||
questionType = strconv.FormatUint(uint64(question.Qtype), 10)
|
||||
}
|
||||
|
||||
numServers := len(c.conf.UpstreamGoogle)
|
||||
@@ -84,6 +78,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b
|
||||
req, err := http.NewRequest("GET", requestURL, nil)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeServerFailure
|
||||
w.WriteMsg(reply)
|
||||
return &DNSRequest{
|
||||
@@ -97,6 +92,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b
|
||||
c.httpClientMux.RUnlock()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeServerFailure
|
||||
w.WriteMsg(reply)
|
||||
err1 := c.newHTTPClient()
|
||||
@@ -110,7 +106,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b
|
||||
|
||||
return &DNSRequest{
|
||||
response: resp,
|
||||
reply: reply,
|
||||
reply: jsonDNS.PrepareReply(r),
|
||||
udpSize: udpSize,
|
||||
ednsClientAddress: ednsClientAddress,
|
||||
ednsClientNetmask: ednsClientNetmask,
|
||||
|
||||
@@ -32,7 +32,6 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -41,31 +40,6 @@ import (
|
||||
)
|
||||
|
||||
func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
|
||||
if len(r.Question) != 1 {
|
||||
log.Println("Number of questions is not 1")
|
||||
reply.Rcode = dns.RcodeFormatError
|
||||
w.WriteMsg(reply)
|
||||
return &DNSRequest{
|
||||
err: &dns.Error{},
|
||||
}
|
||||
}
|
||||
|
||||
question := &r.Question[0]
|
||||
questionName := question.Name
|
||||
questionType := ""
|
||||
if qtype, ok := dns.TypeToString[question.Qtype]; ok {
|
||||
questionType = qtype
|
||||
} else {
|
||||
questionType = strconv.Itoa(int(question.Qtype))
|
||||
}
|
||||
|
||||
if c.conf.Verbose {
|
||||
fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType)
|
||||
}
|
||||
|
||||
question.Name = questionName
|
||||
opt := r.IsEdns0()
|
||||
udpSize := uint16(512)
|
||||
if opt == nil {
|
||||
@@ -115,6 +89,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo
|
||||
requestBinary, err := r.Pack()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeFormatError
|
||||
w.WriteMsg(reply)
|
||||
return &DNSRequest{
|
||||
@@ -156,6 +131,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo
|
||||
c.httpClientMux.RUnlock()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
reply.Rcode = dns.RcodeServerFailure
|
||||
w.WriteMsg(reply)
|
||||
err1 := c.newHTTPClient()
|
||||
@@ -169,7 +145,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo
|
||||
|
||||
return &DNSRequest{
|
||||
response: resp,
|
||||
reply: reply,
|
||||
reply: jsonDNS.PrepareReply(r),
|
||||
udpSize: udpSize,
|
||||
ednsClientAddress: ednsClientAddress,
|
||||
ednsClientNetmask: ednsClientNetmask,
|
||||
@@ -220,7 +196,7 @@ func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool,
|
||||
|
||||
fullReply := new(dns.Msg)
|
||||
err = fullReply.Unpack(body)
|
||||
if err != nil {
|
||||
if err != nil && err != dns.ErrTruncated {
|
||||
log.Println(err)
|
||||
req.reply.Rcode = dns.RcodeServerFailure
|
||||
w.WriteMsg(req.reply)
|
||||
|
||||
@@ -184,9 +184,9 @@ func (s *Server) generateResponseGoogle(w http.ResponseWriter, r *http.Request,
|
||||
w.Header().Set("Vary", "Accept")
|
||||
if respJSON.HaveTTL {
|
||||
if req.isTailored {
|
||||
w.Header().Set("Cache-Control", "private, max-age="+strconv.Itoa(int(respJSON.LeastTTL)))
|
||||
w.Header().Set("Cache-Control", "private, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10))
|
||||
} else {
|
||||
w.Header().Set("Cache-Control", "public, max-age="+strconv.Itoa(int(respJSON.LeastTTL)))
|
||||
w.Header().Set("Cache-Control", "public, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10))
|
||||
}
|
||||
w.Header().Set("Expires", respJSON.EarliestExpires.Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
@@ -85,13 +85,13 @@ func (s *Server) parseRequestIETF(w http.ResponseWriter, r *http.Request) *DNSRe
|
||||
if qclass, ok := dns.ClassToString[question.Qclass]; ok {
|
||||
questionClass = qclass
|
||||
} else {
|
||||
questionClass = strconv.Itoa(int(question.Qclass))
|
||||
questionClass = strconv.FormatUint(uint64(question.Qclass), 10)
|
||||
}
|
||||
questionType := ""
|
||||
if qtype, ok := dns.TypeToString[question.Qtype]; ok {
|
||||
questionType = qtype
|
||||
} else {
|
||||
questionType = strconv.Itoa(int(question.Qtype))
|
||||
questionType = strconv.FormatUint(uint64(question.Qtype), 10)
|
||||
}
|
||||
fmt.Printf("%s - - [%s] \"%s %s %s\"\n", r.RemoteAddr, time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType)
|
||||
}
|
||||
@@ -165,9 +165,9 @@ func (s *Server) generateResponseIETF(w http.ResponseWriter, r *http.Request, re
|
||||
|
||||
if respJSON.HaveTTL {
|
||||
if req.isTailored {
|
||||
w.Header().Set("Cache-Control", "private, max-age="+strconv.Itoa(int(respJSON.LeastTTL)))
|
||||
w.Header().Set("Cache-Control", "private, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10))
|
||||
} else {
|
||||
w.Header().Set("Cache-Control", "public, max-age="+strconv.Itoa(int(respJSON.LeastTTL)))
|
||||
w.Header().Set("Cache-Control", "public, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10))
|
||||
}
|
||||
w.Header().Set("Expires", respJSON.EarliestExpires.Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ func (s *Server) doDNSQuery(req *DNSRequest) (resp *DNSRequest, err error) {
|
||||
} else {
|
||||
req.response, _, err = s.tcpClient.Exchange(req.request, req.currentUpstream)
|
||||
}
|
||||
if err == nil {
|
||||
if err == nil || err == dns.ErrTruncated {
|
||||
return req, nil
|
||||
}
|
||||
log.Printf("DNS error from upstream %s: %s\n", req.currentUpstream, err.Error())
|
||||
|
||||
@@ -90,7 +90,7 @@ func Marshal(msg *dns.Msg) *Response {
|
||||
} else if ipv4 := clientAddress.To4(); ipv4 != nil {
|
||||
clientAddress = ipv4
|
||||
}
|
||||
resp.EdnsClientSubnet = clientAddress.String() + "/" + strconv.Itoa(int(edns0.SourceScope))
|
||||
resp.EdnsClientSubnet = clientAddress.String() + "/" + strconv.FormatUint(uint64(edns0.SourceScope), 10)
|
||||
}
|
||||
}
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user