mirror of
https://github.com/m13253/dns-over-https.git
synced 2026-03-29 20:40:35 +00:00
Use context for more functions
This commit is contained in:
@@ -217,6 +217,9 @@ func (c *Client) Start() error {
|
||||
}
|
||||
|
||||
func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if r.Response == true {
|
||||
log.Println("Received a response packet")
|
||||
return
|
||||
@@ -298,9 +301,9 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
|
||||
var req *DNSRequest
|
||||
if requestType == "application/dns-json" {
|
||||
req = c.generateRequestGoogle(w, r, isTCP)
|
||||
req = c.generateRequestGoogle(ctx, w, r, isTCP)
|
||||
} else if requestType == "application/dns-message" {
|
||||
req = c.generateRequestIETF(w, r, isTCP)
|
||||
req = c.generateRequestIETF(ctx, w, r, isTCP)
|
||||
} else {
|
||||
panic("Unknown request Content-Type")
|
||||
}
|
||||
@@ -334,9 +337,9 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
}
|
||||
|
||||
if contentType == "application/json" {
|
||||
c.parseResponseGoogle(w, r, isTCP, req)
|
||||
c.parseResponseGoogle(ctx, w, r, isTCP, req)
|
||||
} else if contentType == "application/dns-message" {
|
||||
c.parseResponseIETF(w, r, isTCP, req)
|
||||
c.parseResponseIETF(ctx, w, r, isTCP, req)
|
||||
} else {
|
||||
panic("Unknown response Content-Type")
|
||||
}
|
||||
|
||||
@@ -34,13 +34,12 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/m13253/dns-over-https/json-dns"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
|
||||
func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
|
||||
question := &r.Question[0]
|
||||
questionName := question.Name
|
||||
questionClass := question.Qclass
|
||||
@@ -89,12 +88,10 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b
|
||||
}
|
||||
req.Header.Set("Accept", "application/json, application/dns-message, application/dns-udpwireformat")
|
||||
req.Header.Set("User-Agent", USER_AGENT)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second)
|
||||
req = req.WithContext(ctx)
|
||||
c.httpClientMux.RLock()
|
||||
resp, err := c.httpClient.Do(req)
|
||||
c.httpClientMux.RUnlock()
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
@@ -115,7 +112,7 @@ func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP b
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) parseResponseGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
|
||||
func (c *Client) parseResponseGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
|
||||
if req.response.StatusCode != 200 {
|
||||
log.Printf("HTTP error from upstream %s: %s\n", req.currentUpstream, req.response.Status)
|
||||
req.reply.Rcode = dns.RcodeServerFailure
|
||||
|
||||
@@ -40,7 +40,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
|
||||
func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
|
||||
opt := r.IsEdns0()
|
||||
udpSize := uint16(512)
|
||||
if opt == nil {
|
||||
@@ -127,12 +127,10 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo
|
||||
}
|
||||
req.Header.Set("Accept", "application/dns-message, application/dns-udpwireformat, application/json")
|
||||
req.Header.Set("User-Agent", USER_AGENT)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second)
|
||||
req = req.WithContext(ctx)
|
||||
c.httpClientMux.RLock()
|
||||
resp, err := c.httpClient.Do(req)
|
||||
c.httpClientMux.RUnlock()
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
reply := jsonDNS.PrepareReply(r)
|
||||
@@ -153,7 +151,7 @@ func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP boo
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
|
||||
func (c *Client) parseResponseIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
|
||||
if req.response.StatusCode != 200 {
|
||||
log.Printf("HTTP error from upstream %s: %s\n", req.currentUpstream, req.response.Status)
|
||||
req.reply.Rcode = dns.RcodeServerFailure
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -38,7 +39,7 @@ import (
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
func (s *Server) parseRequestGoogle(w http.ResponseWriter, r *http.Request) *DNSRequest {
|
||||
func (s *Server) parseRequestGoogle(ctx context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest {
|
||||
name := r.FormValue("name")
|
||||
if name == "" {
|
||||
return &DNSRequest{
|
||||
@@ -168,7 +169,7 @@ func (s *Server) parseRequestGoogle(w http.ResponseWriter, r *http.Request) *DNS
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) generateResponseGoogle(w http.ResponseWriter, r *http.Request, req *DNSRequest) {
|
||||
func (s *Server) generateResponseGoogle(ctx context.Context, w http.ResponseWriter, r *http.Request, req *DNSRequest) {
|
||||
respJSON := jsonDNS.Marshal(req.response)
|
||||
respStr, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
|
||||
@@ -25,6 +25,7 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@@ -38,7 +39,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) parseRequestIETF(w http.ResponseWriter, r *http.Request) *DNSRequest {
|
||||
func (s *Server) parseRequestIETF(ctx context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest {
|
||||
requestBase64 := r.FormValue("dns")
|
||||
requestBinary, err := base64.RawURLEncoding.DecodeString(requestBase64)
|
||||
if err != nil {
|
||||
@@ -145,7 +146,7 @@ func (s *Server) parseRequestIETF(w http.ResponseWriter, r *http.Request) *DNSRe
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) generateResponseIETF(w http.ResponseWriter, r *http.Request, req *DNSRequest) {
|
||||
func (s *Server) generateResponseIETF(ctx context.Context, w http.ResponseWriter, r *http.Request, req *DNSRequest) {
|
||||
respJSON := jsonDNS.Marshal(req.response)
|
||||
req.response.Id = req.transactionID
|
||||
respBytes, err := req.response.Pack()
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
@@ -105,6 +106,8 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
w.Header().Set("Server", USER_AGENT)
|
||||
w.Header().Set("X-Powered-By", USER_AGENT)
|
||||
|
||||
@@ -158,11 +161,11 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var req *DNSRequest
|
||||
if contentType == "application/dns-json" {
|
||||
req = s.parseRequestGoogle(w, r)
|
||||
req = s.parseRequestGoogle(ctx, w, r)
|
||||
} else if contentType == "application/dns-message" {
|
||||
req = s.parseRequestIETF(w, r)
|
||||
req = s.parseRequestIETF(ctx, w, r)
|
||||
} else if contentType == "application/dns-udpwireformat" {
|
||||
req = s.parseRequestIETF(w, r)
|
||||
req = s.parseRequestIETF(ctx, w, r)
|
||||
} else {
|
||||
jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"ct\" = %q", contentType), 415)
|
||||
return
|
||||
@@ -178,16 +181,16 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
req = s.patchRootRD(req)
|
||||
|
||||
var err error
|
||||
req, err = s.doDNSQuery(req)
|
||||
req, err = s.doDNSQuery(ctx, req)
|
||||
if err != nil {
|
||||
jsonDNS.FormatError(w, fmt.Sprintf("DNS query failure (%s)", err.Error()), 503)
|
||||
return
|
||||
}
|
||||
|
||||
if responseType == "application/json" {
|
||||
s.generateResponseGoogle(w, r, req)
|
||||
s.generateResponseGoogle(ctx, w, r, req)
|
||||
} else if responseType == "application/dns-message" {
|
||||
s.generateResponseIETF(w, r, req)
|
||||
s.generateResponseIETF(ctx, w, r, req)
|
||||
} else {
|
||||
panic("Unknown response Content-Type")
|
||||
}
|
||||
@@ -232,7 +235,8 @@ func (s *Server) patchRootRD(req *DNSRequest) *DNSRequest {
|
||||
return req
|
||||
}
|
||||
|
||||
func (s *Server) doDNSQuery(req *DNSRequest) (resp *DNSRequest, err error) {
|
||||
func (s *Server) doDNSQuery(ctx context.Context, req *DNSRequest) (resp *DNSRequest, err error) {
|
||||
// TODO(m13253): Make ctx work. Waiting for a patch for ExchangeContext from miekg/dns.
|
||||
numServers := len(s.conf.Upstream)
|
||||
for i := uint(0); i < s.conf.Tries; i++ {
|
||||
req.currentUpstream = s.conf.Upstream[rand.Intn(numServers)]
|
||||
|
||||
Reference in New Issue
Block a user