Content-Type auto detection for client

This commit is contained in:
Star Brilliant
2018-03-21 16:58:04 +08:00
parent 5723558934
commit 0e36d3b31b
4 changed files with 144 additions and 59 deletions

View File

@@ -30,6 +30,7 @@ import (
"net"
"net/http"
"net/http/cookiejar"
"strings"
"time"
"../json-dns"
@@ -46,6 +47,15 @@ type Client struct {
httpClient *http.Client
}
type DNSRequest struct {
response *http.Response
reply *dns.Msg
udpSize uint16
ednsClientAddress net.IP
ednsClientNetmask uint8
err error
}
func NewClient(conf *config) (c *Client, err error) {
c = &Client{
conf: conf,
@@ -144,20 +154,50 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
return
}
requestType := ""
if len(c.conf.UpstreamIETF) == 0 {
c.handlerFuncGoogle(w, r, isTCP)
return
}
if len(c.conf.UpstreamGoogle) == 0 {
c.handlerFuncIETF(w, r, isTCP)
return
}
numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
random := rand.Intn(numServers)
if random < len(c.conf.UpstreamGoogle) {
c.handlerFuncGoogle(w, r, isTCP)
requestType = "application/x-www-form-urlencoded"
} else if len(c.conf.UpstreamGoogle) == 0 {
requestType = "application/dns-udpwireformat"
} else {
c.handlerFuncIETF(w, r, isTCP)
numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
random := rand.Intn(numServers)
if random < len(c.conf.UpstreamGoogle) {
requestType = "application/x-www-form-urlencoded"
} else {
requestType = "application/dns-udpwireformat"
}
}
var req *DNSRequest
if requestType == "application/x-www-form-urlencoded" {
req = c.generateRequestGoogle(w, r, isTCP)
} else if requestType == "application/dns-udpwireformat" {
req = c.generateRequestIETF(w, r, isTCP)
} else {
panic("Unknown request Content-Type")
}
contentType := ""
candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]
if candidateType == "application/json" {
contentType = "application/json"
} else if candidateType == "application/dns-udpwireformat" {
contentType = "application/dns-udpwireformat"
} else {
if requestType == "application/x-www-form-urlencoded" {
contentType = "application/json"
} else if requestType == "application/dns-udpwireformat" {
contentType = "application/dns-udpwireformat"
}
}
if contentType == "application/json" {
c.parseResponseGoogle(w, r, isTCP, req)
} else if contentType == "application/dns-udpwireformat" {
c.parseResponseIETF(w, r, isTCP, req)
} else {
panic("Unknown response Content-Type")
}
}

View File

@@ -39,14 +39,16 @@ import (
"github.com/miekg/dns"
)
func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
reply := jsonDNS.PrepareReply(r)
if len(r.Question) != 1 {
log.Println("Number of questions is not 1")
reply.Rcode = dns.RcodeFormatError
w.WriteMsg(reply)
return
return &DNSRequest{
err: &dns.Error{},
}
}
question := &r.Question[0]
// knot-resolver scrambles capitalization, I think it is unfriendly to cache
@@ -85,9 +87,11 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
return
return &DNSRequest{
err: err,
}
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept", "application/json, application/dns-udpwireformat")
req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)")
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -95,23 +99,36 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
c.httpTransport.CloseIdleConnections()
return
return &DNSRequest{
err: err,
}
}
if resp.StatusCode != 200 {
log.Printf("HTTP error: %s\n", resp.Status)
reply.Rcode = dns.RcodeServerFailure
contentType := resp.Header.Get("Content-Type")
return &DNSRequest{
response: resp,
reply: reply,
udpSize: udpSize,
ednsClientAddress: ednsClientAddress,
ednsClientNetmask: ednsClientNetmask,
}
}
func (c *Client) parseResponseGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
if req.response.StatusCode != 200 {
log.Printf("HTTP error: %s\n", req.response.Status)
req.reply.Rcode = dns.RcodeServerFailure
contentType := req.response.Header.Get("Content-Type")
if contentType != "application/json" && !strings.HasPrefix(contentType, "application/json;") {
w.WriteMsg(reply)
w.WriteMsg(req.reply)
return
}
}
body, err := ioutil.ReadAll(resp.Body)
body, err := ioutil.ReadAll(req.response.Body)
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(req.reply)
return
}
@@ -119,8 +136,8 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
err = json.Unmarshal(body, &respJSON)
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(req.reply)
return
}
@@ -128,22 +145,22 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
log.Printf("DNS error: %s\n", respJSON.Comment)
}
fullReply := jsonDNS.Unmarshal(reply, &respJSON, udpSize, ednsClientNetmask)
fullReply := jsonDNS.Unmarshal(req.reply, &respJSON, req.udpSize, req.ednsClientNetmask)
buf, err := fullReply.Pack()
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(req.reply)
return
}
if !isTCP && len(buf) > int(udpSize) {
if !isTCP && len(buf) > int(req.udpSize) {
fullReply.Truncated = true
buf, err = fullReply.Pack()
if err != nil {
log.Println(err)
return
}
buf = buf[:udpSize]
buf = buf[:req.udpSize]
}
w.Write(buf)
}

