Compare commits

...

5 Commits

Author SHA1 Message Date
Star Brilliant
972d404ebc Add Last-Modified header 2018-03-23 15:32:51 +08:00
Star Brilliant
1be17bff4d Fix a problem when a single HTTP error crashes the program 2018-03-23 15:28:42 +08:00
Star Brilliant
ab2bf57995 Comment out the Googl experimental server 2018-03-21 17:17:14 +08:00
Star Brilliant
06b700cb7e Fix server Content-Type problem 2018-03-21 17:07:40 +08:00
Star Brilliant
0e36d3b31b Content-Type auto detection for client 2018-03-21 16:58:42 +08:00
6 changed files with 154 additions and 61 deletions

View File

@@ -30,6 +30,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"strings"
"time" "time"
"../json-dns" "../json-dns"
@@ -46,6 +47,15 @@ type Client struct {
httpClient *http.Client httpClient *http.Client
} }
type DNSRequest struct {
response *http.Response
reply *dns.Msg
udpSize uint16
ednsClientAddress net.IP
ednsClientNetmask uint8
err error
}
func NewClient(conf *config) (c *Client, err error) { func NewClient(conf *config) (c *Client, err error) {
c = &Client{ c = &Client{
conf: conf, conf: conf,
@@ -144,20 +154,54 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
return return
} }
requestType := ""
if len(c.conf.UpstreamIETF) == 0 { if len(c.conf.UpstreamIETF) == 0 {
c.handlerFuncGoogle(w, r, isTCP) requestType = "application/x-www-form-urlencoded"
return } else if len(c.conf.UpstreamGoogle) == 0 {
} requestType = "application/dns-udpwireformat"
if len(c.conf.UpstreamGoogle) == 0 {
c.handlerFuncIETF(w, r, isTCP)
return
}
numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
random := rand.Intn(numServers)
if random < len(c.conf.UpstreamGoogle) {
c.handlerFuncGoogle(w, r, isTCP)
} else { } else {
c.handlerFuncIETF(w, r, isTCP) numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
random := rand.Intn(numServers)
if random < len(c.conf.UpstreamGoogle) {
requestType = "application/x-www-form-urlencoded"
} else {
requestType = "application/dns-udpwireformat"
}
}
var req *DNSRequest
if requestType == "application/x-www-form-urlencoded" {
req = c.generateRequestGoogle(w, r, isTCP)
} else if requestType == "application/dns-udpwireformat" {
req = c.generateRequestIETF(w, r, isTCP)
} else {
panic("Unknown request Content-Type")
}
if req.err != nil {
return
}
contentType := ""
candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]
if candidateType == "application/json" {
contentType = "application/json"
} else if candidateType == "application/dns-udpwireformat" {
contentType = "application/dns-udpwireformat"
} else {
if requestType == "application/x-www-form-urlencoded" {
contentType = "application/json"
} else if requestType == "application/dns-udpwireformat" {
contentType = "application/dns-udpwireformat"
}
}
if contentType == "application/json" {
c.parseResponseGoogle(w, r, isTCP, req)
} else if contentType == "application/dns-udpwireformat" {
c.parseResponseIETF(w, r, isTCP, req)
} else {
panic("Unknown response Content-Type")
} }
} }

View File

@@ -7,7 +7,7 @@ upstream_google = [
"https://dns.google.com/resolve", "https://dns.google.com/resolve",
] ]
upstream_ietf = [ upstream_ietf = [
"https://dns.google.com/experimental", #"https://dns.google.com/experimental",
] ]
# Bootstrap DNS server to resolve the address of the upstream resolver # Bootstrap DNS server to resolve the address of the upstream resolver

View File

