From 0e36d3b31b9285a60157820c778a8c2abfd92c0b Mon Sep 17 00:00:00 2001 From: Star Brilliant Date: Wed, 21 Mar 2018 16:58:04 +0800 Subject: [PATCH] Content-Type auto detection for client --- doh-client/client.go | 64 ++++++++++++++++++++++++++++------- doh-client/google.go | 57 ++++++++++++++++++++----------- doh-client/ietf.go | 80 +++++++++++++++++++++++++++++--------------- doh-server/server.go | 2 ++ 4 files changed, 144 insertions(+), 59 deletions(-) diff --git a/doh-client/client.go b/doh-client/client.go index 0a7ca93..1624247 100644 --- a/doh-client/client.go +++ b/doh-client/client.go @@ -30,6 +30,7 @@ import ( "net" "net/http" "net/http/cookiejar" + "strings" "time" "../json-dns" @@ -46,6 +47,15 @@ type Client struct { httpClient *http.Client } +type DNSRequest struct { + response *http.Response + reply *dns.Msg + udpSize uint16 + ednsClientAddress net.IP + ednsClientNetmask uint8 + err error +} + func NewClient(conf *config) (c *Client, err error) { c = &Client{ conf: conf, @@ -144,20 +154,50 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { return } + requestType := "" if len(c.conf.UpstreamIETF) == 0 { - c.handlerFuncGoogle(w, r, isTCP) - return - } - if len(c.conf.UpstreamGoogle) == 0 { - c.handlerFuncIETF(w, r, isTCP) - return - } - numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF) - random := rand.Intn(numServers) - if random < len(c.conf.UpstreamGoogle) { - c.handlerFuncGoogle(w, r, isTCP) + requestType = "application/x-www-form-urlencoded" + } else if len(c.conf.UpstreamGoogle) == 0 { + requestType = "application/dns-udpwireformat" } else { - c.handlerFuncIETF(w, r, isTCP) + numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF) + random := rand.Intn(numServers) + if random < len(c.conf.UpstreamGoogle) { + requestType = "application/x-www-form-urlencoded" + } else { + requestType = "application/dns-udpwireformat" + } + } + + var req *DNSRequest + if requestType == "application/x-www-form-urlencoded" { + req = c.generateRequestGoogle(w, r, isTCP) + } else if requestType == "application/dns-udpwireformat" { + req = c.generateRequestIETF(w, r, isTCP) + } else { + panic("Unknown request Content-Type") + } + + contentType := "" + candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0] + if candidateType == "application/json" { + contentType = "application/json" + } else if candidateType == "application/dns-udpwireformat" { + contentType = "application/dns-udpwireformat" + } else { + if requestType == "application/x-www-form-urlencoded" { + contentType = "application/json" + } else if requestType == "application/dns-udpwireformat" { + contentType = "application/dns-udpwireformat" + } + } + + if contentType == "application/json" { + c.parseResponseGoogle(w, r, isTCP, req) + } else if contentType == "application/dns-udpwireformat" { + c.parseResponseIETF(w, r, isTCP, req) + } else { + panic("Unknown response Content-Type") } } diff --git a/doh-client/google.go b/doh-client/google.go index e1c2241..2e1cfb1 100644 --- a/doh-client/google.go +++ b/doh-client/google.go @@ -39,14 +39,16 @@ import ( "github.com/miekg/dns" ) -func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { +func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { reply := jsonDNS.PrepareReply(r) if len(r.Question) != 1 { log.Println("Number of questions is not 1") reply.Rcode = dns.RcodeFormatError w.WriteMsg(reply) - return + return &DNSRequest{ + err: &dns.Error{}, + } } question := &r.Question[0] // knot-resolver scrambles capitalization, I think it is unfriendly to cache @@ -85,9 +87,11 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) log.Println(err) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) - return + return &DNSRequest{ + err: err, + } } - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", "application/json, application/dns-udpwireformat") req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)") resp, err := c.httpClient.Do(req) if err != nil { @@ -95,23 +99,36 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) c.httpTransport.CloseIdleConnections() - return + return &DNSRequest{ + err: err, + } } - if resp.StatusCode != 200 { - log.Printf("HTTP error: %s\n", resp.Status) - reply.Rcode = dns.RcodeServerFailure - contentType := resp.Header.Get("Content-Type") + + return &DNSRequest{ + response: resp, + reply: reply, + udpSize: udpSize, + ednsClientAddress: ednsClientAddress, + ednsClientNetmask: ednsClientNetmask, + } +} + +func (c *Client) parseResponseGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { + if req.response.StatusCode != 200 { + log.Printf("HTTP error: %s\n", req.response.Status) + req.reply.Rcode = dns.RcodeServerFailure + contentType := req.response.Header.Get("Content-Type") if contentType != "application/json" && !strings.HasPrefix(contentType, "application/json;") { - w.WriteMsg(reply) + w.WriteMsg(req.reply) return } } - body, err := ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(req.response.Body) if err != nil { log.Println(err) - reply.Rcode = dns.RcodeServerFailure - w.WriteMsg(reply) + req.reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(req.reply) return } @@ -119,8 +136,8 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) err = json.Unmarshal(body, &respJSON) if err != nil { log.Println(err) - reply.Rcode = dns.RcodeServerFailure - w.WriteMsg(reply) + req.reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(req.reply) return } @@ -128,22 +145,22 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) log.Printf("DNS error: %s\n", respJSON.Comment) } - fullReply := jsonDNS.Unmarshal(reply, &respJSON, udpSize, ednsClientNetmask) + fullReply := jsonDNS.Unmarshal(req.reply, &respJSON, req.udpSize, req.ednsClientNetmask) buf, err := fullReply.Pack() if err != nil { log.Println(err) - reply.Rcode = dns.RcodeServerFailure - w.WriteMsg(reply) + req.reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(req.reply) return } - if !isTCP && len(buf) > int(udpSize) { + if !isTCP && len(buf) > int(req.udpSize) { fullReply.Truncated = true buf, err = fullReply.Pack() if err != nil { log.Println(err) return } - buf = buf[:udpSize] + buf = buf[:req.udpSize] } w.Write(buf) } diff --git a/doh-client/ietf.go b/doh-client/ietf.go index a802926..a7a63b2 100644 --- a/doh-client/ietf.go +++ b/doh-client/ietf.go @@ -30,6 +30,7 @@ import ( "io/ioutil" "log" "math/rand" + "net" "net/http" "strconv" "strings" @@ -39,14 +40,16 @@ import ( "github.com/miekg/dns" ) -func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { +func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { reply := jsonDNS.PrepareReply(r) if len(r.Question) != 1 { log.Println("Number of questions is not 1") reply.Rcode = dns.RcodeFormatError w.WriteMsg(reply) - return + return &DNSRequest{ + err: &dns.Error{}, + } } question := &r.Question[0] @@ -63,8 +66,6 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType) } - requestID := r.Id - r.Id = 0 question.Name = questionName opt := r.IsEdns0() udpSize := uint16(512) @@ -85,9 +86,10 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { break } } + ednsClientAddress, ednsClientNetmask := net.IP(nil), uint8(255) if edns0Subnet == nil { ednsClientFamily := uint16(0) - ednsClientAddress, ednsClientNetmask := c.findClientIP(w, r) + ednsClientAddress, ednsClientNetmask = c.findClientIP(w, r) if ednsClientAddress != nil { if ipv4 := ednsClientAddress.To4(); ipv4 != nil { ednsClientFamily = 1 @@ -105,15 +107,22 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { edns0Subnet.Address = ednsClientAddress opt.Option = append(opt.Option, edns0Subnet) } + } else { + ednsClientAddress, ednsClientNetmask = edns0Subnet.Address, edns0Subnet.SourceNetmask } + requestID := r.Id + r.Id = 0 requestBinary, err := r.Pack() if err != nil { log.Println(err) reply.Rcode = dns.RcodeFormatError w.WriteMsg(reply) - return + return &DNSRequest{ + err: err, + } } + r.Id = requestID requestBase64 := base64.RawURLEncoding.EncodeToString(requestBinary) numServers := len(c.conf.UpstreamIETF) @@ -127,7 +136,9 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { log.Println(err) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) - return + return &DNSRequest{ + err: err, + } } } else { req, err = http.NewRequest("POST", upstream, bytes.NewReader(requestBinary)) @@ -135,11 +146,13 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { log.Println(err) reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) - return + return &DNSRequest{ + err: err, + } } req.Header.Set("Content-Type", "application/dns-udpwireformat") } - req.Header.Set("Accept", "application/dns-udpwireformat") + req.Header.Set("Accept", "application/dns-udpwireformat, application/json") req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)") resp, err := c.httpClient.Do(req) if err != nil { @@ -147,26 +160,39 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { reply.Rcode = dns.RcodeServerFailure w.WriteMsg(reply) c.httpTransport.CloseIdleConnections() - return + return &DNSRequest{ + err: err, + } } - if resp.StatusCode != 200 { - log.Printf("HTTP error: %s\n", resp.Status) - reply.Rcode = dns.RcodeServerFailure - contentType := resp.Header.Get("Content-Type") + + return &DNSRequest{ + response: resp, + reply: reply, + udpSize: udpSize, + ednsClientAddress: ednsClientAddress, + ednsClientNetmask: ednsClientNetmask, + } +} + +func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { + if req.response.StatusCode != 200 { + log.Printf("HTTP error: %s\n", req.response.Status) + req.reply.Rcode = dns.RcodeServerFailure + contentType := req.response.Header.Get("Content-Type") if contentType != "application/dns-udpwireformat" && !strings.HasPrefix(contentType, "application/dns-udpwireformat;") { - w.WriteMsg(reply) + w.WriteMsg(req.reply) return } } - body, err := ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(req.response.Body) if err != nil { log.Println(err) - reply.Rcode = dns.RcodeServerFailure - w.WriteMsg(reply) + req.reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(req.reply) return } - headerNow := resp.Header.Get("Date") + headerNow := req.response.Header.Get("Date") now := time.Now().UTC() if headerNow != "" { if nowDate, err := time.Parse(http.TimeFormat, headerNow); err == nil { @@ -175,7 +201,7 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { log.Println(err) } } - headerLastModified := resp.Header.Get("Last-Modified") + headerLastModified := req.response.Header.Get("Last-Modified") lastModified := now if headerLastModified != "" { if lastModifiedDate, err := time.Parse(http.TimeFormat, headerLastModified); err == nil { @@ -193,12 +219,12 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { err = fullReply.Unpack(body) if err != nil { log.Println(err) - reply.Rcode = dns.RcodeServerFailure - w.WriteMsg(reply) + req.reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(req.reply) return } - fullReply.Id = requestID + fullReply.Id = r.Id for _, rr := range fullReply.Answer { _ = fixRecordTTL(rr, timeDelta) } @@ -215,18 +241,18 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { buf, err := fullReply.Pack() if err != nil { log.Println(err) - reply.Rcode = dns.RcodeServerFailure - w.WriteMsg(reply) + req.reply.Rcode = dns.RcodeServerFailure + w.WriteMsg(req.reply) return } - if !isTCP && len(buf) > int(udpSize) { + if !isTCP && len(buf) > int(req.udpSize) { fullReply.Truncated = true buf, err = fullReply.Pack() if err != nil { log.Println(err) return } - buf = buf[:udpSize] + buf = buf[:req.udpSize] } w.Write(buf) } diff --git a/doh-server/server.go b/doh-server/server.go index a047c2e..ccbb79e 100644 --- a/doh-server/server.go +++ b/doh-server/server.go @@ -146,6 +146,8 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { s.generateResponseGoogle(w, r, req) } else if contentType == "application/dns-udpwireformat" { s.generateResponseIETF(w, r, req) + } else { + panic("Unknown response Content-Type") } }