diff --git a/doh-client/client.go b/doh-client/client.go index 5a80455..88b7168 100644 --- a/doh-client/client.go +++ b/doh-client/client.go @@ -217,6 +217,9 @@ func (c *Client) Start() error { } func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second) + defer cancel() + if r.Response == true { log.Println("Received a response packet") return @@ -298,9 +301,9 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { var req *DNSRequest if requestType == "application/dns-json" { - req = c.generateRequestGoogle(w, r, isTCP) + req = c.generateRequestGoogle(ctx, w, r, isTCP) } else if requestType == "application/dns-message" { - req = c.generateRequestIETF(w, r, isTCP) + req = c.generateRequestIETF(ctx, w, r, isTCP) } else { panic("Unknown request Content-Type") } @@ -334,9 +337,9 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { } if contentType == "application/json" { - c.parseResponseGoogle(w, r, isTCP, req) + c.parseResponseGoogle(ctx, w, r, isTCP, req) } else if contentType == "application/dns-message" { - c.parseResponseIETF(w, r, isTCP, req) + c.parseResponseIETF(ctx, w, r, isTCP, req) } else { panic("Unknown response Content-Type") } diff --git a/doh-client/google.go b/doh-client/google.go index 5661a9f..b02feb0 100644 --- a/doh-client/google.go +++ b/doh-client/google.go @@ -34,13 +34,12 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" ) -func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { +func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { question := &r.Question[0] questionName := question.Name questionClass := question.Qclass @@ -89,12 +88,10 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b } req.Header.Set("Accept", "application/json, application/dns-message, application/dns-udpwireformat") req.Header.Set("User-Agent", USER_AGENT) - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second) req = req.WithContext(ctx) c.httpClientMux.RLock() resp, err := c.httpClient.Do(req) c.httpClientMux.RUnlock() - cancel() if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -115,7 +112,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b } } -func (c *Client) parseResponseGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { +func (c *Client) parseResponseGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { if req.response.StatusCode != 200 { log.Printf("HTTP error from upstream %s: %s\n", req.currentUpstream, req.response.Status) req.reply.Rcode = dns.RcodeServerFailure diff --git a/doh-client/ietf.go b/doh-client/ietf.go index d02d46c..b8e2816 100644 --- a/doh-client/ietf.go +++ b/doh-client/ietf.go @@ -40,7 +40,7 @@ import ( "github.com/miekg/dns" ) -func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { +func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { opt := r.IsEdns0() udpSize := uint16(512) if opt == nil { @@ -127,12 +127,10 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo } req.Header.Set("Accept", "application/dns-message, application/dns-udpwireformat, application/json") req.Header.Set("User-Agent", USER_AGENT) - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second) req = req.WithContext(ctx) c.httpClientMux.RLock() resp, err := c.httpClient.Do(req) c.httpClientMux.RUnlock() - cancel() if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -153,7 +151,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo } } -func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { +func (c *Client) parseResponseIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { if req.response.StatusCode != 200 { log.Printf("HTTP error from upstream %s: %s\n", req.currentUpstream, req.response.Status) req.reply.Rcode = dns.RcodeServerFailure diff --git a/doh-server/google.go b/doh-server/google.go index 67312e9..8504784 100644 --- a/doh-server/google.go +++ b/doh-server/google.go @@ -24,6 +24,7 @@ package main import ( + "context" "encoding/json" "fmt" "log" @@ -38,7 +39,7 @@ import ( "golang.org/x/net/idna" ) -func (s *Server) parseRequestGoogle(w http.ResponseWriter, r *http.Request) *DNSRequest { +func (s *Server) parseRequestGoogle(ctx context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest { name := r.FormValue("name") if name == "" { return &DNSRequest{ @@ -168,7 +169,7 @@ func (s *Server) parseRequestGoogle(w http.ResponseWriter, r *http.Request) *DNS } } -func (s *Server) generateResponseGoogle(w http.ResponseWriter, r *http.Request, req *DNSRequest) { +func (s *Server) generateResponseGoogle(ctx context.Context, w http.ResponseWriter, r *http.Request, req *DNSRequest) { respJSON := jsonDNS.Marshal(req.response) respStr, err := json.Marshal(respJSON) if err != nil { diff --git a/doh-server/ietf.go b/doh-server/ietf.go index 4e85b41..f21b818 100644 --- a/doh-server/ietf.go +++ b/doh-server/ietf.go @@ -25,6 +25,7 @@ package main import ( "bytes" + "context" "encoding/base64" "fmt" "io/ioutil" @@ -38,7 +39,7 @@ import ( "github.com/miekg/dns" ) -func (s *Server) parseRequestIETF(w http.ResponseWriter, r *http.Request) *DNSRequest { +func (s *Server) parseRequestIETF(ctx context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest { requestBase64 := r.FormValue("dns") requestBinary, err := base64.RawURLEncoding.DecodeString(requestBase64) if err != nil { @@ -145,7 +146,7 @@ func (s *Server) parseRequestIETF(w http.ResponseWriter, r *http.Request) *DNSRe } } -func (s *Server) generateResponseIETF(w http.ResponseWriter, r *http.Request, req *DNSRequest) { +func (s *Server) generateResponseIETF(ctx context.Context, w http.ResponseWriter, r *http.Request, req *DNSRequest) { respJSON := jsonDNS.Marshal(req.response) req.response.Id = req.transactionID respBytes, err := req.response.Pack() diff --git a/doh-server/server.go b/doh-server/server.go index ba6a475..2e59cfd 100644 --- a/doh-server/server.go +++ b/doh-server/server.go @@ -24,6 +24,7 @@ package main import ( + "context" "fmt" "log" "math/rand" @@ -105,6 +106,8 @@ func (s *Server) Start() error { } func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + w.Header().Set("Server", USER_AGENT) w.Header().Set("X-Powered-By", USER_AGENT) @@ -158,11 +161,11 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { var req *DNSRequest if contentType == "application/dns-json" { - req = s.parseRequestGoogle(w, r) + req = s.parseRequestGoogle(ctx, w, r) } else if contentType == "application/dns-message" { - req = s.parseRequestIETF(w, r) + req = s.parseRequestIETF(ctx, w, r) } else if contentType == "application/dns-udpwireformat" { - req = s.parseRequestIETF(w, r) + req = s.parseRequestIETF(ctx, w, r) } else { jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"ct\" = %q", contentType), 415) return @@ -178,16 +181,16 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { req = s.patchRootRD(req) var err error - req, err = s.doDNSQuery(req) + req, err = s.doDNSQuery(ctx, req) if err != nil { jsonDNS.FormatError(w, fmt.Sprintf("DNS query failure (%s)", err.Error()), 503) return } if responseType == "application/json" { - s.generateResponseGoogle(w, r, req) + s.generateResponseGoogle(ctx, w, r, req) } else if responseType == "application/dns-message" { - s.generateResponseIETF(w, r, req) + s.generateResponseIETF(ctx, w, r, req) } else { panic("Unknown response Content-Type") } @@ -232,7 +235,8 @@ func (s *Server) patchRootRD(req *DNSRequest) *DNSRequest { return req } -func (s *Server) doDNSQuery(req *DNSRequest) (resp *DNSRequest, err error) { +func (s *Server) doDNSQuery(ctx context.Context, req *DNSRequest) (resp *DNSRequest, err error) { + // TODO(m13253): Make ctx work. Waiting for a patch for ExchangeContext from miekg/dns. numServers := len(s.conf.Upstream) for i := uint(0); i < s.conf.Tries; i++ { req.currentUpstream = s.conf.Upstream[rand.Intn(numServers)]