From a09dfbbbc14ac55b3e42df8441bee49a4ecfc898 Mon Sep 17 00:00:00 2001 From: gdm85 Date: Wed, 16 Oct 2019 12:28:31 +0200 Subject: [PATCH] Add support for type prefix for upstream addresses Add support for DNS-over-TLS upstream addresses Remove tcp_only configuration option --- doh-server/config.go | 31 ++++++++++++++++++-- doh-server/doh-server.conf | 16 ++++++----- doh-server/server.go | 59 ++++++++++++++++++++++++++------------ 3 files changed, 78 insertions(+), 28 deletions(-) diff --git a/doh-server/config.go b/doh-server/config.go index f937c14..7507a43 100644 --- a/doh-server/config.go +++ b/doh-server/config.go @@ -25,6 +25,7 @@ package main import ( "fmt" + "regexp" "github.com/BurntSushi/toml" ) @@ -38,7 +39,6 @@ type config struct { Upstream []string `toml:"upstream"` Timeout uint `toml:"timeout"` Tries uint `toml:"tries"` - TCPOnly bool `toml:"tcp_only"` Verbose bool `toml:"verbose"` DebugHTTPHeaders []string `toml:"debug_http_headers"` LogGuessedIP bool `toml:"log_guessed_client_ip"` @@ -62,7 +62,7 @@ func loadConfig(path string) (*config, error) { conf.Path = "/dns-query" } if len(conf.Upstream) == 0 { - conf.Upstream = []string{"8.8.8.8:53", "8.8.4.4:53"} + conf.Upstream = []string{"udp:8.8.8.8:53", "udp:8.8.4.4:53"} } if conf.Timeout == 0 { conf.Timeout = 10 @@ -75,9 +75,36 @@ func loadConfig(path string) (*config, error) { return nil, &configError{"You must specify both -cert and -key to enable TLS"} } + // validate all upstreams + for _, us := range conf.Upstream { + address, t := addressAndType(us) + if address == "" { + return nil, &configError{"One of the upstreams has not a (udp|tcp|tcp-tls) prefix e.g. udp:1.1.1.1:53"} + } + + switch t { + case "tcp", "udp", "tcp-tls": + // OK + default: + return nil, &configError{"Invalid upstream prefix specified, choose one of: udp tcp tcp-tls"} + } + } + return conf, nil } +var rxUpstreamWithTypePrefix = regexp.MustCompile("^[a-z-]+(:)") + +func addressAndType(us string) (string, string) { + p := rxUpstreamWithTypePrefix.FindStringSubmatchIndex(us) + fmt.Println(p) + if len(p) != 4 { + return "", "" + } + + return us[p[2]+1:], us[:p[2]] +} + type configError struct { err string } diff --git a/doh-server/doh-server.conf b/doh-server/doh-server.conf index 0fed8d9..72e06ad 100644 --- a/doh-server/doh-server.conf +++ b/doh-server/doh-server.conf @@ -27,11 +27,16 @@ path = "/dns-query" # Upstream DNS resolver # If multiple servers are specified, a random one will be chosen each time. +# You can use "udp", "tcp" or "tcp-tls" for the type prefix. +# For "udp", UDP will first be used, and switch to TCP when the server asks to +# or the response is too large. +# For "tcp", only TCP will be used. +# For "tcp-tls", DNS-over-TLS (RFC 7858) will be used to secure the upstream connection. upstream = [ - "1.1.1.1:53", - "1.0.0.1:53", - "8.8.8.8:53", - "8.8.4.4:53", + "udp:1.1.1.1:53", + "udp:1.0.0.1:53", + "udp:8.8.8.8:53", + "udp:8.8.4.4:53", ] # Upstream timeout @@ -40,9 +45,6 @@ timeout = 10 # Number of tries if upstream DNS fails tries = 3 -# Only use TCP for DNS query -tcp_only = false - # Enable logging verbose = false diff --git a/doh-server/server.go b/doh-server/server.go index b8d61d5..24113eb 100644 --- a/doh-server/server.go +++ b/doh-server/server.go @@ -35,15 +35,16 @@ import ( "time" "github.com/gorilla/handlers" - "github.com/m13253/dns-over-https/json-dns" + jsonDNS "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" ) type Server struct { - conf *config - udpClient *dns.Client - tcpClient *dns.Client - servemux *http.ServeMux + conf *config + udpClient *dns.Client + tcpClient *dns.Client + tcpClientTLS *dns.Client + servemux *http.ServeMux } type DNSRequest struct { @@ -69,6 +70,10 @@ func NewServer(conf *config) (*Server, error) { Net: "tcp", Timeout: timeout, }, + tcpClientTLS: &dns.Client{ + Net: "tcp-tls", + Timeout: timeout, + }, servemux: http.NewServeMux(), } if conf.LocalAddr != "" { @@ -88,6 +93,10 @@ func NewServer(conf *config) (*Server, error) { Timeout: timeout, LocalAddr: tcpLocalAddr, } + s.tcpClientTLS.Dialer = &net.Dialer{ + Timeout: timeout, + LocalAddr: tcpLocalAddr, + } } s.servemux.HandleFunc(conf.Path, s.handlerFunc) return s, nil @@ -279,23 +288,35 @@ func (s *Server) doDNSQuery(ctx context.Context, req *DNSRequest) (resp *DNSRequ for i := uint(0); i < s.conf.Tries; i++ { req.currentUpstream = s.conf.Upstream[rand.Intn(numServers)] - // Use TCP if always configured to or if the Query type dictates it (AXFR) - if s.conf.TCPOnly || (s.indexQuestionType(req.request, dns.TypeAXFR) > -1) { - req.response, _, err = s.tcpClient.Exchange(req.request, req.currentUpstream) - } else { - req.response, _, err = s.udpClient.Exchange(req.request, req.currentUpstream) - if err == nil && req.response != nil && req.response.Truncated { - log.Println(err) - req.response, _, err = s.tcpClient.Exchange(req.request, req.currentUpstream) - } + upstream, t := addressAndType(req.currentUpstream) - // Retry with TCP if this was an IXFR request and we only received an SOA - if (s.indexQuestionType(req.request, dns.TypeIXFR) > -1) && - (len(req.response.Answer) == 1) && - (req.response.Answer[0].Header().Rrtype == dns.TypeSOA) { - req.response, _, err = s.tcpClient.Exchange(req.request, req.currentUpstream) + switch t { + default: + log.Printf("invalid DNS type %q in upstream %q", t, upstream) + return nil, &configError{"invalid DNS type"} + // Use DNS-over-TLS (DoT) if configured to do so + case "tcp-tls": + req.response, _, err = s.tcpClientTLS.Exchange(req.request, upstream) + case "tcp", "udp": + // Use TCP if always configured to or if the Query type dictates it (AXFR) + if t == "tcp" || (s.indexQuestionType(req.request, dns.TypeAXFR) > -1) { + req.response, _, err = s.tcpClient.Exchange(req.request, upstream) + } else { + req.response, _, err = s.udpClient.Exchange(req.request, upstream) + if err == nil && req.response != nil && req.response.Truncated { + log.Println(err) + req.response, _, err = s.tcpClient.Exchange(req.request, upstream) + } + + // Retry with TCP if this was an IXFR request and we only received an SOA + if (s.indexQuestionType(req.request, dns.TypeIXFR) > -1) && + (len(req.response.Answer) == 1) && + (req.response.Answer[0].Header().Rrtype == dns.TypeSOA) { + req.response, _, err = s.tcpClient.Exchange(req.request, upstream) + } } } + if err == nil { return req, nil }