@@ -39,14 +39,16 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
reply := jsonDNS.PrepareReply(r) reply := jsonDNS.PrepareReply(r)
if len(r.Question) != 1 { if len(r.Question) != 1 {
log.Println("Number of questions is not 1") log.Println("Number of questions is not 1")
reply.Rcode = dns.RcodeFormatError reply.Rcode = dns.RcodeFormatError
w.WriteMsg(reply) w.WriteMsg(reply)
return return &DNSRequest{
err: &dns.Error{},
}
} }
question := &r.Question[0] question := &r.Question[0]
// knot-resolver scrambles capitalization, I think it is unfriendly to cache // knot-resolver scrambles capitalization, I think it is unfriendly to cache
@@ -85,9 +87,11 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(reply)
return return &DNSRequest{
err: err,
}
} }
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json, application/dns-udpwireformat")
req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)") req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)")
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
@@ -95,23 +99,36 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
reply.Rcode = dns.RcodeServerFailure reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(reply)
c.httpTransport.CloseIdleConnections() c.httpTransport.CloseIdleConnections()
return return &DNSRequest{
err: err,
}
} }
if resp.StatusCode != 200 {
log.Printf("HTTP error: %s\n", resp.Status) return &DNSRequest{
reply.Rcode = dns.RcodeServerFailure response: resp,
contentType := resp.Header.Get("Content-Type") reply: reply,
udpSize: udpSize,
ednsClientAddress: ednsClientAddress,
ednsClientNetmask: ednsClientNetmask,
}
}
func (c *Client) parseResponseGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
if req.response.StatusCode != 200 {
log.Printf("HTTP error: %s\n", req.response.Status)
req.reply.Rcode = dns.RcodeServerFailure
contentType := req.response.Header.Get("Content-Type")
if contentType != "application/json" && !strings.HasPrefix(contentType, "application/json;") { if contentType != "application/json" && !strings.HasPrefix(contentType, "application/json;") {
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(req.response.Body)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
@@ -119,8 +136,8 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
err = json.Unmarshal(body, &respJSON) err = json.Unmarshal(body, &respJSON)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
@@ -128,22 +145,22 @@ func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool)
log.Printf("DNS error: %s\n", respJSON.Comment) log.Printf("DNS error: %s\n", respJSON.Comment)
} }
fullReply := jsonDNS.Unmarshal(reply, &respJSON, udpSize, ednsClientNetmask) fullReply := jsonDNS.Unmarshal(req.reply, &respJSON, req.udpSize, req.ednsClientNetmask)
buf, err := fullReply.Pack() buf, err := fullReply.Pack()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
if !isTCP && len(buf) > int(udpSize) { if !isTCP && len(buf) > int(req.udpSize) {
fullReply.Truncated = true fullReply.Truncated = true
buf, err = fullReply.Pack() buf, err = fullReply.Pack()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
buf = buf[:udpSize] buf = buf[:req.udpSize]
} }
w.Write(buf) w.Write(buf)
} }

View File