View File

@@ -30,6 +30,7 @@ import (
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"strconv"
"strings"
@@ -39,14 +40,16 @@ import (
"github.com/miekg/dns"
)
func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
reply := jsonDNS.PrepareReply(r)
if len(r.Question) != 1 {
log.Println("Number of questions is not 1")
reply.Rcode = dns.RcodeFormatError
w.WriteMsg(reply)
return
return &DNSRequest{
err: &dns.Error{},
}
}
question := &r.Question[0]
@@ -63,8 +66,6 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType)
}
requestID := r.Id
r.Id = 0
question.Name = questionName
opt := r.IsEdns0()
udpSize := uint16(512)
@@ -85,9 +86,10 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
break
}
}
ednsClientAddress, ednsClientNetmask := net.IP(nil), uint8(255)
if edns0Subnet == nil {
ednsClientFamily := uint16(0)
ednsClientAddress, ednsClientNetmask := c.findClientIP(w, r)
ednsClientAddress, ednsClientNetmask = c.findClientIP(w, r)
if ednsClientAddress != nil {
if ipv4 := ednsClientAddress.To4(); ipv4 != nil {
ednsClientFamily = 1
@@ -105,15 +107,22 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
edns0Subnet.Address = ednsClientAddress
opt.Option = append(opt.Option, edns0Subnet)
}
} else {
ednsClientAddress, ednsClientNetmask = edns0Subnet.Address, edns0Subnet.SourceNetmask
}
requestID := r.Id
r.Id = 0
requestBinary, err := r.Pack()
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeFormatError
w.WriteMsg(reply)
return
return &DNSRequest{
err: err,
}
}
r.Id = requestID
requestBase64 := base64.RawURLEncoding.EncodeToString(requestBinary)
numServers := len(c.conf.UpstreamIETF)
@@ -127,7 +136,9 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
return
return &DNSRequest{
err: err,
}
}
} else {
req, err = http.NewRequest("POST", upstream, bytes.NewReader(requestBinary))
@@ -135,11 +146,13 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
return
return &DNSRequest{
err: err,
}
}
req.Header.Set("Content-Type", "application/dns-udpwireformat")
}
req.Header.Set("Accept", "application/dns-udpwireformat")
req.Header.Set("Accept", "application/dns-udpwireformat, application/json")
req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)")
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -147,26 +160,39 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
c.httpTransport.CloseIdleConnections()
return
return &DNSRequest{
err: err,
}
}
if resp.StatusCode != 200 {
log.Printf("HTTP error: %s\n", resp.Status)
reply.Rcode = dns.RcodeServerFailure
contentType := resp.Header.Get("Content-Type")
return &DNSRequest{
response: resp,
reply: reply,
udpSize: udpSize,
ednsClientAddress: ednsClientAddress,
ednsClientNetmask: ednsClientNetmask,
}
}
func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
if req.response.StatusCode != 200 {
log.Printf("HTTP error: %s\n", req.response.Status)
req.reply.Rcode = dns.RcodeServerFailure
contentType := req.response.Header.Get("Content-Type")
if contentType != "application/dns-udpwireformat" && !strings.HasPrefix(contentType, "application/dns-udpwireformat;") {
w.WriteMsg(reply)
w.WriteMsg(req.reply)
return
}
}
body, err := ioutil.ReadAll(resp.Body)
body, err := ioutil.ReadAll(req.response.Body)
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(req.reply)
return
}
headerNow := resp.Header.Get("Date")
headerNow := req.response.Header.Get("Date")
now := time.Now().UTC()
if headerNow != "" {
if nowDate, err := time.Parse(http.TimeFormat, headerNow); err == nil {
@@ -175,7 +201,7 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
log.Println(err)
}
}
headerLastModified := resp.Header.Get("Last-Modified")
headerLastModified := req.response.Header.Get("Last-Modified")
lastModified := now
if headerLastModified != "" {
if lastModifiedDate, err := time.Parse(http.TimeFormat, headerLastModified); err == nil {
@@ -193,12 +219,12 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
err = fullReply.Unpack(body)
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(req.reply)
return
}
fullReply.Id = requestID
fullReply.Id = r.Id
for _, rr := range fullReply.Answer {
_ = fixRecordTTL(rr, timeDelta)
}
@@ -215,18 +241,18 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
buf, err := fullReply.Pack()
if err != nil {
log.Println(err)
reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply)
req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(req.reply)
return
}
if !isTCP && len(buf) > int(udpSize) {
if !isTCP && len(buf) > int(req.udpSize) {
fullReply.Truncated = true
buf, err = fullReply.Pack()
if err != nil {
log.Println(err)
return
}
buf = buf[:udpSize]
buf = buf[:req.udpSize]
}
w.Write(buf)
}

View File

@@ -146,6 +146,8 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) {
s.generateResponseGoogle(w, r, req)
} else if contentType == "application/dns-udpwireformat" {
s.generateResponseIETF(w, r, req)
} else {
panic("Unknown response Content-Type")
}
}