diff --git a/doh-server/server.go b/doh-server/server.go index 0827c16..f9d76db 100644 --- a/doh-server/server.go +++ b/doh-server/server.go @@ -84,13 +84,13 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { name := r.FormValue("name") if name == "" { - http.Error(w, jsonDNS.FormatError("Invalid argument value: \"name\""), 400) + jsonDNS.FormatError(w, "Invalid argument value: \"name\"", 400) return } if punycode, err := idna.ToASCII(name); err == nil { name = punycode } else { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("Invalid argument value: \"name\" = %q (%s)", name, err.Error())), 400) + jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"name\" = %q (%s)", name, err.Error()), 400) return } @@ -102,7 +102,7 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { } else if v, ok := dns.StringToType[strings.ToUpper(rrTypeStr)]; ok { rrType = v } else { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("Invalid argument value: \"type\" = %q", rrTypeStr)), 400) + jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"type\" = %q", rrTypeStr), 400) return } @@ -112,7 +112,7 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { cd = true } else if cdStr == "0" || cdStr == "false" || cdStr == "" { } else { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("Invalid argument value: \"cd\" = %q", cdStr)), 400) + jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"cd\" = %q", cdStr), 400) return } @@ -125,7 +125,7 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { if slash < 0 { ednsClientAddress = net.ParseIP(ednsClientSubnet) if ednsClientAddress == nil { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("Invalid argument value: \"edns_client_subnet\" = %q", ednsClientSubnet)), 400) + jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"edns_client_subnet\" = %q", ednsClientSubnet), 400) return } if ipv4 := ednsClientAddress.To4(); ipv4 != nil { @@ -139,7 +139,7 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { } else { ednsClientAddress = net.ParseIP(ednsClientSubnet[:slash]) if ednsClientAddress == nil { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("Invalid argument value: \"edns_client_subnet\" = %q", ednsClientSubnet)), 400) + jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"edns_client_subnet\" = %q", ednsClientSubnet), 400) return } if ipv4 := ednsClientAddress.To4(); ipv4 != nil { @@ -150,7 +150,7 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { } netmask, err := strconv.ParseUint(ednsClientSubnet[slash + 1:], 10, 8) if err != nil { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("Invalid argument value: \"edns_client_subnet\" = %q (%s)", ednsClientSubnet, err.Error())), 400) + jsonDNS.FormatError(w, fmt.Sprintf("Invalid argument value: \"edns_client_subnet\" = %q (%s)", ednsClientSubnet, err.Error()), 400) return } ednsClientNetmask = uint8(netmask) @@ -190,13 +190,13 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) { resp, err := s.doDNSQuery(msg) if err != nil { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("DNS query failure (%s)", err.Error())), 503) + jsonDNS.FormatError(w, fmt.Sprintf("DNS query failure (%s)", err.Error()), 503) return } respJson := jsonDNS.Marshal(resp) respStr, err := json.Marshal(respJson) if err != nil { - http.Error(w, jsonDNS.FormatError(fmt.Sprintf("DNS packet parse failure (%s)", err.Error())), 500) + jsonDNS.FormatError(w, fmt.Sprintf("DNS packet parse failure (%s)", err.Error()), 500) return } diff --git a/json-dns/error.go b/json-dns/error.go index 6f0a79f..d08ff01 100644 --- a/json-dns/error.go +++ b/json-dns/error.go @@ -21,6 +21,7 @@ package jsonDNS import ( "encoding/json" "log" + "net/http" "github.com/miekg/dns" ) @@ -29,7 +30,7 @@ type dnsError struct { Comment string `json:"Comment,omitempty"` } -func FormatError(comment string) string { +func FormatError(w http.ResponseWriter, comment string, errcode int) { errJson := dnsError { Status: dns.RcodeServerFailure, Comment: comment, @@ -38,5 +39,6 @@ func FormatError(comment string) string { if err != nil { log.Fatalln(err) } - return string(errStr) + w.WriteHeader(errcode) + w.Write(errStr) }