From afa0d563d06449c042ee10a12040ca2ca737f35b Mon Sep 17 00:00:00 2001 From: Star Brilliant Date: Wed, 7 Nov 2018 17:10:39 +0800 Subject: [PATCH] Add passthrough feature, tests are welcome --- doh-client/client.go | 85 +++++++++++++++++++++++++++++++++++++- doh-client/config.go | 1 + doh-client/doh-client.conf | 17 ++++++++ doh-client/google.go | 22 ++++------ doh-client/ietf.go | 32 ++------------ doh-server/google.go | 4 +- doh-server/ietf.go | 8 ++-- doh-server/server.go | 2 +- json-dns/marshal.go | 2 +- 9 files changed, 123 insertions(+), 50 deletions(-) diff --git a/doh-client/client.go b/doh-client/client.go index d296c00..e6c1768 100644 --- a/doh-client/client.go +++ b/doh-client/client.go @@ -25,11 +25,13 @@ package main import ( "context" + "fmt" "log" "math/rand" "net" "net/http" "net/http/cookiejar" + "strconv" "strings" "sync" "time" @@ -37,11 +39,15 @@ import ( "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" "golang.org/x/net/http2" + "golang.org/x/net/idna" ) type Client struct { conf *config bootstrap []string + passthrough []string + udpClient *dns.Client + tcpClient *dns.Client udpServers []*dns.Server tcpServers []*dns.Server bootstrapResolver *net.Resolver @@ -69,6 +75,15 @@ func NewClient(conf *config) (c *Client, err error) { udpHandler := dns.HandlerFunc(c.udpHandlerFunc) tcpHandler := dns.HandlerFunc(c.tcpHandlerFunc) + c.udpClient = &dns.Client{ + Net: "udp", + UDPSize: dns.DefaultMsgSize, + Timeout: time.Duration(conf.Timeout) * time.Second, + } + c.tcpClient = &dns.Client{ + Net: "tcp", + Timeout: time.Duration(conf.Timeout) * time.Second, + } for _, addr := range conf.Listen { c.udpServers = append(c.udpServers, &dns.Server{ Addr: addr, @@ -105,6 +120,15 @@ func NewClient(conf *config) (c *Client, err error) { return conn, err }, } + if len(conf.Passthrough) != 0 { + c.passthrough = make([]string, len(conf.Passthrough)) + for i, passthrough := range conf.Passthrough { + if punycode, err := idna.ToASCII(passthrough); err != nil { + passthrough = punycode + } + c.passthrough[i] = "." + strings.ToLower(strings.Trim(passthrough, ".")) + "." + } + } } // Most CDNs require Cookie support to prevent DDoS attack. // Disabling Cookie does not effectively prevent tracking, @@ -115,7 +139,7 @@ func NewClient(conf *config) (c *Client, err error) { return nil, err } } else { - c.cookieJar = nil; + c.cookieJar = nil } c.httpClientMux = new(sync.RWMutex) @@ -199,6 +223,65 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { return } + if len(r.Question) != 1 { + log.Println("Number of questions is not 1") + reply := jsonDNS.PrepareReply(r) + reply.Rcode = dns.RcodeFormatError + w.WriteMsg(reply) + return + } + question := &r.Question[0] + questionName := question.Name + questionClass := "" + if qclass, ok := dns.ClassToString[question.Qclass]; ok { + questionClass = qclass + } else { + questionClass = strconv.FormatUint(uint64(question.Qclass), 10) + } + questionType := "" + if qtype, ok := dns.TypeToString[question.Qtype]; ok { + questionType = qtype + } else { + questionType = strconv.FormatUint(uint64(question.Qtype), 10) + } + if c.conf.Verbose { + fmt.Printf("%s - - [%s] \"%s %s %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType) + } + + shouldPassthrough := false + passthroughQuestionName := questionName + if punycode, err := idna.ToASCII(passthroughQuestionName); err != nil { + passthroughQuestionName = punycode + } + passthroughQuestionName = "." + strings.ToLower(strings.Trim(passthroughQuestionName, ".")) + "." + for _, passthrough := range c.passthrough { + if strings.HasSuffix(passthroughQuestionName, passthrough) { + shouldPassthrough = true + break + } + } + if shouldPassthrough { + numServers := len(c.bootstrap) + upstream := c.bootstrap[rand.Intn(numServers)] + log.Printf("Request %s %s %s is passed through %s.\n", questionName, questionClass, questionType, upstream) + var reply *dns.Msg + var err error + if !isTCP { + reply, _, err = c.udpClient.Exchange(r, upstream) + } else { + reply, _, err = c.tcpClient.Exchange(r, upstream) + } + if err == nil || err == dns.ErrTruncated { + w.WriteMsg(reply) + return + } + log.Println(err) + reply = jsonDNS.PrepareReply(r) + reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(reply) + return + } + requestType := "" if len(c.conf.UpstreamIETF) == 0 { requestType = "application/dns-json" diff --git a/doh-client/config.go b/doh-client/config.go index 458bddb..e045ddb 100644 --- a/doh-client/config.go +++ b/doh-client/config.go @@ -34,6 +34,7 @@ type config struct { UpstreamGoogle []string `toml:"upstream_google"` UpstreamIETF []string `toml:"upstream_ietf"` Bootstrap []string `toml:"bootstrap"` + Passthrough []string `toml:"passthrough"` Timeout uint `toml:"timeout"` NoCookies bool `toml:"no_cookies"` NoECS bool `toml:"no_ecs"` diff --git a/doh-client/doh-client.conf b/doh-client/doh-client.conf index 3b5de14..5807f6a 100644 --- a/doh-client/doh-client.conf +++ b/doh-client/doh-client.conf @@ -58,6 +58,23 @@ bootstrap = [ ] +# The domain names here are directly passed to bootstrap servers listed above, +# allowing captive portal detection and systems without RTC to work. +# Only effective if at least one bootstrap server is configured. +passthrough = [ + "captive.apple.com", + "connectivitycheck.gstatic.com", + "msftconnecttest.com", + "nmcheck.gnome.org", + + "pool.ntp.org", + "time.apple.com", + "time.asia.apple.com", + "time.euro.apple.com", + "time.nist.gov", + "time.windows.com", +] + # Timeout for upstream request timeout = 30 diff --git a/doh-client/google.go b/doh-client/google.go index c5fdcf8..f2bb3ad 100644 --- a/doh-client/google.go +++ b/doh-client/google.go @@ -33,34 +33,28 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" ) func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { - reply := jsonDNS.PrepareReply(r) - - if len(r.Question) != 1 { - log.Println("Number of questions is not 1") + question := &r.Question[0] + questionName := question.Name + questionClass := question.Qclass + if questionClass != dns.ClassINET { + reply := jsonDNS.PrepareReply(r) reply.Rcode = dns.RcodeFormatError w.WriteMsg(reply) return &DNSRequest{ err: &dns.Error{}, } } - question := &r.Question[0] - questionName := question.Name questionType := "" if qtype, ok := dns.TypeToString[question.Qtype]; ok { questionType = qtype } else { - questionType = strconv.Itoa(int(question.Qtype)) - } - - if c.conf.Verbose { - fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType) + questionType = strconv.FormatUint(uint64(question.Qtype), 10) } numServers := len(c.conf.UpstreamGoogle) @@ -84,6 +78,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b req, err := http.NewRequest("GET", requestURL, nil) if err != nil { log.Println(err) + reply := jsonDNS.PrepareReply(r) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) return &DNSRequest{ @@ -97,6 +92,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b c.httpClientMux.RUnlock() if err != nil { log.Println(err) + reply := jsonDNS.PrepareReply(r) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) err1 := c.newHTTPClient() @@ -110,7 +106,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b return &DNSRequest{ response: resp, - reply: reply, + reply: jsonDNS.PrepareReply(r), udpSize: udpSize, ednsClientAddress: ednsClientAddress, ednsClientNetmask: ednsClientNetmask, diff --git a/doh-client/ietf.go b/doh-client/ietf.go index 830026e..e993b33 100644 --- a/doh-client/ietf.go +++ b/doh-client/ietf.go @@ -32,7 +32,6 @@ import ( "math/rand" "net" "net/http" - "strconv" "strings" "time" @@ -41,31 +40,6 @@ import ( ) func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { - reply := jsonDNS.PrepareReply(r) - - if len(r.Question) != 1 { - log.Println("Number of questions is not 1") - reply.Rcode = dns.RcodeFormatError - w.WriteMsg(reply) - return &DNSRequest{ - err: &dns.Error{}, - } - } - - question := &r.Question[0] - questionName := question.Name - questionType := "" - if qtype, ok := dns.TypeToString[question.Qtype]; ok { - questionType = qtype - } else { - questionType = strconv.Itoa(int(question.Qtype)) - } - - if c.conf.Verbose { - fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType) - } - - question.Name = questionName opt := r.IsEdns0() udpSize := uint16(512) if opt == nil { @@ -115,6 +89,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo requestBinary, err := r.Pack() if err != nil { log.Println(err) + reply := jsonDNS.PrepareReply(r) reply.Rcode = dns.RcodeFormatError w.WriteMsg(reply) return &DNSRequest{ @@ -156,6 +131,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo c.httpClientMux.RUnlock() if err != nil { log.Println(err) + reply := jsonDNS.PrepareReply(r) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) err1 := c.newHTTPClient() @@ -169,7 +145,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo return &DNSRequest{ response: resp, - reply: reply, + reply: jsonDNS.PrepareReply(r), udpSize: udpSize, ednsClientAddress: ednsClientAddress, ednsClientNetmask: ednsClientNetmask, @@ -220,7 +196,7 @@ func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, fullReply := new(dns.Msg) err = fullReply.Unpack(body) - if err != nil { + if err != nil && err != dns.ErrTruncated { log.Println(err) req.reply.Rcode = dns.RcodeServerFailure w.WriteMsg(req.reply) diff --git a/doh-server/google.go b/doh-server/google.go index 96bd7ba..67312e9 100644 --- a/doh-server/google.go +++ b/doh-server/google.go @@ -184,9 +184,9 @@ func (s *Server) generateResponseGoogle(w http.ResponseWriter, r *http.Request, w.Header().Set("Vary", "Accept") if respJSON.HaveTTL { if req.isTailored { - w.Header().Set("Cache-Control", "private, max-age="+strconv.Itoa(int(respJSON.LeastTTL))) + w.Header().Set("Cache-Control", "private, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10)) } else { - w.Header().Set("Cache-Control", "public, max-age="+strconv.Itoa(int(respJSON.LeastTTL))) + w.Header().Set("Cache-Control", "public, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10)) } w.Header().Set("Expires", respJSON.EarliestExpires.Format(http.TimeFormat)) } diff --git a/doh-server/ietf.go b/doh-server/ietf.go index 875c5df..4e85b41 100644 --- a/doh-server/ietf.go +++ b/doh-server/ietf.go @@ -85,13 +85,13 @@ func (s *Server) parseRequestIETF(w http.ResponseWriter, r *http.Request) *DNSRe if qclass, ok := dns.ClassToString[question.Qclass]; ok { questionClass = qclass } else { - questionClass = strconv.Itoa(int(question.Qclass)) + questionClass = strconv.FormatUint(uint64(question.Qclass), 10) } questionType := "" if qtype, ok := dns.TypeToString[question.Qtype]; ok { questionType = qtype } else { - questionType = strconv.Itoa(int(question.Qtype)) + questionType = strconv.FormatUint(uint64(question.Qtype), 10) } fmt.Printf("%s - - [%s] \"%s %s %s\"\n", r.RemoteAddr, time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType) } @@ -165,9 +165,9 @@ func (s *Server) generateResponseIETF(w http.ResponseWriter, r *http.Request, re if respJSON.HaveTTL { if req.isTailored { - w.Header().Set("Cache-Control", "private, max-age="+strconv.Itoa(int(respJSON.LeastTTL))) + w.Header().Set("Cache-Control", "private, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10)) } else { - w.Header().Set("Cache-Control", "public, max-age="+strconv.Itoa(int(respJSON.LeastTTL))) + w.Header().Set("Cache-Control", "public, max-age="+strconv.FormatUint(uint64(respJSON.LeastTTL), 10)) } w.Header().Set("Expires", respJSON.EarliestExpires.Format(http.TimeFormat)) } diff --git a/doh-server/server.go b/doh-server/server.go index bee558b..ba6a475 100644 --- a/doh-server/server.go +++ b/doh-server/server.go @@ -245,7 +245,7 @@ func (s *Server) doDNSQuery(req *DNSRequest) (resp *DNSRequest, err error) { } else { req.response, _, err = s.tcpClient.Exchange(req.request, req.currentUpstream) } - if err == nil { + if err == nil || err == dns.ErrTruncated { return req, nil } log.Printf("DNS error from upstream %s: %s\n", req.currentUpstream, err.Error()) diff --git a/json-dns/marshal.go b/json-dns/marshal.go index 568c40d..1f04dfe 100644 --- a/json-dns/marshal.go +++ b/json-dns/marshal.go @@ -90,7 +90,7 @@ func Marshal(msg *dns.Msg) *Response { } else if ipv4 := clientAddress.To4(); ipv4 != nil { clientAddress = ipv4 } - resp.EdnsClientSubnet = clientAddress.String() + "/" + strconv.Itoa(int(edns0.SourceScope)) + resp.EdnsClientSubnet = clientAddress.String() + "/" + strconv.FormatUint(uint64(edns0.SourceScope), 10) } } continue