From fec1e84d5e53d8f2b6d7e5f0dd59a7edee3fc150 Mon Sep 17 00:00:00 2001 From: Sherlock Holo Date: Sat, 9 Mar 2019 18:12:45 +0800 Subject: [PATCH] Add backend weight round robin select (#34) * Add upstream selector, there are two selector now: - random selector - weight random selector random selector will choose upstream at random; weight random selector will choose upstream at random with weight Signed-off-by: Sherlock Holo * Rewrite config and config file example, prepare for weight round robbin selector Signed-off-by: Sherlock Holo * Replace bad implement of weight random selector with weight round robbin selector, the algorithm is nginx weight round robbin like Signed-off-by: Sherlock Holo * Use new config module Signed-off-by: Sherlock Holo * Disable deprecated DualStack set Signed-off-by: Sherlock Holo * Fix typo Signed-off-by: Sherlock Holo * Optimize upstreamSelector judge Signed-off-by: Sherlock Holo * Fix typo Signed-off-by: Sherlock Holo * Add config timeout unit tips Signed-off-by: Sherlock Holo * Set wrr http client timeout to replace http request timeout Signed-off-by: Sherlock Holo * Add weight value range Signed-off-by: Sherlock Holo * Add a line ending for .gitignore Signed-off-by: Sherlock Holo * Optimize config file style Signed-off-by: Sherlock Holo * Modify Weight type to int32 Signed-off-by: Sherlock Holo * Add upstreamError Signed-off-by: Sherlock Holo * Rewrite Selector interface and wrr implement Signed-off-by: Sherlock Holo * Use http module predefined constant to judge req.response.StatusCode Signed-off-by: Sherlock Holo * Use Selector.ReportUpstreamError to report upstream error for evaluation loop in real time Signed-off-by: Sherlock Holo * Make client selector field private Signed-off-by: Sherlock Holo * Replace config file url to URL Add miss space for 'weight= 50' Signed-off-by: Sherlock Holo * Rewrite Selector.ReportUpstreamError to Selector.ReportUpstreamStatus, report upstream ok in real time Signed-off-by: Sherlock Holo * Fix checkIETFResponse: if upstream OK, won't increase weight Signed-off-by: Sherlock Holo * Fix typo Signed-off-by: Sherlock Holo --- .gitignore | 2 + doh-client/client.go | 184 +++++++++++------ doh-client/{ => config}/config.go | 45 +++- doh-client/doh-client.conf | 64 +++--- doh-client/google.go | 23 ++- doh-client/ietf.go | 23 ++- doh-client/main.go | 6 +- doh-client/selector/randomSelector.go | 50 +++++ doh-client/selector/selector.go | 12 ++ doh-client/selector/upstream.go | 28 +++ doh-client/selector/upstreamStatus.go | 14 ++ .../selector/weightRoundRobinSelector.go | 192 ++++++++++++++++++ go.mod | 14 ++ go.sum | 17 ++ 14 files changed, 552 insertions(+), 122 deletions(-) rename doh-client/{ => config}/config.go (66%) create mode 100644 doh-client/selector/randomSelector.go create mode 100644 doh-client/selector/selector.go create mode 100644 doh-client/selector/upstream.go create mode 100644 doh-client/selector/upstreamStatus.go create mode 100644 doh-client/selector/weightRoundRobinSelector.go create mode 100644 go.mod create mode 100644 go.sum diff --git a/.gitignore b/.gitignore index a1338d6..f46854f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 .glide/ + +.idea/ diff --git a/doh-client/client.go b/doh-client/client.go index 3fbd228..03c9b68 100644 --- a/doh-client/client.go +++ b/doh-client/client.go @@ -31,11 +31,14 @@ import ( "net" "net/http" "net/http/cookiejar" + "net/url" "strconv" "strings" "sync" "time" + "github.com/m13253/dns-over-https/doh-client/config" + "github.com/m13253/dns-over-https/doh-client/selector" "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" "golang.org/x/net/http2" @@ -43,7 +46,7 @@ import ( ) type Client struct { - conf *config + conf *config.Config bootstrap []string passthrough []string udpClient *dns.Client @@ -56,6 +59,7 @@ type Client struct { httpTransport *http.Transport httpClient *http.Client httpClientLastCreate time.Time + selector selector.Selector } type DNSRequest struct { @@ -68,7 +72,7 @@ type DNSRequest struct { err error } -func NewClient(conf *config) (c *Client, err error) { +func NewClient(conf *config.Config) (c *Client, err error) { c = &Client{ conf: conf, } @@ -78,11 +82,11 @@ func NewClient(conf *config) (c *Client, err error) { c.udpClient = &dns.Client{ Net: "udp", UDPSize: dns.DefaultMsgSize, - Timeout: time.Duration(conf.Timeout) * time.Second, + Timeout: time.Duration(conf.Other.Timeout) * time.Second, } c.tcpClient = &dns.Client{ Net: "tcp", - Timeout: time.Duration(conf.Timeout) * time.Second, + Timeout: time.Duration(conf.Other.Timeout) * time.Second, } for _, addr := range conf.Listen { c.udpServers = append(c.udpServers, &dns.Server{ @@ -98,9 +102,9 @@ func NewClient(conf *config) (c *Client, err error) { }) } c.bootstrapResolver = net.DefaultResolver - if len(conf.Bootstrap) != 0 { - c.bootstrap = make([]string, len(conf.Bootstrap)) - for i, bootstrap := range conf.Bootstrap { + if len(conf.Other.Bootstrap) != 0 { + c.bootstrap = make([]string, len(conf.Other.Bootstrap)) + for i, bootstrap := range conf.Other.Bootstrap { bootstrapAddr, err := net.ResolveUDPAddr("udp", bootstrap) if err != nil { bootstrapAddr, err = net.ResolveUDPAddr("udp", "["+bootstrap+"]:53") @@ -120,9 +124,9 @@ func NewClient(conf *config) (c *Client, err error) { return conn, err }, } - if len(conf.Passthrough) != 0 { - c.passthrough = make([]string, len(conf.Passthrough)) - for i, passthrough := range conf.Passthrough { + if len(conf.Other.Passthrough) != 0 { + c.passthrough = make([]string, len(conf.Other.Passthrough)) + for i, passthrough := range conf.Other.Passthrough { if punycode, err := idna.ToASCII(passthrough); err != nil { passthrough = punycode } @@ -133,7 +137,7 @@ func NewClient(conf *config) (c *Client, err error) { // Most CDNs require Cookie support to prevent DDoS attack. // Disabling Cookie does not effectively prevent tracking, // so I will leave it on to make anti-DDoS services happy. - if !c.conf.NoCookies { + if !c.conf.Other.NoCookies { c.cookieJar, err = cookiejar.New(nil) if err != nil { return nil, err @@ -147,23 +151,59 @@ func NewClient(conf *config) (c *Client, err error) { if err != nil { return nil, err } + + switch c.conf.Upstream.UpstreamSelector { + default: + // if selector is invalid or random, use random selector, or should we stop program and let user knows he is wrong? + s := selector.NewRandomSelector() + for _, u := range c.conf.Upstream.UpstreamGoogle { + if err := s.Add(u.URL, selector.Google); err != nil { + return nil, err + } + } + + for _, u := range c.conf.Upstream.UpstreamIETF { + if err := s.Add(u.URL, selector.IETF); err != nil { + return nil, err + } + } + + c.selector = s + + case config.WeightedRoundRobin: + s := selector.NewWeightRoundRobinSelector(time.Duration(c.conf.Other.Timeout) * time.Second) + for _, u := range c.conf.Upstream.UpstreamGoogle { + if err := s.Add(u.URL, selector.Google, u.Weight); err != nil { + return nil, err + } + } + + for _, u := range c.conf.Upstream.UpstreamIETF { + if err := s.Add(u.URL, selector.IETF, u.Weight); err != nil { + return nil, err + } + } + + c.selector = s + } + return c, nil } func (c *Client) newHTTPClient() error { c.httpClientMux.Lock() defer c.httpClientMux.Unlock() - if !c.httpClientLastCreate.IsZero() && time.Since(c.httpClientLastCreate) < time.Duration(c.conf.Timeout)*time.Second { + if !c.httpClientLastCreate.IsZero() && time.Since(c.httpClientLastCreate) < time.Duration(c.conf.Other.Timeout)*time.Second { return nil } if c.httpTransport != nil { c.httpTransport.CloseIdleConnections() } dialer := &net.Dialer{ - Timeout: time.Duration(c.conf.Timeout) * time.Second, + Timeout: time.Duration(c.conf.Other.Timeout) * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, - Resolver: c.bootstrapResolver, + // DualStack: true, + Resolver: c.bootstrapResolver, } c.httpTransport = &http.Transport{ DialContext: dialer.DialContext, @@ -172,9 +212,9 @@ func (c *Client) newHTTPClient() error { MaxIdleConns: 100, MaxIdleConnsPerHost: 10, Proxy: http.ProxyFromEnvironment, - TLSHandshakeTimeout: time.Duration(c.conf.Timeout) * time.Second, + TLSHandshakeTimeout: time.Duration(c.conf.Other.Timeout) * time.Second, } - if c.conf.NoIPv6 { + if c.conf.Other.NoIPv6 { c.httpTransport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { if strings.HasPrefix(network, "tcp") { network = "tcp4" @@ -213,11 +253,15 @@ func (c *Client) Start() error { } } close(results) + + // start evaluation poll + c.selector.StartEvaluate() + return nil } func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Other.Timeout)*time.Second) defer cancel() if r.Response { @@ -246,7 +290,7 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { } else { questionType = strconv.FormatUint(uint64(question.Qtype), 10) } - if c.conf.Verbose { + if c.conf.Other.Verbose { fmt.Printf("%s - - [%s] \"%s %s %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType) } @@ -284,64 +328,80 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) { return } - requestType := "" - if len(c.conf.UpstreamIETF) == 0 { - requestType = "application/dns-json" - } else if len(c.conf.UpstreamGoogle) == 0 { - requestType = "application/dns-message" - } else { - numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF) - random := rand.Intn(numServers) - if random < len(c.conf.UpstreamGoogle) { - requestType = "application/dns-json" - } else { - requestType = "application/dns-message" - } + upstream := c.selector.Get() + requestType := upstream.RequestType + + if c.conf.Other.Verbose { + log.Println("choose upstream:", upstream) } var req *DNSRequest - if requestType == "application/dns-json" { - req = c.generateRequestGoogle(ctx, w, r, isTCP) - } else if requestType == "application/dns-message" { - req = c.generateRequestIETF(ctx, w, r, isTCP) - } else { + switch requestType { + case "application/dns-json": + req = c.generateRequestGoogle(ctx, w, r, isTCP, upstream) + + case "application/dns-message": + req = c.generateRequestIETF(ctx, w, r, isTCP, upstream) + + default: panic("Unknown request Content-Type") } - if req.response != nil { - defer req.response.Body.Close() - for _, header := range c.conf.DebugHTTPHeaders { - if value := req.response.Header.Get(header); value != "" { - log.Printf("%s: %s\n", header, value) + if req.err != nil { + if urlErr, ok := req.err.(*url.Error); ok { + // should we only check timeout? + if urlErr.Timeout() { + c.selector.ReportUpstreamStatus(upstream, selector.Timeout) } } - } - if req.err != nil { + return } - contentType := "" - candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0] - if candidateType == "application/json" { - contentType = "application/json" - } else if candidateType == "application/dns-message" { - contentType = "application/dns-message" - } else if candidateType == "application/dns-udpwireformat" { - contentType = "application/dns-message" - } else { - if requestType == "application/dns-json" { - contentType = "application/json" - } else if requestType == "application/dns-message" { - contentType = "application/dns-message" + // if req.err == nil, req.response != nil + defer req.response.Body.Close() + + for _, header := range c.conf.Other.DebugHTTPHeaders { + if value := req.response.Header.Get(header); value != "" { + log.Printf("%s: %s\n", header, value) } } - if contentType == "application/json" { + candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0] + + switch candidateType { + case "application/json": c.parseResponseGoogle(ctx, w, r, isTCP, req) - } else if contentType == "application/dns-message" { + + case "application/dns-message", "application/dns-udpwireformat": c.parseResponseIETF(ctx, w, r, isTCP, req) - } else { - panic("Unknown response Content-Type") + + default: + switch requestType { + case "application/dns-json": + c.parseResponseGoogle(ctx, w, r, isTCP, req) + + case "application/dns-message": + c.parseResponseIETF(ctx, w, r, isTCP, req) + + default: + panic("Unknown response Content-Type") + } + } + + // https://developers.cloudflare.com/1.1.1.1/dns-over-https/request-structure/ says + // returns code will be 200 / 400 / 413 / 415 / 504, some server will return 503, so + // I think if status code is 5xx, upstream must has some problems + /*if req.response.StatusCode/100 == 5 { + c.selector.ReportUpstreamStatus(upstream, selector.Medium) + }*/ + + switch req.response.StatusCode / 100 { + case 5: + c.selector.ReportUpstreamStatus(upstream, selector.Error) + + case 2: + c.selector.ReportUpstreamStatus(upstream, selector.OK) } } @@ -360,7 +420,7 @@ var ( func (c *Client) findClientIP(w dns.ResponseWriter, r *dns.Msg) (ednsClientAddress net.IP, ednsClientNetmask uint8) { ednsClientNetmask = 255 - if c.conf.NoECS { + if c.conf.Other.NoECS { return net.IPv4(0, 0, 0, 0), 0 } if opt := r.IsEdns0(); opt != nil { diff --git a/doh-client/config.go b/doh-client/config/config.go similarity index 66% rename from doh-client/config.go rename to doh-client/config/config.go index e045ddb..3a9a215 100644 --- a/doh-client/config.go +++ b/doh-client/config/config.go @@ -21,7 +21,7 @@ DEALINGS IN THE SOFTWARE. */ -package main +package config import ( "fmt" @@ -29,10 +29,23 @@ import ( "github.com/BurntSushi/toml" ) -type config struct { - Listen []string `toml:"listen"` - UpstreamGoogle []string `toml:"upstream_google"` - UpstreamIETF []string `toml:"upstream_ietf"` +const ( + Random = "random" + WeightedRoundRobin = "weighted_round_robin" +) + +type upstreamDetail struct { + URL string `toml:"url"` + Weight int32 `toml:"weight"` +} + +type upstream struct { + UpstreamGoogle []upstreamDetail `toml:"upstream_google"` + UpstreamIETF []upstreamDetail `toml:"upstream_ietf"` + UpstreamSelector string `toml:"upstream_selector"` // usable: random or weighted_random +} + +type others struct { Bootstrap []string `toml:"bootstrap"` Passthrough []string `toml:"passthrough"` Timeout uint `toml:"timeout"` @@ -43,8 +56,14 @@ type config struct { DebugHTTPHeaders []string `toml:"debug_http_headers"` } -func loadConfig(path string) (*config, error) { - conf := &config{} +type Config struct { + Listen []string `toml:"listen"` + Upstream upstream `toml:"upstream"` + Other others `toml:"others"` +} + +func LoadConfig(path string) (*Config, error) { + conf := &Config{} metaData, err := toml.DecodeFile(path, conf) if err != nil { return nil, err @@ -56,11 +75,15 @@ func loadConfig(path string) (*config, error) { if len(conf.Listen) == 0 { conf.Listen = []string{"127.0.0.1:53", "[::1]:53"} } - if len(conf.UpstreamGoogle) == 0 && len(conf.UpstreamIETF) == 0 { - conf.UpstreamGoogle = []string{"https://dns.google.com/resolve"} + if len(conf.Upstream.UpstreamGoogle) == 0 && len(conf.Upstream.UpstreamIETF) == 0 { + conf.Upstream.UpstreamGoogle = []upstreamDetail{{URL: "https://dns.google.com/resolve", Weight: 50}} } - if conf.Timeout == 0 { - conf.Timeout = 10 + if conf.Other.Timeout == 0 { + conf.Other.Timeout = 10 + } + + if conf.Upstream.UpstreamSelector == "" { + conf.Upstream.UpstreamSelector = Random } return conf, nil diff --git a/doh-client/doh-client.conf b/doh-client/doh-client.conf index bea2df7..c1f13af 100644 --- a/doh-client/doh-client.conf +++ b/doh-client/doh-client.conf @@ -7,40 +7,52 @@ listen = [ ] # HTTP path for upstream resolver -# If multiple servers are specified, a random one will be chosen each time. -upstream_google = [ - # Google's productive resolver, good ECS, bad DNSSEC - "https://dns.google.com/resolve", +# CloudFlare's resolver for Tor, available only with Tor +# Remember to disable ECS below when using Tor! +# Blog: https://blog.cloudflare.com/welcome-hidden-resolver/ +#"https://dns4torpnlfs2ifuz2s2yf3fc7rdmsbhm6rw75euj35pac6ap25zgqad.onion/dns-query", - # CloudFlare's resolver, bad ECS, good DNSSEC - #"https://cloudflare-dns.com/dns-query", - #"https://1.1.1.1/dns-query", - #"https://1.0.0.1/dns-query", - # CloudFlare's resolver for Tor, available only with Tor - # Remember to disable ECS below when using Tor! - # Blog: https://blog.cloudflare.com/welcome-hidden-resolver/ - #"https://dns4torpnlfs2ifuz2s2yf3fc7rdmsbhm6rw75euj35pac6ap25zgqad.onion/dns-query", +[upstream] -] -upstream_ietf = [ +# available selector: random or weighted_round_robin +upstream_selector = "random" - # Google's experimental resolver, good ECS, good DNSSEC - #"https://dns.google.com/experimental", +# weight should in (0, 100], if upstream_selector is random, weight will be ignored - # CloudFlare's resolver, bad ECS, good DNSSEC - #"https://cloudflare-dns.com/dns-query", - #"https://1.1.1.1/dns-query", - #"https://1.0.0.1/dns-query", +# Google's productive resolver, good ECS, bad DNSSEC +[[upstream.upstream_google]] + url = "https://dns.google.com/resolve" + weight = 50 - # CloudFlare's resolver for Tor, available only with Tor - # Remember to disable ECS below when using Tor! - # Blog: https://blog.cloudflare.com/welcome-hidden-resolver/ - #"https://dns4torpnlfs2ifuz2s2yf3fc7rdmsbhm6rw75euj35pac6ap25zgqad.onion/dns-query", +# CloudFlare's resolver, bad ECS, good DNSSEC +[[upstream.upstream_google]] + url = "https://cloudflare-dns.com/dns-query" + weight = 50 -] +# CloudFlare's resolver, bad ECS, good DNSSEC +[[upstream.upstream_google]] + url = "https://1.1.1.1/dns-query" + weight = 50 +# Google's experimental resolver, good ECS, good DNSSEC +[[upstream.upstream_ietf]] + url = "https://dns.google.com/experimental" + weight = 50 + +# CloudFlare's resolver, bad ECS, good DNSSEC +[[upstream.upstream_ietf]] + url = "https://cloudflare-dns.com/dns-query" + weight = 50 + +# CloudFlare's resolver, bad ECS, good DNSSEC +[[upstream.upstream_ietf]] + url = "https://1.1.1.1/dns-query" + weight = 50 + + +[others] # Bootstrap DNS server to resolve the address of the upstream resolver # If multiple servers are specified, a random one will be chosen each time. # If empty, use the system DNS settings. @@ -76,7 +88,7 @@ passthrough = [ "time.windows.com", ] -# Timeout for upstream request +# Timeout for upstream request in seconds timeout = 30 # Disable HTTP Cookies diff --git a/doh-client/google.go b/doh-client/google.go index 506f519..23e0d9f 100644 --- a/doh-client/google.go +++ b/doh-client/google.go @@ -29,17 +29,17 @@ import ( "fmt" "io/ioutil" "log" - "math/rand" "net/http" "net/url" "strconv" "strings" + "github.com/m13253/dns-over-https/doh-client/selector" "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" ) -func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { +func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, upstream *selector.Upstream) *DNSRequest { question := &r.Question[0] questionName := question.Name questionClass := question.Qclass @@ -58,9 +58,7 @@ func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter questionType = strconv.FormatUint(uint64(question.Qtype), 10) } - numServers := len(c.conf.UpstreamGoogle) - upstream := c.conf.UpstreamGoogle[rand.Intn(numServers)] - requestURL := fmt.Sprintf("%s?ct=application/dns-json&name=%s&type=%s", upstream, url.QueryEscape(questionName), url.QueryEscape(questionType)) + requestURL := fmt.Sprintf("%s?ct=application/dns-json&name=%s&type=%s", upstream.Url, url.QueryEscape(questionName), url.QueryEscape(questionType)) if r.CheckingDisabled { requestURL += "&cd=1" @@ -76,7 +74,7 @@ func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter requestURL += fmt.Sprintf("&edns_client_subnet=%s/%d", ednsClientAddress.String(), ednsClientNetmask) } - req, err := http.NewRequest("GET", requestURL, nil) + req, err := http.NewRequest(http.MethodGet, requestURL, nil) if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -86,19 +84,24 @@ func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter err: err, } } + req.Header.Set("Accept", "application/json, application/dns-message, application/dns-udpwireformat") req.Header.Set("User-Agent", USER_AGENT) req = req.WithContext(ctx) + c.httpClientMux.RLock() resp, err := c.httpClient.Do(req) c.httpClientMux.RUnlock() - if err == context.DeadlineExceeded { + + // if http Client.Do returns non-nil error, it always *url.Error + /*if err == context.DeadlineExceeded { // Do not respond, silently fail to prevent caching of SERVFAIL log.Println(err) return &DNSRequest{ err: err, } - } + }*/ + if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -115,12 +118,12 @@ func (c *Client) generateRequestGoogle(ctx context.Context, w dns.ResponseWriter udpSize: udpSize, ednsClientAddress: ednsClientAddress, ednsClientNetmask: ednsClientNetmask, - currentUpstream: upstream, + currentUpstream: upstream.Url, } } func (c *Client) parseResponseGoogle(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { - if req.response.StatusCode != 200 { + if req.response.StatusCode != http.StatusOK { log.Printf("HTTP error from upstream %s: %s\n", req.currentUpstream, req.response.Status) req.reply.Rcode = dns.RcodeServerFailure contentType := req.response.Header.Get("Content-Type") diff --git a/doh-client/ietf.go b/doh-client/ietf.go index df18d82..f68f08f 100644 --- a/doh-client/ietf.go +++ b/doh-client/ietf.go @@ -30,17 +30,17 @@ import ( "fmt" "io/ioutil" "log" - "math/rand" "net" "net/http" "strings" "time" + "github.com/m13253/dns-over-https/doh-client/selector" "github.com/m13253/dns-over-https/json-dns" "github.com/miekg/dns" ) -func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest { +func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, upstream *selector.Upstream) *DNSRequest { opt := r.IsEdns0() udpSize := uint16(512) if opt == nil { @@ -100,13 +100,11 @@ func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, r.Id = requestID requestBase64 := base64.RawURLEncoding.EncodeToString(requestBinary) - numServers := len(c.conf.UpstreamIETF) - upstream := c.conf.UpstreamIETF[rand.Intn(numServers)] - requestURL := fmt.Sprintf("%s?ct=application/dns-message&dns=%s", upstream, requestBase64) + requestURL := fmt.Sprintf("%s?ct=application/dns-message&dns=%s", upstream.Url, requestBase64) var req *http.Request if len(requestURL) < 2048 { - req, err = http.NewRequest("GET", requestURL, nil) + req, err = http.NewRequest(http.MethodGet, requestURL, nil) if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -117,7 +115,7 @@ func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, } } } else { - req, err = http.NewRequest("POST", upstream, bytes.NewReader(requestBinary)) + req, err = http.NewRequest(http.MethodPost, upstream.Url, bytes.NewReader(requestBinary)) if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -135,13 +133,16 @@ func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, c.httpClientMux.RLock() resp, err := c.httpClient.Do(req) c.httpClientMux.RUnlock() - if err == context.DeadlineExceeded { + + // if http Client.Do returns non-nil error, it always *url.Error + /*if err == context.DeadlineExceeded { // Do not respond, silently fail to prevent caching of SERVFAIL log.Println(err) return &DNSRequest{ err: err, } - } + }*/ + if err != nil { log.Println(err) reply := jsonDNS.PrepareReply(r) @@ -158,12 +159,12 @@ func (c *Client) generateRequestIETF(ctx context.Context, w dns.ResponseWriter, udpSize: udpSize, ednsClientAddress: ednsClientAddress, ednsClientNetmask: ednsClientNetmask, - currentUpstream: upstream, + currentUpstream: upstream.Url, } } func (c *Client) parseResponseIETF(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) { - if req.response.StatusCode != 200 { + if req.response.StatusCode != http.StatusOK { log.Printf("HTTP error from upstream %s: %s\n", req.currentUpstream, req.response.Status) req.reply.Rcode = dns.RcodeServerFailure contentType := req.response.Header.Get("Content-Type") diff --git a/doh-client/main.go b/doh-client/main.go index f6a5295..cb53677 100644 --- a/doh-client/main.go +++ b/doh-client/main.go @@ -32,6 +32,8 @@ import ( "os" "runtime" "strconv" + + "github.com/m13253/dns-over-https/doh-client/config" ) func checkPIDFile(pidFile string) (bool, error) { @@ -101,13 +103,13 @@ func main() { } } - conf, err := loadConfig(*confPath) + conf, err := config.LoadConfig(*confPath) if err != nil { log.Fatalln(err) } if *verbose { - conf.Verbose = true + conf.Other.Verbose = true } client, err := NewClient(conf) diff --git a/doh-client/selector/randomSelector.go b/doh-client/selector/randomSelector.go new file mode 100644 index 0000000..36eca50 --- /dev/null +++ b/doh-client/selector/randomSelector.go @@ -0,0 +1,50 @@ +package selector + +import ( + "errors" + "math/rand" + "time" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +type RandomSelector struct { + upstreams []*Upstream +} + +func NewRandomSelector() *RandomSelector { + return new(RandomSelector) +} + +func (rs *RandomSelector) Add(url string, upstreamType UpstreamType) (err error) { + switch upstreamType { + case Google: + rs.upstreams = append(rs.upstreams, &Upstream{ + Type: Google, + Url: url, + RequestType: "application/dns-json", + }) + + case IETF: + rs.upstreams = append(rs.upstreams, &Upstream{ + Type: IETF, + Url: url, + RequestType: "application/dns-message", + }) + + default: + return errors.New("unknown upstream type") + } + + return nil +} + +func (rs *RandomSelector) Get() *Upstream { + return rs.upstreams[rand.Intn(len(rs.upstreams)-1)] +} + +func (rs *RandomSelector) StartEvaluate() {} + +func (rs *RandomSelector) ReportUpstreamStatus(upstream *Upstream, upstreamStatus upstreamStatus) {} diff --git a/doh-client/selector/selector.go b/doh-client/selector/selector.go new file mode 100644 index 0000000..d32185e --- /dev/null +++ b/doh-client/selector/selector.go @@ -0,0 +1,12 @@ +package selector + +type Selector interface { + // Get returns a upstream + Get() *Upstream + + // StartEvaluate start upstream evaluation loop + StartEvaluate() + + // ReportUpstreamStatus report upstream status + ReportUpstreamStatus(upstream *Upstream, upstreamStatus upstreamStatus) +} diff --git a/doh-client/selector/upstream.go b/doh-client/selector/upstream.go new file mode 100644 index 0000000..e7c8ea0 --- /dev/null +++ b/doh-client/selector/upstream.go @@ -0,0 +1,28 @@ +package selector + +import "fmt" + +type UpstreamType int + +const ( + Google UpstreamType = iota + IETF +) + +var typeMap = map[UpstreamType]string{ + Google: "Google", + IETF: "IETF", +} + +type Upstream struct { + Type UpstreamType + Url string + RequestType string + weight int32 + effectiveWeight int32 + currentWeight int32 +} + +func (u Upstream) String() string { + return fmt.Sprintf("upstream type: %s, upstream url: %s", typeMap[u.Type], u.Url) +} diff --git a/doh-client/selector/upstreamStatus.go b/doh-client/selector/upstreamStatus.go new file mode 100644 index 0000000..c3faf45 --- /dev/null +++ b/doh-client/selector/upstreamStatus.go @@ -0,0 +1,14 @@ +package selector + +type upstreamStatus int + +const ( + // when query upstream timeout, usually upstream is unavailable for a long time + Timeout upstreamStatus = iota + + // when query upstream return 5xx response, upstream still alive, maybe just a lof of query for him + Error + + // when query upstream ok, means upstream is available + OK +) diff --git a/doh-client/selector/weightRoundRobinSelector.go b/doh-client/selector/weightRoundRobinSelector.go new file mode 100644 index 0000000..a0e87c1 --- /dev/null +++ b/doh-client/selector/weightRoundRobinSelector.go @@ -0,0 +1,192 @@ +package selector + +import ( + "encoding/json" + "errors" + "net/http" + "sync/atomic" + "time" +) + +type WeightRoundRobinSelector struct { + upstreams []*Upstream // upstreamsInfo + client http.Client // http client to check the upstream +} + +func NewWeightRoundRobinSelector(timeout time.Duration) *WeightRoundRobinSelector { + return &WeightRoundRobinSelector{ + client: http.Client{Timeout: timeout}, + } +} + +func (ws *WeightRoundRobinSelector) Add(url string, upstreamType UpstreamType, weight int32) (err error) { + switch upstreamType { + case Google: + ws.upstreams = append(ws.upstreams, &Upstream{ + Type: Google, + Url: url, + RequestType: "application/dns-json", + weight: weight, + effectiveWeight: weight, + }) + + case IETF: + ws.upstreams = append(ws.upstreams, &Upstream{ + Type: IETF, + Url: url, + RequestType: "application/dns-message", + weight: weight, + effectiveWeight: weight, + }) + + default: + return errors.New("unknown upstream type") + } + + return nil +} + +// COW, avoid concurrent read write upstreams +func (ws *WeightRoundRobinSelector) StartEvaluate() { + go func() { + for { + for i := range ws.upstreams { + upstreamUrl := ws.upstreams[i].Url + var acceptType string + + switch ws.upstreams[i].Type { + case Google: + upstreamUrl += "?name=www.example.com&type=A" + acceptType = "application/dns-json" + + case IETF: + // www.example.com + upstreamUrl += "?dns=q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB" + acceptType = "application/dns-message" + } + + req, err := http.NewRequest(http.MethodGet, upstreamUrl, nil) + if err != nil { + /*log.Println("upstream:", upstreamUrl, "type:", typeMap[upstream.Type], "check failed:", err) + continue*/ + + // should I only log it? But if there is an error, I think when query the server will return error too + panic("upstream: " + upstreamUrl + " type: " + typeMap[ws.upstreams[i].Type] + " check failed: " + err.Error()) + } + + req.Header.Set("accept", acceptType) + + resp, err := ws.client.Do(req) + if err != nil { + // should I check error in detail? + if atomic.AddInt32(&ws.upstreams[i].effectiveWeight, -10) < 0 { + atomic.StoreInt32(&ws.upstreams[i].effectiveWeight, 0) + } + continue + } + + switch ws.upstreams[i].Type { + case Google: + checkGoogleResponse(resp, ws.upstreams[i]) + + case IETF: + checkIETFResponse(resp, ws.upstreams[i]) + } + } + + time.Sleep(30 * time.Second) + } + }() +} + +// nginx wrr like +func (ws *WeightRoundRobinSelector) Get() *Upstream { + var ( + total int32 + bestUpstreamIndex = -1 + ) + + for i := range ws.upstreams { + effectiveWeight := atomic.LoadInt32(&ws.upstreams[i].effectiveWeight) + ws.upstreams[i].currentWeight += effectiveWeight + total += effectiveWeight + + if bestUpstreamIndex == -1 || ws.upstreams[i].currentWeight > ws.upstreams[bestUpstreamIndex].currentWeight { + bestUpstreamIndex = i + } + } + + ws.upstreams[bestUpstreamIndex].currentWeight -= total + + return ws.upstreams[bestUpstreamIndex] +} + +func (ws *WeightRoundRobinSelector) ReportUpstreamStatus(upstream *Upstream, upstreamStatus upstreamStatus) { + switch upstreamStatus { + case Timeout: + if atomic.AddInt32(&upstream.effectiveWeight, -10) < 0 { + atomic.StoreInt32(&upstream.effectiveWeight, 0) + } + + case Error: + if atomic.AddInt32(&upstream.effectiveWeight, -5) < 0 { + atomic.StoreInt32(&upstream.effectiveWeight, 0) + } + + case OK: + if atomic.AddInt32(&upstream.effectiveWeight, 2) > upstream.weight { + atomic.StoreInt32(&upstream.effectiveWeight, upstream.weight) + } + } +} + +func checkGoogleResponse(resp *http.Response, upstream *Upstream) { + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // server error + if atomic.AddInt32(&upstream.effectiveWeight, -5) < 0 { + atomic.StoreInt32(&upstream.effectiveWeight, 0) + } + return + } + + m := make(map[string]interface{}) + if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { + // should I check error in detail? + if atomic.AddInt32(&upstream.effectiveWeight, -1) < 0 { + atomic.StoreInt32(&upstream.effectiveWeight, 0) + } + return + } + + if status, ok := m["status"]; ok { + if statusNum, ok := status.(int); ok && statusNum == 0 { + if atomic.AddInt32(&upstream.effectiveWeight, 5) > upstream.weight { + atomic.StoreInt32(&upstream.effectiveWeight, upstream.weight) + } + return + } + } + + // should I check error in detail? + if atomic.AddInt32(&upstream.effectiveWeight, -1) < 0 { + atomic.StoreInt32(&upstream.effectiveWeight, 0) + } +} + +func checkIETFResponse(resp *http.Response, upstream *Upstream) { + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // server error + if atomic.AddInt32(&upstream.effectiveWeight, -5) < 0 { + atomic.StoreInt32(&upstream.effectiveWeight, 0) + } + return + } + + if atomic.AddInt32(&upstream.effectiveWeight, 5) > upstream.weight { + atomic.StoreInt32(&upstream.effectiveWeight, upstream.weight) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f395e0d --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/m13253/dns-over-https + +go 1.12 + +require ( + github.com/BurntSushi/toml v0.3.1 + github.com/gorilla/handlers v1.4.0 + github.com/miekg/dns v1.1.4 + golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 // indirect + golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 + golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect + golang.org/x/sys v0.0.0-20190305064518-30e92a19ae4a // indirect + golang.org/x/text v0.3.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5300a42 --- /dev/null +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/gorilla/handlers v1.4.0 h1:XulKRWSQK5uChr4pEgSE4Tc/OcmnU9GJuSwdog/tZsA= +github.com/gorilla/handlers v1.4.0/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= +github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0= +github.com/miekg/dns v1.1.4/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 h1:jsG6UpNLt9iAsb0S2AGW28DveNzzgmbXR+ENoPjUeIU= +golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 h1:fY7Dsw114eJN4boqzVSbpVHO6rTdhq6/GnXeu+PKnzU= +golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 h1:bjcUS9ztw9kFmmIxJInhon/0Is3p+EHBKNgquIzo1OI= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190305064518-30e92a19ae4a h1:wsSB0WNK6x5F2PxWYOQpGTzp/IH7X8V603VJwSXZUWc= +golang.org/x/sys v0.0.0-20190305064518-30e92a19ae4a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=