Files
David Zhao ad76898f03 support auth checks with mock server (#4629)
* support auth checks with mock server

* simplify unit tests
2026-06-28 23:19:05 -07:00

255 lines
7.4 KiB
Go

// Copyright 2026 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command test-server is a programmable mock of the LiveKit server HTTP API,
// used by the server SDKs to test client behavior. See cmd/test-server/README.md.
package main
import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/livekit/protocol/livekit"
	"github.com/livekit/protocol/utils/protojson"
)
// X-Lk-Mock-* request headers control the mock's behavior; see the README.
const (
// headerFailRegions is a comma-separated list of region indexes; a region
// whose own index is listed is selected for failure (see shouldFail).
headerFailRegions = "X-Lk-Mock-Fail-Regions"
// headerFailMode picks how a failure is simulated (case-insensitive):
// "drop" closes the connection, "delay" waits before replying with an
// error, and any other value replies with an error immediately.
headerFailMode = "X-Lk-Mock-Fail-Mode"
// headerFailStatus is the HTTP status used for simulated errors. Values
// outside 100-599 are ignored and 503 is used instead.
headerFailStatus = "X-Lk-Mock-Fail-Status"
// headerFailTwirpCode overrides the Twirp error code in simulated errors;
// when empty the code is derived from the HTTP status.
headerFailTwirpCode = "X-Lk-Mock-Fail-Twirp-Code"
// headerDelayMs is the wait, in milliseconds, used by the "delay" fail mode.
headerDelayMs = "X-Lk-Mock-Delay-Ms"
// headerRegionsStatus, when set to a valid status other than 200, makes
// /settings/regions reply with that status and no body.
headerRegionsStatus = "X-Lk-Mock-Regions-Status"
// NOTE(review): headerResponse is not referenced in the visible part of
// this file; presumably it selects a canned API response - confirm in the README.
headerResponse = "X-Lk-Mock-Response"
// headerSkipAuth disables permission enforcement for a request.
headerSkipAuth = "X-Lk-Mock-Skip-Auth"
// headerRegion is set on responses to the index of the region that served it.
headerRegion = "X-Lk-Mock-Region"
)
// defaultDelayMs is the wait, in milliseconds, applied by the "delay" fail mode
// when the X-Lk-Mock-Delay-Ms header is missing, negative, or not an integer.
const defaultDelayMs = 30_000
// main resolves the settings (flag, then environment, then default), starts
// one HTTP listener per configured port - each acting as a distinct region -
// and then blocks until a listener fails or SIGINT/SIGTERM arrives.
func main() {
	portList := flagValue("--ports", "LK_TEST_SERVER_PORTS", "9999,10000,10001,10002")
	advertiseHost := flagValue("--advertise-host", "LK_TEST_SERVER_ADVERTISE_HOST", "http://127.0.0.1")
	bindAddr := flagValue("--bind", "LK_TEST_SERVER_BIND", "0.0.0.0")
	twirpPrefix := flagValue("--twirp-prefix", "LK_TEST_SERVER_TWIRP_PREFIX", "/twirp")
	// API secret used to verify request tokens for permission enforcement.
	// Defaults to the `livekit-server --dev` secret.
	apiSecret := flagValue("--api-secret", "LK_TEST_SERVER_API_SECRET", "secret")

	ports, err := parsePorts(portList)
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid --ports: %v\n", err)
		os.Exit(1)
	}

	// Drop trailing slashes so "<host>:<port>" is well-formed.
	advertiseHost = strings.TrimRight(advertiseHost, "/")

	// Build the full region list before any listener starts: every handler
	// shares this pointer, so it must not change once requests can arrive.
	regions := &livekit.RegionSettings{}
	for idx, port := range ports {
		info := &livekit.RegionInfo{
			Region:   fmt.Sprintf("region-%d", idx),
			Url:      fmt.Sprintf("%s:%d", advertiseHost, port),
			Distance: int64(idx),
		}
		regions.Regions = append(regions.Regions, info)
	}

	// Buffered to len(ports) so no listener goroutine blocks on reporting.
	listenErrs := make(chan error, len(ports))
	for idx, port := range ports {
		handler := &mockHandler{
			regionIndex: idx,
			regions:     regions,
			twirpPrefix: twirpPrefix,
			apiSecret:   apiSecret,
		}
		server := &http.Server{
			Addr:    fmt.Sprintf("%s:%d", bindAddr, port),
			Handler: handler,
		}
		go func(s *http.Server) { listenErrs <- s.ListenAndServe() }(server)
		fmt.Printf("test-server: region-%d listening on %s:%d (advertised as %s:%d)\n", idx, bindAddr, port, advertiseHost, port)
	}

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

	select {
	case <-signals:
		fmt.Println("test-server: shutting down")
	case listenErr := <-listenErrs:
		fmt.Fprintf(os.Stderr, "listener failed: %v\n", listenErr)
		os.Exit(1)
	}
}
// mockHandler is the http.Handler behind a single listener; each listener
// gets its own instance and stands in for one region.
type mockHandler struct {
// regionIndex is this region's position in the port list. It is reported in
// the X-Lk-Mock-Region response header and matched against
// X-Lk-Mock-Fail-Regions.
regionIndex int
// regions is the region list returned by /settings/regions; main passes the
// same pointer to every handler.
regions *livekit.RegionSettings
// twirpPrefix is the URL path prefix under which API requests are routed.
twirpPrefix string
// apiSecret is not read in the visible part of this file; per the comment in
// main it is used to verify request tokens - confirm in serveAPI.
apiSecret string
}
// ServeHTTP routes a request by path: the region-discovery endpoint, anything
// under the Twirp prefix, and a plain "ok" health response on "/" and
// "/_test/health". Every other path gets a 404.
func (h *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path

	if path == "/settings/regions" {
		h.handleRegions(w, r)
		return
	}
	if strings.HasPrefix(path, h.twirpPrefix+"/") {
		h.handleTwirp(w, r)
		return
	}
	if path == "/" || path == "/_test/health" {
		w.WriteHeader(http.StatusOK)
		// Best-effort write: nothing useful can be done if the client is gone.
		_, _ = w.Write([]byte("ok"))
		return
	}

	http.NotFound(w, r)
}
// handleRegions serves /settings/regions. If the X-Lk-Mock-Regions-Status
// header carries a valid status other than 200, only that status is written;
// otherwise the region list is returned as JSON, tagged with this region's index.
func (h *mockHandler) handleRegions(w http.ResponseWriter, r *http.Request) {
	// 0 means "header absent or invalid", i.e. no override requested.
	override := parseStatus(r.Header.Get(headerRegionsStatus), 0)
	if override != 0 && override != http.StatusOK {
		w.WriteHeader(override)
		return
	}

	payload, err := protojson.Marshal(h.regions)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Cache-Control", "max-age=0")
	hdr.Set(headerRegion, strconv.Itoa(h.regionIndex))
	// Best-effort write: nothing useful can be done if the client is gone.
	_, _ = w.Write(payload)
}
// handleTwirp handles requests routed under the Twirp prefix by handing them
// straight to serveAPI, which is defined outside the visible part of this file.
func (h *mockHandler) handleTwirp(w http.ResponseWriter, r *http.Request) {
h.serveAPI(w, r)
}
// shouldFail reports whether this region's index appears in the request's
// X-Lk-Mock-Fail-Regions header, a comma-separated list of integers. Empty or
// non-numeric entries are ignored.
func (h *mockHandler) shouldFail(r *http.Request) bool {
	listed := strings.Split(r.Header.Get(headerFailRegions), ",")
	for i := range listed {
		// Atoi rejects the empty string, so blank entries are skipped here too.
		idx, err := strconv.Atoi(strings.TrimSpace(listed[i]))
		if err != nil {
			continue
		}
		if idx == h.regionIndex {
			return true
		}
	}
	return false
}
// fail simulates a failing region according to the X-Lk-Mock-Fail-Mode header
// (case-insensitive):
//
//   - "drop": hijack and close the connection without replying; if the
//     connection cannot be hijacked, reply with a bare 503.
//   - "delay": wait X-Lk-Mock-Delay-Ms milliseconds (defaultDelayMs when the
//     header is missing, negative, or not an integer), then reply with a
//     Twirp error.
//   - anything else: reply with a Twirp error immediately.
//
// The Twirp error uses the status from X-Lk-Mock-Fail-Status, defaulting to 503.
func (h *mockHandler) fail(w http.ResponseWriter, r *http.Request) {
	switch strings.ToLower(r.Header.Get(headerFailMode)) {
	case "drop":
		if hj, ok := w.(http.Hijacker); ok {
			if conn, _, err := hj.Hijack(); err == nil {
				// Closing without writing is the whole point of this mode.
				_ = conn.Close()
				return
			}
		}
		w.WriteHeader(http.StatusServiceUnavailable)
		return
	case "delay":
		delay := defaultDelayMs
		if ms, err := strconv.Atoi(r.Header.Get(headerDelayMs)); err == nil && ms >= 0 {
			delay = ms
		}
		// Wait on a timer rather than time.Sleep so the handler goroutine is
		// released as soon as the request is cancelled (e.g. the client timed
		// out and hung up) instead of lingering for the full delay.
		timer := time.NewTimer(time.Duration(delay) * time.Millisecond)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-r.Context().Done():
			// Nobody is left to read a response.
			return
		}
	}
	writeTwirpError(w, r, parseStatus(r.Header.Get(headerFailStatus), http.StatusServiceUnavailable))
}
func writeTwirpError(w http.ResponseWriter, r *http.Request, status int) {
code := r.Header.Get(headerFailTwirpCode)
if code == "" {
code = twirpCodeForStatus(status)
}
writeTwirpErrorCode(w, status, code, fmt.Sprintf("mock failure (status %d)", status))
}
// writeTwirpErrorCode writes a Twirp JSON error with an explicit code and message.
//
// The body has the shape {"code":"...","msg":"..."} and is produced with
// encoding/json so that it is always valid JSON. Go's %q verb is not a
// substitute: it emits Go escapes such as \x80, \a or \U0001F600 that JSON
// parsers reject, and code may be taken verbatim from a request header.
func writeTwirpErrorCode(w http.ResponseWriter, status int, code, msg string) {
	// Marshaling a struct of two strings cannot fail (invalid UTF-8 is
	// replaced rather than rejected), so the error is deliberately ignored.
	body, _ := json.Marshal(struct {
		Code string `json:"code"`
		Msg  string `json:"msg"`
	}{Code: code, Msg: msg})

	w.Header().Set("Content-Type", "application/json")
	// NOTE(review): this sets the region header to an empty value on error
	// responses; presumably deliberate so failures are not attributed to a
	// region - confirm against the SDK tests.
	w.Header().Set(headerRegion, "")
	w.WriteHeader(status)
	// Best-effort write: nothing useful can be done if the client is gone.
	_, _ = w.Write(body)
}
// twirpCodeForStatus maps an HTTP status to the Twirp error code used in mock
// failures. A handful of 4xx statuses have dedicated codes, any status of 500
// or above maps to "unavailable", and everything else maps to "internal".
func twirpCodeForStatus(status int) string {
	switch status {
	case http.StatusBadRequest:
		return "invalid_argument"
	case http.StatusUnauthorized:
		return "unauthenticated"
	case http.StatusForbidden:
		return "permission_denied"
	case http.StatusNotFound:
		return "not_found"
	case http.StatusTooManyRequests:
		return "resource_exhausted"
	}
	if status >= 500 {
		return "unavailable"
	}
	return "internal"
}
// parseStatus converts s to an HTTP status code. Surrounding whitespace is
// ignored. def is returned when s is empty, is not an integer, or falls
// outside the 100-599 range.
func parseStatus(s string, def int) int {
	// Atoi rejects the empty string, so an absent header also ends up at def.
	code, err := strconv.Atoi(strings.TrimSpace(s))
	switch {
	case err != nil:
		return def
	case code < 100, code > 599:
		return def
	}
	return code
}
// parsePorts parses a comma-separated list of TCP port numbers, e.g.
// "9999, 10000". Whitespace around entries and empty entries are ignored.
//
// It returns an error if an entry is not an integer, if a port lies outside
// 1-65535, or if the list contains no ports at all. On error the returned
// slice is nil.
func parsePorts(s string) ([]int, error) {
	var ports []int
	for _, part := range strings.Split(s, ",") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		v, err := strconv.Atoi(part)
		if err != nil {
			return nil, fmt.Errorf("%q is not a port number", part)
		}
		// Reject values that cannot be listened on. Port 0 is excluded as well:
		// it would bind an arbitrary port while the region URL built from the
		// same number would advertise ":0".
		if v < 1 || v > 65535 {
			return nil, fmt.Errorf("port %d is out of range (1-65535)", v)
		}
		ports = append(ports, v)
	}
	if len(ports) == 0 {
		return nil, errors.New("at least one port is required")
	}
	return ports, nil
}
// flagValue resolves a setting from a --flag, then an environment variable, then a default.
//
// Both "--flag value" and "--flag=value" are recognised, scanning the command
// line left to right and returning on the first match. A bare "--flag" as the
// final argument carries no value and is ignored. An unset or empty
// environment variable falls through to def.
func flagValue(flag, env, def string) string {
	args := os.Args[1:]
	assign := flag + "="

	for pos := 0; pos < len(args); pos++ {
		current := args[pos]
		// Space-separated form: the value is the following argument, if any.
		if current == flag && pos+1 < len(args) {
			return args[pos+1]
		}
		// Inline form: everything after "--flag=".
		if strings.HasPrefix(current, assign) {
			return current[len(assign):]
		}
	}

	if fromEnv, ok := os.LookupEnv(env); ok && fromEnv != "" {
		return fromEnv
	}
	return def
}