@@ -30,6 +30,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"math/rand" "math/rand"
"net"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -39,14 +40,16 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
reply := jsonDNS.PrepareReply(r) reply := jsonDNS.PrepareReply(r)
if len(r.Question) != 1 { if len(r.Question) != 1 {
log.Println("Number of questions is not 1") log.Println("Number of questions is not 1")
reply.Rcode = dns.RcodeFormatError reply.Rcode = dns.RcodeFormatError
w.WriteMsg(reply) w.WriteMsg(reply)
return return &DNSRequest{
err: &dns.Error{},
}
} }
question := &r.Question[0] question := &r.Question[0]
@@ -63,8 +66,6 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType) fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType)
} }
requestID := r.Id
r.Id = 0
question.Name = questionName question.Name = questionName
opt := r.IsEdns0() opt := r.IsEdns0()
udpSize := uint16(512) udpSize := uint16(512)
@@ -85,9 +86,10 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
break break
} }
} }
ednsClientAddress, ednsClientNetmask := net.IP(nil), uint8(255)
if edns0Subnet == nil { if edns0Subnet == nil {
ednsClientFamily := uint16(0) ednsClientFamily := uint16(0)
ednsClientAddress, ednsClientNetmask := c.findClientIP(w, r) ednsClientAddress, ednsClientNetmask = c.findClientIP(w, r)
if ednsClientAddress != nil { if ednsClientAddress != nil {
if ipv4 := ednsClientAddress.To4(); ipv4 != nil { if ipv4 := ednsClientAddress.To4(); ipv4 != nil {
ednsClientFamily = 1 ednsClientFamily = 1
@@ -105,15 +107,22 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
edns0Subnet.Address = ednsClientAddress edns0Subnet.Address = ednsClientAddress
opt.Option = append(opt.Option, edns0Subnet) opt.Option = append(opt.Option, edns0Subnet)
} }
} else {
ednsClientAddress, ednsClientNetmask = edns0Subnet.Address, edns0Subnet.SourceNetmask
} }
requestID := r.Id
r.Id = 0
requestBinary, err := r.Pack() requestBinary, err := r.Pack()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeFormatError reply.Rcode = dns.RcodeFormatError
w.WriteMsg(reply) w.WriteMsg(reply)
return return &DNSRequest{
err: err,
}
} }
r.Id = requestID
requestBase64 := base64.RawURLEncoding.EncodeToString(requestBinary) requestBase64 := base64.RawURLEncoding.EncodeToString(requestBinary)
numServers := len(c.conf.UpstreamIETF) numServers := len(c.conf.UpstreamIETF)
@@ -127,7 +136,9 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(reply)
return return &DNSRequest{
err: err,
}
} }
} else { } else {
req, err = http.NewRequest("POST", upstream, bytes.NewReader(requestBinary)) req, err = http.NewRequest("POST", upstream, bytes.NewReader(requestBinary))
@@ -135,11 +146,13 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(reply)
return return &DNSRequest{
err: err,
}
} }
req.Header.Set("Content-Type", "application/dns-udpwireformat") req.Header.Set("Content-Type", "application/dns-udpwireformat")
} }
req.Header.Set("Accept", "application/dns-udpwireformat") req.Header.Set("Accept", "application/dns-udpwireformat, application/json")
req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)") req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)")
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
@@ -147,26 +160,39 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
reply.Rcode = dns.RcodeServerFailure reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(reply)
c.httpTransport.CloseIdleConnections() c.httpTransport.CloseIdleConnections()
return return &DNSRequest{
err: err,
}
} }
if resp.StatusCode != 200 {
log.Printf("HTTP error: %s\n", resp.Status) return &DNSRequest{
reply.Rcode = dns.RcodeServerFailure response: resp,
contentType := resp.Header.Get("Content-Type") reply: reply,
udpSize: udpSize,
ednsClientAddress: ednsClientAddress,
ednsClientNetmask: ednsClientNetmask,
}
}
func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
if req.response.StatusCode != 200 {
log.Printf("HTTP error: %s\n", req.response.Status)
req.reply.Rcode = dns.RcodeServerFailure
contentType := req.response.Header.Get("Content-Type")
if contentType != "application/dns-udpwireformat" && !strings.HasPrefix(contentType, "application/dns-udpwireformat;") { if contentType != "application/dns-udpwireformat" && !strings.HasPrefix(contentType, "application/dns-udpwireformat;") {
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(req.response.Body)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
headerNow := resp.Header.Get("Date") headerNow := req.response.Header.Get("Date")
now := time.Now().UTC() now := time.Now().UTC()
if headerNow != "" { if headerNow != "" {
if nowDate, err := time.Parse(http.TimeFormat, headerNow); err == nil { if nowDate, err := time.Parse(http.TimeFormat, headerNow); err == nil {
@@ -175,7 +201,7 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
log.Println(err) log.Println(err)
} }
} }
headerLastModified := resp.Header.Get("Last-Modified") headerLastModified := req.response.Header.Get("Last-Modified")
lastModified := now lastModified := now
if headerLastModified != "" { if headerLastModified != "" {
if lastModifiedDate, err := time.Parse(http.TimeFormat, headerLastModified); err == nil { if lastModifiedDate, err := time.Parse(http.TimeFormat, headerLastModified); err == nil {
@@ -193,12 +219,12 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
err = fullReply.Unpack(body) err = fullReply.Unpack(body)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
fullReply.Id = requestID fullReply.Id = r.Id
for _, rr := range fullReply.Answer { for _, rr := range fullReply.Answer {
_ = fixRecordTTL(rr, timeDelta) _ = fixRecordTTL(rr, timeDelta)
} }
@@ -215,18 +241,18 @@ func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
buf, err := fullReply.Pack() buf, err := fullReply.Pack()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
reply.Rcode = dns.RcodeServerFailure req.reply.Rcode = dns.RcodeServerFailure
w.WriteMsg(reply) w.WriteMsg(req.reply)
return return
} }
if !isTCP && len(buf) > int(udpSize) { if !isTCP && len(buf) > int(req.udpSize) {
fullReply.Truncated = true fullReply.Truncated = true
buf, err = fullReply.Pack() buf, err = fullReply.Pack()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
buf = buf[:udpSize] buf = buf[:req.udpSize]
} }
w.Write(buf) w.Write(buf)
} }

View File

@@ -31,6 +31,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
"../json-dns" "../json-dns"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -178,6 +179,9 @@ func (s *Server) generateResponseGoogle(w http.ResponseWriter, r *http.Request,
} }
w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Content-Type", "application/json; charset=UTF-8")
now := time.Now().UTC().Format(http.TimeFormat)
w.Header().Set("Date", now)
w.Header().Set("Last-Modified", now)
if respJSON.HaveTTL { if respJSON.HaveTTL {
if req.isTailored { if req.isTailored {
w.Header().Set("Cache-Control", "private, max-age="+strconv.Itoa(int(respJSON.LeastTTL))) w.Header().Set("Cache-Control", "private, max-age="+strconv.Itoa(int(respJSON.LeastTTL)))

View File

@@ -144,8 +144,10 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) {
if responseType == "application/json" { if responseType == "application/json" {
s.generateResponseGoogle(w, r, req) s.generateResponseGoogle(w, r, req)
} else if contentType == "application/dns-udpwireformat" { } else if responseType == "application/dns-udpwireformat" {
s.generateResponseIETF(w, r, req) s.generateResponseIETF(w, r, req)
} else {
panic("Unknown response Content-Type")
} }
} }