mirror of
https://github.com/m13253/dns-over-https.git
synced 2026-05-25 20:44:17 +00:00
Add backend weight round robin select (#34)
* Add upstream selector, there are two selector now:
- random selector
- weight random selector
random selector will choose upstream at random; weight random selector will choose upstream at random with weight
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Rewrite config and config file example, prepare for weight round robbin selector
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Replace bad implement of weight random selector with weight round robbin selector, the algorithm is nginx weight round robbin like
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Use new config module
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Disable deprecated DualStack set
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Fix typo
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Optimize upstreamSelector judge
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Fix typo
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Add config timeout unit tips
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Set wrr http client timeout to replace http request timeout
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Add weight value range
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Add a line ending for .gitignore
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Optimize config file style
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Modify Weight type to int32
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Add upstreamError
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Rewrite Selector interface and wrr implement
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Use http module predefined constant to judge req.response.StatusCode
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Use Selector.ReportUpstreamError to report upstream error for evaluation loop in real time
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Make client selector field private
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Replace config file url to URL
Add miss space for 'weight= 50'
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Rewrite Selector.ReportUpstreamError to Selector.ReportUpstreamStatus, report upstream ok in real time
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Fix checkIETFResponse: if upstream OK, won't increase weight
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
* Fix typo
Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
This commit is contained in:
committed by
Star Brilliant
parent
8f2004d1de
commit
fec1e84d5e
+122
-62
@@ -31,11 +31,14 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/m13253/dns-over-https/doh-client/config"
|
||||
"github.com/m13253/dns-over-https/doh-client/selector"
|
||||
"github.com/m13253/dns-over-https/json-dns"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -43,7 +46,7 @@ import (
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conf *config
|
||||
conf *config.Config
|
||||
bootstrap []string
|
||||
passthrough []string
|
||||
udpClient *dns.Client
|
||||
@@ -56,6 +59,7 @@ type Client struct {
|
||||
httpTransport *http.Transport
|
||||
httpClient *http.Client
|
||||
httpClientLastCreate time.Time
|
||||
selector selector.Selector
|
||||
}
|
||||
|
||||
type DNSRequest struct {
|
||||
@@ -68,7 +72,7 @@ type DNSRequest struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewClient(conf *config) (c *Client, err error) {
|
||||
func NewClient(conf *config.Config) (c *Client, err error) {
|
||||
c = &Client{
|
||||
conf: conf,
|
||||
}
|
||||
@@ -78,11 +82,11 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
c.udpClient = &dns.Client{
|
||||
Net: "udp",
|
||||
UDPSize: dns.DefaultMsgSize,
|
||||
Timeout: time.Duration(conf.Timeout) * time.Second,
|
||||
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
|
||||
}
|
||||
c.tcpClient = &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: time.Duration(conf.Timeout) * time.Second,
|
||||
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
|
||||
}
|
||||
for _, addr := range conf.Listen {
|
||||
c.udpServers = append(c.udpServers, &dns.Server{
|
||||
@@ -98,9 +102,9 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
})
|
||||
}
|
||||
c.bootstrapResolver = net.DefaultResolver
|
||||
if len(conf.Bootstrap) != 0 {
|
||||
c.bootstrap = make([]string, len(conf.Bootstrap))
|
||||
for i, bootstrap := range conf.Bootstrap {
|
||||
if len(conf.Other.Bootstrap) != 0 {
|
||||
c.bootstrap = make([]string, len(conf.Other.Bootstrap))
|
||||
for i, bootstrap := range conf.Other.Bootstrap {
|
||||
bootstrapAddr, err := net.ResolveUDPAddr("udp", bootstrap)
|
||||
if err != nil {
|
||||
bootstrapAddr, err = net.ResolveUDPAddr("udp", "["+bootstrap+"]:53")
|
||||
@@ -120,9 +124,9 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
return conn, err
|
||||
},
|
||||
}
|
||||
if len(conf.Passthrough) != 0 {
|
||||
c.passthrough = make([]string, len(conf.Passthrough))
|
||||
for i, passthrough := range conf.Passthrough {
|
||||
if len(conf.Other.Passthrough) != 0 {
|
||||
c.passthrough = make([]string, len(conf.Other.Passthrough))
|
||||
for i, passthrough := range conf.Other.Passthrough {
|
||||
if punycode, err := idna.ToASCII(passthrough); err != nil {
|
||||
passthrough = punycode
|
||||
}
|
||||
@@ -133,7 +137,7 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
// Most CDNs require Cookie support to prevent DDoS attack.
|
||||
// Disabling Cookie does not effectively prevent tracking,
|
||||
// so I will leave it on to make anti-DDoS services happy.
|
||||
if !c.conf.NoCookies {
|
||||
if !c.conf.Other.NoCookies {
|
||||
c.cookieJar, err = cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -147,23 +151,59 @@ func NewClient(conf *config) (c *Client, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch c.conf.Upstream.UpstreamSelector {
|
||||
default:
|
||||
// if selector is invalid or random, use random selector, or should we stop program and let user knows he is wrong?
|
||||
s := selector.NewRandomSelector()
|
||||
for _, u := range c.conf.Upstream.UpstreamGoogle {
|
||||
if err := s.Add(u.URL, selector.Google); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, u := range c.conf.Upstream.UpstreamIETF {
|
||||
if err := s.Add(u.URL, selector.IETF); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c.selector = s
|
||||
|
||||
case config.WeightedRoundRobin:
|
||||
s := selector.NewWeightRoundRobinSelector(time.Duration(c.conf.Other.Timeout) * time.Second)
|
||||
for _, u := range c.conf.Upstream.UpstreamGoogle {
|
||||
if err := s.Add(u.URL, selector.Google, u.Weight); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, u := range c.conf.Upstream.UpstreamIETF {
|
||||
if err := s.Add(u.URL, selector.IETF, u.Weight); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c.selector = s
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) newHTTPClient() error {
|
||||
c.httpClientMux.Lock()
|
||||
defer c.httpClientMux.Unlock()
|
||||
if !c.httpClientLastCreate.IsZero() && time.Since(c.httpClientLastCreate) < time.Duration(c.conf.Timeout)*time.Second {
|
||||
if !c.httpClientLastCreate.IsZero() && time.Since(c.httpClientLastCreate) < time.Duration(c.conf.Other.Timeout)*time.Second {
|
||||
return nil
|
||||
}
|
||||
if c.httpTransport != nil {
|
||||
c.httpTransport.CloseIdleConnections()
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Duration(c.conf.Timeout) * time.Second,
|
||||
Timeout: time.Duration(c.conf.Other.Timeout) * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
Resolver: c.bootstrapResolver,
|
||||
// DualStack: true,
|
||||
Resolver: c.bootstrapResolver,
|
||||
}
|
||||
c.httpTransport = &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
@@ -172,9 +212,9 @@ func (c *Client) newHTTPClient() error {
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSHandshakeTimeout: time.Duration(c.conf.Timeout) * time.Second,
|
||||
TLSHandshakeTimeout: time.Duration(c.conf.Other.Timeout) * time.Second,
|
||||
}
|
||||
if c.conf.NoIPv6 {
|
||||
if c.conf.Other.NoIPv6 {
|
||||
c.httpTransport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if strings.HasPrefix(network, "tcp") {
|
||||
network = "tcp4"
|
||||
@@ -213,11 +253,15 @@ func (c *Client) Start() error {
|
||||
}
|
||||
}
|
||||
close(results)
|
||||
|
||||
// start evaluation poll
|
||||
c.selector.StartEvaluate()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Other.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if r.Response {
|
||||
@@ -246,7 +290,7 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
} else {
|
||||
questionType = strconv.FormatUint(uint64(question.Qtype), 10)
|
||||
}
|
||||
if c.conf.Verbose {
|
||||
if c.conf.Other.Verbose {
|
||||
fmt.Printf("%s - - [%s] \"%s %s %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType)
|
||||
}
|
||||
|
||||
@@ -284,64 +328,80 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
|
||||
return
|
||||
}
|
||||
|
||||
requestType := ""
|
||||
if len(c.conf.UpstreamIETF) == 0 {
|
||||
requestType = "application/dns-json"
|
||||
} else if len(c.conf.UpstreamGoogle) == 0 {
|
||||
requestType = "application/dns-message"
|
||||
} else {
|
||||
numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
|
||||
random := rand.Intn(numServers)
|
||||
if random < len(c.conf.UpstreamGoogle) {
|
||||
requestType = "application/dns-json"
|
||||
} else {
|
||||
requestType = "application/dns-message"
|
||||
}
|
||||
upstream := c.selector.Get()
|
||||
requestType := upstream.RequestType
|
||||
|
||||
if c.conf.Other.Verbose {
|
||||
log.Println("choose upstream:", upstream)
|
||||
}
|
||||
|
||||
var req *DNSRequest
|
||||
if requestType == "application/dns-json" {
|
||||
req = c.generateRequestGoogle(ctx, w, r, isTCP)
|
||||
} else if requestType == "application/dns-message" {
|
||||
req = c.generateRequestIETF(ctx, w, r, isTCP)
|
||||
} else {
|
||||
switch requestType {
|
||||
case "application/dns-json":
|
||||
req = c.generateRequestGoogle(ctx, w, r, isTCP, upstream)
|
||||
|
||||
case "application/dns-message":
|
||||
req = c.generateRequestIETF(ctx, w, r, isTCP, upstream)
|
||||
|
||||
default:
|
||||
panic("Unknown request Content-Type")
|
||||
}
|
||||
|
||||
if req.response != nil {
|
||||
defer req.response.Body.Close()
|
||||
for _, header := range c.conf.DebugHTTPHeaders {
|
||||
if value := req.response.Header.Get(header); value != "" {
|
||||
log.Printf("%s: %s\n", header, value)
|
||||
if req.err != nil {
|
||||
if urlErr, ok := req.err.(*url.Error); ok {
|
||||
// should we only check timeout?
|
||||
if urlErr.Timeout() {
|
||||
c.selector.ReportUpstreamStatus(upstream, selector.Timeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.err != nil {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
contentType := ""
|
||||
candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]
|
||||
if candidateType == "application/json" {
|
||||
contentType = "application/json"
|
||||
} else if candidateType == "application/dns-message" {
|
||||
contentType = "application/dns-message"
|
||||
} else if candidateType == "application/dns-udpwireformat" {
|
||||
contentType = "application/dns-message"
|
||||
} else {
|
||||
if requestType == "application/dns-json" {
|
||||
contentType = "application/json"
|
||||
} else if requestType == "application/dns-message" {
|
||||
contentType = "application/dns-message"
|
||||
// if req.err == nil, req.response != nil
|
||||
defer req.response.Body.Close()
|
||||
|
||||
for _, header := range c.conf.Other.DebugHTTPHeaders {
|
||||
if value := req.response.Header.Get(header); value != "" {
|
||||
log.Printf("%s: %s\n", header, value)
|
||||
}
|
||||
}
|
||||
|
||||
if contentType == "application/json" {
|
||||
candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]
|
||||
|
||||
switch candidateType {
|
||||
case "application/json":
|
||||
c.parseResponseGoogle(ctx, w, r, isTCP, req)
|
||||
} else if contentType == "application/dns-message" {
|
||||
|
||||
case "application/dns-message", "application/dns-udpwireformat":
|
||||
c.parseResponseIETF(ctx, w, r, isTCP, req)
|
||||
} else {
|
||||
panic("Unknown response Content-Type")
|
||||
|
||||
default:
|
||||
switch requestType {
|
||||
case "application/dns-json":
|
||||
c.parseResponseGoogle(ctx, w, r, isTCP, req)
|
||||
|
||||
case "application/dns-message":
|
||||
c.parseResponseIETF(ctx, w, r, isTCP, req)
|
||||
|
||||
default:
|
||||
panic("Unknown response Content-Type")
|
||||
}
|
||||
}
|
||||
|
||||
// https://developers.cloudflare.com/1.1.1.1/dns-over-https/request-structure/ says
|
||||
// returns code will be 200 / 400 / 413 / 415 / 504, some server will return 503, so
|
||||
// I think if status code is 5xx, upstream must has some problems
|
||||
/*if req.response.StatusCode/100 == 5 {
|
||||
c.selector.ReportUpstreamStatus(upstream, selector.Medium)
|
||||
}*/
|
||||
|
||||
switch req.response.StatusCode / 100 {
|
||||
case 5:
|
||||
c.selector.ReportUpstreamStatus(upstream, selector.Error)
|
||||
|
||||
case 2:
|
||||
c.selector.ReportUpstreamStatus(upstream, selector.OK)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,7 +420,7 @@ var (
|
||||
|
||||
func (c *Client) findClientIP(w dns.ResponseWriter, r *dns.Msg) (ednsClientAddress net.IP, ednsClientNetmask uint8) {
|
||||
ednsClientNetmask = 255
|
||||
if c.conf.NoECS {
|
||||
if c.conf.Other.NoECS {
|
||||
return net.IPv4(0, 0, 0, 0), 0
|
||||
}
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
|
||||
Reference in New Issue
Block a user