diff --git a/.github/banner_dark.png b/.github/banner_dark.png index d19c687ce..20c986b97 100644 Binary files a/.github/banner_dark.png and b/.github/banner_dark.png differ diff --git a/.github/banner_light.png b/.github/banner_light.png index 788b1f94b..7cd686bb5 100644 Binary files a/.github/banner_light.png and b/.github/banner_light.png differ diff --git a/.github/workflows/buildtest.yaml b/.github/workflows/buildtest.yaml index 78aa75721..0b906e6f0 100644 --- a/.github/workflows/buildtest.yaml +++ b/.github/workflows/buildtest.yaml @@ -33,7 +33,7 @@ jobs: - run: redis-cli ping - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: "1.20" @@ -71,7 +71,7 @@ jobs: # Upload the original go test log as an artifact for later review. - name: Upload test log - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 if: always() with: name: test-log diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 58574e162..967b1f734 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -39,7 +39,7 @@ jobs: type=semver,pattern=v{{major}}.{{minor}} - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: '>=1.18' diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index cd3b523fc..2f37c7a87 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -33,7 +33,7 @@ jobs: run: git fetch --force --tags - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: '>=1.18' diff --git a/CHANGELOG b/CHANGELOG index 1947140a6..bb789bceb 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -2,6 +2,32 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.5.2] - 2023-12-21 + +Support for LiveKit SIP Bridge + +### Added +- Add SIP Support (#2240 #2241 #2244 #2250 #2263 #2291 #2293) +- Introduce `LOST` connection quality. (#2265 #2276) +- Expose detailed connection info with ICEConnectionDetails (#2287) +- Add Version to TrackInfo. (#2324 #2325) + +### Fixed +- Guard against bad quality in trackInfo (#2271) +- Group SDES items for one SSRC in the same chunk. (#2280) +- Avoid dropping data packets on local router (#2270) +- Fix signal response delivery after session start failure (#2294) +- Populate disconnect updates with participant identity (#2310) +- Fix mid info lost when migrating multi-codec simulcast track (#2315) +- Store identity in participant update cache. (#2320) +- Fix panic occurs when starting livekit-server with key-file option (#2312) (#2313) + +### Changed +- INFO logging reduction (#2243 #2273 #2275 #2281 #2283 #2285 #2322) +- Clean up restart a bit. (#2247) +- Use a worker to report signal/data stats. (#2260) +- Consolidate TrackInfo. (#2331) + ## [1.5.1] - 2023-11-09 Support for the Agent framework. diff --git a/README.md b/README.md index c05377095..6040880cf 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,12 @@ - - - - The LiveKit icon, the name of the repository and some sample code in the background. - - + + + + + The LiveKit icon, the name of the repository and some sample code in the background. + + + # LiveKit: Real-time video, audio and data for developers @@ -297,9 +299,10 @@ LiveKit server is licensed under Apache License v2.0.
- - - + + + +
LiveKit Ecosystem
Client SDKsComponents · JavaScript · iOS/macOS · Android · Flutter · React Native · Rust · Python · Unity (web) · Unity (beta)
Server SDKsNode.js · Golang · Ruby · Java/Kotlin · PHP (community) · Python (community)
ServicesLivekit server · Egress · Ingress
Real-time SDKsReact Components · JavaScript · iOS/macOS · Android · Flutter · React Native · Rust · Python · Unity (web) · Unity (beta)
Server APIsNode.js · Golang · Ruby · Java/Kotlin · Python · Rust · PHP (community)
Agents FrameworksPython · Playground
ServicesLivekit server · Egress · Ingress · SIP
ResourcesDocs · Example apps · Cloud · Self-hosting · CLI
diff --git a/bootstrap.sh b/bootstrap.sh index 3109ddc38..7511dba3b 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # Copyright 2023 LiveKit, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/cmd/server/main.go b/cmd/server/main.go index c53f01640..b9bfea418 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -292,9 +292,12 @@ func startServer(c *cli.Context) error { signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) go func() { - sig := <-sigChan - logger.Infow("exit requested, shutting down", "signal", sig) - server.Stop(false) + for i := 0; i < 2; i++ { + sig := <-sigChan + force := i > 0 + logger.Infow("exit requested, shutting down", "signal", sig, "force", force) + go server.Stop(force) + } }() return server.Start() diff --git a/config-sample.yaml b/config-sample.yaml index 4a70f2ae8..0d001bd10 100644 --- a/config-sample.yaml +++ b/config-sample.yaml @@ -183,8 +183,6 @@ keys: # since v1.4.0, a more reliable, psrpc based signal relay is available # this gives us the ability to reliably proxy messages between a signal server and RTC node # signal_relay: -# # enabled by default as of v1.5.0, legacy signal proxy will be removed in v1.6 -# enabled: true # # amount of time a message delivery is tried before giving up # retry_timeout: 30s # # minimum amount of time to wait for RTC node to ack, @@ -199,8 +197,6 @@ keys: # PSRPC # since v1.5.1, a more reliable, psrpc based internal rpc # psrpc: -# # enable the psrpc internal api client for roomservice calls -# enabled: true # # maximum number of rpc attempts # max_attempts: 3 # # initial time to wait for calls to complete diff --git a/go.mod b/go.mod index 558fd67d9..1d6f8666e 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/d5/tengo/v2 v2.16.1 github.com/dustin/go-humanize v1.0.1 github.com/elliotchance/orderedmap/v2 v2.2.0 - github.com/florianl/go-tc v0.4.2 + github.com/florianl/go-tc v0.4.3 github.com/frostbyte73/core v0.0.9 github.com/gammazero/deque v0.2.1 github.com/gammazero/workerpool v1.1.3 @@ -17,38 +17,38 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 - github.com/livekit/mediatransportutil v0.0.0-20231128042044-05525c8278cb - github.com/livekit/protocol v1.9.3-0.20231129173544-1c3f5fe919b0 - github.com/livekit/psrpc v0.5.2 + github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f + github.com/livekit/protocol v1.9.8-0.20240130003842-54f76fc6865e + github.com/livekit/psrpc v0.5.3-0.20240129223932-473b29cda289 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 - github.com/maxbrunsfeld/counterfeiter/v6 v6.7.0 + github.com/maxbrunsfeld/counterfeiter/v6 v6.8.1 github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 - github.com/pion/dtls/v2 v2.2.8 - github.com/pion/ice/v2 v2.3.11 + github.com/pion/dtls/v2 v2.2.9 + github.com/pion/ice/v2 v2.3.12 github.com/pion/interceptor v0.1.25 - github.com/pion/rtcp v1.2.12 + github.com/pion/rtcp v1.2.13 github.com/pion/rtp v1.8.3 github.com/pion/sctp v1.8.9 github.com/pion/sdp/v3 v3.0.6 github.com/pion/transport/v2 v2.2.4 github.com/pion/turn/v2 v2.1.4 - github.com/pion/webrtc/v3 v3.2.23 + github.com/pion/webrtc/v3 v3.2.24 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.17.0 - github.com/redis/go-redis/v9 v9.3.0 + github.com/prometheus/client_golang v1.18.0 + github.com/redis/go-redis/v9 v9.4.0 github.com/rs/cors v1.10.1 github.com/stretchr/testify v1.8.4 github.com/thoas/go-funk v0.9.3 github.com/twitchtv/twirp v8.1.3+incompatible - github.com/ua-parser/uap-go v0.0.0-20230823213814-f77b3e91e9dc - github.com/urfave/cli/v2 v2.25.7 + github.com/ua-parser/uap-go v0.0.0-20240113215029-33f8e6d47f38 + github.com/urfave/cli/v2 v2.27.1 github.com/urfave/negroni/v3 v3.0.0 go.uber.org/atomic v1.11.0 - golang.org/x/exp v0.0.0-20231127185646-65229373498e - golang.org/x/sync v0.5.0 - google.golang.org/protobuf v1.31.0 + golang.org/x/exp v0.0.0-20240119083558-1b970713d09a + golang.org/x/sync v0.6.0 + google.golang.org/protobuf v1.32.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -61,47 +61,47 @@ require ( github.com/eapache/channels v1.1.0 // indirect github.com/eapache/queue v1.1.0 // indirect github.com/go-jose/go-jose/v3 v3.0.1 // indirect - github.com/go-logr/logr v1.3.0 // indirect + github.com/go-logr/logr v1.4.1 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/subcommands v1.2.0 // indirect - github.com/google/uuid v1.3.1 // indirect + github.com/google/uuid v1.5.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.5 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/josharian/native v1.1.0 // indirect - github.com/klauspost/compress v1.17.3 // indirect + github.com/klauspost/compress v1.17.5 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/lithammer/shortuuid/v4 v4.0.0 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/mdlayher/netlink v1.7.1 // indirect github.com/mdlayher/socket v0.4.0 // indirect - github.com/nats-io/nats.go v1.31.0 // indirect - github.com/nats-io/nkeys v0.4.6 // indirect + github.com/nats-io/nats.go v1.32.0 // indirect + github.com/nats-io/nkeys v0.4.7 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/pion/datachannel v1.5.5 // indirect github.com/pion/logging v0.2.2 // indirect - github.com/pion/mdns v0.0.8 // indirect + github.com/pion/mdns v0.0.9 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/srtp/v2 v2.0.18 // indirect github.com/pion/stun v0.6.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect - github.com/prometheus/common v0.44.0 // indirect - github.com/prometheus/procfs v0.11.1 // indirect + github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/common v0.45.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect - golang.org/x/crypto v0.16.0 // indirect + golang.org/x/crypto v0.18.0 // indirect golang.org/x/mod v0.14.0 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/net v0.20.0 // indirect + golang.org/x/sys v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.16.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4 // indirect - google.golang.org/grpc v1.59.0 // indirect + golang.org/x/tools v0.17.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240125205218-1f4bbc51befe // indirect + google.golang.org/grpc v1.61.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index a1b81d76c..231b870b9 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/elliotchance/orderedmap/v2 v2.2.0 h1:7/2iwO98kYT4XkOjA9mBEIwvi4KpGB4cyHeOFOnj4Vk= github.com/elliotchance/orderedmap/v2 v2.2.0/go.mod h1:85lZyVbpGaGvHvnKa7Qhx7zncAdBIBq6u56Hb1PRU5Q= -github.com/florianl/go-tc v0.4.2 h1:jan5zcOWCLhA9SRBHZhQ0SSAq7cmDUagiRPngAi5AOQ= -github.com/florianl/go-tc v0.4.2/go.mod h1:2W1jSMFryiYlpQigr4ZpSSpE9XNze+bW7cTsCXWbMwo= +github.com/florianl/go-tc v0.4.3 h1:xpobG2gFNvEqbclU07zjddALSjqTQTWJkxg5/kRYDpw= +github.com/florianl/go-tc v0.4.3/go.mod h1:uvp6pIlOw7Z8hhfnT5M4+V1hHVgZWRZwwMS8Z0JsRxc= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= github.com/frostbyte73/core v0.0.9 h1:AmE9GjgGpPsWk9ZkmY3HsYUs2hf2tZt+/W6r49URBQI= @@ -42,8 +42,8 @@ github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44 github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc= github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= -github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -66,14 +66,15 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= +github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.5.0 h1:I7ELFeVBr3yfPIcc8+MWvrjk+3VjbcSzoXm3JVa+jD8= github.com/google/wire v0.5.0/go.mod h1:ngWDr9Qvq3yZA10YrxfyGELY/AFWGVpy9c1LTRi1EoU= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= @@ -107,8 +108,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786 h1:N527AHMa79 github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4hqbTdfQngbVSZJVWUhGE/lbTFf9jb+ygmNUDQMuOs= github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw= github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc= -github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= -github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.5 h1:d4vBd+7CHydUqpFBgUEKkSdtSugf9YFmSkvUYPquI5E= +github.com/klauspost/compress v1.17.5/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -123,22 +124,22 @@ github.com/lithammer/shortuuid/v4 v4.0.0 h1:QRbbVkfgNippHOS8PXDkti4NaWeyYfcBTHtw github.com/lithammer/shortuuid/v4 v4.0.0/go.mod h1:Zs8puNcrvf2rV9rTH51ZLLcj7ZXqQI3lv67aw4KiB1Y= github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkDaKb5iXdynYrzB84ErPPO4LbRASk58= github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= -github.com/livekit/mediatransportutil v0.0.0-20231128042044-05525c8278cb h1:KiGg4k+kYQD9NjKixaSDMMeYOO2//XBM4IROTI1Itjo= -github.com/livekit/mediatransportutil v0.0.0-20231128042044-05525c8278cb/go.mod h1:GBzn9xL+mivI1pW+tyExcKgbc0VOc29I9yJsNcAVaAc= -github.com/livekit/protocol v1.9.3-0.20231129173544-1c3f5fe919b0 h1:AhJlQejQ+Ma9Q+EPqCNt2S7h6ETJXDiO7qsQdTq9VvM= -github.com/livekit/protocol v1.9.3-0.20231129173544-1c3f5fe919b0/go.mod h1:8f342d5nvfNp9YAEfJokSR+zbNFpaivgU0h6vwaYhes= -github.com/livekit/psrpc v0.5.2 h1:+MvG8Otm/J6MTg2MP/uuMbrkxOWsrj2hDhu/I1VIU1U= -github.com/livekit/psrpc v0.5.2/go.mod h1:cQjxg1oCxYHhxxv6KJH1gSvdtCHQoRZCHgPdm5N8v2g= +github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f h1:XHrwGwLNGQB3ZqolH1YdMH/22hgXKr4vm+2M7JKMMGg= +github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f/go.mod h1:GBzn9xL+mivI1pW+tyExcKgbc0VOc29I9yJsNcAVaAc= +github.com/livekit/protocol v1.9.8-0.20240130003842-54f76fc6865e h1:aQGmivssjKlvca8Mmmm4XzHQrFV+vwrGIhAxQE8tnZk= +github.com/livekit/protocol v1.9.8-0.20240130003842-54f76fc6865e/go.mod h1:lSJlMeTJfQBEv8/D2p3zdCo+i+jTmTtn24ysL4ePK28= +github.com/livekit/psrpc v0.5.3-0.20240129223932-473b29cda289 h1:oTgNH7v9TXsBgoltKk5mnWjv4qqcPF2iV+WtEVQ6ROM= +github.com/livekit/psrpc v0.5.3-0.20240129223932-473b29cda289/go.mod h1:cQjxg1oCxYHhxxv6KJH1gSvdtCHQoRZCHgPdm5N8v2g= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= github.com/mackerelio/go-osstat v0.2.4/go.mod h1:Zy+qzGdZs3A9cuIqmgbJvwbmLQH9dJvtio5ZjJTbdlQ= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= -github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= -github.com/maxbrunsfeld/counterfeiter/v6 v6.7.0 h1:z0CfPybq3CxaJvrrpf7Gme1psZTqHhJxf83q6apkSpI= -github.com/maxbrunsfeld/counterfeiter/v6 v6.7.0/go.mod h1:RVP6/F85JyxTrbJxWIdKU2vlSvK48iCMnMXRkSz7xtg= +github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= +github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= +github.com/maxbrunsfeld/counterfeiter/v6 v6.8.1 h1:NicmruxkeqHjDv03SfSxqmaLuisddudfP3h5wdXFbhM= +github.com/maxbrunsfeld/counterfeiter/v6 v6.8.1/go.mod h1:eyp4DdUJAKkr9tvxR3jWhw2mDK7CWABMG5r9uyaKC7I= github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo= github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc= github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= @@ -160,10 +161,10 @@ github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/nats-io/nats.go v1.31.0 h1:/WFBHEc/dOKBF6qf1TZhrdEfTmOZ5JzdJ+Y3m6Y/p7E= -github.com/nats-io/nats.go v1.31.0/go.mod h1:di3Bm5MLsoB4Bx61CBTsxuarI36WbhAwOm8QrW39+i8= -github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= -github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= +github.com/nats-io/nats.go v1.32.0 h1:Bx9BZS+aXYlxW08k8Gd3yR2s73pV5XSoAQUyp1Kwvp0= +github.com/nats-io/nats.go v1.32.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= +github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= +github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -177,25 +178,28 @@ github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042 github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= +github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= github.com/pion/datachannel v1.5.5 h1:10ef4kwdjije+M9d7Xm9im2Y3O6A6ccQb0zcqZcJew8= github.com/pion/datachannel v1.5.5/go.mod h1:iMz+lECmfdCMqFRhXhcA/219B0SQlbpoR2V118yimL0= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= -github.com/pion/dtls/v2 v2.2.8 h1:BUroldfiIbV9jSnC6cKOMnyiORRWrWWpV11JUyEu5OA= -github.com/pion/dtls/v2 v2.2.8/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= -github.com/pion/ice/v2 v2.3.11 h1:rZjVmUwyT55cmN8ySMpL7rsS8KYsJERsrxJLLxpKhdw= +github.com/pion/dtls/v2 v2.2.9 h1:K+D/aVf9/REahQvqk6G5JavdrD8W1PWDKC11UlwN7ts= +github.com/pion/dtls/v2 v2.2.9/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/ice/v2 v2.3.11/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E= +github.com/pion/ice/v2 v2.3.12 h1:NWKW2b3+oSZS3klbQMIEWQ0i52Kuo0KBg505a5kQv4s= +github.com/pion/ice/v2 v2.3.12/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E= github.com/pion/interceptor v0.1.25 h1:pwY9r7P6ToQ3+IF0bajN0xmk/fNw/suTgaTdlwTDmhc= github.com/pion/interceptor v0.1.25/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.8 h1:HhicWIg7OX5PVilyBO6plhMetInbzkVJAhbdJiAeVaI= github.com/pion/mdns v0.0.8/go.mod h1:hYE72WX8WDveIhg7fmXgMKivD3Puklk0Ymzog0lSyaI= +github.com/pion/mdns v0.0.9 h1:7Ue5KZsqq8EuqStnpPWV33vYYEH0+skdDN5L7EiEsI4= +github.com/pion/mdns v0.0.9/go.mod h1:2JA5exfxwzXiCihmxpTKgFUpiQws2MnipoPK09vecIc= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.10/go.mod h1:ztfEwXZNLGyF1oQDttz/ZKIBaeeg/oWbRYqzBM9TL1I= -github.com/pion/rtcp v1.2.12 h1:bKWiX93XKgDZENEXCijvHRU/wRifm6JV5DGcH6twtSM= github.com/pion/rtcp v1.2.12/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtcp v1.2.13 h1:+EQijuisKwm/8VBs8nWllr0bIndR7Lf7cZG200mpbNo= +github.com/pion/rtcp v1.2.13/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= github.com/pion/rtp v1.8.2/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= github.com/pion/rtp v1.8.3 h1:VEHxqzSVQxCkKDSHro5/4IUUG1ea+MFdqR2R3xSpNU8= github.com/pion/rtp v1.8.3/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= @@ -221,22 +225,22 @@ github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9 github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= github.com/pion/turn/v2 v2.1.4 h1:2xn8rduI5W6sCZQkEnIUDAkrBQNl2eYIBCHMZ3QMmP8= github.com/pion/turn/v2 v2.1.4/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= -github.com/pion/webrtc/v3 v3.2.23 h1:GbqEuxBbVLFhXk0GwxKAoaIJYiEa9TyoZPEZC+2HZxM= -github.com/pion/webrtc/v3 v3.2.23/go.mod h1:1CaT2fcZzZ6VZA+O1i9yK2DU4EOcXVvSbWG9pr5jefs= +github.com/pion/webrtc/v3 v3.2.24 h1:MiFL5DMo2bDaaIFWr0DDpwiV/L4EGbLZb+xoRvfEo1Y= +github.com/pion/webrtc/v3 v3.2.24/go.mod h1:1CaT2fcZzZ6VZA+O1i9yK2DU4EOcXVvSbWG9pr5jefs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= -github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= -github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM= -github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= -github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= -github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= -github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= -github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= -github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0= -github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= +github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= +github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/redis/go-redis/v9 v9.4.0 h1:Yzoz33UZw9I/mFhx4MNrB6Fk+XHO1VukNcCa1+lwyKk= +github.com/redis/go-redis/v9 v9.4.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= @@ -262,10 +266,10 @@ github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw= github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/twitchtv/twirp v8.1.3+incompatible h1:+F4TdErPgSUbMZMwp13Q/KgDVuI7HJXP61mNV3/7iuU= github.com/twitchtv/twirp v8.1.3+incompatible/go.mod h1:RRJoFSAmTEh2weEqWtpPE3vFK5YBhA6bqp2l1kfCC5A= -github.com/ua-parser/uap-go v0.0.0-20230823213814-f77b3e91e9dc h1:iT5lwxf894PiMq7cnMMQg/7VOD1pxmu//gQuHWAFy4s= -github.com/ua-parser/uap-go v0.0.0-20230823213814-f77b3e91e9dc/go.mod h1:BUbeWZiieNxAuuADTBNb3/aeje6on3DhU3rpWsQSB1E= -github.com/urfave/cli/v2 v2.25.7 h1:VAzn5oq403l5pHjc4OhD54+XGO9cdKVL/7lDjF+iKUs= -github.com/urfave/cli/v2 v2.25.7/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= +github.com/ua-parser/uap-go v0.0.0-20240113215029-33f8e6d47f38 h1:F04Na0QJP9GJrwmK3vQDuDrCuGllrrfngW8CIeF1aag= +github.com/ua-parser/uap-go v0.0.0-20240113215029-33f8e6d47f38/go.mod h1:BUbeWZiieNxAuuADTBNb3/aeje6on3DhU3rpWsQSB1E= +github.com/urfave/cli/v2 v2.27.1 h1:8xSQ6szndafKVRmfyeUMxkNUJQMjL1F2zmsZ+qHpfho= +github.com/urfave/cli/v2 v2.27.1/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/urfave/negroni/v3 v3.0.0 h1:Vo8CeZfu1lFR9gW8GnAb6dOGCJyijfil9j/jKKc/JhU= github.com/urfave/negroni/v3 v3.0.0/go.mod h1:jWvnX03kcSjDBl/ShB0iHvx5uOs7mAzZXW+JvJ5XYAs= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= @@ -291,10 +295,11 @@ golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/exp v0.0.0-20231127185646-65229373498e h1:Gvh4YaCaXNs6dKTlfgismwWZKyjVZXwOPfIyUaqU3No= -golang.org/x/exp v0.0.0-20231127185646-65229373498e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= +golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -327,17 +332,17 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -378,8 +383,9 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -389,6 +395,7 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -399,6 +406,7 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -407,16 +415,16 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.16.0 h1:GO788SKMRunPIBCXiQyo2AaexLstOrVhuAL5YwsckQM= -golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4 h1:DC7wcm+i+P1rN3Ff07vL+OndGg5OhNddHyTA+ocPqYE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4/go.mod h1:eJVxU6o+4G1PSczBr85xmyvSNYAKvAYgkub40YGomFM= -google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= -google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240125205218-1f4bbc51befe h1:bQnxqljG/wqi4NTXu2+DJ3n7APcEA882QZ1JvhQAq9o= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s= +google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= +google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -425,8 +433,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/clientconfiguration/staticconfiguration.go b/pkg/clientconfiguration/staticconfiguration.go index e0309c26c..4fb67bae4 100644 --- a/pkg/clientconfiguration/staticconfiguration.go +++ b/pkg/clientconfiguration/staticconfiguration.go @@ -17,6 +17,7 @@ package clientconfiguration import ( "google.golang.org/protobuf/proto" + "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -40,7 +41,9 @@ func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit. for _, c := range s.confs { matched, err := c.Match.Match(clientInfo) if err != nil { - logger.Errorw("matchrule failed", err, "clientInfo", logger.Proto(clientInfo)) + logger.Errorw("matchrule failed", err, + "clientInfo", logger.Proto(utils.ClientInfoWithoutAddress(clientInfo)), + ) continue } if !matched { diff --git a/pkg/config/config.go b/pkg/config/config.go index dd0e84012..36e76d473 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -268,7 +268,6 @@ type NodeSelectorConfig struct { } type SignalRelayConfig struct { - Enabled bool `yaml:"enabled,omitempty"` RetryTimeout time.Duration `yaml:"retry_timeout,omitempty"` MinRetryInterval time.Duration `yaml:"min_retry_interval,omitempty"` MaxRetryInterval time.Duration `yaml:"max_retry_interval,omitempty"` @@ -487,7 +486,6 @@ var DefaultConfig = Config{ CPULoadLimit: 0.9, }, SignalRelay: SignalRelayConfig{ - Enabled: true, RetryTimeout: 7500 * time.Millisecond, MinRetryInterval: 500 * time.Millisecond, MaxRetryInterval: 4 * time.Second, @@ -652,6 +650,7 @@ func (conf *Config) ValidateKeys() error { _ = f.Close() }() decoder := yaml.NewDecoder(f) + conf.Keys = map[string]string{} if err = decoder.Decode(conf.Keys); err != nil { return err } diff --git a/pkg/routing/errors.go b/pkg/routing/errors.go index 4b6af0686..28a0dd1ac 100644 --- a/pkg/routing/errors.go +++ b/pkg/routing/errors.go @@ -14,7 +14,9 @@ package routing -import "errors" +import ( + "errors" +) var ( ErrNotFound = errors.New("could not find object") @@ -26,4 +28,9 @@ var ( ErrInvalidRouterMessage = errors.New("invalid router message") ErrChannelClosed = errors.New("channel closed") ErrChannelFull = errors.New("channel is full") + + // errors when starting signal connection + ErrRequestChannelClosed = errors.New("request channel closed") + ErrCouldNotMigrateParticipant = errors.New("could not migrate participant") + ErrClientInfoNotSet = errors.New("client info not set") ) diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 7a450bd16..d13d734a0 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -21,10 +21,10 @@ import ( "github.com/redis/go-redis/v9" "google.golang.org/protobuf/proto" - "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" ) //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate @@ -63,21 +63,6 @@ type ParticipantInit struct { SubscriberAllowPause *bool } -type NewParticipantCallback func( - ctx context.Context, - roomName livekit.RoomName, - pi ParticipantInit, - requestSource MessageSource, - responseSink MessageSink, -) error - -type RTCMessageCallback func( - ctx context.Context, - roomName livekit.RoomName, - identity livekit.ParticipantIdentity, - msg *livekit.RTCNodeMessage, -) - // Router allows multiple nodes to coordinate the participant session // //counterfeiter:generate . Router @@ -99,28 +84,26 @@ type Router interface { Start() error Drain() Stop() +} - // OnNewParticipantRTC is called to start a new participant's RTC connection - OnNewParticipantRTC(callback NewParticipantCallback) - - // OnRTCMessage is called to execute actions on the RTC node - OnRTCMessage(callback RTCMessageCallback) +type StartParticipantSignalResults struct { + ConnectionID livekit.ConnectionID + RequestSink MessageSink + ResponseSource MessageSource + NodeID livekit.NodeID + NodeSelectionReason string } type MessageRouter interface { // StartParticipantSignal participant signal connection is ready to start - StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) - - // Write a message to a participant or room - WriteParticipantRTC(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) error - WriteRoomRTC(ctx context.Context, roomName livekit.RoomName, msg *livekit.RTCNodeMessage) error + StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error) } -func CreateRouter(config *config.Config, rc redis.UniversalClient, node LocalNode, signalClient SignalClient) Router { +func CreateRouter(rc redis.UniversalClient, node LocalNode, signalClient SignalClient, kps rpc.KeepalivePubSub) Router { lr := NewLocalRouter(node, signalClient) if rc != nil { - return NewRedisRouter(config, lr, rc) + return NewRedisRouter(lr, rc, kps) } // local routing and store diff --git a/pkg/routing/localrouter.go b/pkg/routing/localrouter.go index b0fcbbccb..85279ead0 100644 --- a/pkg/routing/localrouter.go +++ b/pkg/routing/localrouter.go @@ -26,8 +26,7 @@ import ( "github.com/livekit/protocol/logger" ) -// aggregated channel for all participants -const localRTCChannelSize = 10000 +var _ Router = (*LocalRouter)(nil) // a router of messages on the same node, basic implementation for local testing type LocalRouter struct { @@ -39,11 +38,6 @@ type LocalRouter struct { requestChannels map[string]*MessageChannel responseChannels map[string]*MessageChannel isStarted atomic.Bool - - rtcMessageChan *MessageChannel - - onNewParticipant NewParticipantCallback - onRTCMessage RTCMessageCallback } func NewLocalRouter(currentNode LocalNode, signalClient SignalClient) *LocalRouter { @@ -52,7 +46,6 @@ func NewLocalRouter(currentNode LocalNode, signalClient SignalClient) *LocalRout signalClient: signalClient, requestChannels: make(map[string]*MessageChannel), responseChannels: make(map[string]*MessageChannel), - rtcMessageChan: NewMessageChannel(livekit.ConnectionID("local"), localRTCChannelSize), } } @@ -68,7 +61,6 @@ func (r *LocalRouter) SetNodeForRoom(_ context.Context, _ livekit.RoomName, _ li } func (r *LocalRouter) ClearRoomState(_ context.Context, _ livekit.RoomName) error { - // do nothing return nil } @@ -97,82 +89,45 @@ func (r *LocalRouter) ListNodes() ([]*livekit.Node, error) { }, nil } -func (r *LocalRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) { +func (r *LocalRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error) { return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(r.currentNode.Id)) } -func (r *LocalRouter) StartParticipantSignalWithNodeID(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) { - connectionID, reqSink, resSource, err = r.signalClient.StartParticipantSignal(ctx, roomName, pi, nodeID) +func (r *LocalRouter) StartParticipantSignalWithNodeID(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (res StartParticipantSignalResults, err error) { + connectionID, reqSink, resSource, err := r.signalClient.StartParticipantSignal(ctx, roomName, pi, nodeID) if err != nil { logger.Errorw("could not handle new participant", err, "room", roomName, "participant", pi.Identity, "connID", connectionID, ) + } else { + return StartParticipantSignalResults{ + ConnectionID: connectionID, + RequestSink: reqSink, + ResponseSource: resSource, + NodeID: nodeID, + }, nil } return } -func (r *LocalRouter) WriteParticipantRTC(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) error { - r.lock.Lock() - if r.rtcMessageChan.IsClosed() { - // create a new one - r.rtcMessageChan = NewMessageChannel(livekit.ConnectionID("local"), localRTCChannelSize) - } - r.lock.Unlock() - msg.ParticipantKey = string(ParticipantKeyLegacy(roomName, identity)) - msg.ParticipantKeyB62 = string(ParticipantKey(roomName, identity)) - return r.writeRTCMessage(r.rtcMessageChan, msg) -} - -func (r *LocalRouter) WriteRoomRTC(ctx context.Context, roomName livekit.RoomName, msg *livekit.RTCNodeMessage) error { - msg.ParticipantKey = string(ParticipantKeyLegacy(roomName, "")) - msg.ParticipantKeyB62 = string(ParticipantKey(roomName, "")) - return r.WriteNodeRTC(ctx, r.currentNode.Id, msg) -} - -func (r *LocalRouter) WriteNodeRTC(_ context.Context, _ string, msg *livekit.RTCNodeMessage) error { - r.lock.Lock() - if r.rtcMessageChan.IsClosed() { - // create a new one - r.rtcMessageChan = NewMessageChannel(livekit.ConnectionID("local"), localRTCChannelSize) - } - r.lock.Unlock() - return r.writeRTCMessage(r.rtcMessageChan, msg) -} - -func (r *LocalRouter) writeRTCMessage(sink MessageSink, msg *livekit.RTCNodeMessage) error { - defer sink.Close() - msg.SenderTime = time.Now().Unix() - return sink.WriteMessage(msg) -} - -func (r *LocalRouter) OnNewParticipantRTC(callback NewParticipantCallback) { - r.onNewParticipant = callback -} - -func (r *LocalRouter) OnRTCMessage(callback RTCMessageCallback) { - r.onRTCMessage = callback -} - func (r *LocalRouter) Start() error { if r.isStarted.Swap(true) { return nil } go r.statsWorker() // go r.memStatsWorker() - // on local routers, Start doesn't do anything, websocket connections initiate the connections - go r.rtcMessageWorker() return nil } func (r *LocalRouter) Drain() { + r.lock.Lock() + defer r.lock.Unlock() r.currentNode.State = livekit.NodeState_SHUTTING_DOWN } -func (r *LocalRouter) Stop() { - r.rtcMessageChan.Close() -} +func (r *LocalRouter) Stop() {} func (r *LocalRouter) GetRegion() string { return r.currentNode.Region @@ -208,73 +163,3 @@ func (r *LocalRouter) statsWorker() { } } */ -func (r *LocalRouter) rtcMessageWorker() { - // is a new channel available? if so swap to that one - if !r.isStarted.Load() { - return - } - - // start a new worker after this finished - defer func() { - go r.rtcMessageWorker() - }() - - r.lock.RLock() - isClosed := r.rtcMessageChan.IsClosed() - r.lock.RUnlock() - if isClosed { - // sleep and retry - time.Sleep(time.Second) - } - - r.lock.RLock() - msgChan := r.rtcMessageChan.ReadChan() - r.lock.RUnlock() - // consume messages from - for msg := range msgChan { - if rtcMsg, ok := msg.(*livekit.RTCNodeMessage); ok { - var room livekit.RoomName - var identity livekit.ParticipantIdentity - var err error - if rtcMsg.ParticipantKeyB62 != "" { - room, identity, err = parseParticipantKey(livekit.ParticipantKey(rtcMsg.ParticipantKeyB62)) - } - if err != nil { - room, identity, err = parseParticipantKeyLegacy(livekit.ParticipantKey(rtcMsg.ParticipantKey)) - } - if err != nil { - logger.Errorw("could not process RTC message", err) - continue - } - if r.onRTCMessage != nil { - r.onRTCMessage(context.Background(), room, identity, rtcMsg) - } - } - } -} - -func (r *LocalRouter) getMessageChannel(target map[string]*MessageChannel, key string) *MessageChannel { - r.lock.RLock() - defer r.lock.RUnlock() - return target[key] -} - -func (r *LocalRouter) getOrCreateMessageChannel(target map[string]*MessageChannel, key string) *MessageChannel { - r.lock.Lock() - defer r.lock.Unlock() - mc := target[key] - - if mc != nil { - return mc - } - - mc = NewMessageChannel(livekit.ConnectionID(key), DefaultMessageChannelSize) - mc.OnClose(func() { - r.lock.Lock() - delete(target, key) - r.lock.Unlock() - }) - target[key] = mc - - return mc -} diff --git a/pkg/routing/messagechannel_test.go b/pkg/routing/messagechannel_test.go index 5d78c2104..bac68ef1e 100644 --- a/pkg/routing/messagechannel_test.go +++ b/pkg/routing/messagechannel_test.go @@ -39,12 +39,12 @@ func TestMessageChannel_WriteMessageClosed(t *testing.T) { go func() { defer wg.Done() for i := 0; i < 100; i++ { - _ = m.WriteMessage(&livekit.RTCNodeMessage{}) + _ = m.WriteMessage(&livekit.SignalRequest{}) } }() - _ = m.WriteMessage(&livekit.RTCNodeMessage{}) + _ = m.WriteMessage(&livekit.SignalRequest{}) m.Close() - _ = m.WriteMessage(&livekit.RTCNodeMessage{}) + _ = m.WriteMessage(&livekit.SignalRequest{}) wg.Wait() } diff --git a/pkg/routing/redis.go b/pkg/routing/redis.go deleted file mode 100644 index 5d81ce088..000000000 --- a/pkg/routing/redis.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package routing - -import ( - "context" - - "github.com/redis/go-redis/v9" - "go.uber.org/atomic" - "google.golang.org/protobuf/proto" - - "github.com/livekit/protocol/livekit" -) - -const ( - // hash of node_id => Node proto - NodesKey = "nodes" - - // hash of room_name => node_id - NodeRoomKey = "room_node_map" -) - -var redisCtx = context.Background() - -// location of the participant's RTC connection, hash -func participantRTCKey(participantKey livekit.ParticipantKey) string { - return "participant_rtc:" + string(participantKey) -} - -// location of the participant's Signal connection, hash -func participantSignalKey(connectionID livekit.ConnectionID) string { - return "participant_signal:" + string(connectionID) -} - -func rtcNodeChannel(nodeID livekit.NodeID) string { - return "rtc_channel:" + string(nodeID) -} - -func signalNodeChannel(nodeID livekit.NodeID) string { - return "signal_channel:" + string(nodeID) -} - -func publishRTCMessage(rc redis.UniversalClient, nodeID livekit.NodeID, participantKey livekit.ParticipantKey, participantKeyB62 livekit.ParticipantKey, msg proto.Message) error { - rm := &livekit.RTCNodeMessage{ - ParticipantKey: string(participantKey), - ParticipantKeyB62: string(participantKeyB62), - } - switch o := msg.(type) { - case *livekit.StartSession: - rm.Message = &livekit.RTCNodeMessage_StartSession{ - StartSession: o, - } - case *livekit.SignalRequest: - rm.Message = &livekit.RTCNodeMessage_Request{ - Request: o, - } - case *livekit.RTCNodeMessage: - rm = o - rm.ParticipantKey = string(participantKey) - rm.ParticipantKeyB62 = string(participantKeyB62) - default: - return ErrInvalidRouterMessage - } - data, err := proto.Marshal(rm) - if err != nil { - return err - } - - // logger.Debugw("publishing to rtc", "rtcChannel", rtcNodeChannel(nodeID), - // "message", rm.Message) - return rc.Publish(redisCtx, rtcNodeChannel(nodeID), data).Err() -} - -func publishSignalMessage(rc redis.UniversalClient, nodeID livekit.NodeID, connectionID livekit.ConnectionID, msg proto.Message) error { - rm := &livekit.SignalNodeMessage{ - ConnectionId: string(connectionID), - } - switch o := msg.(type) { - case *livekit.SignalResponse: - rm.Message = &livekit.SignalNodeMessage_Response{ - Response: o, - } - case *livekit.EndSession: - rm.Message = &livekit.SignalNodeMessage_EndSession{ - EndSession: o, - } - default: - return ErrInvalidRouterMessage - } - data, err := proto.Marshal(rm) - if err != nil { - return err - } - - // logger.Debugw("publishing to signal", "signalChannel", signalNodeChannel(nodeID), - // "message", rm.Message) - return rc.Publish(redisCtx, signalNodeChannel(nodeID), data).Err() -} - -type RTCNodeSink struct { - rc redis.UniversalClient - nodeID livekit.NodeID - connectionID livekit.ConnectionID - participantKey livekit.ParticipantKey - participantKeyB62 livekit.ParticipantKey - isClosed atomic.Bool - onClose func() -} - -func NewRTCNodeSink( - rc redis.UniversalClient, - nodeID livekit.NodeID, - connectionID livekit.ConnectionID, - participantKey livekit.ParticipantKey, - participantKeyB62 livekit.ParticipantKey, -) *RTCNodeSink { - return &RTCNodeSink{ - rc: rc, - nodeID: nodeID, - connectionID: connectionID, - participantKey: participantKey, - participantKeyB62: participantKeyB62, - } -} - -func (s *RTCNodeSink) WriteMessage(msg proto.Message) error { - if s.isClosed.Load() { - return ErrChannelClosed - } - return publishRTCMessage(s.rc, s.nodeID, s.participantKey, s.participantKeyB62, msg) -} - -func (s *RTCNodeSink) Close() { - if s.isClosed.Swap(true) { - return - } - if s.onClose != nil { - s.onClose() - } -} - -func (s *RTCNodeSink) IsClosed() bool { - return s.isClosed.Load() -} - -func (s *RTCNodeSink) OnClose(f func()) { - s.onClose = f -} - -func (s *RTCNodeSink) ConnectionID() livekit.ConnectionID { - return s.connectionID -} - -// ---------------------------------------------------------------------- - -type SignalNodeSink struct { - rc redis.UniversalClient - nodeID livekit.NodeID - connectionID livekit.ConnectionID - isClosed atomic.Bool - onClose func() -} - -func NewSignalNodeSink(rc redis.UniversalClient, nodeID livekit.NodeID, connectionID livekit.ConnectionID) *SignalNodeSink { - return &SignalNodeSink{ - rc: rc, - nodeID: nodeID, - connectionID: connectionID, - } -} - -func (s *SignalNodeSink) WriteMessage(msg proto.Message) error { - if s.isClosed.Load() { - return ErrChannelClosed - } - return publishSignalMessage(s.rc, s.nodeID, s.connectionID, msg) -} - -func (s *SignalNodeSink) Close() { - if s.isClosed.Swap(true) { - return - } - _ = publishSignalMessage(s.rc, s.nodeID, s.connectionID, &livekit.EndSession{}) - if s.onClose != nil { - s.onClose() - } -} - -func (s *SignalNodeSink) IsClosed() bool { - return s.isClosed.Load() -} - -func (s *SignalNodeSink) OnClose(f func()) { - s.onClose = f -} - -func (s *SignalNodeSink) ConnectionID() livekit.ConnectionID { - return s.connectionID -} diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index a3e7e8365..579671814 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -28,9 +28,8 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/rpc" - "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing/selector" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" ) @@ -40,31 +39,38 @@ const ( participantMappingTTL = 24 * time.Hour statsUpdateInterval = 2 * time.Second statsMaxDelaySeconds = 30 + + // hash of node_id => Node proto + NodesKey = "nodes" + + // hash of room_name => node_id + NodeRoomKey = "room_node_map" ) +var _ Router = (*RedisRouter)(nil) + // RedisRouter uses Redis pub/sub to route signaling messages across different nodes // It relies on the RTC node to be the primary driver of the participant connection. // Because type RedisRouter struct { *LocalRouter - rc redis.UniversalClient - usePSRPCSignal bool - ctx context.Context - isStarted atomic.Bool - nodeMu sync.RWMutex + rc redis.UniversalClient + kps rpc.KeepalivePubSub + ctx context.Context + isStarted atomic.Bool + nodeMu sync.RWMutex // previous stats for computing averages prevStats *livekit.NodeStats - pubsub *redis.PubSub cancel func() } -func NewRedisRouter(config *config.Config, lr *LocalRouter, rc redis.UniversalClient) *RedisRouter { +func NewRedisRouter(lr *LocalRouter, rc redis.UniversalClient, kps rpc.KeepalivePubSub) *RedisRouter { rr := &RedisRouter{ - LocalRouter: lr, - rc: rc, - usePSRPCSignal: config.SignalRelay.Enabled, + LocalRouter: lr, + rc: rc, + kps: kps, } rr.ctx, rr.cancel = context.WithCancel(context.Background()) return rr @@ -156,159 +162,14 @@ func (r *RedisRouter) ListNodes() ([]*livekit.Node, error) { } // StartParticipantSignal signal connection sets up paths to the RTC node, and starts to route messages to that message queue -func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) { +func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error) { // find the node where the room is hosted at rtcNode, err := r.GetNodeForRoom(ctx, roomName) if err != nil { return } - if r.usePSRPCSignal { - connectionID, reqSink, resSource, err = r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id)) - if err != nil { - return - } - - // map signal & rtc nodes - err = r.setParticipantSignalNode(connectionID, r.currentNode.Id) - return - } - - connectionID = livekit.ConnectionID(utils.NewGuid("CO_")) - pKey := ParticipantKeyLegacy(roomName, pi.Identity) - pKeyB62 := ParticipantKey(roomName, pi.Identity) - - // map signal & rtc nodes - if err = r.setParticipantSignalNode(connectionID, r.currentNode.Id); err != nil { - return - } - - // index by connectionID, since there may be multiple connections for the participant - // set up response channel before sending StartSession and be ready to receive responses. - resChan := r.getOrCreateMessageChannel(r.responseChannels, string(connectionID)) - - sink := NewRTCNodeSink(r.rc, livekit.NodeID(rtcNode.Id), connectionID, pKey, pKeyB62) - - // serialize claims - ss, err := pi.ToStartSession(roomName, connectionID) - if err != nil { - return - } - - // sends a message to start session - err = sink.WriteMessage(ss) - if err != nil { - return - } - - return connectionID, sink, resChan, nil -} - -func (r *RedisRouter) WriteParticipantRTC(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) error { - pkey := ParticipantKeyLegacy(roomName, identity) - pkeyB62 := ParticipantKey(roomName, identity) - rtcNode, err := r.getParticipantRTCNode(pkey, pkeyB62) - if err != nil { - return err - } - - rtcSink := NewRTCNodeSink(r.rc, livekit.NodeID(rtcNode), "ephemeral", pkey, pkeyB62) - msg.ParticipantKey = string(ParticipantKeyLegacy(roomName, identity)) - msg.ParticipantKeyB62 = string(ParticipantKey(roomName, identity)) - return r.writeRTCMessage(rtcSink, msg) -} - -func (r *RedisRouter) WriteRoomRTC(ctx context.Context, roomName livekit.RoomName, msg *livekit.RTCNodeMessage) error { - node, err := r.GetNodeForRoom(ctx, roomName) - if err != nil { - return err - } - msg.ParticipantKey = string(ParticipantKeyLegacy(roomName, "")) - msg.ParticipantKeyB62 = string(ParticipantKey(roomName, "")) - return r.WriteNodeRTC(ctx, node.Id, msg) -} - -func (r *RedisRouter) WriteNodeRTC(_ context.Context, rtcNodeID string, msg *livekit.RTCNodeMessage) error { - rtcSink := NewRTCNodeSink(r.rc, livekit.NodeID(rtcNodeID), "ephemeral", livekit.ParticipantKey(msg.ParticipantKey), livekit.ParticipantKey(msg.ParticipantKeyB62)) - return r.writeRTCMessage(rtcSink, msg) -} - -func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantKey livekit.ParticipantKey, participantKeyB62 livekit.ParticipantKey) error { - prometheus.IncrementParticipantRtcInit(1) - // find the node where the room is hosted at - rtcNode, err := r.GetNodeForRoom(r.ctx, livekit.RoomName(ss.RoomName)) - if err != nil { - return err - } - - if rtcNode.Id != r.currentNode.Id { - err = ErrIncorrectRTCNode - logger.Errorw("called participant on incorrect node", err, - "rtcNode", rtcNode, - ) - return err - } - - if err := r.SetParticipantRTCNode(participantKey, participantKeyB62, rtcNode.Id); err != nil { - return err - } - - // find signal node to send responses back - signalNode, err := r.getParticipantSignalNode(livekit.ConnectionID(ss.ConnectionId)) - if err != nil { - return err - } - - // treat it as a new participant connecting - if r.onNewParticipant == nil { - return ErrHandlerNotDefined - } - - // we do not want to re-use the same response sink - // the previous rtc worker thread is still consuming off of it. - // we'll want to sever the connection and switch to the new one - r.lock.RLock() - var requestChan *MessageChannel - var ok bool - var pkey livekit.ParticipantKey - if participantKeyB62 != "" { - requestChan, ok = r.requestChannels[string(participantKeyB62)] - pkey = participantKeyB62 - } else { - requestChan, ok = r.requestChannels[string(participantKey)] - pkey = participantKey - } - r.lock.RUnlock() - if ok { - requestChan.Close() - } - - pi, err := ParticipantInitFromStartSession(ss, r.currentNode.Region) - if err != nil { - return err - } - - reqChan := r.getOrCreateMessageChannel(r.requestChannels, string(pkey)) - resSink := NewSignalNodeSink(r.rc, livekit.NodeID(signalNode), livekit.ConnectionID(ss.ConnectionId)) - go func() { - err := r.onNewParticipant( - r.ctx, - livekit.RoomName(ss.RoomName), - *pi, - reqChan, - resSink, - ) - if err != nil { - logger.Errorw("could not handle new participant", err, - "room", ss.RoomName, - "participant", ss.Identity, - ) - // cleanup request channels - reqChan.Close() - resSink.Close() - } - }() - return nil + return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id)) } func (r *RedisRouter) Start() error { @@ -316,17 +177,12 @@ func (r *RedisRouter) Start() error { return nil } - workerStarted := make(chan struct{}) + workerStarted := make(chan error) go r.statsWorker() - go r.redisWorker(workerStarted) + go r.keepaliveWorker(workerStarted) // wait until worker is running - select { - case <-workerStarted: - return nil - case <-time.After(3 * time.Second): - return errors.New("Unable to start redis router") - } + return <-workerStarted } func (r *RedisRouter) Drain() { @@ -343,63 +199,10 @@ func (r *RedisRouter) Stop() { return } logger.Debugw("stopping RedisRouter") - _ = r.pubsub.Close() _ = r.UnregisterNode() r.cancel() } -func (r *RedisRouter) SetParticipantRTCNode(participantKey livekit.ParticipantKey, participantKeyB62 livekit.ParticipantKey, nodeID string) error { - var err error - if participantKey != "" { - err1 := r.rc.Set(r.ctx, participantRTCKey(participantKey), nodeID, participantMappingTTL).Err() - if err1 != nil { - err = errors.Wrap(err, "could not set rtc node") - } - } - if participantKeyB62 != "" { - err2 := r.rc.Set(r.ctx, participantRTCKey(participantKeyB62), nodeID, participantMappingTTL).Err() - if err2 != nil { - err = errors.Wrap(err, "could not set rtc node") - } - } - return err -} - -func (r *RedisRouter) setParticipantSignalNode(connectionID livekit.ConnectionID, nodeID string) error { - if err := r.rc.Set(r.ctx, participantSignalKey(connectionID), nodeID, participantMappingTTL).Err(); err != nil { - return errors.Wrap(err, "could not set signal node") - } - return nil -} - -func (r *RedisRouter) getParticipantRTCNode(participantKey livekit.ParticipantKey, participantKeyB62 livekit.ParticipantKey) (string, error) { - var val string - var err error - if participantKeyB62 != "" { - val, err = r.rc.Get(r.ctx, participantRTCKey(participantKeyB62)).Result() - if err == redis.Nil { - val, err = r.rc.Get(r.ctx, participantRTCKey(participantKey)).Result() - if err == redis.Nil { - err = ErrNodeNotFound - } - } - } else { - val, err = r.rc.Get(r.ctx, participantRTCKey(participantKey)).Result() - if err == redis.Nil { - err = ErrNodeNotFound - } - } - return val, err -} - -func (r *RedisRouter) getParticipantSignalNode(connectionID livekit.ConnectionID) (nodeID string, err error) { - val, err := r.rc.Get(r.ctx, participantSignalKey(connectionID)).Result() - if err == redis.Nil { - err = ErrNodeNotFound - } - return val, err -} - // update node stats and cleanup func (r *RedisRouter) statsWorker() { goroutineDumped := false @@ -407,9 +210,8 @@ func (r *RedisRouter) statsWorker() { // update periodically select { case <-time.After(statsUpdateInterval): - _ = r.WriteNodeRTC(context.Background(), r.currentNode.Id, &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_KeepAlive{}, - }) + r.kps.PublishPing(r.ctx, livekit.NodeID(r.currentNode.Id), &rpc.KeepalivePing{Timestamp: time.Now().Unix()}) + r.nodeMu.RLock() stats := r.currentNode.Stats r.nodeMu.RUnlock() @@ -433,112 +235,17 @@ func (r *RedisRouter) statsWorker() { } } -// worker that consumes redis messages intended for this node -func (r *RedisRouter) redisWorker(startedChan chan struct{}) { - defer func() { - logger.Debugw("finishing redisWorker", "nodeID", r.currentNode.Id) - }() - logger.Debugw("starting redisWorker", "nodeID", r.currentNode.Id) - - sigChannel := signalNodeChannel(livekit.NodeID(r.currentNode.Id)) - rtcChannel := rtcNodeChannel(livekit.NodeID(r.currentNode.Id)) - r.pubsub = r.rc.Subscribe(r.ctx, sigChannel, rtcChannel) - +func (r *RedisRouter) keepaliveWorker(startedChan chan error) { + pings, err := r.kps.SubscribePing(r.ctx, livekit.NodeID(r.currentNode.Id)) + if err != nil { + startedChan <- err + return + } close(startedChan) - for msg := range r.pubsub.Channel() { - if msg == nil { - return - } - if msg.Channel == sigChannel { - sm := livekit.SignalNodeMessage{} - if err := proto.Unmarshal([]byte(msg.Payload), &sm); err != nil { - logger.Errorw("could not unmarshal signal message on sigchan", err) - prometheus.MessageCounter.WithLabelValues("signal", "failure").Add(1) - continue - } - if err := r.handleSignalMessage(&sm); err != nil { - logger.Errorw("error processing signal message", err) - prometheus.MessageCounter.WithLabelValues("signal", "failure").Add(1) - continue - } - prometheus.MessageCounter.WithLabelValues("signal", "success").Add(1) - } else if msg.Channel == rtcChannel { - rm := livekit.RTCNodeMessage{} - if err := proto.Unmarshal([]byte(msg.Payload), &rm); err != nil { - logger.Errorw("could not unmarshal RTC message on rtcchan", err) - prometheus.MessageCounter.WithLabelValues("rtc", "failure").Add(1) - continue - } - if err := r.handleRTCMessage(&rm); err != nil { - logger.Errorw("error processing RTC message", err) - prometheus.MessageCounter.WithLabelValues("rtc", "failure").Add(1) - continue - } - prometheus.MessageCounter.WithLabelValues("rtc", "success").Add(1) - } - } -} - -func (r *RedisRouter) handleSignalMessage(sm *livekit.SignalNodeMessage) error { - connectionID := sm.ConnectionId - - r.lock.RLock() - resSink := r.responseChannels[connectionID] - r.lock.RUnlock() - - // if a client closed the channel, then sent more messages after that, - if resSink == nil { - return nil - } - - switch rmb := sm.Message.(type) { - case *livekit.SignalNodeMessage_Response: - // logger.Debugw("forwarding signal message", - // "connID", connectionID, - // "type", fmt.Sprintf("%T", rmb.Response.Message)) - if err := resSink.WriteMessage(rmb.Response); err != nil { - return err - } - - case *livekit.SignalNodeMessage_EndSession: - // logger.Debugw("received EndSession, closing signal connection", - // "connID", connectionID) - resSink.Close() - } - return nil -} - -func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { - pKey := livekit.ParticipantKey(rm.ParticipantKey) - pKeyB62 := livekit.ParticipantKey(rm.ParticipantKeyB62) - - switch rmb := rm.Message.(type) { - case *livekit.RTCNodeMessage_StartSession: - // RTC session should start on this node - if err := r.startParticipantRTC(rmb.StartSession, pKey, pKeyB62); err != nil { - return errors.Wrap(err, "could not start participant") - } - - case *livekit.RTCNodeMessage_Request: - r.lock.RLock() - var requestChan *MessageChannel - if pKeyB62 != "" { - requestChan = r.requestChannels[string(pKeyB62)] - } else { - requestChan = r.requestChannels[string(pKey)] - } - r.lock.RUnlock() - if requestChan == nil { - return ErrChannelClosed - } - if err := requestChan.WriteMessage(rmb.Request); err != nil { - return err - } - - case *livekit.RTCNodeMessage_KeepAlive: - if time.Since(time.Unix(rm.SenderTime, 0)) > statsUpdateInterval { - logger.Infow("keep alive too old, skipping", "senderTime", rm.SenderTime) + for ping := range pings.Channel() { + if time.Since(time.Unix(ping.Timestamp, 0)) > statsUpdateInterval { + logger.Infow("keep alive too old, skipping", "timestamp", ping.Timestamp) break } @@ -550,7 +257,7 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { if err != nil { logger.Errorw("could not update node stats", err) r.nodeMu.Unlock() - return err + continue } r.currentNode.Stats = updated if computedAvg { @@ -562,24 +269,5 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { if err := r.RegisterNode(); err != nil { logger.Errorw("could not update node", err) } - - default: - // route it to handler - if r.onRTCMessage != nil { - var roomName livekit.RoomName - var identity livekit.ParticipantIdentity - var err error - if pKeyB62 != "" { - roomName, identity, err = parseParticipantKey(pKeyB62) - } - if err != nil || pKeyB62 == "" { - roomName, identity, err = parseParticipantKeyLegacy(pKey) - } - if err != nil { - return err - } - r.onRTCMessage(r.ctx, roomName, identity, rm) - } } - return nil } diff --git a/pkg/routing/routingfakes/fake_router.go b/pkg/routing/routingfakes/fake_router.go index 190996099..88fa9bf58 100644 --- a/pkg/routing/routingfakes/fake_router.go +++ b/pkg/routing/routingfakes/fake_router.go @@ -62,16 +62,6 @@ type FakeRouter struct { result1 []*livekit.Node result2 error } - OnNewParticipantRTCStub func(routing.NewParticipantCallback) - onNewParticipantRTCMutex sync.RWMutex - onNewParticipantRTCArgsForCall []struct { - arg1 routing.NewParticipantCallback - } - OnRTCMessageStub func(routing.RTCMessageCallback) - onRTCMessageMutex sync.RWMutex - onRTCMessageArgsForCall []struct { - arg1 routing.RTCMessageCallback - } RegisterNodeStub func() error registerNodeMutex sync.RWMutex registerNodeArgsForCall []struct { @@ -115,7 +105,7 @@ type FakeRouter struct { startReturnsOnCall map[int]struct { result1 error } - StartParticipantSignalStub func(context.Context, livekit.RoomName, routing.ParticipantInit) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) + StartParticipantSignalStub func(context.Context, livekit.RoomName, routing.ParticipantInit) (routing.StartParticipantSignalResults, error) startParticipantSignalMutex sync.RWMutex startParticipantSignalArgsForCall []struct { arg1 context.Context @@ -123,16 +113,12 @@ type FakeRouter struct { arg3 routing.ParticipantInit } startParticipantSignalReturns struct { - result1 livekit.ConnectionID - result2 routing.MessageSink - result3 routing.MessageSource - result4 error + result1 routing.StartParticipantSignalResults + result2 error } startParticipantSignalReturnsOnCall map[int]struct { - result1 livekit.ConnectionID - result2 routing.MessageSink - result3 routing.MessageSource - result4 error + result1 routing.StartParticipantSignalResults + result2 error } StopStub func() stopMutex sync.RWMutex @@ -148,33 +134,6 @@ type FakeRouter struct { unregisterNodeReturnsOnCall map[int]struct { result1 error } - WriteParticipantRTCStub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity, *livekit.RTCNodeMessage) error - writeParticipantRTCMutex sync.RWMutex - writeParticipantRTCArgsForCall []struct { - arg1 context.Context - arg2 livekit.RoomName - arg3 livekit.ParticipantIdentity - arg4 *livekit.RTCNodeMessage - } - writeParticipantRTCReturns struct { - result1 error - } - writeParticipantRTCReturnsOnCall map[int]struct { - result1 error - } - WriteRoomRTCStub func(context.Context, livekit.RoomName, *livekit.RTCNodeMessage) error - writeRoomRTCMutex sync.RWMutex - writeRoomRTCArgsForCall []struct { - arg1 context.Context - arg2 livekit.RoomName - arg3 *livekit.RTCNodeMessage - } - writeRoomRTCReturns struct { - result1 error - } - writeRoomRTCReturnsOnCall map[int]struct { - result1 error - } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -439,70 +398,6 @@ func (fake *FakeRouter) ListNodesReturnsOnCall(i int, result1 []*livekit.Node, r }{result1, result2} } -func (fake *FakeRouter) OnNewParticipantRTC(arg1 routing.NewParticipantCallback) { - fake.onNewParticipantRTCMutex.Lock() - fake.onNewParticipantRTCArgsForCall = append(fake.onNewParticipantRTCArgsForCall, struct { - arg1 routing.NewParticipantCallback - }{arg1}) - stub := fake.OnNewParticipantRTCStub - fake.recordInvocation("OnNewParticipantRTC", []interface{}{arg1}) - fake.onNewParticipantRTCMutex.Unlock() - if stub != nil { - fake.OnNewParticipantRTCStub(arg1) - } -} - -func (fake *FakeRouter) OnNewParticipantRTCCallCount() int { - fake.onNewParticipantRTCMutex.RLock() - defer fake.onNewParticipantRTCMutex.RUnlock() - return len(fake.onNewParticipantRTCArgsForCall) -} - -func (fake *FakeRouter) OnNewParticipantRTCCalls(stub func(routing.NewParticipantCallback)) { - fake.onNewParticipantRTCMutex.Lock() - defer fake.onNewParticipantRTCMutex.Unlock() - fake.OnNewParticipantRTCStub = stub -} - -func (fake *FakeRouter) OnNewParticipantRTCArgsForCall(i int) routing.NewParticipantCallback { - fake.onNewParticipantRTCMutex.RLock() - defer fake.onNewParticipantRTCMutex.RUnlock() - argsForCall := fake.onNewParticipantRTCArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *FakeRouter) OnRTCMessage(arg1 routing.RTCMessageCallback) { - fake.onRTCMessageMutex.Lock() - fake.onRTCMessageArgsForCall = append(fake.onRTCMessageArgsForCall, struct { - arg1 routing.RTCMessageCallback - }{arg1}) - stub := fake.OnRTCMessageStub - fake.recordInvocation("OnRTCMessage", []interface{}{arg1}) - fake.onRTCMessageMutex.Unlock() - if stub != nil { - fake.OnRTCMessageStub(arg1) - } -} - -func (fake *FakeRouter) OnRTCMessageCallCount() int { - fake.onRTCMessageMutex.RLock() - defer fake.onRTCMessageMutex.RUnlock() - return len(fake.onRTCMessageArgsForCall) -} - -func (fake *FakeRouter) OnRTCMessageCalls(stub func(routing.RTCMessageCallback)) { - fake.onRTCMessageMutex.Lock() - defer fake.onRTCMessageMutex.Unlock() - fake.OnRTCMessageStub = stub -} - -func (fake *FakeRouter) OnRTCMessageArgsForCall(i int) routing.RTCMessageCallback { - fake.onRTCMessageMutex.RLock() - defer fake.onRTCMessageMutex.RUnlock() - argsForCall := fake.onRTCMessageArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakeRouter) RegisterNode() error { fake.registerNodeMutex.Lock() ret, specificReturn := fake.registerNodeReturnsOnCall[len(fake.registerNodeArgsForCall)] @@ -725,7 +620,7 @@ func (fake *FakeRouter) StartReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeRouter) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) { +func (fake *FakeRouter) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit) (routing.StartParticipantSignalResults, error) { fake.startParticipantSignalMutex.Lock() ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)] fake.startParticipantSignalArgsForCall = append(fake.startParticipantSignalArgsForCall, struct { @@ -741,9 +636,9 @@ func (fake *FakeRouter) StartParticipantSignal(arg1 context.Context, arg2 liveki return stub(arg1, arg2, arg3) } if specificReturn { - return ret.result1, ret.result2, ret.result3, ret.result4 + return ret.result1, ret.result2 } - return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 + return fakeReturns.result1, fakeReturns.result2 } func (fake *FakeRouter) StartParticipantSignalCallCount() int { @@ -752,7 +647,7 @@ func (fake *FakeRouter) StartParticipantSignalCallCount() int { return len(fake.startParticipantSignalArgsForCall) } -func (fake *FakeRouter) StartParticipantSignalCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error)) { +func (fake *FakeRouter) StartParticipantSignalCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit) (routing.StartParticipantSignalResults, error)) { fake.startParticipantSignalMutex.Lock() defer fake.startParticipantSignalMutex.Unlock() fake.StartParticipantSignalStub = stub @@ -765,36 +660,30 @@ func (fake *FakeRouter) StartParticipantSignalArgsForCall(i int) (context.Contex return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 } -func (fake *FakeRouter) StartParticipantSignalReturns(result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) { +func (fake *FakeRouter) StartParticipantSignalReturns(result1 routing.StartParticipantSignalResults, result2 error) { fake.startParticipantSignalMutex.Lock() defer fake.startParticipantSignalMutex.Unlock() fake.StartParticipantSignalStub = nil fake.startParticipantSignalReturns = struct { - result1 livekit.ConnectionID - result2 routing.MessageSink - result3 routing.MessageSource - result4 error - }{result1, result2, result3, result4} + result1 routing.StartParticipantSignalResults + result2 error + }{result1, result2} } -func (fake *FakeRouter) StartParticipantSignalReturnsOnCall(i int, result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) { +func (fake *FakeRouter) StartParticipantSignalReturnsOnCall(i int, result1 routing.StartParticipantSignalResults, result2 error) { fake.startParticipantSignalMutex.Lock() defer fake.startParticipantSignalMutex.Unlock() fake.StartParticipantSignalStub = nil if fake.startParticipantSignalReturnsOnCall == nil { fake.startParticipantSignalReturnsOnCall = make(map[int]struct { - result1 livekit.ConnectionID - result2 routing.MessageSink - result3 routing.MessageSource - result4 error + result1 routing.StartParticipantSignalResults + result2 error }) } fake.startParticipantSignalReturnsOnCall[i] = struct { - result1 livekit.ConnectionID - result2 routing.MessageSink - result3 routing.MessageSource - result4 error - }{result1, result2, result3, result4} + result1 routing.StartParticipantSignalResults + result2 error + }{result1, result2} } func (fake *FakeRouter) Stop() { @@ -874,133 +763,6 @@ func (fake *FakeRouter) UnregisterNodeReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeRouter) WriteParticipantRTC(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.ParticipantIdentity, arg4 *livekit.RTCNodeMessage) error { - fake.writeParticipantRTCMutex.Lock() - ret, specificReturn := fake.writeParticipantRTCReturnsOnCall[len(fake.writeParticipantRTCArgsForCall)] - fake.writeParticipantRTCArgsForCall = append(fake.writeParticipantRTCArgsForCall, struct { - arg1 context.Context - arg2 livekit.RoomName - arg3 livekit.ParticipantIdentity - arg4 *livekit.RTCNodeMessage - }{arg1, arg2, arg3, arg4}) - stub := fake.WriteParticipantRTCStub - fakeReturns := fake.writeParticipantRTCReturns - fake.recordInvocation("WriteParticipantRTC", []interface{}{arg1, arg2, arg3, arg4}) - fake.writeParticipantRTCMutex.Unlock() - if stub != nil { - return stub(arg1, arg2, arg3, arg4) - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeRouter) WriteParticipantRTCCallCount() int { - fake.writeParticipantRTCMutex.RLock() - defer fake.writeParticipantRTCMutex.RUnlock() - return len(fake.writeParticipantRTCArgsForCall) -} - -func (fake *FakeRouter) WriteParticipantRTCCalls(stub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity, *livekit.RTCNodeMessage) error) { - fake.writeParticipantRTCMutex.Lock() - defer fake.writeParticipantRTCMutex.Unlock() - fake.WriteParticipantRTCStub = stub -} - -func (fake *FakeRouter) WriteParticipantRTCArgsForCall(i int) (context.Context, livekit.RoomName, livekit.ParticipantIdentity, *livekit.RTCNodeMessage) { - fake.writeParticipantRTCMutex.RLock() - defer fake.writeParticipantRTCMutex.RUnlock() - argsForCall := fake.writeParticipantRTCArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 -} - -func (fake *FakeRouter) WriteParticipantRTCReturns(result1 error) { - fake.writeParticipantRTCMutex.Lock() - defer fake.writeParticipantRTCMutex.Unlock() - fake.WriteParticipantRTCStub = nil - fake.writeParticipantRTCReturns = struct { - result1 error - }{result1} -} - -func (fake *FakeRouter) WriteParticipantRTCReturnsOnCall(i int, result1 error) { - fake.writeParticipantRTCMutex.Lock() - defer fake.writeParticipantRTCMutex.Unlock() - fake.WriteParticipantRTCStub = nil - if fake.writeParticipantRTCReturnsOnCall == nil { - fake.writeParticipantRTCReturnsOnCall = make(map[int]struct { - result1 error - }) - } - fake.writeParticipantRTCReturnsOnCall[i] = struct { - result1 error - }{result1} -} - -func (fake *FakeRouter) WriteRoomRTC(arg1 context.Context, arg2 livekit.RoomName, arg3 *livekit.RTCNodeMessage) error { - fake.writeRoomRTCMutex.Lock() - ret, specificReturn := fake.writeRoomRTCReturnsOnCall[len(fake.writeRoomRTCArgsForCall)] - fake.writeRoomRTCArgsForCall = append(fake.writeRoomRTCArgsForCall, struct { - arg1 context.Context - arg2 livekit.RoomName - arg3 *livekit.RTCNodeMessage - }{arg1, arg2, arg3}) - stub := fake.WriteRoomRTCStub - fakeReturns := fake.writeRoomRTCReturns - fake.recordInvocation("WriteRoomRTC", []interface{}{arg1, arg2, arg3}) - fake.writeRoomRTCMutex.Unlock() - if stub != nil { - return stub(arg1, arg2, arg3) - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeRouter) WriteRoomRTCCallCount() int { - fake.writeRoomRTCMutex.RLock() - defer fake.writeRoomRTCMutex.RUnlock() - return len(fake.writeRoomRTCArgsForCall) -} - -func (fake *FakeRouter) WriteRoomRTCCalls(stub func(context.Context, livekit.RoomName, *livekit.RTCNodeMessage) error) { - fake.writeRoomRTCMutex.Lock() - defer fake.writeRoomRTCMutex.Unlock() - fake.WriteRoomRTCStub = stub -} - -func (fake *FakeRouter) WriteRoomRTCArgsForCall(i int) (context.Context, livekit.RoomName, *livekit.RTCNodeMessage) { - fake.writeRoomRTCMutex.RLock() - defer fake.writeRoomRTCMutex.RUnlock() - argsForCall := fake.writeRoomRTCArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 -} - -func (fake *FakeRouter) WriteRoomRTCReturns(result1 error) { - fake.writeRoomRTCMutex.Lock() - defer fake.writeRoomRTCMutex.Unlock() - fake.WriteRoomRTCStub = nil - fake.writeRoomRTCReturns = struct { - result1 error - }{result1} -} - -func (fake *FakeRouter) WriteRoomRTCReturnsOnCall(i int, result1 error) { - fake.writeRoomRTCMutex.Lock() - defer fake.writeRoomRTCMutex.Unlock() - fake.WriteRoomRTCStub = nil - if fake.writeRoomRTCReturnsOnCall == nil { - fake.writeRoomRTCReturnsOnCall = make(map[int]struct { - result1 error - }) - } - fake.writeRoomRTCReturnsOnCall[i] = struct { - result1 error - }{result1} -} - func (fake *FakeRouter) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -1014,10 +776,6 @@ func (fake *FakeRouter) Invocations() map[string][][]interface{} { defer fake.getRegionMutex.RUnlock() fake.listNodesMutex.RLock() defer fake.listNodesMutex.RUnlock() - fake.onNewParticipantRTCMutex.RLock() - defer fake.onNewParticipantRTCMutex.RUnlock() - fake.onRTCMessageMutex.RLock() - defer fake.onRTCMessageMutex.RUnlock() fake.registerNodeMutex.RLock() defer fake.registerNodeMutex.RUnlock() fake.removeDeadNodesMutex.RLock() @@ -1032,10 +790,6 @@ func (fake *FakeRouter) Invocations() map[string][][]interface{} { defer fake.stopMutex.RUnlock() fake.unregisterNodeMutex.RLock() defer fake.unregisterNodeMutex.RUnlock() - fake.writeParticipantRTCMutex.RLock() - defer fake.writeParticipantRTCMutex.RUnlock() - fake.writeRoomRTCMutex.RLock() - defer fake.writeRoomRTCMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/pkg/routing/utils.go b/pkg/routing/utils.go deleted file mode 100644 index 2e11fdbe2..000000000 --- a/pkg/routing/utils.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package routing - -import ( - "fmt" - "strings" - - "github.com/jxskiss/base62" - - "github.com/livekit/protocol/livekit" -) - -func ParticipantKeyLegacy(roomName livekit.RoomName, identity livekit.ParticipantIdentity) livekit.ParticipantKey { - return livekit.ParticipantKey(string(roomName) + "|" + string(identity)) -} - -func parseParticipantKeyLegacy(pkey livekit.ParticipantKey) (roomName livekit.RoomName, identity livekit.ParticipantIdentity, err error) { - parts := strings.Split(string(pkey), "|") - if len(parts) == 2 { - roomName = livekit.RoomName(parts[0]) - identity = livekit.ParticipantIdentity(parts[1]) - return - } - - err = fmt.Errorf("invalid participant key: %s", pkey) - return -} - -func ParticipantKey(roomName livekit.RoomName, identity livekit.ParticipantIdentity) livekit.ParticipantKey { - return livekit.ParticipantKey(encode(string(roomName), string(identity))) -} - -func parseParticipantKey(pkey livekit.ParticipantKey) (roomName livekit.RoomName, identity livekit.ParticipantIdentity, err error) { - parts, err := decode(string(pkey)) - if err != nil { - return - } - if len(parts) == 2 { - roomName = livekit.RoomName(parts[0]) - identity = livekit.ParticipantIdentity(parts[1]) - return - } - - err = fmt.Errorf("invalid participant key: %s", pkey) - return -} - -func encode(str ...string) string { - encoded := make([]string, 0, len(str)) - for _, s := range str { - encoded = append(encoded, base62.EncodeToString([]byte(s))) - } - return strings.Join(encoded, "|") -} - -func decode(encoded string) ([]string, error) { - split := strings.Split(encoded, "|") - decoded := make([]string, 0, len(split)) - for _, s := range split { - part, err := base62.DecodeString(s) - if err != nil { - return nil, err - } - decoded = append(decoded, string(part)) - } - return decoded, nil -} diff --git a/pkg/routing/utils_test.go b/pkg/routing/utils_test.go deleted file mode 100644 index 8ae1e9b4c..000000000 --- a/pkg/routing/utils_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package routing - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/livekit/protocol/livekit" -) - -func TestUtils_ParticipantKey(t *testing.T) { - // encode/decode empty - encoded := ParticipantKey("", "") - roomName, identity, err := parseParticipantKey(encoded) - require.NoError(t, err) - require.Equal(t, livekit.RoomName(""), roomName) - require.Equal(t, livekit.ParticipantIdentity(""), identity) - - // decode invalid - _, _, err = parseParticipantKey("abcd") - require.Error(t, err) - - // encode/decode without delimiter - encoded = ParticipantKey("room1", "identity1") - roomName, identity, err = parseParticipantKey(encoded) - require.NoError(t, err) - require.Equal(t, livekit.RoomName("room1"), roomName) - require.Equal(t, livekit.ParticipantIdentity("identity1"), identity) - - // encode/decode with delimiter in roomName - encoded = ParticipantKey("room1|alter_room1", "identity1") - roomName, identity, err = parseParticipantKey(encoded) - require.NoError(t, err) - require.Equal(t, livekit.RoomName("room1|alter_room1"), roomName) - require.Equal(t, livekit.ParticipantIdentity("identity1"), identity) - - // encode/decode with delimiter in identity - encoded = ParticipantKey("room1", "identity1|alter-identity1") - roomName, identity, err = parseParticipantKey(encoded) - require.NoError(t, err) - require.Equal(t, livekit.RoomName("room1"), roomName) - require.Equal(t, livekit.ParticipantIdentity("identity1|alter-identity1"), identity) - - // encode/decode with delimiter in both and multiple delimiters in both - encoded = ParticipantKey("room1|alter_room1|again_room1", "identity1|alter-identity1|again-identity1") - roomName, identity, err = parseParticipantKey(encoded) - require.NoError(t, err) - require.Equal(t, livekit.RoomName("room1|alter_room1|again_room1"), roomName) - require.Equal(t, livekit.ParticipantIdentity("identity1|alter-identity1|again-identity1"), identity) -} diff --git a/pkg/rtc/dynacastmanager.go b/pkg/rtc/dynacastmanager.go index 8943646fe..1163cb09a 100644 --- a/pkg/rtc/dynacastmanager.go +++ b/pkg/rtc/dynacastmanager.go @@ -59,7 +59,7 @@ func NewDynacastManager(params DynacastManagerParams) *DynacastManager { maxSubscribedQuality: make(map[string]livekit.VideoQuality), committedMaxSubscribedQuality: make(map[string]livekit.VideoQuality), maxSubscribedQualityDebounce: debounce.New(params.DynacastPauseDelay), - qualityNotifyOpQueue: utils.NewOpsQueue(params.Logger, "quality-notify", 100), + qualityNotifyOpQueue: utils.NewOpsQueue("quality-notify", 0, true), } d.qualityNotifyOpQueue.Start() return d diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 584c953c3..ba714228d 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -22,7 +22,6 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "go.uber.org/atomic" - "google.golang.org/protobuf/proto" "github.com/livekit/mediatransportutil/pkg/twcc" "github.com/livekit/protocol/livekit" @@ -52,7 +51,6 @@ type MediaTrack struct { } type MediaTrackParams struct { - TrackInfo *livekit.TrackInfo SignalCid string SdpCid string ParticipantID livekit.ParticipantID @@ -71,13 +69,12 @@ type MediaTrackParams struct { SimTracks map[uint32]SimulcastTrackInfo } -func NewMediaTrack(params MediaTrackParams) *MediaTrack { +func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { t := &MediaTrack{ params: params, } t.MediaTrackReceiver = NewMediaTrackReceiver(MediaTrackReceiverParams{ - TrackInfo: params.TrackInfo, MediaTrack: t, IsRelayed: false, ParticipantID: params.ParticipantID, @@ -88,19 +85,9 @@ func NewMediaTrack(params MediaTrackParams) *MediaTrack { AudioConfig: params.AudioConfig, Telemetry: params.Telemetry, Logger: params.Logger, - }) - t.MediaTrackReceiver.OnVideoLayerUpdate(func(layers []*livekit.VideoLayer) { - t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), - &livekit.TrackInfo{ - Sid: string(t.ID()), - Type: livekit.TrackType_VIDEO, - Muted: t.IsMuted(), - Simulcast: t.IsSimulcast(), - Layers: layers, - }) - }) + }, ti) - if params.TrackInfo.Type == livekit.TrackType_AUDIO { + if ti.Type == livekit.TrackType_AUDIO { t.MediaLossProxy = NewMediaLossProxy(MediaLossProxyParams{ Logger: params.Logger, }) @@ -113,7 +100,7 @@ func NewMediaTrack(params MediaTrackParams) *MediaTrack { t.MediaTrackReceiver.OnMediaLossFeedback(t.MediaLossProxy.HandleMaxLossFeedback) } - if params.TrackInfo.Type == livekit.TrackType_VIDEO { + if ti.Type == livekit.TrackType_VIDEO { t.dynacastManager = NewDynacastManager(DynacastManagerParams{ DynacastPauseDelay: params.VideoConfig.DynacastPauseDelay, Logger: params.Logger, @@ -126,7 +113,7 @@ func NewMediaTrack(params MediaTrackParams) *MediaTrack { t.dynacastManager.NotifySubscriberMaxQuality( subscriberID, codec.MimeType, - buffer.SpatialLayerToVideoQuality(layer, t.params.TrackInfo), + buffer.SpatialLayerToVideoQuality(layer, t.MediaTrackReceiver.TrackInfo()), ) }, ) @@ -138,6 +125,7 @@ func NewMediaTrack(params MediaTrackParams) *MediaTrack { func (t *MediaTrack) OnSubscribedMaxQualityChange( f func( trackID livekit.TrackID, + trackInfo *livekit.TrackInfo, subscribedQualities []*livekit.SubscribedCodec, maxSubscribedQualities []types.SubscribedCodecQuality, ) error, @@ -148,13 +136,13 @@ func (t *MediaTrack) OnSubscribedMaxQualityChange( handler := func(subscribedQualities []*livekit.SubscribedCodec, maxSubscribedQualities []types.SubscribedCodecQuality) { if f != nil && !t.IsMuted() { - _ = f(t.ID(), subscribedQualities, maxSubscribedQualities) + _ = f(t.ID(), t.ToProto(), subscribedQualities, maxSubscribedQualities) } for _, q := range maxSubscribedQualities { receiver := t.Receiver(q.CodecMime) if receiver != nil { - receiver.SetMaxExpectedSpatialLayer(buffer.VideoQualityToSpatialLayer(q.Quality, t.params.TrackInfo)) + receiver.SetMaxExpectedSpatialLayer(buffer.VideoQualityToSpatialLayer(q.Quality, t.MediaTrackReceiver.TrackInfo())) } } } @@ -177,8 +165,8 @@ func (t *MediaTrack) HasSdpCid(cid string) bool { return true } - info := t.params.TrackInfo - for _, c := range info.Codecs { + ti := t.MediaTrackReceiver.TrackInfoClone() + for _, c := range ti.Codecs { if c.Cid == cid { return true } @@ -187,24 +175,11 @@ func (t *MediaTrack) HasSdpCid(cid string) bool { } func (t *MediaTrack) ToProto() *livekit.TrackInfo { - info := t.MediaTrackReceiver.TrackInfo(true) - info.Muted = t.IsMuted() - info.Simulcast = t.IsSimulcast() - return info + return t.MediaTrackReceiver.TrackInfoClone() } -func (t *MediaTrack) SetPendingCodecSid(codecs []*livekit.SimulcastCodec) { - ti := proto.Clone(t.params.TrackInfo).(*livekit.TrackInfo) - for _, c := range codecs { - for _, origin := range ti.Codecs { - if strings.Contains(origin.MimeType, c.Codec) { - origin.Cid = c.Cid - break - } - } - } - t.params.TrackInfo = ti - t.MediaTrackReceiver.UpdateTrackInfo(ti) +func (t *MediaTrack) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { + t.MediaTrackReceiver.UpdateCodecCid(codecs) } // AddReceiver adds a new RTP receiver to the track, returns true when receiver represents a new codec @@ -233,24 +208,40 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra } }) + ti := t.MediaTrackReceiver.TrackInfoClone() t.lock.Lock() mime := strings.ToLower(track.Codec().MimeType) - layer := buffer.RidToSpatialLayer(track.RID(), t.trackInfo) - t.params.Logger.Debugw("AddReceiver", "mime", track.Codec().MimeType) + layer := buffer.RidToSpatialLayer(track.RID(), ti) + t.params.Logger.Debugw( + "AddReceiver", + "mime", track.Codec().MimeType, + "rid", track.RID(), + "layer", layer, + "ssrc", track.SSRC(), + ) wr := t.MediaTrackReceiver.Receiver(mime) if wr == nil { priority := -1 - for idx, c := range t.params.TrackInfo.Codecs { + for idx, c := range ti.Codecs { if strings.EqualFold(mime, c.MimeType) { priority = idx break } } - if len(t.params.TrackInfo.Codecs) == 0 { - priority = 0 + if priority < 0 { + switch len(ti.Codecs) { + case 0: + // audio track + priority = 0 + case 1: + // older clients or non simulcast-codec, mime type only set later + if ti.Codecs[0].MimeType == "" { + priority = 0 + } + } } if priority < 0 { - t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(t.params.TrackInfo)) + t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(ti)) t.lock.Unlock() return false } @@ -258,7 +249,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra newWR := sfu.NewWebRTCReceiver( receiver, track, - t.params.TrackInfo, + ti, LoggerWithCodecMime(t.params.Logger, mime), twcc, t.params.VideoConfig.StreamTracker, @@ -277,16 +268,20 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra } } }) - newWR.OnStatsUpdate(func(_ *sfu.WebRTCReceiver, stat *livekit.AnalyticsStat) { - // LK-TODO: this needs to be receiver/mime aware - key := telemetry.StatsKeyForTrack(livekit.StreamType_UPSTREAM, t.PublisherID(), t.ID(), t.params.TrackInfo.Source, t.params.TrackInfo.Type) - t.params.Telemetry.TrackStats(key, stat) - }) + // SIMULCAST-CODEC-TODO: these need to be receiver/mime aware, setting it up only for primary now + if priority == 0 { + newWR.OnStatsUpdate(func(_ *sfu.WebRTCReceiver, stat *livekit.AnalyticsStat) { + key := telemetry.StatsKeyForTrack(livekit.StreamType_UPSTREAM, t.PublisherID(), t.ID(), ti.Source, ti.Type) + t.params.Telemetry.TrackStats(key, stat) + }) + + newWR.OnMaxLayerChange(t.onMaxLayerChange) + } if t.PrimaryReceiver() == nil { // primary codec published, set potential codecs - potentialCodecs := make([]webrtc.RTPCodecParameters, 0, len(t.params.TrackInfo.Codecs)) + potentialCodecs := make([]webrtc.RTPCodecParameters, 0, len(ti.Codecs)) parameters := receiver.GetParameters() - for _, c := range t.params.TrackInfo.Codecs { + for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { if strings.EqualFold(nc.MimeType, c.MimeType) { potentialCodecs = append(potentialCodecs, nc) @@ -301,8 +296,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra } } - newWR.OnMaxLayerChange(t.onMaxLayerChange) - t.buffer = buff t.MediaTrackReceiver.SetupReceiver(newWR, priority, mid) @@ -367,17 +360,7 @@ func (t *MediaTrack) HasPendingCodec() bool { } func (t *MediaTrack) onMaxLayerChange(maxLayer int32) { - ti := &livekit.TrackInfo{ - Sid: t.trackInfo.Sid, - Type: t.trackInfo.Type, - } - - if layer, ok := t.MediaTrackReceiver.layerDimensions[livekit.VideoQuality(maxLayer)]; ok { - ti.Layers = []*livekit.VideoLayer{{Quality: livekit.VideoQuality(maxLayer), Width: layer.Width, Height: layer.Height}} - } else if maxLayer == -1 { - ti.Layers = []*livekit.VideoLayer{{Quality: livekit.VideoQuality_OFF}} - } - t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), ti) + t.MediaTrackReceiver.NotifyMaxLayerChange(maxLayer) } func (t *MediaTrack) Restart() { diff --git a/pkg/rtc/mediatrack_test.go b/pkg/rtc/mediatrack_test.go index 9a0587447..8d96bbae4 100644 --- a/pkg/rtc/mediatrack_test.go +++ b/pkg/rtc/mediatrack_test.go @@ -35,9 +35,7 @@ func TestTrackInfo(t *testing.T) { Muted: true, } - mt := NewMediaTrack(MediaTrackParams{ - TrackInfo: &ti, - }) + mt := NewMediaTrack(MediaTrackParams{}, &ti) outInfo := mt.ToProto() require.Equal(t, ti.Muted, outInfo.Muted) require.Equal(t, ti.Name, outInfo.Name) @@ -51,17 +49,17 @@ func TestTrackInfo(t *testing.T) { require.Equal(t, ti.Simulcast, outInfo.Simulcast) // make it simulcasted - mt.simulcasted.Store(true) + mt.SetSimulcast(true) require.True(t, mt.ToProto().Simulcast) } func TestGetQualityForDimension(t *testing.T) { t.Run("landscape source", func(t *testing.T) { - mt := NewMediaTrack(MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + mt := NewMediaTrack(MediaTrackParams{}, &livekit.TrackInfo{ Type: livekit.TrackType_VIDEO, Width: 1080, Height: 720, - }}) + }) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(120, 120)) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(300, 200)) @@ -71,11 +69,11 @@ func TestGetQualityForDimension(t *testing.T) { }) t.Run("portrait source", func(t *testing.T) { - mt := NewMediaTrack(MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + mt := NewMediaTrack(MediaTrackParams{}, &livekit.TrackInfo{ Type: livekit.TrackType_VIDEO, Width: 540, Height: 960, - }}) + }) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(200, 400)) require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(400, 400)) @@ -84,7 +82,7 @@ func TestGetQualityForDimension(t *testing.T) { }) t.Run("layers provided", func(t *testing.T) { - mt := NewMediaTrack(MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + mt := NewMediaTrack(MediaTrackParams{}, &livekit.TrackInfo{ Type: livekit.TrackType_VIDEO, Width: 1080, Height: 720, @@ -105,7 +103,7 @@ func TestGetQualityForDimension(t *testing.T) { Height: 720, }, }, - }}) + }) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(120, 120)) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(300, 300)) @@ -114,7 +112,7 @@ func TestGetQualityForDimension(t *testing.T) { }) t.Run("highest layer with smallest dimensions", func(t *testing.T) { - mt := NewMediaTrack(MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + mt := NewMediaTrack(MediaTrackParams{}, &livekit.TrackInfo{ Type: livekit.TrackType_VIDEO, Width: 1080, Height: 720, @@ -135,7 +133,7 @@ func TestGetQualityForDimension(t *testing.T) { Height: 720, }, }, - }}) + }) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(120, 120)) require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(300, 300)) @@ -143,7 +141,7 @@ func TestGetQualityForDimension(t *testing.T) { require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(1000, 700)) require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(1200, 800)) - mt = NewMediaTrack(MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + mt = NewMediaTrack(MediaTrackParams{}, &livekit.TrackInfo{ Type: livekit.TrackType_VIDEO, Width: 1080, Height: 720, @@ -164,7 +162,7 @@ func TestGetQualityForDimension(t *testing.T) { Height: 720, }, }, - }}) + }) require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(120, 120)) require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(300, 300)) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 20883e19a..d0818ceb1 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -15,6 +15,7 @@ package rtc import ( + "context" "errors" "fmt" "sort" @@ -23,7 +24,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" - "go.uber.org/atomic" + "golang.org/x/exp/slices" "google.golang.org/protobuf/proto" "github.com/livekit/protocol/livekit" @@ -73,8 +74,7 @@ func (m mediaTrackReceiverState) String() string { type simulcastReceiver struct { sfu.TrackReceiver - priority int - layerSSRCs [livekit.VideoQuality_HIGH + 1]uint32 + priority int } func (r *simulcastReceiver) Priority() int { @@ -82,7 +82,6 @@ func (r *simulcastReceiver) Priority() int { } type MediaTrackReceiverParams struct { - TrackInfo *livekit.TrackInfo MediaTrack types.MediaTrack IsRelayed bool ParticipantID livekit.ParticipantID @@ -96,32 +95,26 @@ type MediaTrackReceiverParams struct { } type MediaTrackReceiver struct { - params MediaTrackReceiverParams - muted atomic.Bool - simulcasted atomic.Bool + params MediaTrackReceiverParams lock sync.RWMutex receivers []*simulcastReceiver - receiversShadow []*simulcastReceiver trackInfo *livekit.TrackInfo - layerDimensions map[livekit.VideoQuality]*livekit.VideoLayer potentialCodecs []webrtc.RTPCodecParameters state mediaTrackReceiverState onSetupReceiver func(mime string) onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) - onVideoLayerUpdate func(layers []*livekit.VideoLayer) onClose []func() *MediaTrackSubscriptions } -func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver { +func NewMediaTrackReceiver(params MediaTrackReceiverParams, ti *livekit.TrackInfo) *MediaTrackReceiver { t := &MediaTrackReceiver{ - params: params, - trackInfo: proto.Clone(params.TrackInfo).(*livekit.TrackInfo), - layerDimensions: make(map[livekit.VideoQuality]*livekit.VideoLayer), - state: mediaTrackReceiverStateOpen, + params: params, + trackInfo: proto.Clone(ti).(*livekit.TrackInfo), + state: mediaTrackReceiverStateOpen, } t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ @@ -137,22 +130,16 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver if t.trackInfo.Muted { t.SetMuted(true) } - - if t.trackInfo != nil && t.Kind() == livekit.TrackType_VIDEO { - t.UpdateVideoLayers(t.trackInfo.Layers) - // LK-TODO: maybe use this or simulcast flag in TrackInfo to set simulcasted here - } - return t } func (t *MediaTrackReceiver) Restart() { - t.lock.Lock() - receivers := t.receiversShadow - t.lock.Unlock() + t.lock.RLock() + hq := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.trackInfo) + t.lock.RUnlock() - for _, receiver := range receivers { - receiver.SetMaxExpectedSpatialLayer(buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.params.TrackInfo)) + for _, receiver := range t.loadReceivers() { + receiver.SetMaxExpectedSpatialLayer(hq) } } @@ -170,9 +157,11 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority return } + receivers := slices.Clone(t.receivers) + // codec position maybe taken by DummyReceiver, check and upgrade to WebRTCReceiver var upgradeReceiver bool - for _, r := range t.receivers { + for _, r := range receivers { if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) { if d, ok := r.TrackReceiver.(*DummyReceiver); ok { d.Upgrade(receiver) @@ -182,39 +171,44 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority } } if !upgradeReceiver { - t.receivers = append(t.receivers, &simulcastReceiver{TrackReceiver: receiver, priority: priority}) + receivers = append(receivers, &simulcastReceiver{TrackReceiver: receiver, priority: priority}) } - sort.Slice(t.receivers, func(i, j int) bool { - return t.receivers[i].Priority() < t.receivers[j].Priority() + sort.Slice(receivers, func(i, j int) bool { + return receivers[i].Priority() < receivers[j].Priority() }) if mid != "" { if priority == 0 { t.trackInfo.MimeType = receiver.Codec().MimeType t.trackInfo.Mid = mid - - // for clients don't have simulcast codecs (old version or single codec), add the primary codec - if len(t.trackInfo.Codecs) == 0 && t.trackInfo.Type == livekit.TrackType_VIDEO { - t.trackInfo.Codecs = append(t.trackInfo.Codecs, &livekit.SimulcastCodecInfo{}) - } } for i, ci := range t.trackInfo.Codecs { if i == priority { - ci.Mid = mid ci.MimeType = receiver.Codec().MimeType + ci.Mid = mid break } } } - t.shadowReceiversLocked() - + t.receivers = receivers onSetupReceiver := t.onSetupReceiver - t.params.Logger.Debugw("setup receiver", "mime", receiver.Codec().MimeType, "priority", priority, "receivers", t.receiversShadow) t.lock.Unlock() + var receiverCodecs []string + for _, r := range receivers { + receiverCodecs = append(receiverCodecs, r.Codec().MimeType) + } + t.params.Logger.Debugw( + "setup receiver", + "mime", receiver.Codec().MimeType, + "priority", priority, + "receivers", receiverCodecs, + "mid", mid, + ) + if onSetupReceiver != nil { onSetupReceiver(receiver.Codec().MimeType) } @@ -231,10 +225,11 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete } } t.lock.Lock() + receivers := slices.Clone(t.receivers) t.potentialCodecs = codecs for i, c := range codecs { var exist bool - for _, r := range t.receivers { + for _, r := range receivers { if strings.EqualFold(c.MimeType, r.Codec().MimeType) { exist = true break @@ -245,53 +240,30 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete if !sfu.IsSvcCodec(c.MimeType) { extHeaders = headersWithoutDD } - t.receivers = append(t.receivers, &simulcastReceiver{ + receivers = append(receivers, &simulcastReceiver{ TrackReceiver: NewDummyReceiver(livekit.TrackID(t.trackInfo.Sid), string(t.PublisherID()), c, extHeaders), priority: i, }) } } - sort.Slice(t.receivers, func(i, j int) bool { - return t.receivers[i].Priority() < t.receivers[j].Priority() + sort.Slice(receivers, func(i, j int) bool { + return receivers[i].Priority() < receivers[j].Priority() }) - t.shadowReceiversLocked() + t.receivers = receivers t.lock.Unlock() } -func (t *MediaTrackReceiver) shadowReceiversLocked() { - t.receiversShadow = make([]*simulcastReceiver, len(t.receivers)) - copy(t.receiversShadow, t.receivers) -} - -func (t *MediaTrackReceiver) SetLayerSsrc(mime string, rid string, ssrc uint32) { - t.lock.Lock() - defer t.lock.Unlock() - - layer := buffer.RidToSpatialLayer(rid, t.params.TrackInfo) - if layer == buffer.InvalidLayerSpatial { - // non-simulcast case will not have `rid` - layer = 0 - } - for _, receiver := range t.receiversShadow { - if strings.EqualFold(receiver.Codec().MimeType, mime) && int(layer) < len(receiver.layerSSRCs) { - receiver.layerSSRCs[layer] = ssrc - return - } - } -} - func (t *MediaTrackReceiver) ClearReceiver(mime string, willBeResumed bool) { - t.params.Logger.Debugw("clearing receiver", "mime", mime) t.lock.Lock() - for idx, receiver := range t.receivers { + receivers := slices.Clone(t.receivers) + for idx, receiver := range receivers { if strings.EqualFold(receiver.Codec().MimeType, mime) { - t.receivers[idx] = t.receivers[len(t.receivers)-1] - t.receivers = t.receivers[:len(t.receivers)-1] + receivers[idx] = receivers[len(receivers)-1] + receivers = receivers[:len(receivers)-1] break } } - - t.shadowReceiversLocked() + t.receivers = receivers t.lock.Unlock() t.removeAllSubscribersForMime(mime, willBeResumed) @@ -300,17 +272,12 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string, willBeResumed bool) { func (t *MediaTrackReceiver) ClearAllReceivers(willBeResumed bool) { t.params.Logger.Debugw("clearing all receivers") t.lock.Lock() - var mimes []string - for _, receiver := range t.receivers { - mimes = append(mimes, receiver.Codec().MimeType) - } - - t.receivers = t.receivers[:0] - t.receiversShadow = nil + receivers := t.receivers + t.receivers = nil t.lock.Unlock() - for _, mime := range mimes { - t.ClearReceiver(mime, willBeResumed) + for _, r := range receivers { + t.removeAllSubscribersForMime(r.Codec().MimeType, willBeResumed) } } @@ -318,10 +285,6 @@ func (t *MediaTrackReceiver) OnMediaLossFeedback(f func(dt *sfu.DownTrack, rr *r t.onMediaLossFeedback = f } -func (t *MediaTrackReceiver) OnVideoLayerUpdate(f func(layers []*livekit.VideoLayer)) { - t.onVideoLayerUpdate = f -} - func (t *MediaTrackReceiver) IsOpen() bool { t.lock.RLock() defer t.lock.RUnlock() @@ -352,7 +315,7 @@ func (t *MediaTrackReceiver) TryClose() bool { return true } - for _, receiver := range t.receiversShadow { + for _, receiver := range t.receivers { if dr, _ := receiver.TrackReceiver.(*DummyReceiver); dr != nil && dr.Receiver() != nil { t.lock.RUnlock() return false @@ -421,11 +384,17 @@ func (t *MediaTrackReceiver) PublisherVersion() uint32 { } func (t *MediaTrackReceiver) IsSimulcast() bool { - return t.simulcasted.Load() + t.lock.RLock() + defer t.lock.RUnlock() + + return t.trackInfo.Simulcast } func (t *MediaTrackReceiver) SetSimulcast(simulcast bool) { - t.simulcasted.Store(simulcast) + t.lock.Lock() + defer t.lock.Unlock() + + t.trackInfo.Simulcast = simulcast } func (t *MediaTrackReceiver) Name() string { @@ -436,16 +405,18 @@ func (t *MediaTrackReceiver) Name() string { } func (t *MediaTrackReceiver) IsMuted() bool { - return t.muted.Load() + t.lock.RLock() + defer t.lock.RUnlock() + + return t.trackInfo.Muted } func (t *MediaTrackReceiver) SetMuted(muted bool) { - t.muted.Store(muted) + t.lock.Lock() + t.trackInfo.Muted = muted + t.lock.Unlock() - t.lock.RLock() - receivers := t.receiversShadow - t.lock.RUnlock() - for _, receiver := range receivers { + for _, receiver := range t.loadReceivers() { receiver.SetUpTrackPaused(muted) } @@ -470,7 +441,7 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su return nil, ErrNotOpen } - receivers := t.receiversShadow + receivers := t.receivers potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs)) copy(potentialCodecs, t.potentialCodecs) t.lock.RUnlock() @@ -552,117 +523,182 @@ func (t *MediaTrackReceiver) RevokeDisallowedSubscribers(allowedSubscriberIdenti return revokedSubscriberIdentities } -func (t *MediaTrackReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { - t.lock.Lock() - t.trackInfo = proto.Clone(ti).(*livekit.TrackInfo) - t.lock.Unlock() +func (t *MediaTrackReceiver) updateTrackInfoOfReceivers() { + t.lock.RLock() + ti := proto.Clone(t.trackInfo).(*livekit.TrackInfo) + t.lock.RUnlock() - if ti != nil && t.Kind() == livekit.TrackType_VIDEO { - t.UpdateVideoLayers(ti.Layers) + for _, r := range t.loadReceivers() { + r.UpdateTrackInfo(ti) } } -func (t *MediaTrackReceiver) TrackInfo(generateLayer bool) *livekit.TrackInfo { - t.lock.RLock() - defer t.lock.RUnlock() - - ti := proto.Clone(t.trackInfo).(*livekit.TrackInfo) - if !generateLayer { - return ti +func (t *MediaTrackReceiver) SetLayerSsrc(mime string, rid string, ssrc uint32) { + t.lock.Lock() + layer := buffer.RidToSpatialLayer(rid, t.trackInfo) + if layer == buffer.InvalidLayerSpatial { + // non-simulcast case will not have `rid` + layer = 0 } - - layers := t.getVideoLayersLocked() - + quality := buffer.SpatialLayerToVideoQuality(layer, t.trackInfo) // set video layer ssrc info - for i, ci := range ti.Codecs { - for _, receiver := range t.receiversShadow { - if receiver.priority == i { - originLayers := ci.Layers - ci.Layers = []*livekit.VideoLayer{} - for layerIdx, layer := range layers { - ci.Layers = append(ci.Layers, proto.Clone(layer).(*livekit.VideoLayer)) + for i, ci := range t.trackInfo.Codecs { + if !strings.EqualFold(ci.MimeType, mime) { + continue + } - // if origin layer has ssrc, don't override it - ssrcFound := false - for _, l := range originLayers { - if l.Quality == ci.Layers[layerIdx].Quality { - if l.Ssrc != 0 { - ci.Layers[layerIdx].Ssrc = l.Ssrc - ssrcFound = true - } - break - } - } - if !ssrcFound && int(layer.Quality) < len(receiver.layerSSRCs) { - ci.Layers[layerIdx].Ssrc = receiver.layerSSRCs[layer.Quality] - } - } - - if i == 0 { - ti.Layers = ci.Layers + // if origin layer has ssrc, don't override it + var matchingLayer *livekit.VideoLayer + ssrcFound := false + for _, l := range ci.Layers { + if l.Quality == quality { + matchingLayer = l + if l.Ssrc != 0 { + ssrcFound = true } break } } + if !ssrcFound && matchingLayer != nil { + matchingLayer.Ssrc = ssrc + } + + // for client don't use simulcast codecs (old client version or single codec) + if i == 0 { + t.trackInfo.Layers = ci.Layers + } + break } + t.lock.Unlock() - // for client don't use simulcast codecs (old client version or single codec) - if len(ti.Codecs) == 0 && len(t.receiversShadow) > 0 { - receiver := t.receiversShadow[0] - originLayers := ti.Layers - ti.Layers = []*livekit.VideoLayer{} - for layerIdx, layer := range layers { - ti.Layers = append(ti.Layers, proto.Clone(layer).(*livekit.VideoLayer)) + t.updateTrackInfoOfReceivers() +} - // if origin layer has ssrc, don't override it - ssrcFound := false - for _, l := range originLayers { - if l.Quality == ti.Layers[layerIdx].Quality { - if l.Ssrc != 0 { - ti.Layers[layerIdx].Ssrc = l.Ssrc - ssrcFound = true - } - break - } - } - if !ssrcFound && int(layer.Quality) < len(receiver.layerSSRCs) { - ti.Layers[layerIdx].Ssrc = receiver.layerSSRCs[layer.Quality] +func (t *MediaTrackReceiver) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { + t.lock.Lock() + for _, c := range codecs { + for _, origin := range t.trackInfo.Codecs { + if strings.Contains(origin.MimeType, c.Codec) { + origin.Cid = c.Cid + break } } } + t.lock.Unlock() - return ti + t.updateTrackInfoOfReceivers() +} + +func (t *MediaTrackReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { + updateMute := false + clonedInfo := proto.Clone(ti).(*livekit.TrackInfo) + + t.lock.Lock() + // patch Mid and SSRC of codecs/layers by keeping original if available + for i, ci := range clonedInfo.Codecs { + for _, originCi := range t.trackInfo.Codecs { + if !strings.EqualFold(ci.MimeType, originCi.MimeType) { + continue + } + + if originCi.Mid != "" { + ci.Mid = originCi.Mid + } + + for _, layer := range ci.Layers { + for _, originLayer := range originCi.Layers { + if layer.Quality == originLayer.Quality { + if originLayer.Ssrc != 0 { + layer.Ssrc = originLayer.Ssrc + } + break + } + } + } + break + } + + // for client don't use simulcast codecs (old client version or single codec) + if i == 0 { + clonedInfo.Layers = ci.Layers + } + } + if t.trackInfo.Muted != clonedInfo.Muted { + updateMute = true + } + t.trackInfo = clonedInfo + t.lock.Unlock() + + if updateMute { + t.SetMuted(clonedInfo.Muted) + } + + t.updateTrackInfoOfReceivers() } func (t *MediaTrackReceiver) UpdateVideoLayers(layers []*livekit.VideoLayer) { t.lock.Lock() - for _, layer := range layers { - t.layerDimensions[layer.Quality] = layer + // set video layer ssrc info + for i, ci := range t.trackInfo.Codecs { + originLayers := ci.Layers + ci.Layers = []*livekit.VideoLayer{} + for layerIdx, layer := range layers { + ci.Layers = append(ci.Layers, proto.Clone(layer).(*livekit.VideoLayer)) + for _, l := range originLayers { + if l.Quality == ci.Layers[layerIdx].Quality { + if l.Ssrc != 0 { + ci.Layers[layerIdx].Ssrc = l.Ssrc + } + break + } + } + } + + // for client don't use simulcast codecs (old client version or single codec) + if i == 0 { + t.trackInfo.Layers = ci.Layers + } } t.lock.Unlock() + t.updateTrackInfoOfReceivers() t.MediaTrackSubscriptions.UpdateVideoLayers() - if t.onVideoLayerUpdate != nil { - t.onVideoLayerUpdate(layers) - } - - // TODO: this might need to trigger a participant update for clients to pick up dimension change } -func (t *MediaTrackReceiver) GetVideoLayers() []*livekit.VideoLayer { +func (t *MediaTrackReceiver) TrackInfo() *livekit.TrackInfo { t.lock.RLock() defer t.lock.RUnlock() - return t.getVideoLayersLocked() + return t.trackInfo } -func (t *MediaTrackReceiver) getVideoLayersLocked() []*livekit.VideoLayer { - layers := make([]*livekit.VideoLayer, 0) - for _, layer := range t.layerDimensions { - layers = append(layers, proto.Clone(layer).(*livekit.VideoLayer)) - } +func (t *MediaTrackReceiver) TrackInfoClone() *livekit.TrackInfo { + t.lock.RLock() + defer t.lock.RUnlock() - return layers + return proto.Clone(t.trackInfo).(*livekit.TrackInfo) +} + +func (t *MediaTrackReceiver) NotifyMaxLayerChange(maxLayer int32) { + t.lock.RLock() + quality := buffer.SpatialLayerToVideoQuality(maxLayer, t.trackInfo) + ti := &livekit.TrackInfo{ + Sid: t.trackInfo.Sid, + Type: t.trackInfo.Type, + Layers: []*livekit.VideoLayer{{Quality: quality}}, + } + if quality != livekit.VideoQuality_OFF { + for _, layer := range t.trackInfo.Layers { + if layer.Quality == quality { + ti.Layers[0].Width = layer.Width + ti.Layers[0].Height = layer.Height + break + } + } + } + t.lock.RUnlock() + + t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), ti) } // GetQualityForDimension finds the closest quality to use for desired dimensions @@ -690,7 +726,7 @@ func (t *MediaTrackReceiver) GetQualityForDimension(width, height uint32) liveki // default sizes representing qualities low - high layerSizes := []uint32{180, 360, origSize} var providedSizes []uint32 - for _, layer := range t.layerDimensions { + for _, layer := range t.trackInfo.Layers { providedSizes = append(providedSizes, layer.Height) } if len(providedSizes) > 0 { @@ -739,15 +775,12 @@ func (t *MediaTrackReceiver) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "ID": t.ID(), "Kind": t.Kind().String(), - "PubMuted": t.muted.Load(), + "PubMuted": t.IsMuted(), } info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() - t.lock.RLock() - receivers := t.receiversShadow - t.lock.RUnlock() - for _, receiver := range receivers { + for _, receiver := range t.loadReceivers() { info[receiver.Codec().MimeType] = receiver.DebugInfo() } @@ -758,20 +791,20 @@ func (t *MediaTrackReceiver) PrimaryReceiver() sfu.TrackReceiver { t.lock.RLock() defer t.lock.RUnlock() - if len(t.receiversShadow) == 0 { + if len(t.receivers) == 0 { return nil } - if dr, ok := t.receiversShadow[0].TrackReceiver.(*DummyReceiver); ok { + if dr, ok := t.receivers[0].TrackReceiver.(*DummyReceiver); ok { return dr.Receiver() } - return t.receiversShadow[0].TrackReceiver + return t.receivers[0].TrackReceiver } func (t *MediaTrackReceiver) Receiver(mime string) sfu.TrackReceiver { t.lock.RLock() defer t.lock.RUnlock() - for _, r := range t.receiversShadow { + for _, r := range t.receivers { if strings.EqualFold(r.Codec().MimeType, mime) { if dr, ok := r.TrackReceiver.(*DummyReceiver); ok { return dr.Receiver() @@ -785,20 +818,21 @@ func (t *MediaTrackReceiver) Receiver(mime string) sfu.TrackReceiver { func (t *MediaTrackReceiver) Receivers() []sfu.TrackReceiver { t.lock.RLock() defer t.lock.RUnlock() - - receivers := make([]sfu.TrackReceiver, 0, len(t.receiversShadow)) - for _, r := range t.receiversShadow { - receivers = append(receivers, r.TrackReceiver) + receivers := make([]sfu.TrackReceiver, len(t.receivers)) + for i, r := range t.receivers { + receivers[i] = r.TrackReceiver } return receivers } -func (t *MediaTrackReceiver) SetRTT(rtt uint32) { - t.lock.Lock() - receivers := t.receiversShadow - t.lock.Unlock() +func (t *MediaTrackReceiver) loadReceivers() []*simulcastReceiver { + t.lock.RLock() + defer t.lock.RUnlock() + return t.receivers +} - for _, r := range receivers { +func (t *MediaTrackReceiver) SetRTT(rtt uint32) { + for _, r := range t.loadReceivers() { if wr, ok := r.TrackReceiver.(*sfu.WebRTCReceiver); ok { wr.SetRTT(rtt) } @@ -829,10 +863,7 @@ func (t *MediaTrackReceiver) IsEncrypted() bool { } func (t *MediaTrackReceiver) GetTrackStats() *livekit.RTPStats { - t.lock.Lock() - receivers := t.receiversShadow - t.lock.Unlock() - + receivers := t.loadReceivers() stats := make([]*livekit.RTPStats, 0, len(receivers)) for _, receiver := range receivers { receiverStats := receiver.GetTrackStats() diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 15db1a8a3..e074c2577 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -196,7 +196,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * }) downTrack.AddReceiverReportListener(func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) { - sub.OnReceiverReport(dt, report) + sub.HandleReceiverReport(dt, report) }) var transceiver *webrtc.RTPTransceiver @@ -207,6 +207,12 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * replacedTrack := false existingTransceiver, dtState = sub.GetCachedDownTrack(trackID) if existingTransceiver != nil { + sub.GetLogger().Debugw( + "trying to use existing transceiver", + "publisher", subTrack.PublisherIdentity(), + "publisherID", subTrack.PublisherID(), + "trackID", trackID, + ) reusingTransceiver.Store(true) rtpSender := existingTransceiver.Sender() if rtpSender != nil { @@ -217,6 +223,12 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * sender = rtpSender transceiver = existingTransceiver replacedTrack = true + sub.GetLogger().Debugw( + "track replaced", + "publisher", subTrack.PublisherIdentity(), + "publisherID", subTrack.PublisherID(), + "trackID", trackID, + ) } } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 6e45fcef9..8c476d5e7 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -36,6 +36,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc/supervisor" + "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" @@ -76,13 +77,14 @@ type downTrackState struct { // --------------------------------------------------------------- type participantUpdateInfo struct { + identity livekit.ParticipantIdentity version uint32 state livekit.ParticipantInfo_State updatedAt time.Time } func (p participantUpdateInfo) String() string { - return fmt.Sprintf("version: %d, state: %s, updatedAt: %s", p.version, p.state.String(), p.updatedAt.String()) + return fmt.Sprintf("identity: %s, version: %d, state: %s, updatedAt: %s", p.identity, p.version, p.state.String(), p.updatedAt.String()) } // --------------------------------------------------------------- @@ -96,6 +98,7 @@ type ParticipantParams struct { AudioConfig config.AudioConfig VideoConfig config.VideoConfig ProtocolVersion types.ProtocolVersion + SessionStartTime time.Time Telemetry telemetry.TelemetryService Trailer []byte PLIThrottleConfig config.PLIThrottleConfig @@ -117,6 +120,7 @@ type ParticipantParams struct { AllowUDPUnstableFallback bool TURNSEnabled bool GetParticipantInfo func(pID livekit.ParticipantID) *livekit.ParticipantInfo + GetRegionSettings func(ip string) *livekit.RegionSettings DisableSupervisor bool ReconnectOnPublicationError bool ReconnectOnSubscriptionError bool @@ -136,13 +140,18 @@ type ParticipantImpl struct { params ParticipantParams isClosed atomic.Bool - state atomic.Value // livekit.ParticipantInfo_State - resSinkMu sync.Mutex - resSink routing.MessageSink + closeReason atomic.Value // types.ParticipantCloseReason + + state atomic.Value // livekit.ParticipantInfo_State + + resSinkMu sync.Mutex + resSink routing.MessageSink + grants *auth.ClaimGrants hidden atomic.Bool isPublisher atomic.Bool + sessionStartRecorded atomic.Bool // when first connected connectedAt time.Time // timer that's set when disconnect is detected on primary PC @@ -158,8 +167,8 @@ type ParticipantImpl struct { pendingTracksLock utils.RWMutex pendingTracks map[string]*pendingTrackInfo pendingPublishingTracks map[livekit.TrackID]*pendingTrackInfo - // migrated in muted tracks are not fired need close at participant close - mutedTrackNotFired []*MediaTrack + // migrated in tracks that have not fired need close at participant close + pendingMigratedTracks []*MediaTrack // supported codecs enabledPublishCodecs []*livekit.Codec @@ -188,7 +197,6 @@ type ParticipantImpl struct { lastRTT uint32 lock utils.RWMutex - once sync.Once dirty atomic.Bool version atomic.Uint32 @@ -198,7 +206,7 @@ type ParticipantImpl struct { onTrackPublished func(types.LocalParticipant, types.MediaTrack) onTrackUpdated func(types.LocalParticipant, types.MediaTrack) onTrackUnpublished func(types.LocalParticipant, types.MediaTrack) - onStateChange func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) + onStateChange func(p types.LocalParticipant, state livekit.ParticipantInfo_State) onMigrateStateChange func(p types.LocalParticipant, migrateState types.MigrateState) onParticipantUpdate func(types.LocalParticipant) onDataPacket func(types.LocalParticipant, *livekit.DataPacket) @@ -249,6 +257,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { if !params.DisableSupervisor { p.supervisor = supervisor.NewParticipantSupervisor(supervisor.ParticipantSupervisorParams{Logger: params.Logger}) } + p.closeReason.Store(types.ParticipantCloseReasonNone) p.version.Store(params.InitialVersion) p.timedVersion.Update(params.VersionGenerator.New()) p.migrateState.Store(types.MigrateStateInit) @@ -358,10 +367,6 @@ func (p *ParticipantImpl) GetClientConfiguration() *livekit.ClientConfiguration return p.params.ClientConf } -func (p *ParticipantImpl) GetICEConnectionType() types.ICEConnectionType { - return p.TransportManager.GetICEConnectionType() -} - func (p *ParticipantImpl) GetBufferFactory() *buffer.Factory { return p.params.Config.BufferFactory } @@ -524,18 +529,36 @@ func (p *ParticipantImpl) OnTrackPublished(callback func(types.LocalParticipant, p.lock.Unlock() } +func (p *ParticipantImpl) getOnTrackPublished() func(types.LocalParticipant, types.MediaTrack) { + p.lock.RLock() + defer p.lock.RUnlock() + return p.onTrackPublished +} + func (p *ParticipantImpl) OnTrackUnpublished(callback func(types.LocalParticipant, types.MediaTrack)) { p.lock.Lock() p.onTrackUnpublished = callback p.lock.Unlock() } -func (p *ParticipantImpl) OnStateChange(callback func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State)) { +func (p *ParticipantImpl) getOnTrackUnpublished() func(types.LocalParticipant, types.MediaTrack) { + p.lock.RLock() + defer p.lock.RUnlock() + return p.onTrackUnpublished +} + +func (p *ParticipantImpl) OnStateChange(callback func(p types.LocalParticipant, state livekit.ParticipantInfo_State)) { p.lock.Lock() p.onStateChange = callback p.lock.Unlock() } +func (p *ParticipantImpl) getOnStateChange() func(p types.LocalParticipant, state livekit.ParticipantInfo_State) { + p.lock.RLock() + defer p.lock.RUnlock() + return p.onStateChange +} + func (p *ParticipantImpl) OnMigrateStateChange(callback func(p types.LocalParticipant, state types.MigrateState)) { p.lock.Lock() p.onMigrateStateChange = callback @@ -545,7 +568,6 @@ func (p *ParticipantImpl) OnMigrateStateChange(callback func(p types.LocalPartic func (p *ParticipantImpl) getOnMigrateStateChange() func(p types.LocalParticipant, state types.MigrateState) { p.lock.RLock() defer p.lock.RUnlock() - return p.onMigrateStateChange } @@ -555,6 +577,12 @@ func (p *ParticipantImpl) OnTrackUpdated(callback func(types.LocalParticipant, t p.lock.Unlock() } +func (p *ParticipantImpl) getOnTrackUpdated() func(types.LocalParticipant, types.MediaTrack) { + p.lock.RLock() + defer p.lock.RUnlock() + return p.onTrackUpdated +} + func (p *ParticipantImpl) OnParticipantUpdate(callback func(types.LocalParticipant)) { p.lock.Lock() p.onParticipantUpdate = callback @@ -582,10 +610,8 @@ func (p *ParticipantImpl) OnClaimsChanged(callback func(types.LocalParticipant)) func (p *ParticipantImpl) HandleSignalSourceClose() { p.TransportManager.SetSignalSourceValid(false) - if !p.TransportManager.HasPublisherEverConnected() && !p.TransportManager.HasSubscriberEverConnected() { - reason := types.ParticipantCloseReasonJoinFailed - p.params.Logger.Infow("closing disconnected participant", "reason", reason) - _ = p.Close(false, reason, false) + if !p.HasConnected() { + _ = p.Close(false, types.ParticipantCloseReasonSignalSourceClose, false) } } @@ -634,13 +660,15 @@ func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription) er } if p.MigrateState() == types.MigrateStateSync { - go p.handleMigrateMutedTrack() + go p.handleMigrateTracks() } return nil } -func (p *ParticipantImpl) handleMigrateMutedTrack() { - // muted track won't send rtp packet, so we add mediatrack manually +func (p *ParticipantImpl) handleMigrateTracks() { + // muted track won't send rtp packet, so it is required to add mediatrack manually. + // But, synthesising track publish for unmuted tracks keeps a consistent path. + // In both cases (muted and unmuted), when publisher sends media packets, OnTrack would register and go from there. var addedTracks []*MediaTrack p.pendingTracksLock.Lock() for cid, pti := range p.pendingTracks { @@ -652,17 +680,14 @@ func (p *ParticipantImpl) handleMigrateMutedTrack() { p.pubLogger.Warnw("too many pending migrated tracks", nil, "trackID", pti.trackInfos[0].Sid, "count", len(pti.trackInfos), "cid", cid) } - ti := pti.trackInfos[0] - if ti.Muted { - mt := p.addMigrateMutedTrack(cid, ti) - if mt != nil { - addedTracks = append(addedTracks, mt) - } else { - p.pubLogger.Warnw("could not find migrated muted track", nil, "cid", cid) - } + mt := p.addMigratedTrack(cid, pti.trackInfos[0]) + if mt != nil { + addedTracks = append(addedTracks, mt) + } else { + p.pubLogger.Warnw("could not find migrated track", nil, "cid", cid) } } - p.mutedTrackNotFired = append(p.mutedTrackNotFired, addedTracks...) + p.pendingMigratedTracks = append(p.pendingMigratedTracks, addedTracks...) if len(addedTracks) != 0 { p.dirty.Store(true) @@ -678,12 +703,12 @@ func (p *ParticipantImpl) handleMigrateMutedTrack() { }() } -func (p *ParticipantImpl) removeMutedTrackNotFired(mt *MediaTrack) { +func (p *ParticipantImpl) removePendingMigratedTrack(mt *MediaTrack) { p.pendingTracksLock.Lock() - for i, t := range p.mutedTrackNotFired { + for i, t := range p.pendingMigratedTracks { if t == mt { - p.mutedTrackNotFired[i] = p.mutedTrackNotFired[len(p.mutedTrackNotFired)-1] - p.mutedTrackNotFired = p.mutedTrackNotFired[:len(p.mutedTrackNotFired)-1] + p.pendingMigratedTracks[i] = p.pendingMigratedTracks[len(p.pendingMigratedTracks)-1] + p.pendingMigratedTracks = p.pendingMigratedTracks[:len(p.pendingMigratedTracks)-1] break } } @@ -730,12 +755,6 @@ func (p *ParticipantImpl) SetMigrateInfo( p.TransportManager.SetMigrateInfo(previousOffer, previousAnswer, dataChannels) } -func (p *ParticipantImpl) Start() { - p.once.Do(func() { - p.UpTrackManager.Start() - }) -} - func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseReason, isExpectedToResume bool) error { if p.isClosed.Swap(true) { // already closed @@ -747,20 +766,13 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea "sendLeave", sendLeave, "reason", reason.String(), "isExpectedToResume", isExpectedToResume, - "clientInfo", logger.Proto(p.params.ClientInfo), ) + p.closeReason.Store(reason) p.clearDisconnectTimer() p.clearMigrationTimer() - // send leave message if sendLeave { - _ = p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_Leave{ - Leave: &livekit.LeaveRequest{ - Reason: reason.ToDisconnectReason(), - }, - }, - }) + p.sendLeaveRequest(reason, isExpectedToResume, false, false) } if p.supervisor != nil { @@ -769,11 +781,11 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea p.pendingTracksLock.Lock() p.pendingTracks = make(map[string]*pendingTrackInfo) - closeMutedTrack := p.mutedTrackNotFired - p.mutedTrackNotFired = p.mutedTrackNotFired[:0] + pendingMigratedTracksToClose := p.pendingMigratedTracks + p.pendingMigratedTracks = p.pendingMigratedTracks[:0] p.pendingTracksLock.Unlock() - for _, t := range closeMutedTrack { + for _, t := range pendingMigratedTracksToClose { t.Close(isExpectedToResume) } @@ -806,6 +818,10 @@ func (p *ParticipantImpl) IsClosed() bool { return p.isClosed.Load() } +func (p *ParticipantImpl) CloseReason() types.ParticipantCloseReason { + return p.closeReason.Load().(types.ParticipantCloseReason) +} + // Negotiate subscriber SDP with client, if force is true, will cancel pending // negotiate task and negotiate immediately func (p *ParticipantImpl) Negotiate(force bool) { @@ -836,6 +852,7 @@ func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { onStart() } + p.sendLeaveRequest(types.ParticipantCloseReasonMigrationRequested, true, false, true) p.CloseSignalConnection(types.SignallingCloseReasonMigration) // @@ -858,7 +875,7 @@ func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { if p.IsClosed() || p.IsDisconnected() { return } - p.subLogger.Infow("closing subscriber peer connection to aid migration") + p.subLogger.Debugw("closing subscriber peer connection to aid migration") // // Close all down tracks before closing subscriber peer connection. @@ -996,6 +1013,10 @@ func (p *ParticipantImpl) GetConnectionQuality() *livekit.ConnectionQualityInfo } p.lock.Unlock() + if minQuality == livekit.ConnectionQuality_LOST && !p.ProtocolVersion().SupportsConnectionQualityLost() { + minQuality = livekit.ConnectionQuality_POOR + } + return &livekit.ConnectionQualityInfo{ ParticipantSid: string(p.ID()), Quality: minQuality, @@ -1116,16 +1137,93 @@ func (p *ParticipantImpl) UpdateMediaRTT(rtt uint32) { } } +type AnyTransportHandler struct { + transport.UnimplementedHandler + p *ParticipantImpl +} + +func (h AnyTransportHandler) OnFailed(isShortLived bool) { + h.p.onAnyTransportFailed() +} + +func (h AnyTransportHandler) OnNegotiationFailed() { + h.p.onAnyTransportNegotiationFailed() +} + +func (h AnyTransportHandler) OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + return h.p.onICECandidate(c, target) +} + +type PublisherTransportHandler struct { + AnyTransportHandler +} + +func (h PublisherTransportHandler) OnAnswer(sd webrtc.SessionDescription) error { + return h.p.onPublisherAnswer(sd) +} + +func (h PublisherTransportHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + h.p.onMediaTrack(track, rtpReceiver) +} + +func (h PublisherTransportHandler) OnInitialConnected() { + h.p.onPublisherInitialConnected() +} + +func (h PublisherTransportHandler) OnDataPacket(kind livekit.DataPacket_Kind, data []byte) { + h.p.onDataMessage(kind, data) +} + +type SubscriberTransportHandler struct { + AnyTransportHandler +} + +func (h SubscriberTransportHandler) OnOffer(sd webrtc.SessionDescription) error { + return h.p.onSubscriberOffer(sd) +} + +func (h SubscriberTransportHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { + return h.p.onStreamStateChange(update) +} + +func (h SubscriberTransportHandler) OnInitialConnected() { + h.p.onSubscriberInitialConnected() +} + +type PrimaryTransportHandler struct { + transport.Handler + p *ParticipantImpl +} + +func (h PrimaryTransportHandler) OnInitialConnected() { + h.Handler.OnInitialConnected() + h.p.onPrimaryTransportInitialConnected() +} + +func (h PrimaryTransportHandler) OnFullyEstablished() { + h.p.onPrimaryTransportFullyEstablished() +} + func (p *ParticipantImpl) setupTransportManager() error { + ath := AnyTransportHandler{p: p} + var pth transport.Handler = PublisherTransportHandler{ath} + var sth transport.Handler = SubscriberTransportHandler{ath} + + subscriberAsPrimary := p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe() + if subscriberAsPrimary { + sth = PrimaryTransportHandler{sth, p} + } else { + pth = PrimaryTransportHandler{pth, p} + } + params := TransportManagerParams{ Identity: p.params.Identity, SID: p.params.SID, // primary connection does not change, canSubscribe can change if permission was updated // after the participant has joined - SubscriberAsPrimary: p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe(), + SubscriberAsPrimary: subscriberAsPrimary, Config: p.params.Config, ProtocolVersion: p.params.ProtocolVersion, - Telemetry: p.params.Telemetry, CongestionControlConfig: p.params.CongestionControlConfig, EnabledPublishCodecs: p.enabledPublishCodecs, EnabledSubscribeCodecs: p.enabledSubscribeCodecs, @@ -1139,6 +1237,8 @@ func (p *ParticipantImpl) setupTransportManager() error { AllowPlayoutDelay: p.params.PlayoutDelay.GetEnabled(), DataChannelMaxBufferedAmount: p.params.DataChannelMaxBufferedAmount, Logger: p.params.Logger.WithComponent(sutils.ComponentTransport), + PublisherHandler: pth, + SubscriberHandler: sth, } if p.params.SyncStreams && p.params.PlayoutDelay.GetEnabled() && p.params.ClientInfo.isFirefox() { // we will disable playout delay for Firefox if the user is expecting @@ -1170,27 +1270,6 @@ func (p *ParticipantImpl) setupTransportManager() error { } }) - tm.OnPublisherICECandidate(func(c *webrtc.ICECandidate) error { - return p.onICECandidate(c, livekit.SignalTarget_PUBLISHER) - }) - tm.OnPublisherAnswer(p.onPublisherAnswer) - tm.OnPublisherTrack(p.onMediaTrack) - tm.OnPublisherInitialConnected(p.onPublisherInitialConnected) - - tm.OnSubscriberOffer(p.onSubscriberOffer) - tm.OnSubscriberICECandidate(func(c *webrtc.ICECandidate) error { - return p.onICECandidate(c, livekit.SignalTarget_SUBSCRIBER) - }) - tm.OnSubscriberInitialConnected(p.onSubscriberInitialConnected) - tm.OnSubscriberStreamStateChange(p.onStreamStateChange) - - tm.OnPrimaryTransportInitialConnected(p.onPrimaryTransportInitialConnected) - tm.OnPrimaryTransportFullyEstablished(p.onPrimaryTransportFullyEstablished) - tm.OnAnyTransportFailed(p.onAnyTransportFailed) - tm.OnAnyTransportNegotiationFailed(p.onAnyTransportNegotiationFailed) - - tm.OnDataMessage(p.onDataMessage) - tm.SetSubscriberAllowPause(p.params.SubscriberAllowPause) p.TransportManager = tm return nil @@ -1204,12 +1283,8 @@ func (p *ParticipantImpl) setupUpTrackManager() { }) p.UpTrackManager.OnPublishedTrackUpdated(func(track types.MediaTrack) { - p.lock.RLock() - onTrackUpdated := p.onTrackUpdated - p.lock.RUnlock() - p.dirty.Store(true) - if onTrackUpdated != nil { + if onTrackUpdated := p.getOnTrackUpdated(); onTrackUpdated != nil { onTrackUpdated(p, track) } }) @@ -1240,27 +1315,16 @@ func (p *ParticipantImpl) setupParticipantTrafficLoad() { } func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { - oldState := p.State() - if state == oldState { + oldState := p.state.Swap(state).(livekit.ParticipantInfo_State) + if oldState == state { return } p.params.Logger.Debugw("updating participant state", "state", state.String()) - p.state.Store(state) p.dirty.Store(true) - p.lock.RLock() - onStateChange := p.onStateChange - p.lock.RUnlock() - if onStateChange != nil { - go func() { - defer func() { - if r := Recover(p.GetLogger()); r != nil { - os.Exit(1) - } - }() - onStateChange(p, oldState) - }() + if onStateChange := p.getOnStateChange(); onStateChange != nil { + go onStateChange(p, state) } } @@ -1343,13 +1407,11 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w "rid", track.RID(), "SSRC", track.SSRC(), "mime", track.Codec().MimeType, + "trackInfo", logger.Proto(publishedTrack.ToProto()), ) if !isNewTrack && !publishedTrack.HasPendingCodec() && p.IsReady() { - p.lock.RLock() - onTrackUpdated := p.onTrackUpdated - p.lock.RUnlock() - if onTrackUpdated != nil { + if onTrackUpdated := p.getOnTrackUpdated(); onTrackUpdated != nil { onTrackUpdated(p, publishedTrack) } } @@ -1421,6 +1483,9 @@ func (p *ParticipantImpl) onPrimaryTransportInitialConnected() { } func (p *ParticipantImpl) onPrimaryTransportFullyEstablished() { + if !p.sessionStartRecorded.Swap(true) { + prometheus.RecordSessionStartTime(int(p.ProtocolVersion()), time.Since(p.params.SessionStartTime)) + } p.updateState(livekit.ParticipantInfo_ACTIVE) } @@ -1444,7 +1509,6 @@ func (p *ParticipantImpl) setupDisconnectTimer() { return } reason := types.ParticipantCloseReasonPeerConnectionDisconnected - p.params.Logger.Infow("closing disconnected participant", "reason", reason) _ = p.Close(true, reason, false) }) p.lock.Unlock() @@ -1452,6 +1516,7 @@ func (p *ParticipantImpl) setupDisconnectTimer() { func (p *ParticipantImpl) onAnyTransportFailed() { // clients support resuming of connections when websocket becomes disconnected + p.sendLeaveRequest(types.ParticipantCloseReasonPeerConnectionDisconnected, true, false, true) p.CloseSignalConnection(types.SignallingCloseReasonTransportFailure) // detect when participant has actually left. @@ -1486,7 +1551,11 @@ func (p *ParticipantImpl) subscriberRTCPWorker() { pkts = append(pkts, sr) sd = append(sd, chunks...) - batchSize = batchSize + 1 + len(chunks) + numItems := 0 + for _, chunk := range chunks { + numItems += len(chunk.Items) + } + batchSize = batchSize + 1 + numItems if batchSize >= sdBatchSize { if len(sd) != 0 { pkts = append(pkts, &rtcp.SourceDescription{Chunks: sd}) @@ -1545,7 +1614,12 @@ func (p *ParticipantImpl) onStreamStateChange(update *streamallocator.StreamStat }) } -func (p *ParticipantImpl) onSubscribedMaxQualityChange(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedCodec, maxSubscribedQualities []types.SubscribedCodecQuality) error { +func (p *ParticipantImpl) onSubscribedMaxQualityChange( + trackID livekit.TrackID, + trackInfo *livekit.TrackInfo, + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, +) error { if p.params.DisableDynacast { return nil } @@ -1554,6 +1628,28 @@ func (p *ParticipantImpl) onSubscribedMaxQualityChange(trackID livekit.TrackID, return nil } + // send layer info about max subscription changes to telemetry + for _, maxSubscribedQuality := range maxSubscribedQualities { + ti := &livekit.TrackInfo{ + Sid: trackInfo.Sid, + Type: trackInfo.Type, + } + for _, layer := range trackInfo.Layers { + if layer.Quality == maxSubscribedQuality.Quality { + ti.Width = layer.Width + ti.Height = layer.Height + break + } + } + p.params.Telemetry.TrackMaxSubscribedVideoQuality( + context.Background(), + p.ID(), + ti, + maxSubscribedQuality.CodecMime, + maxSubscribedQuality.Quality, + ) + } + // normalize the codec name for _, subscribedQuality := range subscribedQualities { subscribedQuality.Codec = strings.ToLower(strings.TrimLeft(subscribedQuality.Codec, "video/")) @@ -1565,36 +1661,6 @@ func (p *ParticipantImpl) onSubscribedMaxQualityChange(trackID livekit.TrackID, SubscribedCodecs: subscribedQualities, } - // send layer info about max subscription changes to telemetry - track := p.UpTrackManager.GetPublishedTrack(trackID) - var layerInfo map[livekit.VideoQuality]*livekit.VideoLayer - if track != nil { - layers := track.ToProto().Layers - layerInfo = make(map[livekit.VideoQuality]*livekit.VideoLayer, len(layers)) - for _, layer := range layers { - layerInfo[layer.Quality] = layer - } - } - - for _, maxSubscribedQuality := range maxSubscribedQualities { - ti := &livekit.TrackInfo{ - Sid: string(trackID), - Type: livekit.TrackType_VIDEO, - } - if info, ok := layerInfo[maxSubscribedQuality.Quality]; ok { - ti.Width = info.Width - ti.Height = info.Height - } - - p.params.Telemetry.TrackMaxSubscribedVideoQuality( - context.Background(), - p.ID(), - ti, - maxSubscribedQuality.CodecMime, - maxSubscribedQuality.Quality, - ) - } - p.pubLogger.Debugw( "sending max subscribed quality", "trackID", trackID, @@ -1619,7 +1685,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l return nil } - track.(*MediaTrack).SetPendingCodecSid(req.SimulcastCodecs) + track.(*MediaTrack).UpdateCodecCid(req.SimulcastCodecs) ti := track.ToProto() return ti } @@ -1642,35 +1708,52 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l ti.Stream = StreamFromTrackSource(ti.Source) } p.setStableTrackID(req.Cid, ti) - seenCodecs := make(map[string]struct{}) - for _, codec := range req.SimulcastCodecs { - mime := codec.Codec - if req.Type == livekit.TrackType_VIDEO { - if !strings.HasPrefix(mime, "video/") { - mime = "video/" + mime - } - if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { - altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) - p.pubLogger.Infow("falling back to alternative codec", - "codec", mime, - "altCodec", altCodec, - "trackID", ti.Sid, - ) - // select an alternative MIME type that's generally supported - mime = altCodec - } - } else if req.Type == livekit.TrackType_AUDIO && !strings.HasPrefix(mime, "audio/") { - mime = "audio/" + mime - } - if _, ok := seenCodecs[mime]; ok || mime == "" { - continue + if len(req.SimulcastCodecs) == 0 { + if req.Type == livekit.TrackType_VIDEO { + // clients not supporting simulcast codecs, synthesise a codec + ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ + Cid: req.Cid, + Layers: req.Layers, + }) + } + } else { + seenCodecs := make(map[string]struct{}) + for _, codec := range req.SimulcastCodecs { + mime := codec.Codec + if req.Type == livekit.TrackType_VIDEO { + if !strings.HasPrefix(mime, "video/") { + mime = "video/" + mime + } + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { + altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) + p.pubLogger.Infow("falling back to alternative codec", + "codec", mime, + "altCodec", altCodec, + "trackID", ti.Sid, + ) + // select an alternative MIME type that's generally supported + mime = altCodec + } + } else if req.Type == livekit.TrackType_AUDIO && !strings.HasPrefix(mime, "audio/") { + mime = "audio/" + mime + } + + if _, ok := seenCodecs[mime]; ok || mime == "" { + continue + } + seenCodecs[mime] = struct{}{} + + clonedLayers := make([]*livekit.VideoLayer, 0, len(req.Layers)) + for _, l := range req.Layers { + clonedLayers = append(clonedLayers, proto.Clone(l).(*livekit.VideoLayer)) + } + ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ + MimeType: mime, + Cid: codec.Cid, + Layers: clonedLayers, + }) } - seenCodecs[mime] = struct{}{} - ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ - MimeType: mime, - Cid: codec.Cid, - }) } p.params.Telemetry.TrackPublishRequested(context.Background(), p.ID(), p.Identity(), ti) @@ -1689,7 +1772,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l } p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} - p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req)) + p.pubLogger.Debugw("pending track added", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req)) return ti } @@ -1706,6 +1789,10 @@ func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.Trac return nil } +func (p *ParticipantImpl) HasConnected() bool { + return p.TransportManager.HasSubscriberEverConnected() || p.TransportManager.HasPublisherEverConnected() +} + func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) { p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", logger.Proto(ti)) _ = p.writeMessage(&livekit.SignalResponse{ @@ -1771,6 +1858,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei p.pendingTracksLock.Lock() newTrack := false + mid := p.TransportManager.GetPublisherMid(rtpReceiver) p.pubLogger.Debugw( "media track received", "kind", track.Kind().String(), @@ -1778,8 +1866,8 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei "rid", track.RID(), "SSRC", track.SSRC(), "mime", track.Codec().MimeType, + "mid", mid, ) - mid := p.TransportManager.GetPublisherMid(rtpReceiver) if mid == "" { p.pendingTracksLock.Unlock() p.pubLogger.Warnw("could not get mid for track", nil, "trackID", track.ID()) @@ -1808,7 +1896,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei } } if codecFound != len(ti.Codecs) { - p.params.Logger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters) + p.pubLogger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters) p.pendingTracksLock.Unlock() p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch) return nil, false @@ -1816,6 +1904,10 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei } ti.MimeType = track.Codec().MimeType + if utils.NewTimedVersionFromProto(ti.Version).IsZero() { + // only assign version on a fresh publish, i. e. avoid updating version in scenarios like migration + ti.Version = p.params.VersionGenerator.New().ToProto() + } mt = p.addMediaTrack(signalCid, track.ID(), ti) newTrack = true p.dirty.Store(true) @@ -1831,17 +1923,25 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei p.pendingTracksLock.Unlock() if mt.AddReceiver(rtpReceiver, track, p.twcc, mid) { - p.removeMutedTrackNotFired(mt) - if newTrack { - go p.handleTrackPublished(mt) - } + p.removePendingMigratedTrack(mt) + } + + if newTrack { + go func() { + p.pubLogger.Debugw( + "track published", + "trackID", mt.ID(), + "track", logger.Proto(mt.ToProto()), + ) + p.handleTrackPublished(mt) + }() } return mt, newTrack } -func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack { - p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti)) +func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack { + p.pubLogger.Infow("add migrated track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti)) rtpReceiver := p.TransportManager.GetPublisherRTPReceiver(ti.Mid) if rtpReceiver == nil { p.pubLogger.Errorw("could not find receiver for migrated track", nil, "trackID", ti.Sid) @@ -1883,19 +1983,18 @@ func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo for _, codec := range ti.Codecs { for ssrc, info := range p.params.SimTracks { if info.Mid == codec.Mid { - mt.MediaTrackReceiver.SetLayerSsrc(codec.MimeType, info.Rid, ssrc) + mt.SetLayerSsrc(codec.MimeType, info.Rid, ssrc) } } } mt.SetSimulcast(ti.Simulcast) - mt.SetMuted(true) + mt.SetMuted(ti.Muted) return mt } func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *livekit.TrackInfo) *MediaTrack { mt := NewMediaTrack(MediaTrackParams{ - TrackInfo: proto.Clone(ti).(*livekit.TrackInfo), SignalCid: signalCid, SdpCid: sdpCid, ParticipantID: p.params.SID, @@ -1911,7 +2010,7 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *liv SubscriberConfig: p.params.Config.Subscriber, PLIThrottleConfig: p.params.PLIThrottleConfig, SimTracks: p.params.SimTracks, - }) + }, ti) mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) @@ -1965,11 +2064,8 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *liv if !p.IsClosed() { // unpublished events aren't necessary when participant is closed - p.pubLogger.Infow("unpublished track", "trackID", ti.Sid, "trackInfo", ti) - p.lock.RLock() - onTrackUnpublished := p.onTrackUnpublished - p.lock.RUnlock() - if onTrackUnpublished != nil { + p.pubLogger.Debugw("track unpublished", "trackID", ti.Sid, "track", logger.Proto(ti)) + if onTrackUnpublished := p.getOnTrackUnpublished(); onTrackUnpublished != nil { onTrackUnpublished(p, mt) } } @@ -1979,10 +2075,7 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *liv } func (p *ParticipantImpl) handleTrackPublished(track types.MediaTrack) { - p.lock.RLock() - onTrackPublished := p.onTrackPublished - p.lock.RUnlock() - if onTrackPublished != nil { + if onTrackPublished := p.getOnTrackPublished(); onTrackPublished != nil { onTrackPublished(p, track) } @@ -2246,14 +2339,7 @@ func (p *ParticipantImpl) GetCachedDownTrack(trackID livekit.TrackID) (*webrtc.R } func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason) { - _ = p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_Leave{ - Leave: &livekit.LeaveRequest{ - CanReconnect: true, - Reason: reason.ToDisconnectReason(), - }, - }, - }) + p.sendLeaveRequest(reason, false, true, false) scr := types.SignallingCloseReasonUnknown switch reason { @@ -2304,7 +2390,7 @@ func (p *ParticipantImpl) onSubscriptionError(trackID livekit.TrackID, fatal boo } func (p *ParticipantImpl) onAnyTransportNegotiationFailed() { - if p.TransportManager.SinceLastSignal() < negotiationFailedTimeout { + if p.TransportManager.SinceLastSignal() < negotiationFailedTimeout/2 { p.params.Logger.Infow("negotiation failed, starting full reconnect") } p.IssueFullReconnect(types.ParticipantCloseReasonNegotiateFailed) diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index afd312bc8..a1b758ff8 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -753,6 +753,7 @@ func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *p Config: rtcConf, Sink: &routingfakes.FakeMessageSink{}, ProtocolVersion: opts.protocolVersion, + SessionStartTime: time.Now(), PLIThrottleConfig: conf.RTC.PLIThrottle, Grants: grants, PublishEnabledCodecs: enabledCodecs, diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 51e98c478..ed47d2da9 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -242,7 +242,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript _, ti, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) p.pendingTracksLock.RUnlock() } else { - ti = track.TrackInfo(false) + ti = track.ToProto() } break } diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go index 5214df47e..63e944e41 100644 --- a/pkg/rtc/participant_signal.go +++ b/pkg/rtc/participant_signal.go @@ -44,7 +44,12 @@ func (p *ParticipantImpl) SendJoinResponse(joinResponse *livekit.JoinResponse) e // keep track of participant updates and versions p.updateLock.Lock() for _, op := range joinResponse.OtherParticipants { - p.updateCache.Add(livekit.ParticipantID(op.Sid), participantUpdateInfo{version: op.Version, state: op.State, updatedAt: time.Now()}) + p.updateCache.Add(livekit.ParticipantID(op.Sid), participantUpdateInfo{ + identity: livekit.ParticipantIdentity(op.Identity), + version: op.Version, + state: op.State, + updatedAt: time.Now(), + }) } p.updateLock.Unlock() @@ -104,7 +109,12 @@ func (p *ParticipantImpl) SendParticipantUpdate(participantsToUpdate []*livekit. isValid = false } if isValid { - p.updateCache.Add(pID, participantUpdateInfo{version: pi.Version, state: pi.State, updatedAt: time.Now()}) + p.updateCache.Add(pID, participantUpdateInfo{ + identity: livekit.ParticipantIdentity(pi.Identity), + version: pi.Version, + state: pi.State, + updatedAt: time.Now(), + }) validUpdates = append(validUpdates, pi) } } @@ -212,15 +222,20 @@ func (p *ParticipantImpl) sendDisconnectUpdatesForReconnect() error { break } else if info.state == livekit.ParticipantInfo_DISCONNECTED { disconnectedParticipants = append(disconnectedParticipants, &livekit.ParticipantInfo{ - Sid: string(keys[i]), - Version: info.version, - State: livekit.ParticipantInfo_DISCONNECTED, + Sid: string(keys[i]), + Identity: string(info.identity), + Version: info.version, + State: livekit.ParticipantInfo_DISCONNECTED, }) } } } p.updateLock.Unlock() + if len(disconnectedParticipants) == 0 { + return nil + } + return p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Update{ Update: &livekit.ParticipantUpdate{ @@ -289,8 +304,52 @@ func (p *ParticipantImpl) writeMessage(msg *livekit.SignalResponse) error { func (p *ParticipantImpl) CloseSignalConnection(reason types.SignallingCloseReason) { sink := p.getResponseSink() if sink != nil { - p.params.Logger.Infow("closing signal connection", "reason", reason, "connID", sink.ConnectionID()) + if reason != types.SignallingCloseReasonParticipantClose { + p.params.Logger.Infow("closing signal connection", "reason", reason, "connID", sink.ConnectionID()) + } sink.Close() p.SetResponseSink(nil) } } + +func (p *ParticipantImpl) sendLeaveRequest( + reason types.ParticipantCloseReason, + isExpectedToResume bool, + isExpectedToReconnect bool, + sendOnlyIfSupportingLeaveRequestWithAction bool, +) error { + var leave *livekit.LeaveRequest + if p.ProtocolVersion().SupportsRegionsInLeaveRequest() { + leave = &livekit.LeaveRequest{ + Reason: reason.ToDisconnectReason(), + } + switch { + case isExpectedToResume: + leave.Action = livekit.LeaveRequest_RESUME + case isExpectedToReconnect: + leave.Action = livekit.LeaveRequest_RECONNECT + default: + leave.Action = livekit.LeaveRequest_DISCONNECT + } + if leave.Action != livekit.LeaveRequest_DISCONNECT && p.params.GetRegionSettings != nil { + // sending region settings even for RESUME just in case client wants to a full reconnect despite server saying RESUME + leave.Regions = p.params.GetRegionSettings(p.params.ClientInfo.Address) + } + } else { + if !sendOnlyIfSupportingLeaveRequestWithAction { + leave = &livekit.LeaveRequest{ + CanReconnect: isExpectedToReconnect, + Reason: reason.ToDisconnectReason(), + } + } + } + if leave != nil { + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Leave{ + Leave: leave, + }, + }) + } + + return nil +} diff --git a/pkg/rtc/participant_traffic_load.go b/pkg/rtc/participant_traffic_load.go index eef44ad5e..87b8b9ef2 100644 --- a/pkg/rtc/participant_traffic_load.go +++ b/pkg/rtc/participant_traffic_load.go @@ -144,7 +144,7 @@ func (p *ParticipantTrafficLoad) updateTrafficLoad() *types.TrafficLoad { trafficTypeStats := make([]*types.TrafficTypeStats, 0, 6) addTypeStats := func(statsList []*types.TrafficStats, trackType livekit.TrackType, streamType livekit.StreamType) { - agg := types.AggregateTrafficStats(statsList) + agg := types.AggregateTrafficStats(statsList...) if agg != nil { trafficTypeStats = append(trafficTypeStats, &types.TrafficTypeStats{ TrackType: trackType, diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 09904506f..4e848a7c5 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -17,9 +17,11 @@ package rtc import ( "context" "errors" + "fmt" "io" "math" "sort" + "strings" "sync" "time" @@ -51,6 +53,8 @@ const ( subscriberUpdateInterval = 3 * time.Second dataForwardLoadBalanceThreshold = 20 + + simulateDisconnectSignalTimeout = 5 * time.Second ) var ( @@ -64,6 +68,17 @@ type broadcastOptions struct { immediate bool } +type participantUpdate struct { + pi *livekit.ParticipantInfo + isSynthesizedDisconnect bool + closeReason types.ParticipantCloseReason +} + +type disconnectSignalOnResumeNoMessages struct { + expiry time.Time + closedCount int +} + type Room struct { lock sync.RWMutex @@ -91,7 +106,7 @@ type Room struct { bufferFactory *buffer.FactoryOfBufferFactory // batch update participant info for non-publishers - batchedUpdates map[livekit.ParticipantIdentity]*livekit.ParticipantInfo + batchedUpdates map[livekit.ParticipantIdentity]*participantUpdate batchedUpdatesMu sync.Mutex // time the first participant joined the room @@ -106,6 +121,10 @@ type Room struct { onParticipantChanged func(p types.LocalParticipant) onRoomUpdated func() onClose func() + + simulationLock sync.Mutex + disconnectSignalOnResumeParticipants map[livekit.ParticipantIdentity]time.Time + disconnectSignalOnResumeNoMessagesParticipants map[livekit.ParticipantIdentity]*disconnectSignalOnResumeNoMessages } type ParticipantOptions struct { @@ -130,21 +149,23 @@ func NewRoom( livekit.RoomName(room.Name), livekit.RoomID(room.Sid), ), - config: config, - audioConfig: audioConfig, - telemetry: telemetry, - egressLauncher: egressLauncher, - agentClient: agentClient, - trackManager: NewRoomTrackManager(), - serverInfo: serverInfo, - participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), - participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), - participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), - hasPublished: make(map[livekit.ParticipantIdentity]bool), - bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSize), - batchedUpdates: make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo), - closed: make(chan struct{}), - trailer: []byte(utils.RandomSecret()), + config: config, + audioConfig: audioConfig, + telemetry: telemetry, + egressLauncher: egressLauncher, + agentClient: agentClient, + trackManager: NewRoomTrackManager(), + serverInfo: serverInfo, + participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), + participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), + participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), + hasPublished: make(map[livekit.ParticipantIdentity]bool), + bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSize), + batchedUpdates: make(map[livekit.ParticipantIdentity]*participantUpdate), + closed: make(chan struct{}), + trailer: []byte(utils.RandomSecret()), + disconnectSignalOnResumeParticipants: make(map[livekit.ParticipantIdentity]time.Time), + disconnectSignalOnResumeNoMessagesParticipants: make(map[livekit.ParticipantIdentity]*disconnectSignalOnResumeNoMessages), } r.protoProxy = utils.NewProtoProxy[*livekit.Room](roomUpdateInterval, r.updateProto) @@ -173,6 +194,7 @@ func NewRoom( go r.audioUpdateWorker() go r.connectionQualityWorker() go r.changeUpdateWorker() + go r.simulationCleanupWorker() return r } @@ -305,7 +327,6 @@ func (r *Room) Join(participant types.LocalParticipant, requestSource routing.Me if r.participants[participant.Identity()] != nil { return ErrAlreadyJoined } - if r.protoRoom.MaxParticipants > 0 && !participant.IsRecorder() { numParticipants := uint32(0) for _, p := range r.participants { @@ -322,41 +343,42 @@ func (r *Room) Join(participant types.LocalParticipant, requestSource routing.Me r.joinedAt.Store(time.Now().Unix()) } - // it's important to set this before connection, we don't want to miss out on any published tracks - participant.OnTrackPublished(r.onTrackPublished) - participant.OnStateChange(func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) { - r.Logger.Infow("participant state changed", - "state", p.State(), - "participant", p.Identity(), - "pID", p.ID(), - "oldState", oldState) + participant.OnStateChange(func(p types.LocalParticipant, state livekit.ParticipantInfo_State) { if r.onParticipantChanged != nil { - r.onParticipantChanged(participant) + r.onParticipantChanged(p) } r.broadcastParticipantState(p, broadcastOptions{skipSource: true}) - state := p.State() if state == livekit.ParticipantInfo_ACTIVE { // subscribe participant to existing published tracks r.subscribeToExistingTracks(p) - // start the workers once connectivity is established - p.Start() - + meta := &livekit.AnalyticsClientMeta{ + ClientConnectTime: uint32(time.Since(p.ConnectedAt()).Milliseconds()), + } + cds := p.GetICEConnectionDetails() + for _, cd := range cds { + if cd.Type != types.ICEConnectionTypeUnknown { + meta.ConnectionType = string(cd.Type) + break + } + } r.telemetry.ParticipantActive(context.Background(), r.ToProto(), p.ToProto(), - &livekit.AnalyticsClientMeta{ - ClientConnectTime: uint32(time.Since(p.ConnectedAt()).Milliseconds()), - ConnectionType: string(p.GetICEConnectionType()), - }, + meta, false, ) + + p.GetLogger().Infow("participant active", connectionDetailsFields(cds)...) } else if state == livekit.ParticipantInfo_DISCONNECTED { // remove participant from room - go r.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonStateDisconnected) + // participant should already be closed and have a close reason, so NONE is fine here + go r.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonNone) } }) + // it's important to set this before connection, we don't want to miss out on any published tracks + participant.OnTrackPublished(r.onTrackPublished) participant.OnTrackUpdated(r.onTrackUpdated) participant.OnTrackUnpublished(r.onTrackUnpublished) participant.OnParticipantUpdate(r.onParticipantUpdate) @@ -394,8 +416,8 @@ func (r *Room) Join(participant types.LocalParticipant, requestSource routing.Me }, }, true) } - }) + r.Logger.Debugw("new participant joined", "pID", participant.ID(), "participant", participant.Identity(), @@ -474,6 +496,26 @@ func (r *Room) ResumeParticipant(p types.LocalParticipant, requestSource routing p.SetSignalSourceValid(true) + // check for simulated signal disconnect on resume before sending any signal response messages + r.simulationLock.Lock() + if state, ok := r.disconnectSignalOnResumeNoMessagesParticipants[p.Identity()]; ok { + // WARNING: this uses knowledge that service layer tries internally + simulated := false + if time.Now().Before(state.expiry) { + state.closedCount++ + p.CloseSignalConnection(types.SignallingCloseReasonDisconnectOnResumeNoMessages) + simulated = true + } + if state.closedCount == 3 { + delete(r.disconnectSignalOnResumeNoMessagesParticipants, p.Identity()) + } + if simulated { + r.simulationLock.Unlock() + return nil + } + } + r.simulationLock.Unlock() + if err := p.HandleReconnectAndSendResponse(reason, &livekit.ReconnectResponse{ IceServers: iceServers, ClientConfiguration: p.GetClientConfiguration(), @@ -489,30 +531,44 @@ func (r *Room) ResumeParticipant(p types.LocalParticipant, requestSource routing _ = p.SendRoomUpdate(r.ToProto()) p.ICERestart(nil) + + // check for simulated signal disconnect on resume + r.simulationLock.Lock() + if timeout, ok := r.disconnectSignalOnResumeParticipants[p.Identity()]; ok { + if time.Now().Before(timeout) { + p.CloseSignalConnection(types.SignallingCloseReasonDisconnectOnResume) + } + delete(r.disconnectSignalOnResumeParticipants, p.Identity()) + } + r.simulationLock.Unlock() + return nil } func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livekit.ParticipantID, reason types.ParticipantCloseReason) { r.lock.Lock() p, ok := r.participants[identity] - if ok { - if pID != "" && p.ID() != pID { - // participant session has been replaced - r.lock.Unlock() - return - } + if !ok { + r.lock.Unlock() + return + } - delete(r.participants, identity) - delete(r.participantOpts, identity) - delete(r.participantRequestSources, identity) - delete(r.hasPublished, identity) - if !p.Hidden() { - r.protoRoom.NumParticipants-- - } + if pID != "" && p.ID() != pID { + // participant session has been replaced + r.lock.Unlock() + return + } + + delete(r.participants, identity) + delete(r.participantOpts, identity) + delete(r.participantRequestSources, identity) + delete(r.hasPublished, identity) + if !p.Hidden() { + r.protoRoom.NumParticipants-- } immediateChange := false - if (p != nil && p.IsRecorder()) || r.protoRoom.ActiveRecording { + if p.IsRecorder() { activeRecording := false for _, op := range r.participants { if op.IsRecorder() { @@ -524,14 +580,16 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek if r.protoRoom.ActiveRecording != activeRecording { r.protoRoom.ActiveRecording = activeRecording immediateChange = true - } } r.lock.Unlock() r.protoProxy.MarkDirty(immediateChange) - if !ok { - return + if !p.HasConnected() { + fields := append(connectionDetailsFields(p.GetICEConnectionDetails()), + "reason", reason.String(), + ) + p.GetLogger().Infow("removing participant without connection", fields...) } // send broadcast only if it's not already closed @@ -551,7 +609,6 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek p.OnSubscribeStatusChanged(nil) // close participant as well - r.Logger.Debugw("closing participant for removal", "pID", p.ID(), "participant", p.Identity()) _ = p.Close(true, reason, false) r.leftAt.Store(time.Now().Unix()) @@ -714,11 +771,11 @@ func (r *Room) CloseIfEmpty() { r.lock.Unlock() if elapsed >= int64(timeout) { - r.Close() + r.Close(types.ParticipantCloseReasonNone) } } -func (r *Room) Close() { +func (r *Room) Close(reason types.ParticipantCloseReason) { r.lock.Lock() select { case <-r.closed: @@ -729,11 +786,14 @@ func (r *Room) Close() { } close(r.closed) r.lock.Unlock() + r.Logger.Infow("closing room") for _, p := range r.GetParticipants() { - _ = p.Close(true, types.ParticipantCloseReasonRoomClose, false) + _ = p.Close(true, reason, false) } + r.protoProxy.Stop() + if r.onClose != nil { r.onClose() } @@ -757,11 +817,11 @@ func (r *Room) SendDataPacket(up *livekit.UserPacket, kind livekit.DataPacket_Ki r.onDataPacket(nil, dp) } -func (r *Room) SetMetadata(metadata string) { +func (r *Room) SetMetadata(metadata string) <-chan struct{} { r.lock.Lock() r.protoRoom.Metadata = metadata r.lock.Unlock() - r.protoProxy.MarkDirty(true) + return r.protoProxy.MarkDirty(true) } func (r *Room) UpdateParticipantMetadata(participant types.LocalParticipant, name string, metadata string) { @@ -839,6 +899,18 @@ func (r *Room) SimulateScenario(participant types.LocalParticipant, simulateScen r.Logger.Infow("simulating subscriber bandwidth end", "participant", participant.Identity()) } participant.SetSubscriberChannelCapacity(scenario.SubscriberBandwidth) + case *livekit.SimulateScenario_DisconnectSignalOnResume: + participant.GetLogger().Infow("simulating disconnect signal on resume") + r.simulationLock.Lock() + r.disconnectSignalOnResumeParticipants[participant.Identity()] = time.Now().Add(simulateDisconnectSignalTimeout) + r.simulationLock.Unlock() + case *livekit.SimulateScenario_DisconnectSignalOnResumeNoMessages: + participant.GetLogger().Infow("simulating disconnect signal on resume before sending any response messages") + r.simulationLock.Lock() + r.disconnectSignalOnResumeNoMessagesParticipants[participant.Identity()] = &disconnectSignalOnResumeNoMessages{ + expiry: time.Now().Add(simulateDisconnectSignalTimeout), + } + r.simulationLock.Unlock() } return nil } @@ -1039,27 +1111,49 @@ func (r *Room) broadcastParticipantState(p types.LocalParticipant, opts broadcas // send update only to hidden participant err := p.SendParticipantUpdate([]*livekit.ParticipantInfo{pi}) if err != nil { - r.Logger.Errorw("could not send update to participant", err, - "participant", p.Identity(), "pID", p.ID()) + p.GetLogger().Errorw("could not send update to participant", err) } } return } - updates := r.pushAndDequeueUpdates(pi, opts.immediate) + updates := r.pushAndDequeueUpdates(pi, p.CloseReason(), opts.immediate) r.sendParticipantUpdates(updates) } -func (r *Room) sendParticipantUpdates(updates []*livekit.ParticipantInfo) { +func (r *Room) sendParticipantUpdates(updates []*participantUpdate) { if len(updates) == 0 { return } + // For filtered updates, skip + // 1. synthesized DISCONNECT - this happens on SID change + // 2. close reasons of DUPLICATE_IDENTITY/STALE - A newer session for that identity exists. + // + // Filtered updates are used with clients that can handle identity based reconnect and hence those + // conditions can be skipped. + var filteredUpdates []*livekit.ParticipantInfo + for _, update := range updates { + if update.isSynthesizedDisconnect || IsCloseNotifySkippable(update.closeReason) { + continue + } + filteredUpdates = append(filteredUpdates, update.pi) + } + + var fullUpdates []*livekit.ParticipantInfo + for _, update := range updates { + fullUpdates = append(fullUpdates, update.pi) + } + for _, op := range r.GetParticipants() { - err := op.SendParticipantUpdate(updates) + var err error + if op.ProtocolVersion().SupportsIdentityBasedReconnection() { + err = op.SendParticipantUpdate(filteredUpdates) + } else { + err = op.SendParticipantUpdate(fullUpdates) + } if err != nil { - r.Logger.Errorw("could not send update to participant", err, - "participant", op.Identity(), "pID", op.ID()) + op.GetLogger().Errorw("could not send update to participant", err) } } } @@ -1105,29 +1199,34 @@ func (r *Room) sendSpeakerChanges(speakers []*livekit.SpeakerInfo) { // * subscriber-only updates will be queued for batch updates // * publisher & immediate updates will be returned without queuing // * when the SID changes, it will return both updates, with the earlier participant set to disconnected -func (r *Room) pushAndDequeueUpdates(pi *livekit.ParticipantInfo, isImmediate bool) []*livekit.ParticipantInfo { +func (r *Room) pushAndDequeueUpdates( + pi *livekit.ParticipantInfo, + closeReason types.ParticipantCloseReason, + isImmediate bool, +) []*participantUpdate { r.batchedUpdatesMu.Lock() defer r.batchedUpdatesMu.Unlock() - var updates []*livekit.ParticipantInfo + var updates []*participantUpdate identity := livekit.ParticipantIdentity(pi.Identity) existing := r.batchedUpdates[identity] shouldSend := isImmediate || pi.IsPublisher if existing != nil { - if pi.Sid == existing.Sid { + if pi.Sid == existing.pi.Sid { // same participant session - if pi.Version < existing.Version { + if pi.Version < existing.pi.Version { // out of order update return nil } } else { // different participant sessions - if existing.JoinedAt < pi.JoinedAt { + if existing.pi.JoinedAt < pi.JoinedAt { // existing is older, synthesize a DISCONNECT for older and // send immediately along with newer session to signal switch shouldSend = true - existing.State = livekit.ParticipantInfo_DISCONNECTED + existing.pi.State = livekit.ParticipantInfo_DISCONNECTED + existing.isSynthesizedDisconnect = true updates = append(updates, existing) } else { // older session update, newer session has already become active, so nothing to do @@ -1148,10 +1247,10 @@ func (r *Room) pushAndDequeueUpdates(pi *livekit.ParticipantInfo, isImmediate bo if shouldSend { // include any queued update, and return delete(r.batchedUpdates, identity) - updates = append(updates, pi) + updates = append(updates, &participantUpdate{pi: pi, closeReason: closeReason}) } else { // enqueue for batch - r.batchedUpdates[identity] = pi + r.batchedUpdates[identity] = &participantUpdate{pi: pi, closeReason: closeReason} } return updates @@ -1192,18 +1291,14 @@ func (r *Room) changeUpdateWorker() { case <-subTicker.C: r.batchedUpdatesMu.Lock() updatesMap := r.batchedUpdates - r.batchedUpdates = make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo) + r.batchedUpdates = make(map[livekit.ParticipantIdentity]*participantUpdate) r.batchedUpdatesMu.Unlock() if len(updatesMap) == 0 { continue } - updates := make([]*livekit.ParticipantInfo, 0, len(updatesMap)) - for _, pi := range updatesMap { - updates = append(updates, pi) - } - r.sendParticipantUpdates(updates) + r.sendParticipantUpdates(maps.Values(updatesMap)) } } } @@ -1328,6 +1423,31 @@ func (r *Room) connectionQualityWorker() { } } +func (r *Room) simulationCleanupWorker() { + for { + if r.IsClosed() { + return + } + + now := time.Now() + r.simulationLock.Lock() + for identity, timeout := range r.disconnectSignalOnResumeParticipants { + if now.After(timeout) { + delete(r.disconnectSignalOnResumeParticipants, identity) + } + } + + for identity, state := range r.disconnectSignalOnResumeNoMessagesParticipants { + if now.After(state.expiry) { + delete(r.disconnectSignalOnResumeNoMessagesParticipants, identity) + } + } + r.simulationLock.Unlock() + + time.Sleep(10 * time.Second) + } +} + func (r *Room) launchPublisherAgent(p types.Participant) { if p == nil || p.IsRecorder() || p.IsAgent() { return @@ -1360,6 +1480,8 @@ func (r *Room) DebugInfo() map[string]interface{} { return info } +// ------------------------------------------------------------ + func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, dp *livekit.DataPacket, logger logger.Logger) { dest := dp.GetUser().GetDestinationSids() var dpData []byte @@ -1416,3 +1538,43 @@ func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, dp } }) } + +func IsCloseNotifySkippable(closeReason types.ParticipantCloseReason) bool { + return closeReason == types.ParticipantCloseReasonDuplicateIdentity +} + +func connectionDetailsFields(cds []*types.ICEConnectionDetails) []interface{} { + var fields []interface{} + connectionType := types.ICEConnectionTypeUnknown + for _, cd := range cds { + candidates := make([]string, 0, len(cd.Remote)+len(cd.Local)) + for _, c := range cd.Local { + cStr := "[local]" + if c.Selected { + cStr += "[selected]" + } else if c.Filtered { + cStr += "[filtered]" + } + cStr += " " + c.Local.String() + candidates = append(candidates, cStr) + } + for _, c := range cd.Remote { + cStr := "[remote]" + if c.Selected { + cStr += "[selected]" + } else if c.Filtered { + cStr += "[filtered]" + } + cStr += " " + c.Remote.String() + candidates = append(candidates, cStr) + } + if len(candidates) > 0 { + fields = append(fields, fmt.Sprintf("%sCandidates", strings.ToLower(cd.Transport.String())), candidates) + } + if cd.Type != types.ICEConnectionTypeUnknown { + connectionType = cd.Type + } + } + fields = append(fields, "connectionType", connectionType) + return fields +} diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index 81dfe90c8..0febf2989 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -110,8 +110,7 @@ func TestRoomJoin(t *testing.T) { stateChangeCB := p.OnStateChangeArgsForCall(0) require.NotNil(t, stateChangeCB) - p.StateReturns(livekit.ParticipantInfo_ACTIVE) - stateChangeCB(p, livekit.ParticipantInfo_JOINED) + stateChangeCB(p, livekit.ParticipantInfo_ACTIVE) // it should become a subscriber when connectivity changes numTracks := 0 @@ -136,7 +135,7 @@ func TestRoomJoin(t *testing.T) { disconnectedParticipant := participants[1].(*typesfakes.FakeLocalParticipant) disconnectedParticipant.StateReturns(livekit.ParticipantInfo_DISCONNECTED) - rm.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonStateDisconnected) + rm.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonClientRequestLeave) time.Sleep(defaultDelay) require.Equal(t, p, changedParticipant) @@ -266,17 +265,18 @@ func TestPushAndDequeueUpdates(t *testing.T) { require.Equal(t, a.Version, b.Version) } testCases := []struct { - name string - pi *livekit.ParticipantInfo - immediate bool - existing *livekit.ParticipantInfo - expected []*livekit.ParticipantInfo - validate func(t *testing.T, rm *Room, updates []*livekit.ParticipantInfo) + name string + pi *livekit.ParticipantInfo + closeReason types.ParticipantCloseReason + immediate bool + existing *participantUpdate + expected []*participantUpdate + validate func(t *testing.T, rm *Room, updates []*participantUpdate) }{ { name: "publisher updates are immediate", pi: publisher1v1, - expected: []*livekit.ParticipantInfo{publisher1v1}, + expected: []*participantUpdate{{pi: publisher1v1}}, }, { name: "subscriber updates are queued", @@ -285,20 +285,20 @@ func TestPushAndDequeueUpdates(t *testing.T) { { name: "last version is enqueued", pi: subscriber1v2, - existing: subscriber1v1, - validate: func(t *testing.T, rm *Room, _ []*livekit.ParticipantInfo) { + existing: &participantUpdate{pi: proto.Clone(subscriber1v1).(*livekit.ParticipantInfo)}, // clone the existing value since it can be modified when setting to disconnected + validate: func(t *testing.T, rm *Room, _ []*participantUpdate) { queued := rm.batchedUpdates[livekit.ParticipantIdentity(identity)] require.NotNil(t, queued) - requirePIEquals(t, subscriber1v2, queued) + requirePIEquals(t, subscriber1v2, queued.pi) }, }, { name: "latest version when immediate", pi: subscriber1v2, - existing: subscriber1v1, + existing: &participantUpdate{pi: proto.Clone(subscriber1v1).(*livekit.ParticipantInfo)}, immediate: true, - expected: []*livekit.ParticipantInfo{subscriber1v2}, - validate: func(t *testing.T, rm *Room, _ []*livekit.ParticipantInfo) { + expected: []*participantUpdate{{pi: subscriber1v2}}, + validate: func(t *testing.T, rm *Room, _ []*participantUpdate) { queued := rm.batchedUpdates[livekit.ParticipantIdentity(identity)] require.Nil(t, queued) }, @@ -306,32 +306,37 @@ func TestPushAndDequeueUpdates(t *testing.T) { { name: "out of order updates are rejected", pi: subscriber1v1, - existing: subscriber1v2, - validate: func(t *testing.T, rm *Room, updates []*livekit.ParticipantInfo) { + existing: &participantUpdate{pi: proto.Clone(subscriber1v2).(*livekit.ParticipantInfo)}, + validate: func(t *testing.T, rm *Room, updates []*participantUpdate) { queued := rm.batchedUpdates[livekit.ParticipantIdentity(identity)] - requirePIEquals(t, subscriber1v2, queued) + requirePIEquals(t, subscriber1v2, queued.pi) }, }, { - name: "sid change is broadcasted immediately", - pi: publisher2, - existing: subscriber1v2, - expected: []*livekit.ParticipantInfo{ + name: "sid change is broadcasted immediately with synthsized disconnect", + pi: publisher2, + closeReason: types.ParticipantCloseReasonServiceRequestRemoveParticipant, // just to test if update contain the close reason + existing: &participantUpdate{pi: proto.Clone(subscriber1v2).(*livekit.ParticipantInfo), closeReason: types.ParticipantCloseReasonStale}, + expected: []*participantUpdate{ { - Identity: identity, - Sid: "1", - Version: 2, - State: livekit.ParticipantInfo_DISCONNECTED, + pi: &livekit.ParticipantInfo{ + Identity: identity, + Sid: "1", + Version: 2, + State: livekit.ParticipantInfo_DISCONNECTED, + }, + isSynthesizedDisconnect: true, + closeReason: types.ParticipantCloseReasonStale, }, - publisher2, + {pi: publisher2, closeReason: types.ParticipantCloseReasonServiceRequestRemoveParticipant}, }, }, { name: "when switching to publisher, queue is cleared", pi: publisher1v2, - existing: subscriber1v1, - expected: []*livekit.ParticipantInfo{publisher1v2}, - validate: func(t *testing.T, rm *Room, updates []*livekit.ParticipantInfo) { + existing: &participantUpdate{pi: proto.Clone(subscriber1v1).(*livekit.ParticipantInfo)}, + expected: []*participantUpdate{{pi: publisher1v2}}, + validate: func(t *testing.T, rm *Room, updates []*participantUpdate) { require.Empty(t, rm.batchedUpdates) }, }, @@ -341,13 +346,14 @@ func TestPushAndDequeueUpdates(t *testing.T) { t.Run(tc.name, func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) if tc.existing != nil { - // clone the existing value since it can be modified when setting to disconnected - rm.batchedUpdates[livekit.ParticipantIdentity(tc.existing.Identity)] = proto.Clone(tc.existing).(*livekit.ParticipantInfo) + rm.batchedUpdates[livekit.ParticipantIdentity(tc.existing.pi.Identity)] = tc.existing } - updates := rm.pushAndDequeueUpdates(tc.pi, tc.immediate) + updates := rm.pushAndDequeueUpdates(tc.pi, tc.closeReason, tc.immediate) require.Equal(t, len(tc.expected), len(updates)) for i, item := range tc.expected { - requirePIEquals(t, item, updates[i]) + requirePIEquals(t, item.pi, updates[i].pi) + require.Equal(t, item.isSynthesizedDisconnect, updates[i].isSynthesizedDisconnect) + require.Equal(t, item.closeReason, updates[i].closeReason) } if tc.validate != nil { @@ -444,7 +450,7 @@ func TestActiveSpeakers(t *testing.T) { audioUpdateDuration := (audioUpdateInterval + 10) * time.Millisecond t.Run("participant should not be getting audio updates (protocol 2)", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 1, protocol: 2}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) p := rm.GetParticipants()[0].(*typesfakes.FakeLocalParticipant) require.Empty(t, rm.GetActiveSpeakers()) @@ -456,7 +462,7 @@ func TestActiveSpeakers(t *testing.T) { t.Run("speakers should be sorted by loudness", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) p2 := participants[1].(*typesfakes.FakeLocalParticipant) @@ -471,7 +477,7 @@ func TestActiveSpeakers(t *testing.T) { t.Run("participants are getting audio updates (protocol 3+)", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: 3}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) time.Sleep(time.Millisecond) // let the first update cycle run @@ -510,7 +516,7 @@ func TestActiveSpeakers(t *testing.T) { t.Run("audio level is smoothed", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: 3, audioSmoothIntervals: 3}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) op := participants[1].(*typesfakes.FakeLocalParticipant) @@ -567,7 +573,7 @@ func TestDataChannel(t *testing.T) { t.Run("participants should receive data", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 3}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) @@ -597,7 +603,7 @@ func TestDataChannel(t *testing.T) { t.Run("only one participant should receive the data", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 4}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) p1 := participants[1].(*typesfakes.FakeLocalParticipant) @@ -628,7 +634,7 @@ func TestDataChannel(t *testing.T) { t.Run("publishing disallowed", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) p.CanPublishDataReturns(false) @@ -656,7 +662,7 @@ func TestDataChannel(t *testing.T) { func TestHiddenParticipants(t *testing.T) { t.Run("other participants don't receive hidden updates", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2, numHidden: 1}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) pNew := NewMockParticipant("new", types.CurrentProtocol, false, false) rm.Join(pNew, nil, nil, iceServersForRoom) @@ -679,17 +685,16 @@ func TestHiddenParticipants(t *testing.T) { stateChangeCB := hidden.OnStateChangeArgsForCall(0) require.NotNil(t, stateChangeCB) - hidden.StateReturns(livekit.ParticipantInfo_ACTIVE) - stateChangeCB(hidden, livekit.ParticipantInfo_JOINED) + stateChangeCB(hidden, livekit.ParticipantInfo_ACTIVE) - require.Equal(t, 2, hidden.SubscribeToTrackCallCount()) + require.Eventually(t, func() bool { return hidden.SubscribeToTrackCallCount() == 2 }, 5*time.Second, 10*time.Millisecond) }) } func TestRoomUpdate(t *testing.T) { t.Run("updates are sent when participant joined", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) p1 := rm.GetParticipants()[0].(*typesfakes.FakeLocalParticipant) require.Equal(t, 0, p1.SendRoomUpdateCallCount()) @@ -705,7 +710,7 @@ func TestRoomUpdate(t *testing.T) { t.Run("participants should receive metadata update", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) - defer rm.Close() + defer rm.Close(types.ParticipantCloseReasonNone) rm.SetMetadata("test metadata...") diff --git a/pkg/rtc/subscriptionmanager.go b/pkg/rtc/subscriptionmanager.go index 602e786fd..c53791aa5 100644 --- a/pkg/rtc/subscriptionmanager.go +++ b/pkg/rtc/subscriptionmanager.go @@ -18,6 +18,7 @@ package rtc import ( "context" + "errors" "sync" "time" @@ -503,10 +504,22 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { } subTrack, err := track.AddSubscriber(m.params.Participant) - if err != nil && err != errAlreadySubscribed { - // ignore already subscribed error + if err != nil && !errors.Is(err, errAlreadySubscribed) { + // ignore error(s): already subscribed + if !errors.Is(err, ErrTrackNotAttached) && !errors.Is(err, ErrNoReceiver) { + // as track resolution could take some time, not logging errors due to waiting for track resolution + m.params.Logger.Warnw("add subscriber failed", err, "trackID", trackID) + } return err } + if err == errAlreadySubscribed { + m.params.Logger.Debugw( + "already subscribed to track", + "trackID", trackID, + "subscribedAudioCount", m.subscribedAudioCount.Load(), + "subscribedVideoCount", m.subscribedVideoCount.Load(), + ) + } if err == nil && subTrack != nil { // subTrack could be nil if already subscribed subTrack.OnClose(func(willBeResumed bool) { m.handleSubscribedTrackClose(s, willBeResumed) @@ -536,9 +549,14 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { } go m.params.OnTrackSubscribed(subTrack) - } - m.params.Logger.Debugw("subscribed to track", "trackID", trackID, "subscribedAudioCount", m.subscribedAudioCount.Load(), "subscribedVideoCount", m.subscribedVideoCount.Load()) + m.params.Logger.Debugw( + "subscribed to track", + "trackID", trackID, + "subscribedAudioCount", m.subscribedAudioCount.Load(), + "subscribedVideoCount", m.subscribedVideoCount.Load(), + ) + } // add mark the participant as someone we've subscribed to firstSubscribe := false @@ -555,7 +573,7 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { m.lock.Unlock() if changedCB != nil && firstSubscribe { - go changedCB(publisherID, true) + changedCB(publisherID, true) } return nil } @@ -596,7 +614,8 @@ func (m *SubscriptionManager) handleSourceTrackRemoved(trackID livekit.TrackID) // - UpTrack was closed // - publisher revoked permissions for the participant func (m *SubscriptionManager) handleSubscribedTrackClose(s *trackSubscription, willBeResumed bool) { - s.logger.Debugw("subscribed track closed", + s.logger.Debugw( + "subscribed track closed", "willBeResumed", willBeResumed, ) wasBound := s.isBound() @@ -924,7 +943,7 @@ func (s *trackSubscription) handleSourceTrackRemoved() { } // source track removed, we would unsubscribe - s.logger.Infow("unsubscribing from track since source track was removed") + s.logger.Debugw("unsubscribing from track since source track was removed") s.desired = false s.setChangedNotifierLocked(nil) diff --git a/pkg/rtc/testutils.go b/pkg/rtc/testutils.go index 845cc4629..b63c372a8 100644 --- a/pkg/rtc/testutils.go +++ b/pkg/rtc/testutils.go @@ -62,7 +62,7 @@ func NewMockParticipant(identity livekit.ParticipantIdentity, protocol types.Pro f = p.OnTrackUpdatedArgsForCall(p.OnTrackUpdatedCallCount() - 1) } if f != nil { - f(p, nil) + f(p, NewMockTrack(livekit.TrackType_VIDEO, "testcam")) } } @@ -83,5 +83,9 @@ func NewMockTrack(kind livekit.TrackType, name string) *typesfakes.FakeMediaTrac t.IDReturns(livekit.TrackID(utils.NewGuid(utils.TrackPrefix))) t.KindReturns(kind) t.NameReturns(name) + t.ToProtoReturns(&livekit.TrackInfo{ + Type: kind, + Name: name, + }) return t } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 16e454e78..c0e804d65 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -23,7 +23,6 @@ import ( "github.com/bep/debounce" "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/ice/v2" "github.com/pion/interceptor" "github.com/pion/interceptor/pkg/cc" "github.com/pion/interceptor/pkg/gcc" @@ -35,20 +34,18 @@ import ( "github.com/pkg/errors" "go.uber.org/atomic" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/sfu/rtpextension" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/logger/pionlogger" lksdp "github.com/livekit/protocol/sdp" - "github.com/livekit/protocol/utils" - - "github.com/livekit/livekit-server/pkg/config" - "github.com/livekit/livekit-server/pkg/rtc/types" - "github.com/livekit/livekit-server/pkg/sfu/pacer" - "github.com/livekit/livekit-server/pkg/sfu/rtpextension" - "github.com/livekit/livekit-server/pkg/sfu/streamallocator" - "github.com/livekit/livekit-server/pkg/telemetry" - "github.com/livekit/livekit-server/pkg/telemetry/prometheus" ) const ( @@ -80,9 +77,6 @@ var ( ErrIceRestartOnClosedPeerConnection = errors.New("ICE restart on closed peer connection") ErrNoTransceiver = errors.New("no transceiver") ErrNoSender = errors.New("no sender") - ErrNoICECandidateHandler = errors.New("no ICE candidate handler") - ErrNoOfferHandler = errors.New("no offer handler") - ErrNoAnswerHandler = errors.New("no answer handler") ErrMidNotFound = errors.New("mid not found") ) @@ -94,7 +88,6 @@ const ( signalICEGatheringComplete signal = iota signalLocalICECandidate signalRemoteICECandidate - signalLogICECandidates signalSendOffer signalRemoteDescriptionReceived signalICERestart @@ -108,8 +101,6 @@ func (s signal) String() string { return "LOCAL_ICE_CANDIDATE" case signalRemoteICECandidate: return "REMOTE_ICE_CANDIDATE" - case signalLogICECandidates: - return "LOG_ICE_CANDIDATES" case signalSendOffer: return "SEND_OFFER" case signalRemoteDescriptionReceived: @@ -134,31 +125,6 @@ func (e event) String() string { // ------------------------------------------------------- -type NegotiationState int - -const ( - NegotiationStateNone NegotiationState = iota - // waiting for remote description - NegotiationStateRemote - // need to Negotiate again - NegotiationStateRetry -) - -func (n NegotiationState) String() string { - switch n { - case NegotiationStateNone: - return "NONE" - case NegotiationStateRemote: - return "WAITING_FOR_REMOTE" - case NegotiationStateRetry: - return "RETRY" - default: - return fmt.Sprintf("%d", int(n)) - } -} - -// ------------------------------------------------------- - type SimulcastTrackInfo struct { Mid string Rid string @@ -181,7 +147,6 @@ type PCTransport struct { reliableDCOpened bool lossyDC *webrtc.DataChannel lossyDCOpened bool - onDataPacket func(kind livekit.DataPacket_Kind, data []byte) iceStartedAt time.Time iceConnectedAt time.Time @@ -192,18 +157,10 @@ type PCTransport struct { resetShortConnOnICERestart atomic.Bool signalingRTT atomic.Uint32 // milliseconds - onFullyEstablished func() - debouncedNegotiate func(func()) debouncePending bool - onICECandidate func(c *webrtc.ICECandidate) error - onOffer func(offer webrtc.SessionDescription) error - onAnswer func(answer webrtc.SessionDescription) error - onInitialConnected func() - onFailed func(isShortLived bool) - onNegotiationStateChanged func(state NegotiationState) - onNegotiationFailed func() + onNegotiationStateChanged func(state transport.NegotiationState) // stream allocator for subscriber PC streamAllocator *streamallocator.StreamAllocator @@ -219,8 +176,7 @@ type PCTransport struct { preferTCP atomic.Bool isClosed atomic.Bool - eventChMu sync.RWMutex - eventCh chan event + eventsQueue *sutils.OpsQueue // the following should be accessed only in event processing go routine cacheLocalCandidates bool @@ -228,29 +184,26 @@ type PCTransport struct { pendingRemoteCandidates []*webrtc.ICECandidateInit restartAfterGathering bool restartAtNextOffer bool - negotiationState NegotiationState + negotiationState transport.NegotiationState negotiateCounter atomic.Int32 signalStateCheckTimer *time.Timer currentOfferIceCredential string // ice user:pwd, for publish side ice restart checking pendingRestartIceOffer *webrtc.SessionDescription - // for cleaner logging - allowedLocalCandidates *utils.DedupedSlice[string] - allowedRemoteCandidates *utils.DedupedSlice[string] - filteredLocalCandidates *utils.DedupedSlice[string] - filteredRemoteCandidates *utils.DedupedSlice[string] + connectionDetails *types.ICEConnectionDetails } type TransportParams struct { + Handler transport.Handler ParticipantID livekit.ParticipantID ParticipantIdentity livekit.ParticipantIdentity ProtocolVersion types.ProtocolVersion Config *WebRTCConfig DirectionConfig DirectionConfig CongestionControlConfig config.CongestionControlConfig - Telemetry telemetry.TelemetryService EnabledCodecs []*livekit.Codec Logger logger.Logger + Transport livekit.SignalTarget SimTracks map[uint32]SimulcastTrackInfo ClientInfo ClientInfo IsOfferer bool @@ -389,20 +342,18 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { t := &PCTransport{ params: params, debouncedNegotiate: debounce.New(negotiationFrequency), - negotiationState: NegotiationStateNone, - eventCh: make(chan event, 50), + negotiationState: transport.NegotiationStateNone, + eventsQueue: sutils.NewOpsQueue("transport", 64, false), previousTrackDescription: make(map[string]*trackDescription), canReuseTransceiver: true, - allowedLocalCandidates: utils.NewDedupedSlice[string](maxICECandidates), - allowedRemoteCandidates: utils.NewDedupedSlice[string](maxICECandidates), - filteredLocalCandidates: utils.NewDedupedSlice[string](maxICECandidates), - filteredRemoteCandidates: utils.NewDedupedSlice[string](maxICECandidates), + connectionDetails: types.NewICEConnectionDetails(params.Transport, params.Logger), } if params.IsSendSide { t.streamAllocator = streamallocator.NewStreamAllocator(streamallocator.StreamAllocatorParams{ Config: params.CongestionControlConfig, Logger: params.Logger.WithComponent(sutils.ComponentCongestionControl), }) + t.streamAllocator.OnStreamStateChange(params.Handler.OnStreamStateChange) t.streamAllocator.Start() t.pacer = pacer.NewPassThrough(params.Logger) } @@ -411,7 +362,7 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { return nil, err } - go t.processEvents() + t.eventsQueue.Start() return t, nil } @@ -433,6 +384,7 @@ func (t *PCTransport) createPeerConnection() error { t.pc.OnConnectionStateChange(t.onPeerConnectionStateChange) t.pc.OnDataChannel(t.onDataChannel) + t.pc.OnTrack(t.params.Handler.OnTrack) t.me = me @@ -466,10 +418,10 @@ func (t *PCTransport) setICEStartedAt(at time.Time) { } else if tcpICETimeout > maxTcpICEConnectTimeout { tcpICETimeout = maxTcpICEConnectTimeout } - t.params.Logger.Debugw("set tcp ice connect timer", "timeout", tcpICETimeout, "signalRTT", signalingRTT) + t.params.Logger.Debugw("set TCP ICE connect timer", "timeout", tcpICETimeout, "signalRTT", signalingRTT) t.tcpICETimer = time.AfterFunc(tcpICETimeout, func() { if t.pc.ICEConnectionState() == webrtc.ICEConnectionStateChecking { - t.params.Logger.Infow("tcp ice connect timeout", "timeout", tcpICETimeout, "signalRTT", signalingRTT) + t.params.Logger.Infow("TCP ICE connect timeout", "timeout", tcpICETimeout, "signalRTT", signalingRTT) t.handleConnectionFailed(true) } }) @@ -497,7 +449,7 @@ func (t *PCTransport) setICEConnectedAt(at time.Time) { if connTimeoutAfterICE > maxConnectTimeoutAfterICE { connTimeoutAfterICE = maxConnectTimeoutAfterICE } - t.params.Logger.Debugw("setting connection timer after ice connected", "timeout", connTimeoutAfterICE, "iceDuration", iceDuration) + t.params.Logger.Debugw("setting connection timer after ICE connected", "timeout", connTimeoutAfterICE, "iceDuration", iceDuration) t.connectAfterICETimer = time.AfterFunc(connTimeoutAfterICE, func() { state := t.pc.ConnectionState() // if pc is still checking or connected but not fully established after timeout, then fire connection fail @@ -546,12 +498,12 @@ func (t *PCTransport) IsShortConnection(at time.Time) (bool, time.Duration) { } func (t *PCTransport) getSelectedPair() (*webrtc.ICECandidatePair, error) { - sctp := t.pc.SCTP() - if sctp == nil { + s := t.pc.SCTP() + if s == nil { return nil, errors.New("no SCTP") } - dtlsTransport := sctp.Transport() + dtlsTransport := s.Transport() if dtlsTransport == nil { return nil, errors.New("no DTLS transport") } @@ -561,13 +513,16 @@ func (t *PCTransport) getSelectedPair() (*webrtc.ICECandidatePair, error) { return nil, errors.New("no ICE transport") } - return iceTransport.GetSelectedCandidatePair() -} + pair, err := iceTransport.GetSelectedCandidatePair() + if err != nil { + return nil, err + } -func (t *PCTransport) logICECandidates() { - t.postEvent(event{ - signal: signalLogICECandidates, - }) + if pair == nil { + return nil, errors.New("no selected pair") + } + + return pair, err } func (t *PCTransport) setConnectedAt(at time.Time) bool { @@ -619,9 +574,7 @@ func (t *PCTransport) handleConnectionFailed(forceShortConn bool) { t.params.Logger.Infow("force short ICE connection") } - if onFailed := t.getOnFailed(); onFailed != nil { - onFailed(isShort) - } + t.params.Handler.OnFailed(isShort) } func (t *PCTransport) onICEConnectionStateChange(state webrtc.ICEConnectionState) { @@ -629,19 +582,22 @@ func (t *PCTransport) onICEConnectionStateChange(state webrtc.ICEConnectionState switch state { case webrtc.ICEConnectionStateConnected: t.setICEConnectedAt(time.Now()) - if pair, err := t.getSelectedPair(); err != nil { - t.params.Logger.Errorw("error getting selected ICE candidate pair", err) - } else { - t.params.Logger.Infow("selected ICE candidate pair", "pair", pair) - } + go func() { + pair, err := t.getSelectedPair() + if err != nil { + t.params.Logger.Warnw("failed to get selected candidate pair", err) + return + } + t.connectionDetails.SetSelectedPair(pair) + }() case webrtc.ICEConnectionStateChecking: t.setICEStartedAt(time.Now()) case webrtc.ICEConnectionStateDisconnected: - fallthrough - case webrtc.ICEConnectionStateFailed: t.params.Logger.Infow("ice connection state change unexpected", "state", state.String()) + case webrtc.ICEConnectionStateFailed: + t.params.Logger.Debugw("ice connection state change unexpected", "state", state.String()) } } @@ -652,9 +608,7 @@ func (t *PCTransport) onPeerConnectionStateChange(state webrtc.PeerConnectionSta t.clearConnTimer() isInitialConnection := t.setConnectedAt(time.Now()) if isInitialConnection { - if onInitialConnected := t.getOnInitialConnected(); onInitialConnected != nil { - onInitialConnected() - } + t.params.Handler.OnInitialConnected() t.maybeNotifyFullyEstablished() } @@ -674,9 +628,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.reliableDCOpened = true t.lock.Unlock() dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if onDataPacket := t.getOnDataPacket(); onDataPacket != nil { - onDataPacket(livekit.DataPacket_RELIABLE, msg.Data) - } + t.params.Handler.OnDataPacket(livekit.DataPacket_RELIABLE, msg.Data) }) t.maybeNotifyFullyEstablished() @@ -686,9 +638,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.lossyDCOpened = true t.lock.Unlock() dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if onDataPacket := t.getOnDataPacket(); onDataPacket != nil { - onDataPacket(livekit.DataPacket_LOSSY, msg.Data) - } + t.params.Handler.OnDataPacket(livekit.DataPacket_LOSSY, msg.Data) }) t.maybeNotifyFullyEstablished() @@ -699,9 +649,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { func (t *PCTransport) maybeNotifyFullyEstablished() { if t.isFullyEstablished() { - if onFullyEstablished := t.getOnFullyEstablished(); onFullyEstablished != nil { - onFullyEstablished() - } + t.params.Handler.OnFullyEstablished() } } @@ -910,6 +858,10 @@ func (t *PCTransport) HasEverConnected() bool { return !t.firstConnectedAt.IsZero() } +func (t *PCTransport) GetICEConnectionDetails() *types.ICEConnectionDetails { + return t.connectionDetails +} + func (t *PCTransport) WriteRTCP(pkts []rtcp.Packet) error { return t.pc.WriteRTCP(pkts) } @@ -940,14 +892,12 @@ func (t *PCTransport) SendDataPacket(dp *livekit.DataPacket, data []byte) error } func (t *PCTransport) Close() { - t.eventChMu.Lock() if t.isClosed.Swap(true) { - t.eventChMu.Unlock() return } - close(t.eventCh) - t.eventChMu.Unlock() + <-t.eventsQueue.Stop() + t.clearSignalStateCheckTimer() if t.streamAllocator != nil { t.streamAllocator.Stop() @@ -981,128 +931,19 @@ func (t *PCTransport) HandleRemoteDescription(sd webrtc.SessionDescription) { }) } -func (t *PCTransport) OnICECandidate(f func(c *webrtc.ICECandidate) error) { - t.lock.Lock() - t.onICECandidate = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnICECandidate() func(c *webrtc.ICECandidate) error { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onICECandidate -} - -func (t *PCTransport) OnInitialConnected(f func()) { - t.lock.Lock() - t.onInitialConnected = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnInitialConnected() func() { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onInitialConnected -} - -func (t *PCTransport) OnFullyEstablished(f func()) { - t.lock.Lock() - t.onFullyEstablished = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnFullyEstablished() func() { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onFullyEstablished -} - -func (t *PCTransport) OnFailed(f func(isShortLived bool)) { - t.lock.Lock() - t.onFailed = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnFailed() func(isShortLived bool) { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onFailed -} - -func (t *PCTransport) OnTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { - t.pc.OnTrack(f) -} - -func (t *PCTransport) OnDataPacket(f func(kind livekit.DataPacket_Kind, data []byte)) { - t.lock.Lock() - t.onDataPacket = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnDataPacket() func(kind livekit.DataPacket_Kind, data []byte) { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onDataPacket -} - -// OnOffer is called when the PeerConnection starts negotiation and prepares an offer -func (t *PCTransport) OnOffer(f func(sd webrtc.SessionDescription) error) { - t.lock.Lock() - t.onOffer = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnOffer() func(sd webrtc.SessionDescription) error { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onOffer -} - -func (t *PCTransport) OnAnswer(f func(sd webrtc.SessionDescription) error) { - t.lock.Lock() - t.onAnswer = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnAnswer() func(sd webrtc.SessionDescription) error { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onAnswer -} - -func (t *PCTransport) OnNegotiationStateChanged(f func(state NegotiationState)) { +func (t *PCTransport) OnNegotiationStateChanged(f func(state transport.NegotiationState)) { t.lock.Lock() t.onNegotiationStateChanged = f t.lock.Unlock() } -func (t *PCTransport) getOnNegotiationStateChanged() func(state NegotiationState) { +func (t *PCTransport) getOnNegotiationStateChanged() func(state transport.NegotiationState) { t.lock.RLock() defer t.lock.RUnlock() return t.onNegotiationStateChanged } -func (t *PCTransport) OnNegotiationFailed(f func()) { - t.lock.Lock() - t.onNegotiationFailed = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnNegotiationFailed() func() { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onNegotiationFailed -} - func (t *PCTransport) Negotiate(force bool) { if t.isClosed.Load() { return @@ -1153,14 +994,6 @@ func (t *PCTransport) ResetShortConnOnICERestart() { t.resetShortConnOnICERestart.Store(true) } -func (t *PCTransport) OnStreamStateChange(f func(update *streamallocator.StreamStateUpdate) error) { - if t.streamAllocator == nil { - return - } - - t.streamAllocator.OnStreamStateChange(f) -} - func (t *PCTransport) AddTrackToStreamAllocator(subTrack types.SubscribedTrack) { if t.streamAllocator == nil { return @@ -1197,44 +1030,6 @@ func (t *PCTransport) SetChannelCapacityOfStreamAllocator(channelCapacity int64) t.streamAllocator.SetChannelCapacity(channelCapacity) } -func (t *PCTransport) GetICEConnectionType() types.ICEConnectionType { - unknown := types.ICEConnectionTypeUnknown - if t.pc == nil { - return unknown - } - p, err := t.getSelectedPair() - if err != nil || p == nil { - return unknown - } - - if p.Remote.Typ == webrtc.ICECandidateTypeRelay { - return types.ICEConnectionTypeTURN - } else if p.Remote.Typ == webrtc.ICECandidateTypePrflx { - // if the remote relay candidate pings us *before* we get a relay candidate, - // Pion would have created a prflx candidate with the same address as the relay candidate. - // to report an accurate connection type, we'll compare to see if existing relay candidates match - t.lock.RLock() - allowedRemoteCandidates := t.allowedRemoteCandidates.Get() - t.lock.RUnlock() - - for _, ci := range allowedRemoteCandidates { - candidateValue := strings.TrimPrefix(ci, "candidate:") - candidate, err := ice.UnmarshalCandidate(candidateValue) - if err == nil && candidate.Type() == ice.CandidateTypeRelay { - if p.Remote.Address == candidate.Address() && - p.Remote.Port == uint16(candidate.Port()) && - p.Remote.Protocol.String() == candidate.NetworkType().NetworkShort() { - return types.ICEConnectionTypeTURN - } - } - } - } - if p.Remote.Protocol == webrtc.ICEProtocolTCP { - return types.ICEConnectionTypeTCP - } - return types.ICEConnectionTypeUDP -} - func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error { // sticky data channel to first m-lines, if someday we don't send sdp without media streams to // client's subscribe pc after joining, should change this step @@ -1262,7 +1057,7 @@ func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error // trying to replicate previous setup, read from previous answer and use that role. // se := webrtc.SettingEngine{} - se.SetAnsweringDTLSRole(lksdp.ExtractDTLSRole(parsed)) + _ = se.SetAnsweringDTLSRole(lksdp.ExtractDTLSRole(parsed)) api := webrtc.NewAPI( webrtc.WithSettingEngine(se), webrtc.WithMediaEngine(t.me), @@ -1360,9 +1155,7 @@ func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDesc func (t *PCTransport) SetPreviousSdp(offer, answer *webrtc.SessionDescription) { // when there is no previous answer, cannot migrate, force a full reconnect if answer == nil { - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() return } @@ -1373,9 +1166,7 @@ func (t *PCTransport) SetPreviousSdp(offer, answer *webrtc.SessionDescription) { t.params.Logger.Errorw("initPCWithPreviousAnswer failed", err) t.lock.Unlock() - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() return } else if offer != nil { // in migration case, can't reuse transceiver before negotiated except track subscribed at previous node @@ -1417,40 +1208,15 @@ func (t *PCTransport) parseTrackMid(offer webrtc.SessionDescription, senders map } func (t *PCTransport) postEvent(event event) { - t.eventChMu.RLock() - if t.isClosed.Load() { - t.eventChMu.RUnlock() - return - } - - select { - case t.eventCh <- event: - default: - t.params.Logger.Warnw("event queue full", nil, "event", event.String()) - } - t.eventChMu.RUnlock() -} - -func (t *PCTransport) processEvents() { - for event := range t.eventCh { - if t.isClosed.Load() { - // just drain the channel without processing events - continue - } - + t.eventsQueue.Enqueue(func() { err := t.handleEvent(&event) if err != nil { - t.params.Logger.Errorw("error handling event", err, "event", event.String()) - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() + if !t.isClosed.Load() { + t.params.Logger.Errorw("error handling event", err, "event", event.String()) + t.params.Handler.OnNegotiationFailed() } - break } - } - - t.clearSignalStateCheckTimer() - t.params.Logger.Debugw("leaving events processor") - t.handleLogICECandidates(nil) + }) } func (t *PCTransport) handleEvent(e *event) error { @@ -1461,8 +1227,6 @@ func (t *PCTransport) handleEvent(e *event) error { return t.handleLocalICECandidate(e) case signalRemoteICECandidate: return t.handleRemoteICECandidate(e) - case signalLogICECandidates: - return t.handleLogICECandidates(e) case signalSendOffer: return t.handleSendOffer(e) case signalRemoteDescriptionReceived: @@ -1474,7 +1238,7 @@ func (t *PCTransport) handleEvent(e *event) error { return nil } -func (t *PCTransport) handleICEGatheringComplete(e *event) error { +func (t *PCTransport) handleICEGatheringComplete(_ *event) error { if t.params.IsOfferer { return t.handleICEGatheringCompleteOfferer() } else { @@ -1518,59 +1282,45 @@ func (t *PCTransport) localDescriptionSent() error { cachedLocalCandidates := t.cachedLocalCandidates t.cachedLocalCandidates = nil - if onICECandidate := t.getOnICECandidate(); onICECandidate != nil { - for _, c := range cachedLocalCandidates { - if err := onICECandidate(c); err != nil { - return err - } + for _, c := range cachedLocalCandidates { + if err := t.params.Handler.OnICECandidate(c, t.params.Transport); err != nil { + return err } - - return nil } - - return ErrNoICECandidateHandler + return nil } func (t *PCTransport) clearLocalDescriptionSent() { t.cacheLocalCandidates = true t.cachedLocalCandidates = nil - - t.allowedLocalCandidates.Clear() - t.lock.Lock() - t.allowedRemoteCandidates.Clear() - t.lock.Unlock() - t.filteredLocalCandidates.Clear() - t.filteredRemoteCandidates.Clear() + t.connectionDetails.Clear() } func (t *PCTransport) handleLocalICECandidate(e *event) error { c := e.data.(*webrtc.ICECandidate) filtered := false - if t.preferTCP.Load() && c != nil && c.Protocol != webrtc.ICEProtocolTCP { - cstr := c.String() - t.params.Logger.Debugw("filtering out local candidate", "candidate", cstr) - t.filteredLocalCandidates.Add(cstr) - filtered = true + if c != nil { + if t.preferTCP.Load() && c.Protocol != webrtc.ICEProtocolTCP { + t.params.Logger.Debugw("filtering out local candidate", + "candidate", func() interface{} { + return c.String() + }) + filtered = true + } + t.connectionDetails.AddLocalCandidate(c, filtered) } if filtered { return nil } - if c != nil { - t.allowedLocalCandidates.Add(c.String()) - } if t.cacheLocalCandidates { t.cachedLocalCandidates = append(t.cachedLocalCandidates, c) return nil } - if onICECandidate := t.getOnICECandidate(); onICECandidate != nil { - return onICECandidate(c) - } - - return ErrNoICECandidateHandler + return t.params.Handler.OnICECandidate(c, t.params.Transport) } func (t *PCTransport) handleRemoteICECandidate(e *event) error { @@ -1579,18 +1329,14 @@ func (t *PCTransport) handleRemoteICECandidate(e *event) error { filtered := false if t.preferTCP.Load() && !strings.Contains(c.Candidate, "tcp") { t.params.Logger.Debugw("filtering out remote candidate", "candidate", c.Candidate) - t.filteredRemoteCandidates.Add(c.Candidate) filtered = true } + t.connectionDetails.AddRemoteCandidate(*c, filtered) if filtered { return nil } - t.lock.Lock() - t.allowedRemoteCandidates.Add(c.Candidate) - t.lock.Unlock() - if t.pc.RemoteDescription() == nil { t.pendingRemoteCandidates = append(t.pendingRemoteCandidates, c) return nil @@ -1603,23 +1349,7 @@ func (t *PCTransport) handleRemoteICECandidate(e *event) error { return nil } -func (t *PCTransport) handleLogICECandidates(e *event) error { - lc := t.allowedLocalCandidates.Get() - rc := t.allowedRemoteCandidates.Get() - if len(lc) != 0 || len(rc) != 0 { - t.params.Logger.Infow( - "ice candidates", - "lc", lc, - "rc", rc, - "lc (filtered)", t.filteredLocalCandidates.Get(), - "rc (filtered)", t.filteredRemoteCandidates.Get(), - ) - } - - return nil -} - -func (t *PCTransport) setNegotiationState(state NegotiationState) { +func (t *PCTransport) setNegotiationState(state transport.NegotiationState) { t.negotiationState = state if onNegotiationStateChanged := t.getOnNegotiationStateChanged(); onNegotiationStateChanged != nil { onNegotiationStateChanged(t.negotiationState) @@ -1680,9 +1410,9 @@ func (t *PCTransport) setupSignalStateCheckTimer() { t.signalStateCheckTimer = time.AfterFunc(negotiationFailedTimeout, func() { t.clearSignalStateCheckTimer() - failed := t.negotiationState != NegotiationStateNone + failed := t.negotiationState != transport.NegotiationStateNone - if t.negotiateCounter.Load() == negotiateVersion && failed { + if t.negotiateCounter.Load() == negotiateVersion && failed && t.pc.ConnectionState() == webrtc.PeerConnectionStateConnected { t.params.Logger.Infow( "negotiation timed out", "localCurrent", t.pc.CurrentLocalDescription(), @@ -1690,9 +1420,7 @@ func (t *PCTransport) setupSignalStateCheckTimer() { "remoteCurrent", t.pc.CurrentRemoteDescription(), "remotePending", t.pc.PendingRemoteDescription(), ) - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() } }) } @@ -1704,11 +1432,11 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { } // when there's an ongoing negotiation, let it finish and not disrupt its state - if t.negotiationState == NegotiationStateRemote { - t.params.Logger.Infow("skipping negotiation, trying again later") - t.setNegotiationState(NegotiationStateRetry) + if t.negotiationState == transport.NegotiationStateRemote { + t.params.Logger.Debugw("skipping negotiation, trying again later") + t.setNegotiationState(transport.NegotiationStateRetry) return nil - } else if t.negotiationState == NegotiationStateRetry { + } else if t.negotiationState == transport.NegotiationStateRetry { // already set to retry, we can safely skip this attempt return nil } @@ -1778,24 +1506,20 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { } // indicate waiting for remote - t.setNegotiationState(NegotiationStateRemote) + t.setNegotiationState(transport.NegotiationStateRemote) t.setupSignalStateCheckTimer() - if onOffer := t.getOnOffer(); onOffer != nil { - if err := onOffer(offer); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) - return errors.Wrap(err, "could not send offer") - } - - prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) - return t.localDescriptionSent() + if err := t.params.Handler.OnOffer(offer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) + return errors.Wrap(err, "could not send offer") } - return ErrNoOfferHandler + prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) + return t.localDescriptionSent() } -func (t *PCTransport) handleSendOffer(e *event) error { +func (t *PCTransport) handleSendOffer(_ *event) error { return t.createAndSendOffer(nil) } @@ -1899,17 +1623,13 @@ func (t *PCTransport) createAndSendAnswer() error { t.params.Logger.Debugw("local answer (filtered)", "sdp", answer.SDP) } - if onAnswer := t.getOnAnswer(); onAnswer != nil { - if err := onAnswer(answer); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "write_message").Add(1) - return errors.Wrap(err, "could not send answer") - } - - prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) - return t.localDescriptionSent() + if err := t.params.Handler.OnAnswer(answer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "write_message").Add(1) + return errors.Wrap(err, "could not send answer") } - return ErrNoAnswerHandler + prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) + return t.localDescriptionSent() } func (t *PCTransport) handleRemoteOfferReceived(sd *webrtc.SessionDescription) error { @@ -1944,6 +1664,8 @@ func (t *PCTransport) handleRemoteOfferReceived(sd *webrtc.SessionDescription) e } func (t *PCTransport) handleRemoteAnswerReceived(sd *webrtc.SessionDescription) error { + t.clearSignalStateCheckTimer() + if err := t.setRemoteDescription(*sd); err != nil { // Pion will call RTPSender.Send method for each new added Downtrack, and return error if the DownTrack.Bind // returns error. In case of Downtrack.Bind returns ErrUnsupportedCodec, the signal state will be stable as negotiation is aleady compelted @@ -1954,16 +1676,14 @@ func (t *PCTransport) handleRemoteAnswerReceived(sd *webrtc.SessionDescription) } } - t.clearSignalStateCheckTimer() - - if t.negotiationState == NegotiationStateRetry { - t.setNegotiationState(NegotiationStateNone) + if t.negotiationState == transport.NegotiationStateRetry { + t.setNegotiationState(transport.NegotiationStateNone) t.params.Logger.Debugw("re-negotiate after receiving answer") return t.createAndSendOffer(nil) } - t.setNegotiationState(NegotiationStateNone) + t.setNegotiationState(transport.NegotiationStateNone) return nil } @@ -1984,7 +1704,7 @@ func (t *PCTransport) doICERestart() error { t.resetShortConn() } - if t.negotiationState == NegotiationStateNone { + if t.negotiationState == transport.NegotiationStateNone { return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) } @@ -1998,18 +1718,15 @@ func (t *PCTransport) doICERestart() error { return ErrIceRestartWithoutLocalSDP } else { t.params.Logger.Infow("deferring ice restart to next offer") - t.setNegotiationState(NegotiationStateRetry) + t.setNegotiationState(transport.NegotiationStateRetry) t.restartAtNextOffer = true - if onOffer := t.getOnOffer(); onOffer != nil { - err := onOffer(*offer) - if err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) - } else { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) - } - return err + err := t.params.Handler.OnOffer(*offer) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) + } else { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) } - return ErrNoOfferHandler + return err } } else { // recover by re-applying the last answer @@ -2018,13 +1735,13 @@ func (t *PCTransport) doICERestart() error { prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "remote_description").Add(1) return errors.Wrap(err, "set remote description failed") } else { - t.setNegotiationState(NegotiationStateNone) + t.setNegotiationState(transport.NegotiationStateNone) return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) } } } -func (t *PCTransport) handleICERestart(e *event) error { +func (t *PCTransport) handleICERestart(_ *event) error { return t.doICERestart() } diff --git a/pkg/rtc/transport/handler.go b/pkg/rtc/transport/handler.go new file mode 100644 index 000000000..eab3de349 --- /dev/null +++ b/pkg/rtc/transport/handler.go @@ -0,0 +1,55 @@ +package transport + +import ( + "errors" + + "github.com/pion/webrtc/v3" + + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/protocol/livekit" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +var ( + ErrNoICECandidateHandler = errors.New("no ICE candidate handler") + ErrNoOfferHandler = errors.New("no offer handler") + ErrNoAnswerHandler = errors.New("no answer handler") +) + +//counterfeiter:generate . Handler +type Handler interface { + OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error + OnInitialConnected() + OnFullyEstablished() + OnFailed(isShortLived bool) + OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) + OnDataPacket(kind livekit.DataPacket_Kind, data []byte) + OnOffer(sd webrtc.SessionDescription) error + OnAnswer(sd webrtc.SessionDescription) error + OnNegotiationStateChanged(state NegotiationState) + OnNegotiationFailed() + OnStreamStateChange(update *streamallocator.StreamStateUpdate) error +} + +type UnimplementedHandler struct{} + +func (h UnimplementedHandler) OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + return ErrNoICECandidateHandler +} +func (h UnimplementedHandler) OnInitialConnected() {} +func (h UnimplementedHandler) OnFullyEstablished() {} +func (h UnimplementedHandler) OnFailed(isShortLived bool) {} +func (h UnimplementedHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) {} +func (h UnimplementedHandler) OnDataPacket(kind livekit.DataPacket_Kind, data []byte) {} +func (h UnimplementedHandler) OnOffer(sd webrtc.SessionDescription) error { + return ErrNoOfferHandler +} +func (h UnimplementedHandler) OnAnswer(sd webrtc.SessionDescription) error { + return ErrNoAnswerHandler +} +func (h UnimplementedHandler) OnNegotiationStateChanged(state NegotiationState) {} +func (h UnimplementedHandler) OnNegotiationFailed() {} +func (h UnimplementedHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { + return nil +} diff --git a/pkg/rtc/transport/negotiationstate.go b/pkg/rtc/transport/negotiationstate.go new file mode 100644 index 000000000..8f074f83f --- /dev/null +++ b/pkg/rtc/transport/negotiationstate.go @@ -0,0 +1,26 @@ +package transport + +import "fmt" + +type NegotiationState int + +const ( + NegotiationStateNone NegotiationState = iota + // waiting for remote description + NegotiationStateRemote + // need to Negotiate again + NegotiationStateRetry +) + +func (n NegotiationState) String() string { + switch n { + case NegotiationStateNone: + return "NONE" + case NegotiationStateRemote: + return "WAITING_FOR_REMOTE" + case NegotiationStateRetry: + return "RETRY" + default: + return fmt.Sprintf("%d", int(n)) + } +} diff --git a/pkg/rtc/transport/transportfakes/fake_handler.go b/pkg/rtc/transport/transportfakes/fake_handler.go new file mode 100644 index 000000000..dc7ad3c3e --- /dev/null +++ b/pkg/rtc/transport/transportfakes/fake_handler.go @@ -0,0 +1,593 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package transportfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/protocol/livekit" + webrtc "github.com/pion/webrtc/v3" +) + +type FakeHandler struct { + OnAnswerStub func(webrtc.SessionDescription) error + onAnswerMutex sync.RWMutex + onAnswerArgsForCall []struct { + arg1 webrtc.SessionDescription + } + onAnswerReturns struct { + result1 error + } + onAnswerReturnsOnCall map[int]struct { + result1 error + } + OnDataPacketStub func(livekit.DataPacket_Kind, []byte) + onDataPacketMutex sync.RWMutex + onDataPacketArgsForCall []struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + } + OnFailedStub func(bool) + onFailedMutex sync.RWMutex + onFailedArgsForCall []struct { + arg1 bool + } + OnFullyEstablishedStub func() + onFullyEstablishedMutex sync.RWMutex + onFullyEstablishedArgsForCall []struct { + } + OnICECandidateStub func(*webrtc.ICECandidate, livekit.SignalTarget) error + onICECandidateMutex sync.RWMutex + onICECandidateArgsForCall []struct { + arg1 *webrtc.ICECandidate + arg2 livekit.SignalTarget + } + onICECandidateReturns struct { + result1 error + } + onICECandidateReturnsOnCall map[int]struct { + result1 error + } + OnInitialConnectedStub func() + onInitialConnectedMutex sync.RWMutex + onInitialConnectedArgsForCall []struct { + } + OnNegotiationFailedStub func() + onNegotiationFailedMutex sync.RWMutex + onNegotiationFailedArgsForCall []struct { + } + OnNegotiationStateChangedStub func(transport.NegotiationState) + onNegotiationStateChangedMutex sync.RWMutex + onNegotiationStateChangedArgsForCall []struct { + arg1 transport.NegotiationState + } + OnOfferStub func(webrtc.SessionDescription) error + onOfferMutex sync.RWMutex + onOfferArgsForCall []struct { + arg1 webrtc.SessionDescription + } + onOfferReturns struct { + result1 error + } + onOfferReturnsOnCall map[int]struct { + result1 error + } + OnStreamStateChangeStub func(*streamallocator.StreamStateUpdate) error + onStreamStateChangeMutex sync.RWMutex + onStreamStateChangeArgsForCall []struct { + arg1 *streamallocator.StreamStateUpdate + } + onStreamStateChangeReturns struct { + result1 error + } + onStreamStateChangeReturnsOnCall map[int]struct { + result1 error + } + OnTrackStub func(*webrtc.TrackRemote, *webrtc.RTPReceiver) + onTrackMutex sync.RWMutex + onTrackArgsForCall []struct { + arg1 *webrtc.TrackRemote + arg2 *webrtc.RTPReceiver + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHandler) OnAnswer(arg1 webrtc.SessionDescription) error { + fake.onAnswerMutex.Lock() + ret, specificReturn := fake.onAnswerReturnsOnCall[len(fake.onAnswerArgsForCall)] + fake.onAnswerArgsForCall = append(fake.onAnswerArgsForCall, struct { + arg1 webrtc.SessionDescription + }{arg1}) + stub := fake.OnAnswerStub + fakeReturns := fake.onAnswerReturns + fake.recordInvocation("OnAnswer", []interface{}{arg1}) + fake.onAnswerMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnAnswerCallCount() int { + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + return len(fake.onAnswerArgsForCall) +} + +func (fake *FakeHandler) OnAnswerCalls(stub func(webrtc.SessionDescription) error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = stub +} + +func (fake *FakeHandler) OnAnswerArgsForCall(i int) webrtc.SessionDescription { + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + argsForCall := fake.onAnswerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnAnswerReturns(result1 error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = nil + fake.onAnswerReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnAnswerReturnsOnCall(i int, result1 error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = nil + if fake.onAnswerReturnsOnCall == nil { + fake.onAnswerReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onAnswerReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnDataPacket(arg1 livekit.DataPacket_Kind, arg2 []byte) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.onDataPacketMutex.Lock() + fake.onDataPacketArgsForCall = append(fake.onDataPacketArgsForCall, struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + }{arg1, arg2Copy}) + stub := fake.OnDataPacketStub + fake.recordInvocation("OnDataPacket", []interface{}{arg1, arg2Copy}) + fake.onDataPacketMutex.Unlock() + if stub != nil { + fake.OnDataPacketStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnDataPacketCallCount() int { + fake.onDataPacketMutex.RLock() + defer fake.onDataPacketMutex.RUnlock() + return len(fake.onDataPacketArgsForCall) +} + +func (fake *FakeHandler) OnDataPacketCalls(stub func(livekit.DataPacket_Kind, []byte)) { + fake.onDataPacketMutex.Lock() + defer fake.onDataPacketMutex.Unlock() + fake.OnDataPacketStub = stub +} + +func (fake *FakeHandler) OnDataPacketArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { + fake.onDataPacketMutex.RLock() + defer fake.onDataPacketMutex.RUnlock() + argsForCall := fake.onDataPacketArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnFailed(arg1 bool) { + fake.onFailedMutex.Lock() + fake.onFailedArgsForCall = append(fake.onFailedArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.OnFailedStub + fake.recordInvocation("OnFailed", []interface{}{arg1}) + fake.onFailedMutex.Unlock() + if stub != nil { + fake.OnFailedStub(arg1) + } +} + +func (fake *FakeHandler) OnFailedCallCount() int { + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + return len(fake.onFailedArgsForCall) +} + +func (fake *FakeHandler) OnFailedCalls(stub func(bool)) { + fake.onFailedMutex.Lock() + defer fake.onFailedMutex.Unlock() + fake.OnFailedStub = stub +} + +func (fake *FakeHandler) OnFailedArgsForCall(i int) bool { + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + argsForCall := fake.onFailedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnFullyEstablished() { + fake.onFullyEstablishedMutex.Lock() + fake.onFullyEstablishedArgsForCall = append(fake.onFullyEstablishedArgsForCall, struct { + }{}) + stub := fake.OnFullyEstablishedStub + fake.recordInvocation("OnFullyEstablished", []interface{}{}) + fake.onFullyEstablishedMutex.Unlock() + if stub != nil { + fake.OnFullyEstablishedStub() + } +} + +func (fake *FakeHandler) OnFullyEstablishedCallCount() int { + fake.onFullyEstablishedMutex.RLock() + defer fake.onFullyEstablishedMutex.RUnlock() + return len(fake.onFullyEstablishedArgsForCall) +} + +func (fake *FakeHandler) OnFullyEstablishedCalls(stub func()) { + fake.onFullyEstablishedMutex.Lock() + defer fake.onFullyEstablishedMutex.Unlock() + fake.OnFullyEstablishedStub = stub +} + +func (fake *FakeHandler) OnICECandidate(arg1 *webrtc.ICECandidate, arg2 livekit.SignalTarget) error { + fake.onICECandidateMutex.Lock() + ret, specificReturn := fake.onICECandidateReturnsOnCall[len(fake.onICECandidateArgsForCall)] + fake.onICECandidateArgsForCall = append(fake.onICECandidateArgsForCall, struct { + arg1 *webrtc.ICECandidate + arg2 livekit.SignalTarget + }{arg1, arg2}) + stub := fake.OnICECandidateStub + fakeReturns := fake.onICECandidateReturns + fake.recordInvocation("OnICECandidate", []interface{}{arg1, arg2}) + fake.onICECandidateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnICECandidateCallCount() int { + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + return len(fake.onICECandidateArgsForCall) +} + +func (fake *FakeHandler) OnICECandidateCalls(stub func(*webrtc.ICECandidate, livekit.SignalTarget) error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = stub +} + +func (fake *FakeHandler) OnICECandidateArgsForCall(i int) (*webrtc.ICECandidate, livekit.SignalTarget) { + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + argsForCall := fake.onICECandidateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnICECandidateReturns(result1 error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = nil + fake.onICECandidateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnICECandidateReturnsOnCall(i int, result1 error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = nil + if fake.onICECandidateReturnsOnCall == nil { + fake.onICECandidateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onICECandidateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnInitialConnected() { + fake.onInitialConnectedMutex.Lock() + fake.onInitialConnectedArgsForCall = append(fake.onInitialConnectedArgsForCall, struct { + }{}) + stub := fake.OnInitialConnectedStub + fake.recordInvocation("OnInitialConnected", []interface{}{}) + fake.onInitialConnectedMutex.Unlock() + if stub != nil { + fake.OnInitialConnectedStub() + } +} + +func (fake *FakeHandler) OnInitialConnectedCallCount() int { + fake.onInitialConnectedMutex.RLock() + defer fake.onInitialConnectedMutex.RUnlock() + return len(fake.onInitialConnectedArgsForCall) +} + +func (fake *FakeHandler) OnInitialConnectedCalls(stub func()) { + fake.onInitialConnectedMutex.Lock() + defer fake.onInitialConnectedMutex.Unlock() + fake.OnInitialConnectedStub = stub +} + +func (fake *FakeHandler) OnNegotiationFailed() { + fake.onNegotiationFailedMutex.Lock() + fake.onNegotiationFailedArgsForCall = append(fake.onNegotiationFailedArgsForCall, struct { + }{}) + stub := fake.OnNegotiationFailedStub + fake.recordInvocation("OnNegotiationFailed", []interface{}{}) + fake.onNegotiationFailedMutex.Unlock() + if stub != nil { + fake.OnNegotiationFailedStub() + } +} + +func (fake *FakeHandler) OnNegotiationFailedCallCount() int { + fake.onNegotiationFailedMutex.RLock() + defer fake.onNegotiationFailedMutex.RUnlock() + return len(fake.onNegotiationFailedArgsForCall) +} + +func (fake *FakeHandler) OnNegotiationFailedCalls(stub func()) { + fake.onNegotiationFailedMutex.Lock() + defer fake.onNegotiationFailedMutex.Unlock() + fake.OnNegotiationFailedStub = stub +} + +func (fake *FakeHandler) OnNegotiationStateChanged(arg1 transport.NegotiationState) { + fake.onNegotiationStateChangedMutex.Lock() + fake.onNegotiationStateChangedArgsForCall = append(fake.onNegotiationStateChangedArgsForCall, struct { + arg1 transport.NegotiationState + }{arg1}) + stub := fake.OnNegotiationStateChangedStub + fake.recordInvocation("OnNegotiationStateChanged", []interface{}{arg1}) + fake.onNegotiationStateChangedMutex.Unlock() + if stub != nil { + fake.OnNegotiationStateChangedStub(arg1) + } +} + +func (fake *FakeHandler) OnNegotiationStateChangedCallCount() int { + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + return len(fake.onNegotiationStateChangedArgsForCall) +} + +func (fake *FakeHandler) OnNegotiationStateChangedCalls(stub func(transport.NegotiationState)) { + fake.onNegotiationStateChangedMutex.Lock() + defer fake.onNegotiationStateChangedMutex.Unlock() + fake.OnNegotiationStateChangedStub = stub +} + +func (fake *FakeHandler) OnNegotiationStateChangedArgsForCall(i int) transport.NegotiationState { + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + argsForCall := fake.onNegotiationStateChangedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnOffer(arg1 webrtc.SessionDescription) error { + fake.onOfferMutex.Lock() + ret, specificReturn := fake.onOfferReturnsOnCall[len(fake.onOfferArgsForCall)] + fake.onOfferArgsForCall = append(fake.onOfferArgsForCall, struct { + arg1 webrtc.SessionDescription + }{arg1}) + stub := fake.OnOfferStub + fakeReturns := fake.onOfferReturns + fake.recordInvocation("OnOffer", []interface{}{arg1}) + fake.onOfferMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnOfferCallCount() int { + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + return len(fake.onOfferArgsForCall) +} + +func (fake *FakeHandler) OnOfferCalls(stub func(webrtc.SessionDescription) error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = stub +} + +func (fake *FakeHandler) OnOfferArgsForCall(i int) webrtc.SessionDescription { + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + argsForCall := fake.onOfferArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnOfferReturns(result1 error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = nil + fake.onOfferReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnOfferReturnsOnCall(i int, result1 error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = nil + if fake.onOfferReturnsOnCall == nil { + fake.onOfferReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onOfferReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnStreamStateChange(arg1 *streamallocator.StreamStateUpdate) error { + fake.onStreamStateChangeMutex.Lock() + ret, specificReturn := fake.onStreamStateChangeReturnsOnCall[len(fake.onStreamStateChangeArgsForCall)] + fake.onStreamStateChangeArgsForCall = append(fake.onStreamStateChangeArgsForCall, struct { + arg1 *streamallocator.StreamStateUpdate + }{arg1}) + stub := fake.OnStreamStateChangeStub + fakeReturns := fake.onStreamStateChangeReturns + fake.recordInvocation("OnStreamStateChange", []interface{}{arg1}) + fake.onStreamStateChangeMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnStreamStateChangeCallCount() int { + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + return len(fake.onStreamStateChangeArgsForCall) +} + +func (fake *FakeHandler) OnStreamStateChangeCalls(stub func(*streamallocator.StreamStateUpdate) error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = stub +} + +func (fake *FakeHandler) OnStreamStateChangeArgsForCall(i int) *streamallocator.StreamStateUpdate { + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + argsForCall := fake.onStreamStateChangeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnStreamStateChangeReturns(result1 error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = nil + fake.onStreamStateChangeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnStreamStateChangeReturnsOnCall(i int, result1 error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = nil + if fake.onStreamStateChangeReturnsOnCall == nil { + fake.onStreamStateChangeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onStreamStateChangeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnTrack(arg1 *webrtc.TrackRemote, arg2 *webrtc.RTPReceiver) { + fake.onTrackMutex.Lock() + fake.onTrackArgsForCall = append(fake.onTrackArgsForCall, struct { + arg1 *webrtc.TrackRemote + arg2 *webrtc.RTPReceiver + }{arg1, arg2}) + stub := fake.OnTrackStub + fake.recordInvocation("OnTrack", []interface{}{arg1, arg2}) + fake.onTrackMutex.Unlock() + if stub != nil { + fake.OnTrackStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnTrackCallCount() int { + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + return len(fake.onTrackArgsForCall) +} + +func (fake *FakeHandler) OnTrackCalls(stub func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + fake.onTrackMutex.Lock() + defer fake.onTrackMutex.Unlock() + fake.OnTrackStub = stub +} + +func (fake *FakeHandler) OnTrackArgsForCall(i int) (*webrtc.TrackRemote, *webrtc.RTPReceiver) { + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + argsForCall := fake.onTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + fake.onDataPacketMutex.RLock() + defer fake.onDataPacketMutex.RUnlock() + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + fake.onFullyEstablishedMutex.RLock() + defer fake.onFullyEstablishedMutex.RUnlock() + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + fake.onInitialConnectedMutex.RLock() + defer fake.onInitialConnectedMutex.RUnlock() + fake.onNegotiationFailedMutex.RLock() + defer fake.onNegotiationFailedMutex.RUnlock() + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ transport.Handler = new(FakeHandler) diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index eb59df779..7bd1067b3 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -17,6 +17,7 @@ package rtc import ( "fmt" "strings" + "sync" "testing" "time" @@ -25,6 +26,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/atomic" + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" "github.com/livekit/livekit-server/pkg/testutils" "github.com/livekit/protocol/livekit" ) @@ -36,33 +39,39 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) - _, err = transportA.pc.CreateDataChannel("test", nil) + _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) - connectTransports(t, transportA, transportB, false, 1, 1) + connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) var negotiationState atomic.Value - transportA.OnNegotiationStateChanged(func(state NegotiationState) { + transportA.OnNegotiationStateChanged(func(state transport.NegotiationState) { negotiationState.Store(state) }) // offer again, but missed var offerReceived atomic.Bool - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState()) - require.Equal(t, NegotiationStateRemote, negotiationState.Load().(NegotiationState)) + require.Equal(t, transport.NegotiationStateRemote, negotiationState.Load().(transport.NegotiationState)) offerReceived.Store(true) return nil }) @@ -71,7 +80,7 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { return offerReceived.Load() }, 10*time.Second, time.Millisecond*10, "transportA offer not received") - connectTransports(t, transportA, transportB, true, 1, 1) + connectTransports(t, transportA, transportB, handlerA, handlerB, true, 1, 1) require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) @@ -86,70 +95,76 @@ func TestNegotiationTiming(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) - _, err = transportA.pc.CreateDataChannel("test", nil) + _, err = transportA.pc.CreateDataChannel(LossyDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false - transportB, err := NewPCTransport(params) + transportB, err := NewPCTransport(paramsB) require.NoError(t, err) require.False(t, transportA.IsEstablished()) require.False(t, transportB.IsEstablished()) - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) offer := atomic.Value{} - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { offer.Store(&sd) return nil }) var negotiationState atomic.Value - transportA.OnNegotiationStateChanged(func(state NegotiationState) { + transportA.OnNegotiationStateChanged(func(state transport.NegotiationState) { negotiationState.Store(state) }) // initial offer transportA.Negotiate(true) require.Eventually(t, func() bool { - state, ok := negotiationState.Load().(NegotiationState) + state, ok := negotiationState.Load().(transport.NegotiationState) if !ok { return false } - return state == NegotiationStateRemote + return state == transport.NegotiationStateRemote }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRemote") // second try, should've flipped transport status to retry transportA.Negotiate(true) require.Eventually(t, func() bool { - state, ok := negotiationState.Load().(NegotiationState) + state, ok := negotiationState.Load().(transport.NegotiationState) if !ok { return false } - return state == NegotiationStateRetry + return state == transport.NegotiationStateRetry }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") // third try, should've stayed at retry transportA.Negotiate(true) time.Sleep(100 * time.Millisecond) // some time to process the negotiate event require.Eventually(t, func() bool { - state, ok := negotiationState.Load().(NegotiationState) + state, ok := negotiationState.Load().(transport.NegotiationState) if !ok { return false } - return state == NegotiationStateRetry + return state == transport.NegotiationStateRetry }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") time.Sleep(5 * time.Millisecond) actualOffer, ok := offer.Load().(*webrtc.SessionDescription) require.True(t, ok) - transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription) error { transportA.HandleRemoteDescription(answer) return nil }) @@ -163,7 +178,7 @@ func TestNegotiationTiming(t *testing.T) { }, 10*time.Second, time.Millisecond*10, "transportB is not established") // it should still be negotiating again - require.Equal(t, NegotiationStateRemote, negotiationState.Load().(NegotiationState)) + require.Equal(t, transport.NegotiationStateRemote, negotiationState.Load().(transport.NegotiationState)) offer2, ok := offer.Load().(*webrtc.SessionDescription) require.True(t, ok) require.False(t, offer2 == actualOffer) @@ -179,22 +194,28 @@ func TestFirstOfferMissedDuringICERestart(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) - _, err = transportA.pc.CreateDataChannel("test", nil) + _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) // first offer missed var firstOfferReceived atomic.Bool - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { firstOfferReceived.Store(true) return nil }) @@ -206,13 +227,13 @@ func TestFirstOfferMissedDuringICERestart(t *testing.T) { // set offer/answer with restart ICE, will negotiate twice, // first one is recover from missed offer // second one is restartICE - transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription) error { transportA.HandleRemoteDescription(answer) return nil }) var offerCount atomic.Int32 - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { offerCount.Inc() // the second offer is a ice restart offer, so we wait transportB complete the ice gathering @@ -247,22 +268,28 @@ func TestFirstAnswerMissedDuringICERestart(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) - _, err = transportA.pc.CreateDataChannel("test", nil) + _, err = transportA.pc.CreateDataChannel(LossyDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) // first answer missed var firstAnswerReceived atomic.Bool - transportB.OnAnswer(func(sd webrtc.SessionDescription) error { + handlerB.OnAnswerCalls(func(sd webrtc.SessionDescription) error { if firstAnswerReceived.Load() { transportA.HandleRemoteDescription(sd) } else { @@ -271,7 +298,7 @@ func TestFirstAnswerMissedDuringICERestart(t *testing.T) { } return nil }) - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { transportB.HandleRemoteDescription(sd) return nil }) @@ -285,7 +312,7 @@ func TestFirstAnswerMissedDuringICERestart(t *testing.T) { // first one is recover from missed offer // second one is restartICE var offerCount atomic.Int32 - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { offerCount.Inc() // the second offer is a ice restart offer, so we wait transportB complete the ice gathering @@ -320,20 +347,32 @@ func TestNegotiationFailed(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) + require.NoError(t, err) + _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) - transportA.OnICECandidate(func(candidate *webrtc.ICECandidate) error { - if candidate == nil { - return nil - } - t.Logf("got ICE candidate from A: %v", candidate) - return nil - }) + paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB + paramsB.IsOfferer = false + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) - transportA.OnOffer(func(sd webrtc.SessionDescription) error { return nil }) + // exchange ICE + handleICEExchange(t, transportA, transportB, handlerA, handlerB) + + // wait for transport to be connected before maiming the signalling channel + connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) + + // reset OnOffer to force a negotiation failure + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { return nil }) var failed atomic.Int32 - transportA.OnNegotiationFailed(func() { + handlerA.OnNegotiationFailedCalls(func() { failed.Inc() }) transportA.Negotiate(true) @@ -354,11 +393,12 @@ func TestFilteringCandidates(t *testing.T) { {Mime: webrtc.MimeTypeVP8}, {Mime: webrtc.MimeTypeH264}, }, + Handler: &transportfakes.FakeHandler{}, } transport, err := NewPCTransport(params) require.NoError(t, err) - _, err = transport.pc.CreateDataChannel("test", nil) + _, err = transport.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) _, err = transport.pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio) @@ -464,8 +504,8 @@ func TestFilteringCandidates(t *testing.T) { transport.Close() } -func handleICEExchange(t *testing.T, a, b *PCTransport) { - a.OnICECandidate(func(candidate *webrtc.ICECandidate) error { +func handleICEExchange(t *testing.T, a, b *PCTransport, ah, bh *transportfakes.FakeHandler) { + ah.OnICECandidateCalls(func(candidate *webrtc.ICECandidate, target livekit.SignalTarget) error { if candidate == nil { return nil } @@ -473,7 +513,7 @@ func handleICEExchange(t *testing.T, a, b *PCTransport) { b.AddICECandidate(candidate.ToJSON()) return nil }) - b.OnICECandidate(func(candidate *webrtc.ICECandidate) error { + bh.OnICECandidateCalls(func(candidate *webrtc.ICECandidate, target livekit.SignalTarget) error { if candidate == nil { return nil } @@ -483,16 +523,16 @@ func handleICEExchange(t *testing.T, a, b *PCTransport) { }) } -func connectTransports(t *testing.T, offerer, answerer *PCTransport, isICERestart bool, expectedOfferCount int32, expectedAnswerCount int32) { +func connectTransports(t *testing.T, offerer, answerer *PCTransport, offererHandler, answererHandler *transportfakes.FakeHandler, isICERestart bool, expectedOfferCount int32, expectedAnswerCount int32) { var offerCount atomic.Int32 var answerCount atomic.Int32 - answerer.OnAnswer(func(answer webrtc.SessionDescription) error { + answererHandler.OnAnswerCalls(func(answer webrtc.SessionDescription) error { answerCount.Inc() offerer.HandleRemoteDescription(answer) return nil }) - offerer.OnOffer(func(offer webrtc.SessionDescription) error { + offererHandler.OnOfferCalls(func(offer webrtc.SessionDescription) error { offerCount.Inc() answerer.HandleRemoteDescription(offer) return nil @@ -519,6 +559,31 @@ func connectTransports(t *testing.T, offerer, answerer *PCTransport, isICERestar require.Eventually(t, func() bool { return answerer.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected }, 10*time.Second, time.Millisecond*10, "answerer did not become connected") + + transportsConnected := untilTransportsConnected(offererHandler, answererHandler) + transportsConnected.Wait() +} + +func untilTransportsConnected(transports ...*transportfakes.FakeHandler) *sync.WaitGroup { + var triggered sync.WaitGroup + triggered.Add(len(transports)) + + for _, t := range transports { + var done atomic.Value + done.Store(false) + hdlr := func() { + if val, ok := done.Load().(bool); ok && !val { + done.Store(true) + triggered.Done() + } + } + + if t.OnInitialConnectedCallCount() != 0 { + hdlr() + } + t.OnInitialConnectedCalls(hdlr) + } + return &triggered } func TestConfigureAudioTransceiver(t *testing.T) { diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 6805bd430..401fea8e0 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -29,11 +29,10 @@ import ( "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/pacer" - "github.com/livekit/livekit-server/pkg/sfu/streamallocator" - "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -48,13 +47,31 @@ const ( udpLossUnstableCountThreshold = 20 ) +type TransportManagerTransportHandler struct { + transport.Handler + t *TransportManager +} + +func (h TransportManagerTransportHandler) OnFailed(isShortLived bool) { + h.t.handleConnectionFailed(isShortLived) + h.Handler.OnFailed(isShortLived) +} + +type TransportManagerPublisherTransportHandler struct { + TransportManagerTransportHandler +} + +func (h TransportManagerPublisherTransportHandler) OnAnswer(sd webrtc.SessionDescription) error { + h.t.lastPublisherAnswer.Store(sd) + return h.Handler.OnAnswer(sd) +} + type TransportManagerParams struct { Identity livekit.ParticipantIdentity SID livekit.ParticipantID SubscriberAsPrimary bool Config *WebRTCConfig ProtocolVersion types.ProtocolVersion - Telemetry telemetry.TelemetryService CongestionControlConfig config.CongestionControlConfig EnabledSubscribeCodecs []*livekit.Codec EnabledPublishCodecs []*livekit.Codec @@ -68,6 +85,8 @@ type TransportManagerParams struct { AllowPlayoutDelay bool DataChannelMaxBufferedAmount uint64 Logger logger.Logger + PublisherHandler transport.Handler + SubscriberHandler transport.Handler } type TransportManager struct { @@ -93,11 +112,6 @@ type TransportManager struct { udpLossUnstableCount uint32 signalingRTT, udpRTT uint32 - onPublisherInitialConnected func() - onSubscriberInitialConnected func() - onPrimaryTransportInitialConnected func() - onAnyTransportFailed func() - onICEConfigChanged func(iceConfig *livekit.ICEConfig) } @@ -119,30 +133,17 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro Config: params.Config, DirectionConfig: params.Config.Publisher, CongestionControlConfig: params.CongestionControlConfig, - Telemetry: params.Telemetry, EnabledCodecs: params.EnabledPublishCodecs, Logger: LoggerWithPCTarget(params.Logger, livekit.SignalTarget_PUBLISHER), SimTracks: params.SimTracks, ClientInfo: params.ClientInfo, + Transport: livekit.SignalTarget_PUBLISHER, + Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t}}, }) if err != nil { return nil, err } t.publisher = publisher - t.publisher.OnInitialConnected(func() { - if t.onPublisherInitialConnected != nil { - t.onPublisherInitialConnected() - } - if !t.params.SubscriberAsPrimary && t.onPrimaryTransportInitialConnected != nil { - t.onPrimaryTransportInitialConnected() - } - }) - t.publisher.OnFailed(func(isShortLived bool) { - t.handleConnectionFailed(isShortLived) - if t.onAnyTransportFailed != nil { - t.onAnyTransportFailed() - } - }) subscriber, err := NewPCTransport(TransportParams{ ParticipantID: params.SID, @@ -151,7 +152,6 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro Config: params.Config, DirectionConfig: params.Config.Subscriber, CongestionControlConfig: params.CongestionControlConfig, - Telemetry: params.Telemetry, EnabledCodecs: params.EnabledSubscribeCodecs, Logger: LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER), ClientInfo: params.ClientInfo, @@ -159,25 +159,13 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro IsSendSide: true, AllowPlayoutDelay: params.AllowPlayoutDelay, DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, + Transport: livekit.SignalTarget_SUBSCRIBER, + Handler: TransportManagerTransportHandler{params.SubscriberHandler, t}, }) if err != nil { return nil, err } t.subscriber = subscriber - t.subscriber.OnInitialConnected(func() { - if t.onSubscriberInitialConnected != nil { - t.onSubscriberInitialConnected() - } - if t.params.SubscriberAsPrimary && t.onPrimaryTransportInitialConnected != nil { - t.onPrimaryTransportInitialConnected() - } - }) - t.subscriber.OnFailed(func(isShortLived bool) { - t.handleConnectionFailed(isShortLived) - if t.onAnyTransportFailed != nil { - t.onAnyTransportFailed() - } - }) if !t.params.Migration { if err := t.createDataChannelsForSubscriber(nil); err != nil { return nil, err @@ -185,7 +173,6 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro } t.signalSourceValid.Store(true) - return t, nil } @@ -198,25 +185,6 @@ func (t *TransportManager) SubscriberClose() { t.subscriber.Close() } -func (t *TransportManager) OnPublisherICECandidate(f func(c *webrtc.ICECandidate) error) { - t.publisher.OnICECandidate(f) -} - -func (t *TransportManager) OnPublisherAnswer(f func(answer webrtc.SessionDescription) error) { - t.publisher.OnAnswer(func(sd webrtc.SessionDescription) error { - t.lastPublisherAnswer.Store(sd) - return f(sd) - }) -} - -func (t *TransportManager) OnPublisherInitialConnected(f func()) { - t.onPublisherInitialConnected = f -} - -func (t *TransportManager) OnPublisherTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { - t.publisher.OnTrack(f) -} - func (t *TransportManager) HasPublisherEverConnected() bool { return t.publisher.HasEverConnected() } @@ -237,22 +205,6 @@ func (t *TransportManager) WritePublisherRTCP(pkts []rtcp.Packet) error { return t.publisher.WriteRTCP(pkts) } -func (t *TransportManager) OnSubscriberICECandidate(f func(c *webrtc.ICECandidate) error) { - t.subscriber.OnICECandidate(f) -} - -func (t *TransportManager) OnSubscriberOffer(f func(offer webrtc.SessionDescription) error) { - t.subscriber.OnOffer(f) -} - -func (t *TransportManager) OnSubscriberInitialConnected(f func()) { - t.onSubscriberInitialConnected = f -} - -func (t *TransportManager) OnSubscriberStreamStateChange(f func(update *streamallocator.StreamStateUpdate) error) { - t.subscriber.OnStreamStateChange(f) -} - func (t *TransportManager) HasSubscriberEverConnected() bool { return t.subscriber.HasEverConnected() } @@ -277,23 +229,6 @@ func (t *TransportManager) GetSubscriberPacer() pacer.Pacer { return t.subscriber.GetPacer() } -func (t *TransportManager) OnPrimaryTransportInitialConnected(f func()) { - t.onPrimaryTransportInitialConnected = f -} - -func (t *TransportManager) OnPrimaryTransportFullyEstablished(f func()) { - t.getTransport(true).OnFullyEstablished(f) -} - -func (t *TransportManager) OnAnyTransportFailed(f func()) { - t.onAnyTransportFailed = f -} - -func (t *TransportManager) OnAnyTransportNegotiationFailed(f func()) { - t.publisher.OnNegotiationFailed(f) - t.subscriber.OnNegotiationFailed(f) -} - func (t *TransportManager) AddSubscribedTrack(subTrack types.SubscribedTrack) { t.subscriber.AddTrackToStreamAllocator(subTrack) } @@ -302,11 +237,6 @@ func (t *TransportManager) RemoveSubscribedTrack(subTrack types.SubscribedTrack) t.subscriber.RemoveTrackFromStreamAllocator(subTrack) } -func (t *TransportManager) OnDataMessage(f func(kind livekit.DataPacket_Kind, data []byte)) { - // upstream data is always comes in via publisher peer connection irrespective of which is primary - t.publisher.OnDataPacket(f) -} - func (t *TransportManager) SendDataPacket(dp *livekit.DataPacket, data []byte) error { // downstream data is sent via primary peer connection return t.getTransport(true).SendDataPacket(dp, data) @@ -554,8 +484,15 @@ func (t *TransportManager) SubscriberAsPrimary() bool { return t.params.SubscriberAsPrimary } -func (t *TransportManager) GetICEConnectionType() types.ICEConnectionType { - return t.getTransport(true).GetICEConnectionType() +func (t *TransportManager) GetICEConnectionDetails() []*types.ICEConnectionDetails { + details := make([]*types.ICEConnectionDetails, 0, 2) + for _, pc := range []*PCTransport{t.publisher, t.subscriber} { + cd := pc.GetICEConnectionDetails() + if cd.HasCandidates() { + details = append(details, cd.Clone()) + } + } + return details } func (t *TransportManager) getTransport(isPrimary bool) *PCTransport { @@ -714,7 +651,7 @@ func (t *TransportManager) ProcessPendingPublisherDataChannels() { } } -func (t *TransportManager) OnReceiverReport(dt *sfu.DownTrack, report *rtcp.ReceiverReport) { +func (t *TransportManager) HandleReceiverReport(dt *sfu.DownTrack, report *rtcp.ReceiverReport) { t.mediaLossProxy.HandleMaxLossFeedback(dt, report) } @@ -732,10 +669,7 @@ func (t *TransportManager) onMediaLossUpdate(loss uint8) { t.lock.Unlock() t.params.Logger.Infow("udp connection unstable, switch to tcp", "signalingRTT", t.signalingRTT) - t.handleConnectionFailed(true) - if t.onAnyTransportFailed != nil { - t.onAnyTransportFailed() - } + t.params.SubscriberHandler.OnFailed(true) return } } diff --git a/pkg/rtc/types/ice.go b/pkg/rtc/types/ice.go new file mode 100644 index 000000000..f628c4a30 --- /dev/null +++ b/pkg/rtc/types/ice.go @@ -0,0 +1,272 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + "strings" + "sync" + + "github.com/pion/ice/v2" + "github.com/pion/webrtc/v3" + "golang.org/x/exp/slices" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type ICEConnectionType string + +const ( + ICEConnectionTypeUDP ICEConnectionType = "udp" + ICEConnectionTypeTCP ICEConnectionType = "tcp" + ICEConnectionTypeTURN ICEConnectionType = "turn" + ICEConnectionTypeUnknown ICEConnectionType = "unknown" +) + +type ICECandidateExtended struct { + // only one of local or remote is set. This is due to type foo in Pion + Local *webrtc.ICECandidate + Remote ice.Candidate + Selected bool + Filtered bool +} + +type ICEConnectionDetails struct { + Local []*ICECandidateExtended + Remote []*ICECandidateExtended + Transport livekit.SignalTarget + Type ICEConnectionType + lock sync.Mutex + logger logger.Logger +} + +func NewICEConnectionDetails(transport livekit.SignalTarget, l logger.Logger) *ICEConnectionDetails { + d := &ICEConnectionDetails{ + Transport: transport, + Type: ICEConnectionTypeUnknown, + logger: l, + } + return d +} + +func (d *ICEConnectionDetails) HasCandidates() bool { + d.lock.Lock() + defer d.lock.Unlock() + return len(d.Local) > 0 || len(d.Remote) > 0 +} + +// Clone returns a copy of the ICEConnectionDetails, where fields can be read without locking +func (d *ICEConnectionDetails) Clone() *ICEConnectionDetails { + d.lock.Lock() + defer d.lock.Unlock() + clone := &ICEConnectionDetails{ + Transport: d.Transport, + Type: d.Type, + logger: d.logger, + Local: make([]*ICECandidateExtended, 0, len(d.Local)), + Remote: make([]*ICECandidateExtended, 0, len(d.Remote)), + } + for _, c := range d.Local { + clone.Local = append(clone.Local, &ICECandidateExtended{ + Local: c.Local, + Filtered: c.Filtered, + Selected: c.Selected, + }) + } + for _, c := range d.Remote { + clone.Remote = append(clone.Remote, &ICECandidateExtended{ + Remote: c.Remote, + Filtered: c.Filtered, + Selected: c.Selected, + }) + } + return clone +} + +func (d *ICEConnectionDetails) AddLocalCandidate(c *webrtc.ICECandidate, filtered bool) { + d.lock.Lock() + defer d.lock.Unlock() + compFn := func(e *ICECandidateExtended) bool { + return isCandidateEqualTo(e.Local, c) + } + if slices.ContainsFunc[[]*ICECandidateExtended, *ICECandidateExtended](d.Local, compFn) { + return + } + d.Local = append(d.Local, &ICECandidateExtended{ + Local: c, + Filtered: filtered, + }) +} + +func (d *ICEConnectionDetails) AddRemoteCandidate(c webrtc.ICECandidateInit, filtered bool) { + candidate, err := unmarshalICECandidate(c) + if err != nil { + d.logger.Errorw("could not unmarshal candidate", err, "candidate", c) + return + } + if candidate == nil { + // end-of-candidates candidate + return + } + + d.lock.Lock() + defer d.lock.Unlock() + compFn := func(e *ICECandidateExtended) bool { + return isICECandidateEqualTo(e.Remote, *candidate) + } + if slices.ContainsFunc[[]*ICECandidateExtended, *ICECandidateExtended](d.Remote, compFn) { + return + } + d.Remote = append(d.Remote, &ICECandidateExtended{ + Remote: *candidate, + Filtered: filtered, + }) +} + +func (d *ICEConnectionDetails) Clear() { + d.lock.Lock() + defer d.lock.Unlock() + d.Local = nil + d.Remote = nil + d.Type = ICEConnectionTypeUnknown +} + +func (d *ICEConnectionDetails) SetSelectedPair(pair *webrtc.ICECandidatePair) { + d.lock.Lock() + defer d.lock.Unlock() + remoteIdx := slices.IndexFunc[[]*ICECandidateExtended, *ICECandidateExtended](d.Remote, func(e *ICECandidateExtended) bool { + return isICECandidateEqualToCandidate(e.Remote, pair.Remote) + }) + if remoteIdx < 0 { + // it's possible for prflx candidates to be generated by Pion, we'll add them + candidate, err := unmarshalICECandidate(pair.Remote.ToJSON()) + if err != nil { + d.logger.Errorw("could not unmarshal remote candidate", err, "candidate", pair.Remote) + return + } + if candidate == nil { + return + } + d.Remote = append(d.Remote, &ICECandidateExtended{ + Remote: *candidate, + Filtered: false, + }) + remoteIdx = len(d.Remote) - 1 + } + remote := d.Remote[remoteIdx] + remote.Selected = true + + localIdx := slices.IndexFunc[[]*ICECandidateExtended, *ICECandidateExtended](d.Local, func(e *ICECandidateExtended) bool { + return isCandidateEqualTo(e.Local, pair.Local) + }) + if localIdx < 0 { + d.logger.Errorw("could not match local candidate", nil, "local", pair.Local) + // should not happen + return + } + local := d.Local[localIdx] + local.Selected = true + + d.Type = ICEConnectionTypeUDP + if pair.Remote.Protocol == webrtc.ICEProtocolTCP { + d.Type = ICEConnectionTypeTCP + } + if pair.Remote.Typ == webrtc.ICECandidateTypeRelay { + d.Type = ICEConnectionTypeTURN + } else if pair.Remote.Typ == webrtc.ICECandidateTypePrflx { + // if the remote relay candidate pings us *before* we get a relay candidate, + // Pion would have created a prflx candidate with the same address as the relay candidate. + // to report an accurate connection type, we'll compare to see if existing relay candidates match + for _, other := range d.Remote { + or := other.Remote + if or.Type() == ice.CandidateTypeRelay && + pair.Remote.Address == or.Address() && + pair.Remote.Port == uint16(or.Port()) && + pair.Remote.Protocol.String() == or.NetworkType().NetworkShort() { + d.Type = ICEConnectionTypeTURN + } + } + } +} + +func isCandidateEqualTo(c1, c2 *webrtc.ICECandidate) bool { + if c1 == nil && c2 == nil { + return true + } + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return false + } + return c1.Typ == c2.Typ && + c1.Protocol == c2.Protocol && + c1.Address == c2.Address && + c1.Port == c2.Port && + c1.Component == c2.Component && + c1.Foundation == c2.Foundation && + c1.Priority == c2.Priority && + c1.RelatedAddress == c2.RelatedAddress && + c1.RelatedPort == c2.RelatedPort && + c1.TCPType == c2.TCPType +} + +func isICECandidateEqualTo(c1, c2 ice.Candidate) bool { + if c1 == nil && c2 == nil { + return true + } + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return false + } + return c1.Type() == c2.Type() && + c1.NetworkType() == c2.NetworkType() && + c1.Address() == c2.Address() && + c1.Port() == c2.Port() && + c1.Component() == c2.Component() && + c1.Foundation() == c2.Foundation() && + c1.Priority() == c2.Priority() && + c1.RelatedAddress().Equal(c2.RelatedAddress()) && + c1.TCPType() == c2.TCPType() +} + +func isICECandidateEqualToCandidate(c1 ice.Candidate, c2 *webrtc.ICECandidate) bool { + if c1 == nil && c2 == nil { + return true + } + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return false + } + return c1.Type().String() == c2.Typ.String() && + c1.NetworkType().NetworkShort() == c2.Protocol.String() && + c1.Address() == c2.Address && + c1.Port() == int(c2.Port) && + c1.Component() == c2.Component && + c1.Foundation() == c2.Foundation && + c1.Priority() == c2.Priority && + c1.TCPType().String() == c2.TCPType +} + +func unmarshalICECandidate(c webrtc.ICECandidateInit) (*ice.Candidate, error) { + candidateValue := strings.TrimPrefix(c.Candidate, "candidate:") + if candidateValue == "" { + return nil, nil + } + + candidate, err := ice.UnmarshalCandidate(candidateValue) + if err != nil { + return nil, err + } + + return &candidate, nil +} diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 266e24b8b..32eb1e35c 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -81,14 +81,13 @@ type SubscribedCodecQuality struct { type ParticipantCloseReason int const ( - ParticipantCloseReasonClientRequestLeave ParticipantCloseReason = iota + ParticipantCloseReasonNone ParticipantCloseReason = iota + ParticipantCloseReasonClientRequestLeave ParticipantCloseReasonRoomManagerStop - ParticipantCloseReasonRoomClose ParticipantCloseReasonVerifyFailed ParticipantCloseReasonJoinFailed ParticipantCloseReasonJoinTimeout ParticipantCloseReasonMessageBusFailed - ParticipantCloseReasonStateDisconnected ParticipantCloseReasonPeerConnectionDisconnected ParticipantCloseReasonDuplicateIdentity ParticipantCloseReasonMigrationComplete @@ -100,21 +99,21 @@ const ( ParticipantCloseReasonSimulateServerLeave ParticipantCloseReasonNegotiateFailed ParticipantCloseReasonMigrationRequested - ParticipantCloseReasonOvercommitted ParticipantCloseReasonPublicationError ParticipantCloseReasonSubscriptionError ParticipantCloseReasonDataChannelError ParticipantCloseReasonMigrateCodecMismatch + ParticipantCloseReasonSignalSourceClose ) func (p ParticipantCloseReason) String() string { switch p { + case ParticipantCloseReasonNone: + return "NONE" case ParticipantCloseReasonClientRequestLeave: return "CLIENT_REQUEST_LEAVE" case ParticipantCloseReasonRoomManagerStop: return "ROOM_MANAGER_STOP" - case ParticipantCloseReasonRoomClose: - return "ROOM_CLOSE" case ParticipantCloseReasonVerifyFailed: return "VERIFY_FAILED" case ParticipantCloseReasonJoinFailed: @@ -123,8 +122,6 @@ func (p ParticipantCloseReason) String() string { return "JOIN_TIMEOUT" case ParticipantCloseReasonMessageBusFailed: return "MESSAGE_BUS_FAILED" - case ParticipantCloseReasonStateDisconnected: - return "STATE_DISCONNECTED" case ParticipantCloseReasonPeerConnectionDisconnected: return "PEER_CONNECTION_DISCONNECTED" case ParticipantCloseReasonDuplicateIdentity: @@ -147,8 +144,6 @@ func (p ParticipantCloseReason) String() string { return "NEGOTIATE_FAILED" case ParticipantCloseReasonMigrationRequested: return "MIGRATION_REQUESTED" - case ParticipantCloseReasonOvercommitted: - return "OVERCOMMITTED" case ParticipantCloseReasonPublicationError: return "PUBLICATION_ERROR" case ParticipantCloseReasonSubscriptionError: @@ -157,6 +152,8 @@ func (p ParticipantCloseReason) String() string { return "DATA_CHANNEL_ERROR" case ParticipantCloseReasonMigrateCodecMismatch: return "MIGRATE_CODEC_MISMATCH" + case ParticipantCloseReasonSignalSourceClose: + return "SIGNAL_SOURCE_CLOSE" default: return fmt.Sprintf("%d", int(p)) } @@ -173,22 +170,20 @@ func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason { return livekit.DisconnectReason_JOIN_FAILURE case ParticipantCloseReasonPeerConnectionDisconnected: return livekit.DisconnectReason_STATE_MISMATCH - case ParticipantCloseReasonDuplicateIdentity, ParticipantCloseReasonMigrationComplete, ParticipantCloseReasonStale: + case ParticipantCloseReasonDuplicateIdentity, ParticipantCloseReasonStale: return livekit.DisconnectReason_DUPLICATE_IDENTITY + case ParticipantCloseReasonMigrationRequested, ParticipantCloseReasonMigrationComplete, ParticipantCloseReasonSimulateMigration: + return livekit.DisconnectReason_MIGRATION case ParticipantCloseReasonServiceRequestRemoveParticipant: return livekit.DisconnectReason_PARTICIPANT_REMOVED case ParticipantCloseReasonServiceRequestDeleteRoom: return livekit.DisconnectReason_ROOM_DELETED - case ParticipantCloseReasonSimulateMigration: - return livekit.DisconnectReason_DUPLICATE_IDENTITY - case ParticipantCloseReasonSimulateNodeFailure: - return livekit.DisconnectReason_SERVER_SHUTDOWN - case ParticipantCloseReasonSimulateServerLeave: - return livekit.DisconnectReason_SERVER_SHUTDOWN - case ParticipantCloseReasonOvercommitted: + case ParticipantCloseReasonSimulateNodeFailure, ParticipantCloseReasonSimulateServerLeave: return livekit.DisconnectReason_SERVER_SHUTDOWN case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch: return livekit.DisconnectReason_STATE_MISMATCH + case ParticipantCloseReasonSignalSourceClose: + return livekit.DisconnectReason_SIGNAL_CLOSE default: // the other types will map to unknown reason return livekit.DisconnectReason_UNKNOWN_REASON @@ -209,6 +204,8 @@ const ( SignallingCloseReasonFullReconnectDataChannelError SignallingCloseReasonFullReconnectNegotiateFailed SignallingCloseReasonParticipantClose + SignallingCloseReasonDisconnectOnResume + SignallingCloseReasonDisconnectOnResumeNoMessages ) func (s SignallingCloseReason) String() string { @@ -231,6 +228,10 @@ func (s SignallingCloseReason) String() string { return "FULL_RECONNECT_NEGOTIATE_FAILED" case SignallingCloseReasonParticipantClose: return "PARTICIPANT_CLOSE" + case SignallingCloseReasonDisconnectOnResume: + return "DISCONNECT_ON_RESUME" + case SignallingCloseReasonDisconnectOnResumeNoMessages: + return "DISCONNECT_ON_RESUME_NO_MESSAGES" default: return fmt.Sprintf("%d", int(s)) } @@ -243,6 +244,7 @@ type Participant interface { ID() livekit.ParticipantID Identity() livekit.ParticipantIdentity State() livekit.ParticipantInfo_State + CloseReason() ParticipantCloseReason CanSkipBroadcast() bool ToProto() *livekit.ParticipantInfo @@ -266,7 +268,6 @@ type Participant interface { IsRecorder() bool IsAgent() bool - Start() Close(sendLeave bool, reason ParticipantCloseReason, isExpectedToResume bool) error SubscriptionPermission() (*livekit.SubscriptionPermission, utils.TimedVersion) @@ -285,15 +286,6 @@ type Participant interface { // ------------------------------------------------------- -type ICEConnectionType string - -const ( - ICEConnectionTypeUDP ICEConnectionType = "udp" - ICEConnectionTypeTCP ICEConnectionType = "tcp" - ICEConnectionTypeTURN ICEConnectionType = "turn" - ICEConnectionTypeUnknown ICEConnectionType = "unknown" -) - type AddTrackParams struct { Stereo bool Red bool @@ -320,10 +312,11 @@ type LocalParticipant interface { SubscriberAsPrimary() bool GetClientInfo() *livekit.ClientInfo GetClientConfiguration() *livekit.ClientConfiguration - GetICEConnectionType() ICEConnectionType GetBufferFactory() *buffer.Factory GetPlayoutDelayConfig() *livekit.PlayoutDelay GetPendingTrack(trackID livekit.TrackID) *livekit.TrackInfo + GetICEConnectionDetails() []*ICEConnectionDetails + HasConnected() bool SetResponseSink(sink routing.MessageSink) CloseSignalConnection(reason SignallingCloseReason) @@ -380,7 +373,7 @@ type LocalParticipant interface { IssueFullReconnect(reason ParticipantCloseReason) // callbacks - OnStateChange(func(p LocalParticipant, oldState livekit.ParticipantInfo_State)) + OnStateChange(func(p LocalParticipant, state livekit.ParticipantInfo_State)) OnMigrateStateChange(func(p LocalParticipant, migrateState MigrateState)) // OnTrackPublished - remote added a track OnTrackPublished(func(LocalParticipant, MediaTrack)) @@ -394,9 +387,10 @@ type LocalParticipant interface { OnSubscribeStatusChanged(fn func(publisherID livekit.ParticipantID, subscribed bool)) OnClose(callback func(LocalParticipant)) OnClaimsChanged(callback func(LocalParticipant)) - OnReceiverReport(dt *sfu.DownTrack, report *rtcp.ReceiverReport) OnTrafficLoad(callback func(trafficLoad *TrafficLoad)) + HandleReceiverReport(dt *sfu.DownTrack, report *rtcp.ReceiverReport) + // session migration MaybeStartMigration(force bool, onStart func()) bool SetMigrateState(s MigrateState) @@ -452,6 +446,7 @@ type MediaTrack interface { Source() livekit.TrackSource Stream() string + UpdateTrackInfo(ti *livekit.TrackInfo) ToProto() *livekit.TrackInfo PublisherID() livekit.ParticipantID diff --git a/pkg/rtc/types/protocol_version.go b/pkg/rtc/types/protocol_version.go index 9c4e47cc3..ae2282085 100644 --- a/pkg/rtc/types/protocol_version.go +++ b/pkg/rtc/types/protocol_version.go @@ -16,7 +16,7 @@ package types type ProtocolVersion int -const CurrentProtocol = 10 +const CurrentProtocol = 13 func (v ProtocolVersion) SupportsPackedStreamId() bool { return v > 0 @@ -75,3 +75,19 @@ func (v ProtocolVersion) SupportHandlesDisconnectedUpdate() bool { func (v ProtocolVersion) SupportSyncStreamID() bool { return v > 9 } + +func (v ProtocolVersion) SupportsConnectionQualityLost() bool { + return v > 10 +} + +func (v ProtocolVersion) SupportsAsyncRoomID() bool { + return v > 11 +} + +func (v ProtocolVersion) SupportsIdentityBasedReconnection() bool { + return v > 11 +} + +func (v ProtocolVersion) SupportsRegionsInLeaveRequest() bool { + return v > 12 +} diff --git a/pkg/rtc/types/trafficstats.go b/pkg/rtc/types/trafficstats.go index 70b269db8..f5c0b386d 100644 --- a/pkg/rtc/types/trafficstats.go +++ b/pkg/rtc/types/trafficstats.go @@ -21,10 +21,13 @@ import ( ) type TrafficStats struct { - StartTime time.Time - EndTime time.Time - Packets uint32 - Bytes uint64 + StartTime time.Time + EndTime time.Time + Packets uint32 + PacketsLost uint32 + PacketsPadding uint32 + PacketsOutOfOrder uint32 + Bytes uint64 } type TrafficTypeStats struct { @@ -47,24 +50,47 @@ func RTPStatsDiffToTrafficStats(before, after *livekit.RTPStats) *TrafficStats { startTime = before.EndTime } - if before == nil { + getAfter := func() *TrafficStats { return &TrafficStats{ - StartTime: startTime.AsTime(), - EndTime: after.EndTime.AsTime(), - Packets: after.Packets, - Bytes: after.Bytes + after.BytesDuplicate + after.BytesPadding, + StartTime: startTime.AsTime(), + EndTime: after.EndTime.AsTime(), + Packets: after.Packets, + PacketsLost: after.PacketsLost, + PacketsPadding: after.PacketsPadding, + PacketsOutOfOrder: after.PacketsOutOfOrder, + Bytes: after.Bytes + after.BytesDuplicate + after.BytesPadding, } } + if before == nil { + return getAfter() + } + + if (after.Packets - before.Packets) > (1 << 31) { + // after packets < before packets, probably got reset, just return after + return getAfter() + } + if ((after.Bytes + after.BytesDuplicate + after.BytesPadding) - (before.Bytes + before.BytesDuplicate + before.BytesPadding)) > (1 << 63) { + // after bytes < before bytes, probably got reset, just return after + return getAfter() + } + + packetsLost := uint32(0) + if after.PacketsLost >= before.PacketsLost { + packetsLost = after.PacketsLost - before.PacketsLost + } return &TrafficStats{ - StartTime: startTime.AsTime(), - EndTime: after.EndTime.AsTime(), - Packets: after.Packets - before.Packets, - Bytes: (after.Bytes + after.BytesDuplicate + after.BytesPadding) - (before.Bytes + before.BytesDuplicate + before.BytesPadding), + StartTime: startTime.AsTime(), + EndTime: after.EndTime.AsTime(), + Packets: after.Packets - before.Packets, + PacketsLost: packetsLost, + PacketsPadding: after.PacketsPadding - before.PacketsPadding, + PacketsOutOfOrder: after.PacketsOutOfOrder - before.PacketsOutOfOrder, + Bytes: (after.Bytes + after.BytesDuplicate + after.BytesPadding) - (before.Bytes + before.BytesDuplicate + before.BytesPadding), } } -func AggregateTrafficStats(statsList []*TrafficStats) *TrafficStats { +func AggregateTrafficStats(statsList ...*TrafficStats) *TrafficStats { if len(statsList) == 0 { return nil } @@ -73,6 +99,9 @@ func AggregateTrafficStats(statsList []*TrafficStats) *TrafficStats { endTime := time.Time{} packets := uint32(0) + packetsLost := uint32(0) + packetsPadding := uint32(0) + packetsOutOfOrder := uint32(0) bytes := uint64(0) for _, stats := range statsList { @@ -85,6 +114,9 @@ func AggregateTrafficStats(statsList []*TrafficStats) *TrafficStats { } packets += stats.Packets + packetsLost += stats.PacketsLost + packetsPadding += stats.PacketsPadding + packetsOutOfOrder += stats.PacketsOutOfOrder bytes += stats.Bytes } @@ -92,10 +124,13 @@ func AggregateTrafficStats(statsList []*TrafficStats) *TrafficStats { endTime = time.Now() } return &TrafficStats{ - StartTime: startTime, - EndTime: endTime, - Packets: packets, - Bytes: bytes, + StartTime: startTime, + EndTime: endTime, + Packets: packets, + PacketsLost: packetsLost, + PacketsPadding: packetsPadding, + PacketsOutOfOrder: packetsOutOfOrder, + Bytes: bytes, } } diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index a606f8ef1..be86e74c4 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -332,6 +332,11 @@ type FakeLocalMediaTrack struct { toProtoReturnsOnCall map[int]struct { result1 *livekit.TrackInfo } + UpdateTrackInfoStub func(*livekit.TrackInfo) + updateTrackInfoMutex sync.RWMutex + updateTrackInfoArgsForCall []struct { + arg1 *livekit.TrackInfo + } UpdateVideoLayersStub func([]*livekit.VideoLayer) updateVideoLayersMutex sync.RWMutex updateVideoLayersArgsForCall []struct { @@ -2072,6 +2077,38 @@ func (fake *FakeLocalMediaTrack) ToProtoReturnsOnCall(i int, result1 *livekit.Tr }{result1} } +func (fake *FakeLocalMediaTrack) UpdateTrackInfo(arg1 *livekit.TrackInfo) { + fake.updateTrackInfoMutex.Lock() + fake.updateTrackInfoArgsForCall = append(fake.updateTrackInfoArgsForCall, struct { + arg1 *livekit.TrackInfo + }{arg1}) + stub := fake.UpdateTrackInfoStub + fake.recordInvocation("UpdateTrackInfo", []interface{}{arg1}) + fake.updateTrackInfoMutex.Unlock() + if stub != nil { + fake.UpdateTrackInfoStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfoCallCount() int { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + return len(fake.updateTrackInfoArgsForCall) +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfoCalls(stub func(*livekit.TrackInfo)) { + fake.updateTrackInfoMutex.Lock() + defer fake.updateTrackInfoMutex.Unlock() + fake.UpdateTrackInfoStub = stub +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfoArgsForCall(i int) *livekit.TrackInfo { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + argsForCall := fake.updateTrackInfoArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalMediaTrack) UpdateVideoLayers(arg1 []*livekit.VideoLayer) { var arg1Copy []*livekit.VideoLayer if arg1 != nil { @@ -2182,6 +2219,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.streamMutex.RUnlock() fake.toProtoMutex.RLock() defer fake.toProtoMutex.RUnlock() + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() fake.updateVideoLayersMutex.RLock() defer fake.updateVideoLayersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index a01a3834b..c20a1a1d7 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -133,6 +133,16 @@ type FakeLocalParticipant struct { closeReturnsOnCall map[int]struct { result1 error } + CloseReasonStub func() types.ParticipantCloseReason + closeReasonMutex sync.RWMutex + closeReasonArgsForCall []struct { + } + closeReasonReturns struct { + result1 types.ParticipantCloseReason + } + closeReasonReturnsOnCall map[int]struct { + result1 types.ParticipantCloseReason + } CloseSignalConnectionStub func(types.SignallingCloseReason) closeSignalConnectionMutex sync.RWMutex closeSignalConnectionArgsForCall []struct { @@ -233,15 +243,15 @@ type FakeLocalParticipant struct { getConnectionQualityReturnsOnCall map[int]struct { result1 *livekit.ConnectionQualityInfo } - GetICEConnectionTypeStub func() types.ICEConnectionType - getICEConnectionTypeMutex sync.RWMutex - getICEConnectionTypeArgsForCall []struct { + GetICEConnectionDetailsStub func() []*types.ICEConnectionDetails + getICEConnectionDetailsMutex sync.RWMutex + getICEConnectionDetailsArgsForCall []struct { } - getICEConnectionTypeReturns struct { - result1 types.ICEConnectionType + getICEConnectionDetailsReturns struct { + result1 []*types.ICEConnectionDetails } - getICEConnectionTypeReturnsOnCall map[int]struct { - result1 types.ICEConnectionType + getICEConnectionDetailsReturnsOnCall map[int]struct { + result1 []*types.ICEConnectionDetails } GetLoggerStub func() logger.Logger getLoggerMutex sync.RWMutex @@ -355,6 +365,12 @@ type FakeLocalParticipant struct { handleOfferArgsForCall []struct { arg1 webrtc.SessionDescription } + HandleReceiverReportStub func(*sfu.DownTrack, *rtcp.ReceiverReport) + handleReceiverReportMutex sync.RWMutex + handleReceiverReportArgsForCall []struct { + arg1 *sfu.DownTrack + arg2 *rtcp.ReceiverReport + } HandleReconnectAndSendResponseStub func(livekit.ReconnectReason, *livekit.ReconnectResponse) error handleReconnectAndSendResponseMutex sync.RWMutex handleReconnectAndSendResponseArgsForCall []struct { @@ -371,6 +387,16 @@ type FakeLocalParticipant struct { handleSignalSourceCloseMutex sync.RWMutex handleSignalSourceCloseArgsForCall []struct { } + HasConnectedStub func() bool + hasConnectedMutex sync.RWMutex + hasConnectedArgsForCall []struct { + } + hasConnectedReturns struct { + result1 bool + } + hasConnectedReturnsOnCall map[int]struct { + result1 bool + } HasPermissionStub func(livekit.TrackID, livekit.ParticipantIdentity) bool hasPermissionMutex sync.RWMutex hasPermissionArgsForCall []struct { @@ -561,16 +587,10 @@ type FakeLocalParticipant struct { onParticipantUpdateArgsForCall []struct { arg1 func(types.LocalParticipant) } - OnReceiverReportStub func(*sfu.DownTrack, *rtcp.ReceiverReport) - onReceiverReportMutex sync.RWMutex - onReceiverReportArgsForCall []struct { - arg1 *sfu.DownTrack - arg2 *rtcp.ReceiverReport - } - OnStateChangeStub func(func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State)) + OnStateChangeStub func(func(p types.LocalParticipant, state livekit.ParticipantInfo_State)) onStateChangeMutex sync.RWMutex onStateChangeArgsForCall []struct { - arg1 func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) + arg1 func(p types.LocalParticipant, state livekit.ParticipantInfo_State) } OnSubscribeStatusChangedStub func(func(publisherID livekit.ParticipantID, subscribed bool)) onSubscribeStatusChangedMutex sync.RWMutex @@ -776,10 +796,6 @@ type FakeLocalParticipant struct { setTrackMutedReturnsOnCall map[int]struct { result1 *livekit.TrackInfo } - StartStub func() - startMutex sync.RWMutex - startArgsForCall []struct { - } StateStub func() livekit.ParticipantInfo_State stateMutex sync.RWMutex stateArgsForCall []struct { @@ -1539,6 +1555,59 @@ func (fake *FakeLocalParticipant) CloseReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeLocalParticipant) CloseReason() types.ParticipantCloseReason { + fake.closeReasonMutex.Lock() + ret, specificReturn := fake.closeReasonReturnsOnCall[len(fake.closeReasonArgsForCall)] + fake.closeReasonArgsForCall = append(fake.closeReasonArgsForCall, struct { + }{}) + stub := fake.CloseReasonStub + fakeReturns := fake.closeReasonReturns + fake.recordInvocation("CloseReason", []interface{}{}) + fake.closeReasonMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CloseReasonCallCount() int { + fake.closeReasonMutex.RLock() + defer fake.closeReasonMutex.RUnlock() + return len(fake.closeReasonArgsForCall) +} + +func (fake *FakeLocalParticipant) CloseReasonCalls(stub func() types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = stub +} + +func (fake *FakeLocalParticipant) CloseReasonReturns(result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + fake.closeReasonReturns = struct { + result1 types.ParticipantCloseReason + }{result1} +} + +func (fake *FakeLocalParticipant) CloseReasonReturnsOnCall(i int, result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + if fake.closeReasonReturnsOnCall == nil { + fake.closeReasonReturnsOnCall = make(map[int]struct { + result1 types.ParticipantCloseReason + }) + } + fake.closeReasonReturnsOnCall[i] = struct { + result1 types.ParticipantCloseReason + }{result1} +} + func (fake *FakeLocalParticipant) CloseSignalConnection(arg1 types.SignallingCloseReason) { fake.closeSignalConnectionMutex.Lock() fake.closeSignalConnectionArgsForCall = append(fake.closeSignalConnectionArgsForCall, struct { @@ -2062,15 +2131,15 @@ func (fake *FakeLocalParticipant) GetConnectionQualityReturnsOnCall(i int, resul }{result1} } -func (fake *FakeLocalParticipant) GetICEConnectionType() types.ICEConnectionType { - fake.getICEConnectionTypeMutex.Lock() - ret, specificReturn := fake.getICEConnectionTypeReturnsOnCall[len(fake.getICEConnectionTypeArgsForCall)] - fake.getICEConnectionTypeArgsForCall = append(fake.getICEConnectionTypeArgsForCall, struct { +func (fake *FakeLocalParticipant) GetICEConnectionDetails() []*types.ICEConnectionDetails { + fake.getICEConnectionDetailsMutex.Lock() + ret, specificReturn := fake.getICEConnectionDetailsReturnsOnCall[len(fake.getICEConnectionDetailsArgsForCall)] + fake.getICEConnectionDetailsArgsForCall = append(fake.getICEConnectionDetailsArgsForCall, struct { }{}) - stub := fake.GetICEConnectionTypeStub - fakeReturns := fake.getICEConnectionTypeReturns - fake.recordInvocation("GetICEConnectionType", []interface{}{}) - fake.getICEConnectionTypeMutex.Unlock() + stub := fake.GetICEConnectionDetailsStub + fakeReturns := fake.getICEConnectionDetailsReturns + fake.recordInvocation("GetICEConnectionDetails", []interface{}{}) + fake.getICEConnectionDetailsMutex.Unlock() if stub != nil { return stub() } @@ -2080,38 +2149,38 @@ func (fake *FakeLocalParticipant) GetICEConnectionType() types.ICEConnectionType return fakeReturns.result1 } -func (fake *FakeLocalParticipant) GetICEConnectionTypeCallCount() int { - fake.getICEConnectionTypeMutex.RLock() - defer fake.getICEConnectionTypeMutex.RUnlock() - return len(fake.getICEConnectionTypeArgsForCall) +func (fake *FakeLocalParticipant) GetICEConnectionDetailsCallCount() int { + fake.getICEConnectionDetailsMutex.RLock() + defer fake.getICEConnectionDetailsMutex.RUnlock() + return len(fake.getICEConnectionDetailsArgsForCall) } -func (fake *FakeLocalParticipant) GetICEConnectionTypeCalls(stub func() types.ICEConnectionType) { - fake.getICEConnectionTypeMutex.Lock() - defer fake.getICEConnectionTypeMutex.Unlock() - fake.GetICEConnectionTypeStub = stub +func (fake *FakeLocalParticipant) GetICEConnectionDetailsCalls(stub func() []*types.ICEConnectionDetails) { + fake.getICEConnectionDetailsMutex.Lock() + defer fake.getICEConnectionDetailsMutex.Unlock() + fake.GetICEConnectionDetailsStub = stub } -func (fake *FakeLocalParticipant) GetICEConnectionTypeReturns(result1 types.ICEConnectionType) { - fake.getICEConnectionTypeMutex.Lock() - defer fake.getICEConnectionTypeMutex.Unlock() - fake.GetICEConnectionTypeStub = nil - fake.getICEConnectionTypeReturns = struct { - result1 types.ICEConnectionType +func (fake *FakeLocalParticipant) GetICEConnectionDetailsReturns(result1 []*types.ICEConnectionDetails) { + fake.getICEConnectionDetailsMutex.Lock() + defer fake.getICEConnectionDetailsMutex.Unlock() + fake.GetICEConnectionDetailsStub = nil + fake.getICEConnectionDetailsReturns = struct { + result1 []*types.ICEConnectionDetails }{result1} } -func (fake *FakeLocalParticipant) GetICEConnectionTypeReturnsOnCall(i int, result1 types.ICEConnectionType) { - fake.getICEConnectionTypeMutex.Lock() - defer fake.getICEConnectionTypeMutex.Unlock() - fake.GetICEConnectionTypeStub = nil - if fake.getICEConnectionTypeReturnsOnCall == nil { - fake.getICEConnectionTypeReturnsOnCall = make(map[int]struct { - result1 types.ICEConnectionType +func (fake *FakeLocalParticipant) GetICEConnectionDetailsReturnsOnCall(i int, result1 []*types.ICEConnectionDetails) { + fake.getICEConnectionDetailsMutex.Lock() + defer fake.getICEConnectionDetailsMutex.Unlock() + fake.GetICEConnectionDetailsStub = nil + if fake.getICEConnectionDetailsReturnsOnCall == nil { + fake.getICEConnectionDetailsReturnsOnCall = make(map[int]struct { + result1 []*types.ICEConnectionDetails }) } - fake.getICEConnectionTypeReturnsOnCall[i] = struct { - result1 types.ICEConnectionType + fake.getICEConnectionDetailsReturnsOnCall[i] = struct { + result1 []*types.ICEConnectionDetails }{result1} } @@ -2725,6 +2794,39 @@ func (fake *FakeLocalParticipant) HandleOfferArgsForCall(i int) webrtc.SessionDe return argsForCall.arg1 } +func (fake *FakeLocalParticipant) HandleReceiverReport(arg1 *sfu.DownTrack, arg2 *rtcp.ReceiverReport) { + fake.handleReceiverReportMutex.Lock() + fake.handleReceiverReportArgsForCall = append(fake.handleReceiverReportArgsForCall, struct { + arg1 *sfu.DownTrack + arg2 *rtcp.ReceiverReport + }{arg1, arg2}) + stub := fake.HandleReceiverReportStub + fake.recordInvocation("HandleReceiverReport", []interface{}{arg1, arg2}) + fake.handleReceiverReportMutex.Unlock() + if stub != nil { + fake.HandleReceiverReportStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) HandleReceiverReportCallCount() int { + fake.handleReceiverReportMutex.RLock() + defer fake.handleReceiverReportMutex.RUnlock() + return len(fake.handleReceiverReportArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleReceiverReportCalls(stub func(*sfu.DownTrack, *rtcp.ReceiverReport)) { + fake.handleReceiverReportMutex.Lock() + defer fake.handleReceiverReportMutex.Unlock() + fake.HandleReceiverReportStub = stub +} + +func (fake *FakeLocalParticipant) HandleReceiverReportArgsForCall(i int) (*sfu.DownTrack, *rtcp.ReceiverReport) { + fake.handleReceiverReportMutex.RLock() + defer fake.handleReceiverReportMutex.RUnlock() + argsForCall := fake.handleReceiverReportArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeLocalParticipant) HandleReconnectAndSendResponse(arg1 livekit.ReconnectReason, arg2 *livekit.ReconnectResponse) error { fake.handleReconnectAndSendResponseMutex.Lock() ret, specificReturn := fake.handleReconnectAndSendResponseReturnsOnCall[len(fake.handleReconnectAndSendResponseArgsForCall)] @@ -2811,6 +2913,59 @@ func (fake *FakeLocalParticipant) HandleSignalSourceCloseCalls(stub func()) { fake.HandleSignalSourceCloseStub = stub } +func (fake *FakeLocalParticipant) HasConnected() bool { + fake.hasConnectedMutex.Lock() + ret, specificReturn := fake.hasConnectedReturnsOnCall[len(fake.hasConnectedArgsForCall)] + fake.hasConnectedArgsForCall = append(fake.hasConnectedArgsForCall, struct { + }{}) + stub := fake.HasConnectedStub + fakeReturns := fake.hasConnectedReturns + fake.recordInvocation("HasConnected", []interface{}{}) + fake.hasConnectedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HasConnectedCallCount() int { + fake.hasConnectedMutex.RLock() + defer fake.hasConnectedMutex.RUnlock() + return len(fake.hasConnectedArgsForCall) +} + +func (fake *FakeLocalParticipant) HasConnectedCalls(stub func() bool) { + fake.hasConnectedMutex.Lock() + defer fake.hasConnectedMutex.Unlock() + fake.HasConnectedStub = stub +} + +func (fake *FakeLocalParticipant) HasConnectedReturns(result1 bool) { + fake.hasConnectedMutex.Lock() + defer fake.hasConnectedMutex.Unlock() + fake.HasConnectedStub = nil + fake.hasConnectedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) HasConnectedReturnsOnCall(i int, result1 bool) { + fake.hasConnectedMutex.Lock() + defer fake.hasConnectedMutex.Unlock() + fake.HasConnectedStub = nil + if fake.hasConnectedReturnsOnCall == nil { + fake.hasConnectedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasConnectedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) HasPermission(arg1 livekit.TrackID, arg2 livekit.ParticipantIdentity) bool { fake.hasPermissionMutex.Lock() ret, specificReturn := fake.hasPermissionReturnsOnCall[len(fake.hasPermissionArgsForCall)] @@ -3867,43 +4022,10 @@ func (fake *FakeLocalParticipant) OnParticipantUpdateArgsForCall(i int) func(typ return argsForCall.arg1 } -func (fake *FakeLocalParticipant) OnReceiverReport(arg1 *sfu.DownTrack, arg2 *rtcp.ReceiverReport) { - fake.onReceiverReportMutex.Lock() - fake.onReceiverReportArgsForCall = append(fake.onReceiverReportArgsForCall, struct { - arg1 *sfu.DownTrack - arg2 *rtcp.ReceiverReport - }{arg1, arg2}) - stub := fake.OnReceiverReportStub - fake.recordInvocation("OnReceiverReport", []interface{}{arg1, arg2}) - fake.onReceiverReportMutex.Unlock() - if stub != nil { - fake.OnReceiverReportStub(arg1, arg2) - } -} - -func (fake *FakeLocalParticipant) OnReceiverReportCallCount() int { - fake.onReceiverReportMutex.RLock() - defer fake.onReceiverReportMutex.RUnlock() - return len(fake.onReceiverReportArgsForCall) -} - -func (fake *FakeLocalParticipant) OnReceiverReportCalls(stub func(*sfu.DownTrack, *rtcp.ReceiverReport)) { - fake.onReceiverReportMutex.Lock() - defer fake.onReceiverReportMutex.Unlock() - fake.OnReceiverReportStub = stub -} - -func (fake *FakeLocalParticipant) OnReceiverReportArgsForCall(i int) (*sfu.DownTrack, *rtcp.ReceiverReport) { - fake.onReceiverReportMutex.RLock() - defer fake.onReceiverReportMutex.RUnlock() - argsForCall := fake.onReceiverReportArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 -} - -func (fake *FakeLocalParticipant) OnStateChange(arg1 func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State)) { +func (fake *FakeLocalParticipant) OnStateChange(arg1 func(p types.LocalParticipant, state livekit.ParticipantInfo_State)) { fake.onStateChangeMutex.Lock() fake.onStateChangeArgsForCall = append(fake.onStateChangeArgsForCall, struct { - arg1 func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) + arg1 func(p types.LocalParticipant, state livekit.ParticipantInfo_State) }{arg1}) stub := fake.OnStateChangeStub fake.recordInvocation("OnStateChange", []interface{}{arg1}) @@ -3919,13 +4041,13 @@ func (fake *FakeLocalParticipant) OnStateChangeCallCount() int { return len(fake.onStateChangeArgsForCall) } -func (fake *FakeLocalParticipant) OnStateChangeCalls(stub func(func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State))) { +func (fake *FakeLocalParticipant) OnStateChangeCalls(stub func(func(p types.LocalParticipant, state livekit.ParticipantInfo_State))) { fake.onStateChangeMutex.Lock() defer fake.onStateChangeMutex.Unlock() fake.OnStateChangeStub = stub } -func (fake *FakeLocalParticipant) OnStateChangeArgsForCall(i int) func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) { +func (fake *FakeLocalParticipant) OnStateChangeArgsForCall(i int) func(p types.LocalParticipant, state livekit.ParticipantInfo_State) { fake.onStateChangeMutex.RLock() defer fake.onStateChangeMutex.RUnlock() argsForCall := fake.onStateChangeArgsForCall[i] @@ -5109,30 +5231,6 @@ func (fake *FakeLocalParticipant) SetTrackMutedReturnsOnCall(i int, result1 *liv }{result1} } -func (fake *FakeLocalParticipant) Start() { - fake.startMutex.Lock() - fake.startArgsForCall = append(fake.startArgsForCall, struct { - }{}) - stub := fake.StartStub - fake.recordInvocation("Start", []interface{}{}) - fake.startMutex.Unlock() - if stub != nil { - fake.StartStub() - } -} - -func (fake *FakeLocalParticipant) StartCallCount() int { - fake.startMutex.RLock() - defer fake.startMutex.RUnlock() - return len(fake.startArgsForCall) -} - -func (fake *FakeLocalParticipant) StartCalls(stub func()) { - fake.startMutex.Lock() - defer fake.startMutex.Unlock() - fake.StartStub = stub -} - func (fake *FakeLocalParticipant) State() livekit.ParticipantInfo_State { fake.stateMutex.Lock() ret, specificReturn := fake.stateReturnsOnCall[len(fake.stateArgsForCall)] @@ -6136,6 +6234,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.claimGrantsMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() + fake.closeReasonMutex.RLock() + defer fake.closeReasonMutex.RUnlock() fake.closeSignalConnectionMutex.RLock() defer fake.closeSignalConnectionMutex.RUnlock() fake.connectedAtMutex.RLock() @@ -6156,8 +6256,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.getClientInfoMutex.RUnlock() fake.getConnectionQualityMutex.RLock() defer fake.getConnectionQualityMutex.RUnlock() - fake.getICEConnectionTypeMutex.RLock() - defer fake.getICEConnectionTypeMutex.RUnlock() + fake.getICEConnectionDetailsMutex.RLock() + defer fake.getICEConnectionDetailsMutex.RUnlock() fake.getLoggerMutex.RLock() defer fake.getLoggerMutex.RUnlock() fake.getPacerMutex.RLock() @@ -6182,10 +6282,14 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.handleAnswerMutex.RUnlock() fake.handleOfferMutex.RLock() defer fake.handleOfferMutex.RUnlock() + fake.handleReceiverReportMutex.RLock() + defer fake.handleReceiverReportMutex.RUnlock() fake.handleReconnectAndSendResponseMutex.RLock() defer fake.handleReconnectAndSendResponseMutex.RUnlock() fake.handleSignalSourceCloseMutex.RLock() defer fake.handleSignalSourceCloseMutex.RUnlock() + fake.hasConnectedMutex.RLock() + defer fake.hasConnectedMutex.RUnlock() fake.hasPermissionMutex.RLock() defer fake.hasPermissionMutex.RUnlock() fake.hiddenMutex.RLock() @@ -6232,8 +6336,6 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.onMigrateStateChangeMutex.RUnlock() fake.onParticipantUpdateMutex.RLock() defer fake.onParticipantUpdateMutex.RUnlock() - fake.onReceiverReportMutex.RLock() - defer fake.onReceiverReportMutex.RUnlock() fake.onStateChangeMutex.RLock() defer fake.onStateChangeMutex.RUnlock() fake.onSubscribeStatusChangedMutex.RLock() @@ -6288,8 +6390,6 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.setSubscriberChannelCapacityMutex.RUnlock() fake.setTrackMutedMutex.RLock() defer fake.setTrackMutedMutex.RUnlock() - fake.startMutex.RLock() - defer fake.startMutex.RUnlock() fake.stateMutex.RLock() defer fake.stateMutex.RUnlock() fake.subscribeToTrackMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index fd472c52e..d4bdfc17e 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -268,6 +268,11 @@ type FakeMediaTrack struct { toProtoReturnsOnCall map[int]struct { result1 *livekit.TrackInfo } + UpdateTrackInfoStub func(*livekit.TrackInfo) + updateTrackInfoMutex sync.RWMutex + updateTrackInfoArgsForCall []struct { + arg1 *livekit.TrackInfo + } UpdateVideoLayersStub func([]*livekit.VideoLayer) updateVideoLayersMutex sync.RWMutex updateVideoLayersArgsForCall []struct { @@ -1658,6 +1663,38 @@ func (fake *FakeMediaTrack) ToProtoReturnsOnCall(i int, result1 *livekit.TrackIn }{result1} } +func (fake *FakeMediaTrack) UpdateTrackInfo(arg1 *livekit.TrackInfo) { + fake.updateTrackInfoMutex.Lock() + fake.updateTrackInfoArgsForCall = append(fake.updateTrackInfoArgsForCall, struct { + arg1 *livekit.TrackInfo + }{arg1}) + stub := fake.UpdateTrackInfoStub + fake.recordInvocation("UpdateTrackInfo", []interface{}{arg1}) + fake.updateTrackInfoMutex.Unlock() + if stub != nil { + fake.UpdateTrackInfoStub(arg1) + } +} + +func (fake *FakeMediaTrack) UpdateTrackInfoCallCount() int { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + return len(fake.updateTrackInfoArgsForCall) +} + +func (fake *FakeMediaTrack) UpdateTrackInfoCalls(stub func(*livekit.TrackInfo)) { + fake.updateTrackInfoMutex.Lock() + defer fake.updateTrackInfoMutex.Unlock() + fake.UpdateTrackInfoStub = stub +} + +func (fake *FakeMediaTrack) UpdateTrackInfoArgsForCall(i int) *livekit.TrackInfo { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + argsForCall := fake.updateTrackInfoArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeMediaTrack) UpdateVideoLayers(arg1 []*livekit.VideoLayer) { var arg1Copy []*livekit.VideoLayer if arg1 != nil { @@ -1752,6 +1789,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.streamMutex.RUnlock() fake.toProtoMutex.RLock() defer fake.toProtoMutex.RUnlock() + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() fake.updateVideoLayersMutex.RLock() defer fake.updateVideoLayersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index 3f46bdbba..98fff9c80 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -33,6 +33,16 @@ type FakeParticipant struct { closeReturnsOnCall map[int]struct { result1 error } + CloseReasonStub func() types.ParticipantCloseReason + closeReasonMutex sync.RWMutex + closeReasonArgsForCall []struct { + } + closeReasonReturns struct { + result1 types.ParticipantCloseReason + } + closeReasonReturnsOnCall map[int]struct { + result1 types.ParticipantCloseReason + } DebugInfoStub func() map[string]interface{} debugInfoMutex sync.RWMutex debugInfoArgsForCall []struct { @@ -165,10 +175,6 @@ type FakeParticipant struct { setNameArgsForCall []struct { arg1 string } - StartStub func() - startMutex sync.RWMutex - startArgsForCall []struct { - } StateStub func() livekit.ParticipantInfo_State stateMutex sync.RWMutex stateArgsForCall []struct { @@ -346,6 +352,59 @@ func (fake *FakeParticipant) CloseReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeParticipant) CloseReason() types.ParticipantCloseReason { + fake.closeReasonMutex.Lock() + ret, specificReturn := fake.closeReasonReturnsOnCall[len(fake.closeReasonArgsForCall)] + fake.closeReasonArgsForCall = append(fake.closeReasonArgsForCall, struct { + }{}) + stub := fake.CloseReasonStub + fakeReturns := fake.closeReasonReturns + fake.recordInvocation("CloseReason", []interface{}{}) + fake.closeReasonMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) CloseReasonCallCount() int { + fake.closeReasonMutex.RLock() + defer fake.closeReasonMutex.RUnlock() + return len(fake.closeReasonArgsForCall) +} + +func (fake *FakeParticipant) CloseReasonCalls(stub func() types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = stub +} + +func (fake *FakeParticipant) CloseReasonReturns(result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + fake.closeReasonReturns = struct { + result1 types.ParticipantCloseReason + }{result1} +} + +func (fake *FakeParticipant) CloseReasonReturnsOnCall(i int, result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + if fake.closeReasonReturnsOnCall == nil { + fake.closeReasonReturnsOnCall = make(map[int]struct { + result1 types.ParticipantCloseReason + }) + } + fake.closeReasonReturnsOnCall[i] = struct { + result1 types.ParticipantCloseReason + }{result1} +} + func (fake *FakeParticipant) DebugInfo() map[string]interface{} { fake.debugInfoMutex.Lock() ret, specificReturn := fake.debugInfoReturnsOnCall[len(fake.debugInfoArgsForCall)] @@ -1047,30 +1106,6 @@ func (fake *FakeParticipant) SetNameArgsForCall(i int) string { return argsForCall.arg1 } -func (fake *FakeParticipant) Start() { - fake.startMutex.Lock() - fake.startArgsForCall = append(fake.startArgsForCall, struct { - }{}) - stub := fake.StartStub - fake.recordInvocation("Start", []interface{}{}) - fake.startMutex.Unlock() - if stub != nil { - fake.StartStub() - } -} - -func (fake *FakeParticipant) StartCallCount() int { - fake.startMutex.RLock() - defer fake.startMutex.RUnlock() - return len(fake.startArgsForCall) -} - -func (fake *FakeParticipant) StartCalls(stub func()) { - fake.startMutex.Lock() - defer fake.startMutex.Unlock() - fake.StartStub = stub -} - func (fake *FakeParticipant) State() livekit.ParticipantInfo_State { fake.stateMutex.Lock() ret, specificReturn := fake.stateReturnsOnCall[len(fake.stateArgsForCall)] @@ -1365,6 +1400,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.canSkipBroadcastMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() + fake.closeReasonMutex.RLock() + defer fake.closeReasonMutex.RUnlock() fake.debugInfoMutex.RLock() defer fake.debugInfoMutex.RUnlock() fake.getAudioLevelMutex.RLock() @@ -1393,8 +1430,6 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.setMetadataMutex.RUnlock() fake.setNameMutex.RLock() defer fake.setNameMutex.RUnlock() - fake.startMutex.RLock() - defer fake.startMutex.RUnlock() fake.stateMutex.RLock() defer fake.stateMutex.RUnlock() fake.subscriptionPermissionMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 1c079dc34..ed01e1f79 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -62,9 +62,6 @@ func NewUpTrackManager(params UpTrackManagerParams) *UpTrackManager { } } -func (u *UpTrackManager) Start() { -} - func (u *UpTrackManager) Close(willBeResumed bool) { u.lock.Lock() u.closed = true @@ -111,7 +108,7 @@ func (u *UpTrackManager) SetPublishedTrackMuted(trackID livekit.TrackID, muted b track.SetMuted(muted) if currentMuted != track.IsMuted() { - u.params.Logger.Infow("publisher mute status changed", "trackID", trackID, "muted", track.IsMuted()) + u.params.Logger.Debugw("publisher mute status changed", "trackID", trackID, "muted", track.IsMuted()) if u.onTrackUpdated != nil { u.onTrackUpdated(track) } @@ -142,7 +139,7 @@ func (u *UpTrackManager) GetPublishedTracks() []types.MediaTrack { func (u *UpTrackManager) UpdateSubscriptionPermission( subscriptionPermission *livekit.SubscriptionPermission, timedVersion utils.TimedVersion, - resolverByIdentity func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, + _ func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, // TODO: separate PR to remove this argument resolverBySid func(participantID livekit.ParticipantID) types.LocalParticipant, ) error { u.lock.Lock() @@ -203,7 +200,7 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( } u.lock.Unlock() - u.maybeRevokeSubscriptions(resolverByIdentity) + u.maybeRevokeSubscriptions() return nil } @@ -381,7 +378,7 @@ func (u *UpTrackManager) getAllowedSubscribersLocked(trackID livekit.TrackID) [] return allowed } -func (u *UpTrackManager) maybeRevokeSubscriptions(resolver func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant) { +func (u *UpTrackManager) maybeRevokeSubscriptions() { u.lock.Lock() defer u.lock.Unlock() diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index 8df479ce7..a887ef832 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -294,6 +294,12 @@ func (d *DummyReceiver) TrackInfo() *livekit.TrackInfo { return nil } +func (d *DummyReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + r.UpdateTrackInfo(ti) + } +} + func (d *DummyReceiver) IsClosed() bool { if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { return r.IsClosed() diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index f2ba672b2..20636bc60 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -28,6 +28,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" @@ -54,6 +55,7 @@ type AgentHandler struct { roomWorkers map[string]*worker publisherRegistered bool publisherWorkers map[string]*worker + onWorkerRegistered func(handler *AgentHandler) } type worker struct { @@ -65,6 +67,7 @@ type worker struct { jobType livekit.JobType status livekit.WorkerStatus activeJobs int + logger logger.Logger } type availability struct { @@ -102,18 +105,18 @@ func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) { // require a claim claims := GetGrants(r.Context()) if claims == nil || claims.Video == nil || !claims.Video.Agent { - handleError(writer, http.StatusUnauthorized, rtc.ErrPermissionDenied) + handleError(writer, r, http.StatusUnauthorized, rtc.ErrPermissionDenied) return } // upgrade conn, err := s.upgrader.Upgrade(writer, r, nil) if err != nil { - handleError(writer, http.StatusInternalServerError, err) + handleError(writer, r, http.StatusInternalServerError, err) return } - s.HandleConnection(conn) + s.HandleConnection(r.Context(), conn) } func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, publisherTopic string) *AgentHandler { @@ -128,11 +131,19 @@ func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, publisherTo } } -func (s *AgentHandler) HandleConnection(conn *websocket.Conn) { +// OnWorkerRegistered registers a callback to be called when the first worker of each type is registered +func (s *AgentHandler) OnWorkerRegistered(handler func(handler *AgentHandler)) { + s.mu.Lock() + defer s.mu.Unlock() + s.onWorkerRegistered = handler +} + +func (s *AgentHandler) HandleConnection(ctx context.Context, conn *websocket.Conn) { sigConn := NewWSSignalConnection(conn) w := &worker{ conn: conn, sigConn: sigConn, + logger: utils.GetLogger(ctx), } s.mu.Lock() @@ -177,9 +188,9 @@ func (s *AgentHandler) HandleConnection(conn *websocket.Conn) { websocket.CloseNormalClosure, websocket.CloseNoStatusReceived, ) { - logger.Infow("exit ws read loop for closed connection", "wsError", err) + w.logger.Infow("Agent worker closed WS connection", "wsError", err) } else { - logger.Errorw("error reading from websocket", err) + w.logger.Errorw("error reading from websocket", err) } return } @@ -199,7 +210,7 @@ func (s *AgentHandler) HandleConnection(conn *websocket.Conn) { func (s *AgentHandler) handleRegister(worker *worker, msg *livekit.RegisterWorkerRequest) { if err := s.doHandleRegister(worker, msg); err != nil { - logger.Errorw("failed to register worker", err, "workerID", msg.WorkerId, "jobType", msg.Type) + worker.logger.Errorw("failed to register worker", err, "workerID", msg.WorkerId, "jobType", msg.Type) worker.conn.Close() } } @@ -215,6 +226,9 @@ func (s *AgentHandler) doHandleRegister(worker *worker, msg *livekit.RegisterWor return errors.New("worker already registered") } + onRegistered := s.onWorkerRegistered + firstWorker := false + switch msg.Type { case livekit.JobType_JT_ROOM: worker.id = msg.WorkerId @@ -225,9 +239,10 @@ func (s *AgentHandler) doHandleRegister(worker *worker, msg *livekit.RegisterWor if !s.roomRegistered { err := s.agentServer.RegisterJobRequestTopic(s.roomTopic) if err != nil { - logger.Errorw("failed to register room agents", err) + worker.logger.Errorw("failed to register room agents", err) } else { s.roomRegistered = true + firstWorker = true } } @@ -240,9 +255,10 @@ func (s *AgentHandler) doHandleRegister(worker *worker, msg *livekit.RegisterWor if !s.publisherRegistered { err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic) if err != nil { - logger.Errorw("failed to register publisher agents", err) + worker.logger.Errorw("failed to register publisher agents", err) } else { s.publisherRegistered = true + firstWorker = true } } default: @@ -260,9 +276,12 @@ func (s *AgentHandler) doHandleRegister(worker *worker, msg *livekit.RegisterWor }, }) if err != nil { - logger.Errorw("failed to write server message", err) + worker.logger.Errorw("failed to write server message", err) } + if firstWorker && onRegistered != nil { + onRegistered(s) + } return nil } @@ -282,9 +301,9 @@ func (s *AgentHandler) handleAvailability(w *worker, msg *livekit.AvailabilityRe func (s *AgentHandler) handleJobUpdate(w *worker, msg *livekit.JobStatusUpdate) { switch msg.Status { case livekit.JobStatus_JS_SUCCESS: - logger.Debugw("job complete", "jobID", msg.JobId) + w.logger.Debugw("job complete", "jobID", msg.JobId) case livekit.JobStatus_JS_FAILED: - logger.Warnw("job failed", errors.New(msg.Error), "jobID", msg.JobId) + w.logger.Warnw("job failed", errors.New(msg.Error), "jobID", msg.JobId) } w.mu.Lock() @@ -307,7 +326,7 @@ func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) s.agentServer.DeregisterJobRequestTopic(s.roomTopic) } else if !s.roomRegistered && s.roomAvailableLocked() { if err := s.agentServer.RegisterJobRequestTopic(s.roomTopic); err != nil { - logger.Errorw("failed to register room agents", err) + w.logger.Errorw("failed to register room agents", err) } else { s.roomRegistered = true } @@ -318,7 +337,7 @@ func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) s.agentServer.DeregisterJobRequestTopic(s.publisherTopic) } else if !s.publisherRegistered && s.publisherAvailableLocked() { if err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic); err != nil { - logger.Errorw("failed to register publisher agents", err) + w.logger.Errorw("failed to register publisher agents", err) } else { s.publisherRegistered = true } @@ -388,7 +407,7 @@ func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty Availability: &livekit.AvailabilityRequest{Job: job}, }}) if err != nil { - logger.Errorw("failed to send availability request", err, "workerID", selected.id) + selected.logger.Errorw("failed to send availability request", err, "workerID", selected.id) } select { @@ -400,7 +419,7 @@ func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty Assignment: &livekit.JobAssignment{Job: job}, }}) if err != nil { - logger.Errorw("failed to assign job", err, "workerID", selected.id) + selected.logger.Errorw("failed to assign job", err, "workerID", selected.id) } else { selected.mu.Lock() selected.activeJobs++ diff --git a/pkg/service/auth.go b/pkg/service/auth.go index 1b5829e30..5633b2889 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -62,7 +62,7 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, if authHeader != "" { if !strings.HasPrefix(authHeader, bearerPrefix) { - handleError(w, http.StatusUnauthorized, ErrMissingAuthorization) + handleError(w, r, http.StatusUnauthorized, ErrMissingAuthorization) return } @@ -75,19 +75,19 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, if authToken != "" { v, err := auth.ParseAPIToken(authToken) if err != nil { - handleError(w, http.StatusUnauthorized, ErrInvalidAuthorizationToken) + handleError(w, r, http.StatusUnauthorized, ErrInvalidAuthorizationToken) return } secret := m.provider.GetSecret(v.APIKey()) if secret == "" { - handleError(w, http.StatusUnauthorized, errors.New("invalid API key: "+v.APIKey())) + handleError(w, r, http.StatusUnauthorized, errors.New("invalid API key: "+v.APIKey())) return } grants, err := v.Verify(secret) if err != nil { - handleError(w, http.StatusUnauthorized, errors.New("invalid token: "+authToken+", error: "+err.Error())) + handleError(w, r, http.StatusUnauthorized, errors.New("invalid token: "+authToken+", error: "+err.Error())) return } diff --git a/pkg/service/ingress.go b/pkg/service/ingress.go index a8fd81d6c..2acda0064 100644 --- a/pkg/service/ingress.go +++ b/pkg/service/ingress.go @@ -132,7 +132,7 @@ func (s *IngressService) CreateIngressWithUrl(ctx context.Context, urlStr string if err != nil { return nil, psrpc.NewError(psrpc.InvalidArgument, err) } - if urlObj.Scheme != "http" && urlObj.Scheme != "https" { + if urlObj.Scheme != "http" && urlObj.Scheme != "https" && urlObj.Scheme != "srt" { return nil, ingress.ErrInvalidIngress(fmt.Sprintf("invalid url scheme %s", urlObj.Scheme)) } // Marshall the URL again for sanitization diff --git a/pkg/service/interfaces.go b/pkg/service/interfaces.go index 5dd0cb20b..73c7d4a67 100644 --- a/pkg/service/interfaces.go +++ b/pkg/service/interfaces.go @@ -35,7 +35,6 @@ type ObjectStore interface { UnlockRoom(ctx context.Context, roomName livekit.RoomName, uid string) error StoreRoom(ctx context.Context, room *livekit.Room, internal *livekit.RoomInternal) error - DeleteRoom(ctx context.Context, roomName livekit.RoomName) error StoreParticipant(ctx context.Context, roomName livekit.RoomName, participant *livekit.ParticipantInfo) error DeleteParticipant(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) error @@ -44,6 +43,7 @@ type ObjectStore interface { //counterfeiter:generate . ServiceStore type ServiceStore interface { LoadRoom(ctx context.Context, roomName livekit.RoomName, includeInternal bool) (*livekit.Room, *livekit.RoomInternal, error) + DeleteRoom(ctx context.Context, roomName livekit.RoomName) error // ListRooms returns currently active rooms. if names is not nil, it'll filter and return // only rooms that match @@ -88,9 +88,4 @@ type SIPStore interface { LoadSIPDispatchRule(ctx context.Context, sipDispatchRuleID string) (*livekit.SIPDispatchRuleInfo, error) ListSIPDispatchRule(ctx context.Context) ([]*livekit.SIPDispatchRuleInfo, error) DeleteSIPDispatchRule(ctx context.Context, info *livekit.SIPDispatchRuleInfo) error - - StoreSIPParticipant(ctx context.Context, info *livekit.SIPParticipantInfo) error - LoadSIPParticipant(ctx context.Context, sipParticipantID string) (*livekit.SIPParticipantInfo, error) - ListSIPParticipant(ctx context.Context) ([]*livekit.SIPParticipantInfo, error) - DeleteSIPParticipant(ctx context.Context, info *livekit.SIPParticipantInfo) error } diff --git a/pkg/service/ioservice.go b/pkg/service/ioservice.go index 70158d5cc..67f995033 100644 --- a/pkg/service/ioservice.go +++ b/pkg/service/ioservice.go @@ -148,6 +148,15 @@ func (s *IOInfoService) ListEgress(ctx context.Context, req *livekit.ListEgressR return &livekit.ListEgressResponse{Items: items}, nil } +func (s *IOInfoService) UpdateMetrics(ctx context.Context, req *rpc.UpdateMetricsRequest) (*emptypb.Empty, error) { + logger.Infow("received egress metrics", + "egressID", req.Info.EgressId, + "avgCpu", req.AvgCpuUsage, + "maxCpu", req.MaxCpuUsage, + ) + return &emptypb.Empty{}, nil +} + func (s *IOInfoService) GetIngressInfo(ctx context.Context, req *rpc.GetIngressInfoRequest) (*rpc.GetIngressInfoResponse, error) { info, err := s.loadIngressFromInfoRequest(req) if err != nil { @@ -184,7 +193,8 @@ func (s *IOInfoService) UpdateIngressState(ctx context.Context, req *rpc.UpdateI switch req.State.Status { case livekit.IngressState_ENDPOINT_ERROR, - livekit.IngressState_ENDPOINT_INACTIVE: + livekit.IngressState_ENDPOINT_INACTIVE, + livekit.IngressState_ENDPOINT_COMPLETE: s.telemetry.IngressEnded(ctx, info) if req.State.Error != "" { @@ -209,7 +219,7 @@ func (s *IOInfoService) UpdateIngressState(ctx context.Context, req *rpc.UpdateI s.telemetry.IngressUpdated(ctx, info) - logger.Infow("ingress updated", "ingressID", req.IngressId) + logger.Infow("ingress updated", "ingressID", req.IngressId, "status", info.State.Status) } return &emptypb.Empty{}, nil diff --git a/pkg/service/ioservice_sip.go b/pkg/service/ioservice_sip.go index 3f0b34042..540b906b5 100644 --- a/pkg/service/ioservice_sip.go +++ b/pkg/service/ioservice_sip.go @@ -1,252 +1,28 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package service import ( "context" - "fmt" - "math" - "regexp" - "sort" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/sip" ) -// sipRulePriority returns sorting priority for dispatch rules. Lower value means higher priority. -func sipRulePriority(info *livekit.SIPDispatchRuleInfo) int32 { - // In all these cases, prefer pin-protected rules. - // Thus, the order will be the following: - // - 0: Direct or Pin (both pin-protected) - // - 1: Individual (pin-protected) - // - 100: Direct (open) - // - 101: Individual (open) - const ( - last = math.MaxInt32 - ) - // TODO: Maybe allow setting specific priorities for dispatch rules? - switch rule := info.GetRule().GetRule().(type) { - default: - return last - case *livekit.SIPDispatchRule_DispatchRuleDirect: - if rule.DispatchRuleDirect.GetPin() != "" { - return 0 - } - return 100 - case *livekit.SIPDispatchRule_DispatchRuleIndividual: - if rule.DispatchRuleIndividual.GetPin() != "" { - return 1 - } - return 101 - } -} - -// sipSortRules predictably sorts dispatch rules by priority (first one is highest). -func sipSortRules(rules []*livekit.SIPDispatchRuleInfo) { - sort.Slice(rules, func(i, j int) bool { - p1, p2 := sipRulePriority(rules[i]), sipRulePriority(rules[j]) - if p1 < p2 { - return true - } else if p1 > p2 { - return false - } - // For predictable sorting order. - room1, _, _ := sipGetPinAndRoom(rules[i]) - room2, _, _ := sipGetPinAndRoom(rules[j]) - return room1 < room2 - }) -} - -// sipSelectDispatch takes a list of dispatch rules, and takes the decision which one should be selected. -// It returns an error if there are conflicting rules. Returns nil if no rules match. -func sipSelectDispatch(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) { - if len(rules) == 0 { - return nil, nil - } - // Sorting will do the selection for us. We already filtered out irrelevant ones in matchSIPDispatchRule. - sipSortRules(rules) - byPin := make(map[string]*livekit.SIPDispatchRuleInfo) - var ( - pinRule *livekit.SIPDispatchRuleInfo - openRule *livekit.SIPDispatchRuleInfo - ) - openCnt := 0 - for _, r := range rules { - _, pin, err := sipGetPinAndRoom(r) - if err != nil { - return nil, err - } - if pin == "" { - openRule = r // last one - openCnt++ - } else if r2 := byPin[pin]; r2 != nil { - return nil, fmt.Errorf("Conflicting SIP Dispatch Rules: Same PIN for %q and %q", - r.SipDispatchRuleId, r2.SipDispatchRuleId) - } else { - byPin[pin] = r - // Pick the first one with a Pin. If Pin was provided in the request, we already filtered the right rules. - // If not, this rule will just be used to send RequestPin=true flag. - if pinRule == nil { - pinRule = r - } - } - } - if req.GetPin() != "" { - // If it's still nil that's fine. We will report "no rules matched" later. - return pinRule, nil - } - if pinRule != nil { - return pinRule, nil - } - if openCnt > 1 { - return nil, fmt.Errorf("Conflicting SIP Dispatch Rules: Matched %d open rules for %q", openCnt, req.CallingNumber) - } - return openRule, nil -} - -// sipGetPinAndRoom returns a room name/prefix and the pin for a dispatch rule. Just a convenience wrapper. -func sipGetPinAndRoom(info *livekit.SIPDispatchRuleInfo) (room, pin string, err error) { - // TODO: Could probably add methods on SIPDispatchRuleInfo struct instead. - switch rule := info.GetRule().GetRule().(type) { - default: - return "", "", fmt.Errorf("Unsupported SIP Dispatch Rule: %T", rule) - case *livekit.SIPDispatchRule_DispatchRuleDirect: - pin = rule.DispatchRuleDirect.GetPin() - room = rule.DispatchRuleDirect.GetRoomName() - case *livekit.SIPDispatchRule_DispatchRuleIndividual: - pin = rule.DispatchRuleIndividual.GetPin() - room = rule.DispatchRuleIndividual.GetRoomPrefix() - } - return room, pin, nil -} - -// sipMatchTrunk finds a SIP Trunk definition matching the request. -// Returns nil if no rules matched or an error if there are conflicting definitions. -func sipMatchTrunk(trunks []*livekit.SIPTrunkInfo, calling, called string) (*livekit.SIPTrunkInfo, error) { - var ( - selectedTrunk *livekit.SIPTrunkInfo - defaultTrunk *livekit.SIPTrunkInfo - defaultTrunkCnt int // to error in case there are multiple ones - ) - for _, tr := range trunks { - // Do not consider it if regexp doesn't match. - matches := len(tr.InboundNumbersRegex) == 0 - for _, reStr := range tr.InboundNumbersRegex { - // TODO: we should cache it - re, err := regexp.Compile(reStr) - if err != nil { - logger.Errorw("cannot parse SIP trunk regexp", err, "trunkID", tr.SipTrunkId) - continue - } - if re.MatchString(calling) { - matches = true - break - } - } - if !matches { - continue - } - if tr.OutboundNumber == "" { - // Default/wildcard trunk. - defaultTrunk = tr - defaultTrunkCnt++ - } else if tr.OutboundNumber == called { - // Trunk specific to the number. - if selectedTrunk != nil { - return nil, fmt.Errorf("Multiple SIP Trunks matched for %q", called) - } - selectedTrunk = tr - // Keep searching! We want to know if there are any conflicting Trunk definitions. - } - } - if selectedTrunk != nil { - return selectedTrunk, nil - } - if defaultTrunkCnt > 1 { - return nil, fmt.Errorf("Multiple default SIP Trunks matched for %q", called) - } - // Could still be nil here. - return defaultTrunk, nil -} - -// sipMatchDispatchRule finds the best dispatch rule matching the request parameters. Returns an error if no rule matched. -// Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs). -func sipMatchDispatchRule(trunk *livekit.SIPTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) { - // Trunk can still be nil here in case none matched or were defined. - // This is still fine, but only in case we'll match exactly one wildcard dispatch rule. - if len(rules) == 0 { - return nil, fmt.Errorf("No SIP Dispatch Rules defined") - } - // We split the matched dispatch rules into two sets: specific and default (aka wildcard). - // First, attempt to match any of the specific rules, where we did match the Trunk ID. - // If nothing matches there - fallback to default/wildcard rules, where no Trunk IDs were mentioned. - var ( - specificRules []*livekit.SIPDispatchRuleInfo - defaultRules []*livekit.SIPDispatchRuleInfo - ) - noPin := req.NoPin - sentPin := req.GetPin() - for _, info := range rules { - _, rulePin, err := sipGetPinAndRoom(info) - if err != nil { - logger.Errorw("Invalid SIP Dispatch Rule", err, "dispatchRuleID", info.SipDispatchRuleId) - continue - } - // Filter heavily on the Pin, so that only relevant rules remain. - if noPin { - if rulePin != "" { - // Skip pin-protected rules if no pin mode requested. - continue - } - } else if sentPin != "" { - if rulePin == "" { - // Pin already sent, skip non-pin-protected rules. - continue - } - if sentPin != rulePin { - // Pin doesn't match. Don't return an error here, just wait for other rule to match (or none at all). - // Note that we will NOT match non-pin-protected rules, thus it will not fallback to open rules. - continue - } - } - if len(info.TrunkIds) == 0 { - // Default/wildcard dispatch rule. - defaultRules = append(defaultRules, info) - continue - } - // Specific dispatch rules. Require a Trunk associated with the number. - if trunk == nil { - continue - } - matches := false - for _, id := range info.TrunkIds { - if id == trunk.SipTrunkId { - matches = true - break - } - } - if !matches { - continue - } - specificRules = append(specificRules, info) - } - best, err := sipSelectDispatch(specificRules, req) - if err != nil { - return nil, err - } else if best != nil { - return best, nil - } - best, err = sipSelectDispatch(defaultRules, req) - if err != nil { - return nil, err - } else if best != nil { - return best, nil - } - if trunk == nil { - return nil, fmt.Errorf("No SIP Trunk or Dispatch Rules matched for %q", req.CalledNumber) - } - return nil, fmt.Errorf("No SIP Dispatch Rules matched for %q", req.CalledNumber) -} - // matchSIPTrunk finds a SIP Trunk definition matching the request. // Returns nil if no rules matched or an error if there are conflicting definitions. func (s *IOInfoService) matchSIPTrunk(ctx context.Context, calling, called string) (*livekit.SIPTrunkInfo, error) { @@ -254,7 +30,7 @@ func (s *IOInfoService) matchSIPTrunk(ctx context.Context, calling, called strin if err != nil { return nil, err } - return sipMatchTrunk(trunks, calling, called) + return sip.MatchTrunk(trunks, calling, called) } // matchSIPDispatchRule finds the best dispatch rule matching the request parameters. Returns an error if no rule matched. @@ -266,7 +42,7 @@ func (s *IOInfoService) matchSIPDispatchRule(ctx context.Context, trunk *livekit if err != nil { return nil, err } - return sipMatchDispatchRule(trunk, rules, req) + return sip.MatchDispatchRule(trunk, rules, req) } func (s *IOInfoService) EvaluateSIPDispatchRules(ctx context.Context, req *rpc.EvaluateSIPDispatchRulesRequest) (*rpc.EvaluateSIPDispatchRulesResponse, error) { @@ -274,46 +50,17 @@ func (s *IOInfoService) EvaluateSIPDispatchRules(ctx context.Context, req *rpc.E if err != nil { return nil, err } + if trunk != nil { + logger.Debugw("SIP trunk matched", "trunkID", trunk.SipTrunkId, "called", req.CalledNumber, "calling", req.CallingNumber) + } else { + logger.Debugw("No SIP trunk matched", "trunkID", "", "called", req.CalledNumber, "calling", req.CallingNumber) + } best, err := s.matchSIPDispatchRule(ctx, trunk, req) if err != nil { return nil, err } - sentPin := req.GetPin() - - from := req.CallingNumber - if best.HidePhoneNumber { - // TODO: Decide on the phone masking format. - // Maybe keep regional code, but mask all but 4 last digits? - from = from[len(from)-4:] - } - fromName := "Phone " + from - - room, rulePin, err := sipGetPinAndRoom(best) - if err != nil { - return nil, err - } - if rulePin != "" { - if sentPin == "" { - return &rpc.EvaluateSIPDispatchRulesResponse{ - RequestPin: true, - }, nil - } - if rulePin != sentPin { - // This should never happen in practice, because matchSIPDispatchRule should remove rules with the wrong pin. - return nil, fmt.Errorf("Incorrect PIN for SIP room") - } - } else { - // Pin was sent, but room doesn't require one. Assume user accidentally pressed phone button. - } - switch rule := best.GetRule().GetRule().(type) { - case *livekit.SIPDispatchRule_DispatchRuleIndividual: - // TODO: Decide on the suffix. Do we need to escape specific characters? - room = rule.DispatchRuleIndividual.GetRoomPrefix() + from - } - return &rpc.EvaluateSIPDispatchRulesResponse{ - RoomName: room, - ParticipantIdentity: fromName, - }, nil + logger.Debugw("SIP dispatch rule matched", "dispatchRule", best.SipDispatchRuleId, "called", req.CalledNumber, "calling", req.CallingNumber) + return sip.EvaluateDispatchRule(best, req) } func (s *IOInfoService) GetSIPTrunkAuthentication(ctx context.Context, req *rpc.GetSIPTrunkAuthenticationRequest) (*rpc.GetSIPTrunkAuthenticationResponse, error) { @@ -321,6 +68,11 @@ func (s *IOInfoService) GetSIPTrunkAuthentication(ctx context.Context, req *rpc. if err != nil { return nil, err } + if trunk == nil { + logger.Debugw("No SIP trunk matched for auth", "trunkID", "", "called", req.To, "calling", req.From) + return &rpc.GetSIPTrunkAuthenticationResponse{}, nil + } + logger.Debugw("SIP trunk matched for auth", "trunkID", trunk.SipTrunkId, "called", req.To, "calling", req.From) return &rpc.GetSIPTrunkAuthenticationResponse{ Username: trunk.InboundUsername, Password: trunk.InboundPassword, diff --git a/pkg/service/ioservice_sip_test.go b/pkg/service/ioservice_sip_test.go deleted file mode 100644 index b8b57fd07..000000000 --- a/pkg/service/ioservice_sip_test.go +++ /dev/null @@ -1,390 +0,0 @@ -package service - -import ( - "fmt" - "testing" - - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/rpc" - "github.com/stretchr/testify/require" -) - -const ( - sipNumber1 = "1111 1111" - sipNumber2 = "2222 2222" - sipNumber3 = "3333 3333" - sipTrunkID1 = "aaa" - sipTrunkID2 = "bbb" -) - -func TestSIPMatchTrunk(t *testing.T) { - cases := []struct { - name string - trunks []*livekit.SIPTrunkInfo - exp int - expErr bool - }{ - { - name: "empty", - trunks: nil, - exp: -1, // no error; nil result - }, - { - name: "one wildcard", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa"}, - }, - exp: 0, - }, - { - name: "matching", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber2}, - }, - exp: 0, - }, - { - name: "matching regexp", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber2, InboundNumbersRegex: []string{`^\d+ \d+$`}}, - }, - exp: 0, - }, - { - name: "not matching", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, - }, - exp: -1, - }, - { - name: "not matching regexp", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber2, InboundNumbersRegex: []string{`^\d+$`}}, - }, - exp: -1, - }, - { - name: "one match", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, - {SipTrunkId: "bbb", OutboundNumber: sipNumber2}, - }, - exp: 1, - }, - { - name: "many matches", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, - {SipTrunkId: "bbb", OutboundNumber: sipNumber2}, - {SipTrunkId: "ccc", OutboundNumber: sipNumber2}, - }, - expErr: true, - }, - { - name: "many matches default", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, - {SipTrunkId: "bbb"}, - {SipTrunkId: "ccc", OutboundNumber: sipNumber2}, - {SipTrunkId: "ddd"}, - }, - exp: 2, - }, - { - name: "regexp", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, - {SipTrunkId: "bbb", OutboundNumber: sipNumber2}, - {SipTrunkId: "ccc", OutboundNumber: sipNumber2, InboundNumbersRegex: []string{`^\d+$`}}, - }, - exp: 1, - }, - { - name: "multiple defaults", - trunks: []*livekit.SIPTrunkInfo{ - {SipTrunkId: "aaa", OutboundNumber: sipNumber3}, - {SipTrunkId: "bbb"}, - {SipTrunkId: "ccc"}, - }, - expErr: true, - }, - } - for _, c := range cases { - c := c - t.Run(c.name, func(t *testing.T) { - got, err := sipMatchTrunk(c.trunks, sipNumber1, sipNumber2) - if c.expErr { - require.Error(t, err) - require.Nil(t, got) - t.Log(err) - } else { - var exp *livekit.SIPTrunkInfo - if c.exp >= 0 { - exp = c.trunks[c.exp] - } - require.NoError(t, err) - require.Equal(t, exp, got) - } - }) - } -} - -func newSIPTrunkDispatch() *livekit.SIPTrunkInfo { - return &livekit.SIPTrunkInfo{ - SipTrunkId: sipTrunkID1, - OutboundNumber: sipNumber2, - } -} - -func newSIPReqDispatch(pin string, noPin bool) *rpc.EvaluateSIPDispatchRulesRequest { - return &rpc.EvaluateSIPDispatchRulesRequest{ - CallingNumber: sipNumber1, - CalledNumber: sipNumber2, - Pin: pin, - //NoPin: noPin, // TODO - } -} - -func newDirectDispatch(room, pin string) *livekit.SIPDispatchRule { - return &livekit.SIPDispatchRule{ - Rule: &livekit.SIPDispatchRule_DispatchRuleDirect{ - DispatchRuleDirect: &livekit.SIPDispatchRuleDirect{ - RoomName: room, Pin: pin, - }, - }, - } -} - -func newIndividualDispatch(roomPref, pin string) *livekit.SIPDispatchRule { - return &livekit.SIPDispatchRule{ - Rule: &livekit.SIPDispatchRule_DispatchRuleIndividual{ - DispatchRuleIndividual: &livekit.SIPDispatchRuleIndividual{ - RoomPrefix: roomPref, Pin: pin, - }, - }, - } -} - -func TestSIPMatchDispatchRule(t *testing.T) { - cases := []struct { - name string - trunk *livekit.SIPTrunkInfo - rules []*livekit.SIPDispatchRuleInfo - reqPin string - noPin bool - exp int - expErr bool - }{ - // These cases just validate that no rules produce an error. - { - name: "empty", - trunk: nil, - rules: nil, - expErr: true, - }, - { - name: "only trunk", - trunk: newSIPTrunkDispatch(), - rules: nil, - expErr: true, - }, - // Default rules should work even if no trunk is defined. - { - name: "one rule/no trunk", - trunk: nil, - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip", "")}, - }, - exp: 0, - }, - // Default rule should work with a trunk too. - { - name: "one rule/default trunk", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip", "")}, - }, - exp: 0, - }, - // Rule matching the trunk should be selected. - { - name: "one rule/specific trunk", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: []string{sipTrunkID1, sipTrunkID2}, Rule: newDirectDispatch("sip", "")}, - }, - exp: 0, - }, - // Rule NOT matching the trunk should NOT be selected. - { - name: "one rule/wrong trunk", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: []string{"zzz"}, Rule: newDirectDispatch("sip", "")}, - }, - expErr: true, - }, - // Direct rule with a pin should be selected, even if no pin is provided. - { - name: "direct pin/correct", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip", "123")}, - {TrunkIds: []string{sipTrunkID2}, Rule: newDirectDispatch("sip", "456")}, - }, - reqPin: "123", - exp: 0, - }, - // Direct rule with a pin should reject wrong pin. - { - name: "direct pin/wrong", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip", "123")}, - {TrunkIds: []string{sipTrunkID2}, Rule: newDirectDispatch("sip", "456")}, - }, - reqPin: "zzz", - expErr: true, - }, - // Multiple direct rules with the same pin should result in an error. - { - name: "direct pin/conflict", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip1", "123")}, - {TrunkIds: []string{sipTrunkID1, sipTrunkID2}, Rule: newDirectDispatch("sip2", "123")}, - }, - reqPin: "123", - expErr: true, - }, - // Multiple direct rules with the same pin on different trunks are ok. - { - name: "direct pin/no conflict on different trunk", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip1", "123")}, - {TrunkIds: []string{sipTrunkID2}, Rule: newDirectDispatch("sip2", "123")}, - }, - reqPin: "123", - exp: 0, - }, - // Specific direct rules should take priority over default direct rules. - { - name: "direct pin/default and specific", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip1", "123")}, - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip2", "123")}, - }, - reqPin: "123", - exp: 1, - }, - // Specific direct rules should take priority over default direct rules. No pin. - { - name: "direct/default and specific", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip1", "")}, - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip2", "")}, - }, - exp: 1, - }, - // Specific direct rules should take priority over default direct rules. One with pin, other without. - { - name: "direct/default and specific/mixed 1", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip1", "123")}, - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip2", "")}, - }, - exp: 1, - }, - { - name: "direct/default and specific/mixed 2", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip1", "")}, - {TrunkIds: []string{sipTrunkID1}, Rule: newDirectDispatch("sip2", "123")}, - }, - exp: 1, - }, - // Multiple default direct rules are not allowed. - { - name: "direct/multiple defaults", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newDirectDispatch("sip1", "")}, - {TrunkIds: nil, Rule: newDirectDispatch("sip2", "")}, - }, - expErr: true, - }, - // Cannot use both direct and individual rules with the same pin setup. - { - name: "direct vs individual/private", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newIndividualDispatch("pref_", "123")}, - {TrunkIds: nil, Rule: newDirectDispatch("sip", "123")}, - }, - expErr: true, - }, - { - name: "direct vs individual/open", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newIndividualDispatch("pref_", "")}, - {TrunkIds: nil, Rule: newDirectDispatch("sip", "")}, - }, - expErr: true, - }, - // Direct rules take priority over individual rules. - { - name: "direct vs individual/priority", - trunk: newSIPTrunkDispatch(), - rules: []*livekit.SIPDispatchRuleInfo{ - {TrunkIds: nil, Rule: newIndividualDispatch("pref_", "123")}, - {TrunkIds: nil, Rule: newDirectDispatch("sip", "456")}, - }, - reqPin: "456", - exp: 1, - }, - } - for _, c := range cases { - c := c - t.Run(c.name, func(t *testing.T) { - pins := []string{c.reqPin} - if !c.expErr && c.reqPin != "" { - // Should match the same rule, even if no pin is set (so that it can be requested). - pins = append(pins, "") - } - for i, r := range c.rules { - if r.SipDispatchRuleId == "" { - r.SipDispatchRuleId = fmt.Sprintf("rule_%d", i) - } - } - for _, pin := range pins { - pin := pin - name := pin - if name == "" { - name = "no pin" - } - t.Run(name, func(t *testing.T) { - got, err := sipMatchDispatchRule(c.trunk, c.rules, newSIPReqDispatch(pin, c.noPin)) - if c.expErr { - require.Error(t, err) - require.Nil(t, got) - t.Log(err) - } else { - var exp *livekit.SIPDispatchRuleInfo - if c.exp >= 0 { - exp = c.rules[c.exp] - } - require.NoError(t, err) - require.Equal(t, exp, got) - } - }) - } - }) - } -} diff --git a/pkg/service/redisstore.go b/pkg/service/redisstore.go index dfb4c3626..91b437f40 100644 --- a/pkg/service/redisstore.go +++ b/pkg/service/redisstore.go @@ -26,11 +26,12 @@ import ( "github.com/redis/go-redis/v9" "google.golang.org/protobuf/proto" - "github.com/livekit/livekit-server/version" "github.com/livekit/protocol/ingress" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/version" ) const ( @@ -53,7 +54,6 @@ const ( SIPTrunkKey = "sip_trunk" SIPDispatchRuleKey = "sip_dispatch_rule" - SIPParticipantKey = "sip_participant" // RoomParticipantsPrefix is hash of participant_name => ParticipantInfo RoomParticipantsPrefix = "room_participants:" @@ -908,37 +908,3 @@ func (s *RedisStore) ListSIPDispatchRule(ctx context.Context) (infos []*livekit. return infos, err } - -func (s *RedisStore) StoreSIPParticipant(ctx context.Context, info *livekit.SIPParticipantInfo) error { - data, err := proto.Marshal(info) - if err != nil { - return err - } - - return s.rc.HSet(s.ctx, SIPParticipantKey, info.SipParticipantId, data).Err() -} -func (s *RedisStore) LoadSIPParticipant(ctx context.Context, sipParticipantId string) (*livekit.SIPParticipantInfo, error) { - info := &livekit.SIPParticipantInfo{} - if err := s.loadOne(ctx, SIPParticipantKey, sipParticipantId, info, ErrSIPParticipantNotFound); err != nil { - return nil, err - } - - return info, nil -} - -func (s *RedisStore) DeleteSIPParticipant(ctx context.Context, info *livekit.SIPParticipantInfo) error { - return s.rc.HDel(s.ctx, SIPParticipantKey, info.SipParticipantId).Err() -} - -func (s *RedisStore) ListSIPParticipant(ctx context.Context) (infos []*livekit.SIPParticipantInfo, err error) { - err = s.loadMany(ctx, SIPParticipantKey, func() proto.Message { - infos = append(infos, &livekit.SIPParticipantInfo{}) - return infos[len(infos)-1] - }) - - return infos, err -} - -func (s *RedisStore) SendSIPParticipantDTMF(ctx context.Context, info *livekit.SendSIPParticipantDTMFRequest) (*livekit.SIPParticipantDTMFInfo, error) { - return nil, fmt.Errorf("TODO") -} diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index a6fa679f3..b09ec48c2 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -101,7 +101,7 @@ func NewLocalRoomManager( return nil, err } - r := &RoomManager{ + return &RoomManager{ config: conf, rtcConfig: rtcConf, currentNode: currentNode, @@ -126,12 +126,7 @@ func NewLocalRoomManager( Region: conf.Region, NodeId: currentNode.Id, }, - } - - // hook up to router - router.OnNewParticipantRTC(r.StartSession) - router.OnRTCMessage(r.handleRTCMessage) - return r, nil + }, nil } func (r *RoomManager) GetRoom(_ context.Context, roomName livekit.RoomName) *rtc.Room { @@ -218,10 +213,7 @@ func (r *RoomManager) Stop() { r.lock.RUnlock() for _, room := range rooms { - for _, p := range room.GetParticipants() { - _ = p.Close(true, types.ParticipantCloseReasonRoomManagerStop, false) - } - room.Close() + room.Close(types.ParticipantCloseReasonRoomManagerStop) } r.roomServers.Kill() @@ -245,6 +237,8 @@ func (r *RoomManager) StartSession( requestSource routing.MessageSource, responseSink routing.MessageSink, ) error { + sessionStartTime := time.Now() + room, err := r.getOrCreateRoom(ctx, roomName) if err != nil { return err @@ -279,12 +273,23 @@ func (r *RoomManager) StartSession( "participant", pi.Identity, "reason", pi.ReconnectReason, ) + + var leave *livekit.LeaveRequest + pv := types.ProtocolVersion(pi.Client.Protocol) + if pv.SupportsRegionsInLeaveRequest() { + leave = &livekit.LeaveRequest{ + Reason: livekit.DisconnectReason_STATE_MISMATCH, + Action: livekit.LeaveRequest_RECONNECT, + } + } else { + leave = &livekit.LeaveRequest{ + CanReconnect: true, + Reason: livekit.DisconnectReason_STATE_MISMATCH, + } + } _ = responseSink.WriteMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Leave{ - Leave: &livekit.LeaveRequest{ - CanReconnect: true, - Reason: livekit.DisconnectReason_STATE_MISMATCH, - }, + Leave: leave, }, }) return errors.New("could not restart closed participant") @@ -324,12 +329,22 @@ func (r *RoomManager) StartSession( } else if pi.Reconnect { // send leave request if participant is trying to reconnect without keep subscribe state // but missing from the room + var leave *livekit.LeaveRequest + pv := types.ProtocolVersion(pi.Client.Protocol) + if pv.SupportsRegionsInLeaveRequest() { + leave = &livekit.LeaveRequest{ + Reason: livekit.DisconnectReason_STATE_MISMATCH, + Action: livekit.LeaveRequest_RECONNECT, + } + } else { + leave = &livekit.LeaveRequest{ + CanReconnect: true, + Reason: livekit.DisconnectReason_STATE_MISMATCH, + } + } _ = responseSink.WriteMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Leave{ - Leave: &livekit.LeaveRequest{ - CanReconnect: true, - Reason: livekit.DisconnectReason_STATE_MISMATCH, - }, + Leave: leave, }, }) return errors.New("could not restart participant") @@ -390,6 +405,7 @@ func (r *RoomManager) StartSession( AudioConfig: r.config.Audio, VideoConfig: r.config.Video, ProtocolVersion: pv, + SessionStartTime: sessionStartTime, Telemetry: r.telemetry, Trailer: room.Trailer(), PLIThrottleConfig: r.config.RTC.PLIThrottle, @@ -635,26 +651,6 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.LocalPa } } -// handles RTC messages resulted from Room API calls -func (r *RoomManager) handleRTCMessage(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) { - switch rm := msg.Message.(type) { - case *livekit.RTCNodeMessage_RemoveParticipant: - r.RemoveParticipant(ctx, rm.RemoveParticipant) - case *livekit.RTCNodeMessage_MuteTrack: - r.MutePublishedTrack(ctx, rm.MuteTrack) - case *livekit.RTCNodeMessage_UpdateParticipant: - r.UpdateParticipant(ctx, rm.UpdateParticipant) - case *livekit.RTCNodeMessage_DeleteRoom: - r.DeleteRoom(ctx, rm.DeleteRoom) - case *livekit.RTCNodeMessage_UpdateSubscriptions: - r.UpdateSubscriptions(ctx, rm.UpdateSubscriptions) - case *livekit.RTCNodeMessage_SendData: - r.SendData(ctx, rm.SendData) - case *livekit.RTCNodeMessage_UpdateRoomMetadata: - r.UpdateRoomMetadata(ctx, rm.UpdateRoomMetadata) - } -} - type participantReq interface { GetRoom() string GetIdentity() string @@ -728,10 +724,7 @@ func (r *RoomManager) DeleteRoom(ctx context.Context, req *livekit.DeleteRoomReq } } else { room.Logger.Infow("deleting room") - for _, p := range room.GetParticipants() { - _ = p.Close(true, types.ParticipantCloseReasonServiceRequestDeleteRoom, false) - } - room.Close() + room.Close(types.ParticipantCloseReasonServiceRequestDeleteRoom) } return &livekit.DeleteRoomResponse{}, nil } @@ -776,7 +769,9 @@ func (r *RoomManager) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat } room.Logger.Debugw("updating room") - room.SetMetadata(req.Metadata) + done := room.SetMetadata(req.Metadata) + // wait till the update is applied + <-done return room.ToProto(), nil } diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index 174ea3391..d9bac7065 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -16,22 +16,19 @@ package service import ( "context" - "fmt" "strconv" "time" "github.com/pkg/errors" - "github.com/thoas/go-funk" "github.com/twitchtv/twirp" - "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" ) // A rooms service that supports a single node @@ -93,15 +90,15 @@ func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomReq } // actually start the room on an RTC node, to ensure metadata & empty timeout functionality - _, sink, source, err := s.router.StartParticipantSignal(ctx, + res, err := s.router.StartParticipantSignal(ctx, livekit.RoomName(req.Name), routing.ParticipantInit{}, ) if err != nil { return nil, err } - defer sink.Close() - defer source.Close() + defer res.RequestSink.Close() + defer res.ResponseSource.Close() // ensure it's created correctly err = s.confirmExecution(func() error { @@ -170,39 +167,13 @@ func (s *RoomService) DeleteRoom(ctx context.Context, req *livekit.DeleteRoomReq return nil, twirpAuthError(err) } - if s.psrpcConf.Enabled { - return s.roomClient.DeleteRoom(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + _, err := s.roomClient.DeleteRoom(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + if !errors.Is(err, psrpc.ErrNoResponse) { + return &livekit.DeleteRoomResponse{}, err } - if _, _, err := s.roomStore.LoadRoom(ctx, livekit.RoomName(req.Room), false); err == ErrRoomNotFound { - return nil, twirp.NotFoundError("room not found") - } - - err := s.router.WriteRoomRTC(ctx, livekit.RoomName(req.Room), &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_DeleteRoom{ - DeleteRoom: req, - }, - }) - if err != nil { - return nil, err - } - - // we should not return until when the room is confirmed deleted - err = s.confirmExecution(func() error { - _, _, err := s.roomStore.LoadRoom(ctx, livekit.RoomName(req.Room), false) - if err == nil { - return ErrOperationFailed - } else if err != ErrRoomNotFound { - return err - } else { - return nil - } - }) - if err != nil { - return nil, err - } - - return &livekit.DeleteRoomResponse{}, nil + err = s.roomStore.DeleteRoom(ctx, livekit.RoomName(req.Room)) + return &livekit.DeleteRoomResponse{}, err } func (s *RoomService) ListParticipants(ctx context.Context, req *livekit.ListParticipantsRequest) (*livekit.ListParticipantsResponse, error) { @@ -247,34 +218,7 @@ func (s *RoomService) RemoveParticipant(ctx context.Context, req *livekit.RoomPa return nil, twirp.NotFoundError("participant not found") } - if s.psrpcConf.Enabled { - return s.participantClient.RemoveParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) - } - - err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_RemoveParticipant{ - RemoveParticipant: req, - }, - }) - if err != nil { - return nil, err - } - - err = s.confirmExecution(func() error { - _, err := s.roomStore.LoadParticipant(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)) - if err == ErrParticipantNotFound { - return nil - } else if err != nil { - return err - } else { - return ErrOperationFailed - } - }) - if err != nil { - return nil, err - } - - return &livekit.RemoveParticipantResponse{}, nil + return s.participantClient.RemoveParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } func (s *RoomService) MutePublishedTrack(ctx context.Context, req *livekit.MuteRoomTrackRequest) (*livekit.MuteRoomTrackResponse, error) { @@ -283,47 +227,7 @@ func (s *RoomService) MutePublishedTrack(ctx context.Context, req *livekit.MuteR return nil, twirpAuthError(err) } - if s.psrpcConf.Enabled { - return s.participantClient.MutePublishedTrack(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) - } - - err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_MuteTrack{ - MuteTrack: req, - }, - }) - if err != nil { - return nil, err - } - - var track *livekit.TrackInfo - err = s.confirmExecution(func() error { - p, err := s.roomStore.LoadParticipant(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)) - if err != nil { - return err - } - // ensure track is muted - t := funk.Find(p.Tracks, func(t *livekit.TrackInfo) bool { - return t.Sid == req.TrackSid - }) - var ok bool - track, ok = t.(*livekit.TrackInfo) - if !ok { - return ErrTrackNotFound - } - if track.Muted != req.Muted { - return ErrOperationFailed - } - return nil - }) - if err != nil { - return nil, err - } - - res := &livekit.MuteRoomTrackResponse{ - Track: track, - } - return res, nil + return s.participantClient.MutePublishedTrack(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } func (s *RoomService) UpdateParticipant(ctx context.Context, req *livekit.UpdateParticipantRequest) (*livekit.ParticipantInfo, error) { @@ -337,42 +241,7 @@ func (s *RoomService) UpdateParticipant(ctx context.Context, req *livekit.Update return nil, twirpAuthError(err) } - if s.psrpcConf.Enabled { - return s.participantClient.UpdateParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) - } - - err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_UpdateParticipant{ - UpdateParticipant: req, - }, - }) - if err != nil { - return nil, err - } - - var participant *livekit.ParticipantInfo - var detailedError error - err = s.confirmExecution(func() error { - participant, err = s.roomStore.LoadParticipant(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)) - if err != nil { - return err - } - if req.Metadata != "" && participant.Metadata != req.Metadata { - detailedError = fmt.Errorf("metadata does not match") - return ErrOperationFailed - } - if req.Permission != nil && !proto.Equal(req.Permission, participant.Permission) { - detailedError = fmt.Errorf("permissions do not match, expected: %v, actual: %v", req.Permission, participant.Permission) - return ErrOperationFailed - } - return nil - }) - if err != nil { - logger.Warnw("could not confirm participant update", detailedError) - return nil, err - } - - return participant, nil + return s.participantClient.UpdateParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } func (s *RoomService) UpdateSubscriptions(ctx context.Context, req *livekit.UpdateSubscriptionsRequest) (*livekit.UpdateSubscriptionsResponse, error) { @@ -386,20 +255,7 @@ func (s *RoomService) UpdateSubscriptions(ctx context.Context, req *livekit.Upda return nil, twirpAuthError(err) } - if s.psrpcConf.Enabled { - return s.participantClient.UpdateSubscriptions(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) - } - - err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_UpdateSubscriptions{ - UpdateSubscriptions: req, - }, - }) - if err != nil { - return nil, err - } - - return &livekit.UpdateSubscriptionsResponse{}, nil + return s.participantClient.UpdateSubscriptions(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } func (s *RoomService) SendData(ctx context.Context, req *livekit.SendDataRequest) (*livekit.SendDataResponse, error) { @@ -409,20 +265,7 @@ func (s *RoomService) SendData(ctx context.Context, req *livekit.SendDataRequest return nil, twirpAuthError(err) } - if s.psrpcConf.Enabled { - return s.roomClient.SendData(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) - } - - err := s.router.WriteRoomRTC(ctx, roomName, &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_SendData{ - SendData: req, - }, - }) - if err != nil { - return nil, err - } - - return &livekit.SendDataResponse{}, nil + return s.roomClient.SendData(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) } func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.UpdateRoomMetadataRequest) (*livekit.Room, error) { @@ -451,20 +294,9 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return nil, err } - if s.psrpcConf.Enabled { - _, err := s.roomClient.UpdateRoomMetadata(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) - if err != nil { - return nil, err - } - } else { - err = s.router.WriteRoomRTC(ctx, livekit.RoomName(req.Room), &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_UpdateRoomMetadata{ - UpdateRoomMetadata: req, - }, - }) - if err != nil { - return nil, err - } + _, err = s.roomClient.UpdateRoomMetadata(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + if err != nil { + return nil, err } err = s.confirmExecution(func() error { @@ -494,14 +326,6 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return room, nil } -func (s *RoomService) writeParticipantMessage(ctx context.Context, room livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) error { - if err := EnsureAdminPermission(ctx, room); err != nil { - return twirpAuthError(err) - } - - return s.router.WriteParticipantRTC(ctx, room, identity, msg) -} - func (s *RoomService) confirmExecution(f func() error) error { expired := time.After(s.apiConf.ExecutionTimeout) var err error diff --git a/pkg/service/roomservice_test.go b/pkg/service/roomservice_test.go index f9de24a51..0237117f4 100644 --- a/pkg/service/roomservice_test.go +++ b/pkg/service/roomservice_test.go @@ -33,26 +33,6 @@ import ( ) func TestDeleteRoom(t *testing.T) { - t.Run("delete non-existent", func(t *testing.T) { - svc := newTestRoomService(config.RoomConfig{}) - grant := &auth.ClaimGrants{ - Video: &auth.VideoGrant{ - RoomCreate: true, - }, - } - ctx := service.WithGrants(context.Background(), grant) - svc.store.LoadRoomReturns(nil, nil, service.ErrRoomNotFound) - _, err := svc.DeleteRoom(ctx, &livekit.DeleteRoomRequest{ - Room: "testroom", - }) - require.Error(t, err) - if terr, ok := err.(twirp.Error); ok { - require.Equal(t, twirp.NotFound, terr.Code()) - } else { - require.Fail(t, "should be twirp error") - } - }) - t.Run("missing permissions", func(t *testing.T) { svc := newTestRoomService(config.RoomConfig{}) grant := &auth.ClaimGrants{ diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index 86ff5f630..baef4aa30 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -29,6 +29,7 @@ import ( "github.com/gorilla/websocket" "github.com/ua-parser/uap-go/uaparser" + "go.uber.org/atomic" "golang.org/x/exp/maps" "github.com/livekit/livekit-server/pkg/config" @@ -39,7 +40,6 @@ import ( "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" putil "github.com/livekit/protocol/utils" "github.com/livekit/psrpc" ) @@ -97,7 +97,7 @@ func NewRTCService( func (s *RTCService) Validate(w http.ResponseWriter, r *http.Request) { _, _, code, err := s.validate(r) if err != nil { - handleError(w, code, err) + handleError(w, r, code, err) return } _, _ = w.Write([]byte("success")) @@ -202,7 +202,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { roomName, pi, code, err := s.validate(r) if err != nil { - handleError(w, code, err) + handleError(w, r, code, err) return } @@ -213,6 +213,8 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { "remote", false, } + l := utils.GetLogger(r.Context()) + // give it a few attempts to start session var cr connectionResult var initialResponse *livekit.SignalResponse @@ -229,13 +231,13 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if i < 2 { fieldsWithAttempt := append(loggerFields, "attempt", i) - logger.Warnw("failed to start connection, retrying", err, fieldsWithAttempt...) + l.Warnw("failed to start connection, retrying", err, fieldsWithAttempt...) } } if err != nil { prometheus.IncrementParticipantJoinFail(1) - handleError(w, http.StatusInternalServerError, err, loggerFields...) + handleError(w, r, http.StatusInternalServerError, err, loggerFields...) return } @@ -254,16 +256,20 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { } pLogger := rtc.LoggerWithParticipant( - rtc.LoggerWithRoom(logger.GetLogger(), roomName, livekit.RoomID(cr.Room.Sid)), + rtc.LoggerWithRoom(l, roomName, livekit.RoomID(cr.Room.Sid)), pi.Identity, pi.ID, false, ) + closedByClient := atomic.NewBool(false) done := make(chan struct{}) // function exits when websocket terminates, it'll close the event reading off of request sink and response source as well defer func() { - pLogger.Infow("finishing WS connection", "connID", cr.ConnectionID) + pLogger.Debugw("finishing WS connection", + "connID", cr.ConnectionID, + "closedByClient", closedByClient.Load(), + ) cr.ResponseSource.Close() cr.RequestSink.Close() close(done) @@ -276,7 +282,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { // upgrade only once the basics are good to go conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { - handleError(w, http.StatusInternalServerError, err, loggerFields...) + handleError(w, r, http.StatusInternalServerError, err, loggerFields...) return } @@ -300,11 +306,13 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { if signalStats != nil { signalStats.AddBytes(uint64(count), true) } - pLogger.Infow("new client WS connected", + pLogger.Debugw("new client WS connected", "connID", cr.ConnectionID, "reconnect", pi.Reconnect, "reconnectReason", pi.ReconnectReason, "adaptiveStream", pi.AdaptiveStream, + "selectedNodeID", cr.NodeID, + "nodeSelectionReason", cr.NodeSelectionReason, ) // handle responses @@ -366,7 +374,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { req, count, err := sigConn.ReadRequest() if err != nil { // normal/expected closure - if err == io.EOF || + if errors.Is(err, io.EOF) || strings.HasSuffix(err.Error(), "use of closed network connection") || strings.HasSuffix(err.Error(), "connection reset by peer") || websocket.IsCloseError( @@ -376,7 +384,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { websocket.CloseNormalClosure, websocket.CloseNoStatusReceived, ) { - pLogger.Infow("exit ws read loop for closed connection", "connID", cr.ConnectionID, "wsError", err) + closedByClient.Store(true) } else { pLogger.Errorw("error reading from websocket", err, "connID", cr.ConnectionID) } @@ -512,10 +520,8 @@ func (s *RTCService) DrainConnections(interval time.Duration) { } type connectionResult struct { - Room *livekit.Room - ConnectionID livekit.ConnectionID - RequestSink routing.MessageSink - ResponseSource routing.MessageSource + routing.StartParticipantSignalResults + Room *livekit.Room } func (s *RTCService) startConnection( @@ -533,7 +539,7 @@ func (s *RTCService) startConnection( } // this needs to be started first *before* using router functions on this node - cr.ConnectionID, cr.RequestSink, cr.ResponseSource, err = s.router.StartParticipantSignal(ctx, roomName, pi) + cr.StartParticipantSignalResults, err = s.router.StartParticipantSignal(ctx, roomName, pi) if err != nil { return cr, nil, err } diff --git a/pkg/service/server.go b/pkg/service/server.go index f66d848b5..fcb09a648 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -37,7 +37,6 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" - sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/livekit-server/version" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -107,7 +106,7 @@ func NewLivekitServer(conf *config.Config, middlewares = append(middlewares, NewAPIKeyAuthMiddleware(keyProvider)) } - twirpLoggingHook := TwirpLogger(logger.GetLogger().WithComponent(sutils.ComponentAPI)) + twirpLoggingHook := TwirpLogger() twirpRequestStatusHook := TwirpRequestStatusReporter() roomServer := livekit.NewRoomServiceServer(roomService, twirpLoggingHook) egressServer := livekit.NewEgressServer(egressService, twirp.WithServerHooks( diff --git a/pkg/service/servicefakes/fake_service_store.go b/pkg/service/servicefakes/fake_service_store.go index 0ac442161..7ecf3e213 100644 --- a/pkg/service/servicefakes/fake_service_store.go +++ b/pkg/service/servicefakes/fake_service_store.go @@ -10,6 +10,18 @@ import ( ) type FakeServiceStore struct { + DeleteRoomStub func(context.Context, livekit.RoomName) error + deleteRoomMutex sync.RWMutex + deleteRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + deleteRoomReturns struct { + result1 error + } + deleteRoomReturnsOnCall map[int]struct { + result1 error + } ListParticipantsStub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error) listParticipantsMutex sync.RWMutex listParticipantsArgsForCall []struct { @@ -74,6 +86,68 @@ type FakeServiceStore struct { invocationsMutex sync.RWMutex } +func (fake *FakeServiceStore) DeleteRoom(arg1 context.Context, arg2 livekit.RoomName) error { + fake.deleteRoomMutex.Lock() + ret, specificReturn := fake.deleteRoomReturnsOnCall[len(fake.deleteRoomArgsForCall)] + fake.deleteRoomArgsForCall = append(fake.deleteRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.DeleteRoomStub + fakeReturns := fake.deleteRoomReturns + fake.recordInvocation("DeleteRoom", []interface{}{arg1, arg2}) + fake.deleteRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeServiceStore) DeleteRoomCallCount() int { + fake.deleteRoomMutex.RLock() + defer fake.deleteRoomMutex.RUnlock() + return len(fake.deleteRoomArgsForCall) +} + +func (fake *FakeServiceStore) DeleteRoomCalls(stub func(context.Context, livekit.RoomName) error) { + fake.deleteRoomMutex.Lock() + defer fake.deleteRoomMutex.Unlock() + fake.DeleteRoomStub = stub +} + +func (fake *FakeServiceStore) DeleteRoomArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.deleteRoomMutex.RLock() + defer fake.deleteRoomMutex.RUnlock() + argsForCall := fake.deleteRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeServiceStore) DeleteRoomReturns(result1 error) { + fake.deleteRoomMutex.Lock() + defer fake.deleteRoomMutex.Unlock() + fake.DeleteRoomStub = nil + fake.deleteRoomReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeServiceStore) DeleteRoomReturnsOnCall(i int, result1 error) { + fake.deleteRoomMutex.Lock() + defer fake.deleteRoomMutex.Unlock() + fake.DeleteRoomStub = nil + if fake.deleteRoomReturnsOnCall == nil { + fake.deleteRoomReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteRoomReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeServiceStore) ListParticipants(arg1 context.Context, arg2 livekit.RoomName) ([]*livekit.ParticipantInfo, error) { fake.listParticipantsMutex.Lock() ret, specificReturn := fake.listParticipantsReturnsOnCall[len(fake.listParticipantsArgsForCall)] @@ -347,6 +421,8 @@ func (fake *FakeServiceStore) LoadRoomReturnsOnCall(i int, result1 *livekit.Room func (fake *FakeServiceStore) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.deleteRoomMutex.RLock() + defer fake.deleteRoomMutex.RUnlock() fake.listParticipantsMutex.RLock() defer fake.listParticipantsMutex.RUnlock() fake.listRoomsMutex.RLock() diff --git a/pkg/service/servicefakes/fake_session_handler.go b/pkg/service/servicefakes/fake_session_handler.go new file mode 100644 index 000000000..552386918 --- /dev/null +++ b/pkg/service/servicefakes/fake_session_handler.go @@ -0,0 +1,199 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type FakeSessionHandler struct { + HandleSessionStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error + handleSessionMutex sync.RWMutex + handleSessionArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.ConnectionID + arg5 routing.MessageSource + arg6 routing.MessageSink + } + handleSessionReturns struct { + result1 error + } + handleSessionReturnsOnCall map[int]struct { + result1 error + } + LoggerStub func(context.Context) logger.Logger + loggerMutex sync.RWMutex + loggerArgsForCall []struct { + arg1 context.Context + } + loggerReturns struct { + result1 logger.Logger + } + loggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSessionHandler) HandleSession(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.ConnectionID, arg5 routing.MessageSource, arg6 routing.MessageSink) error { + fake.handleSessionMutex.Lock() + ret, specificReturn := fake.handleSessionReturnsOnCall[len(fake.handleSessionArgsForCall)] + fake.handleSessionArgsForCall = append(fake.handleSessionArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.ConnectionID + arg5 routing.MessageSource + arg6 routing.MessageSink + }{arg1, arg2, arg3, arg4, arg5, arg6}) + stub := fake.HandleSessionStub + fakeReturns := fake.handleSessionReturns + fake.recordInvocation("HandleSession", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6}) + fake.handleSessionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4, arg5, arg6) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSessionHandler) HandleSessionCallCount() int { + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + return len(fake.handleSessionArgsForCall) +} + +func (fake *FakeSessionHandler) HandleSessionCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = stub +} + +func (fake *FakeSessionHandler) HandleSessionArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) { + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + argsForCall := fake.handleSessionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6 +} + +func (fake *FakeSessionHandler) HandleSessionReturns(result1 error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = nil + fake.handleSessionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSessionHandler) HandleSessionReturnsOnCall(i int, result1 error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = nil + if fake.handleSessionReturnsOnCall == nil { + fake.handleSessionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleSessionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSessionHandler) Logger(arg1 context.Context) logger.Logger { + fake.loggerMutex.Lock() + ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)] + fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.LoggerStub + fakeReturns := fake.loggerReturns + fake.recordInvocation("Logger", []interface{}{arg1}) + fake.loggerMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSessionHandler) LoggerCallCount() int { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + return len(fake.loggerArgsForCall) +} + +func (fake *FakeSessionHandler) LoggerCalls(stub func(context.Context) logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = stub +} + +func (fake *FakeSessionHandler) LoggerArgsForCall(i int) context.Context { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + argsForCall := fake.loggerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSessionHandler) LoggerReturns(result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + fake.loggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeSessionHandler) LoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + if fake.loggerReturnsOnCall == nil { + fake.loggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.loggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeSessionHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSessionHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.SessionHandler = new(FakeSessionHandler) diff --git a/pkg/service/servicefakes/fake_sipstore.go b/pkg/service/servicefakes/fake_sipstore.go new file mode 100644 index 000000000..fe88a5359 --- /dev/null +++ b/pkg/service/servicefakes/fake_sipstore.go @@ -0,0 +1,663 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeSIPStore struct { + DeleteSIPDispatchRuleStub func(context.Context, *livekit.SIPDispatchRuleInfo) error + deleteSIPDispatchRuleMutex sync.RWMutex + deleteSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPDispatchRuleInfo + } + deleteSIPDispatchRuleReturns struct { + result1 error + } + deleteSIPDispatchRuleReturnsOnCall map[int]struct { + result1 error + } + DeleteSIPTrunkStub func(context.Context, *livekit.SIPTrunkInfo) error + deleteSIPTrunkMutex sync.RWMutex + deleteSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPTrunkInfo + } + deleteSIPTrunkReturns struct { + result1 error + } + deleteSIPTrunkReturnsOnCall map[int]struct { + result1 error + } + ListSIPDispatchRuleStub func(context.Context) ([]*livekit.SIPDispatchRuleInfo, error) + listSIPDispatchRuleMutex sync.RWMutex + listSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + } + listSIPDispatchRuleReturns struct { + result1 []*livekit.SIPDispatchRuleInfo + result2 error + } + listSIPDispatchRuleReturnsOnCall map[int]struct { + result1 []*livekit.SIPDispatchRuleInfo + result2 error + } + ListSIPTrunkStub func(context.Context) ([]*livekit.SIPTrunkInfo, error) + listSIPTrunkMutex sync.RWMutex + listSIPTrunkArgsForCall []struct { + arg1 context.Context + } + listSIPTrunkReturns struct { + result1 []*livekit.SIPTrunkInfo + result2 error + } + listSIPTrunkReturnsOnCall map[int]struct { + result1 []*livekit.SIPTrunkInfo + result2 error + } + LoadSIPDispatchRuleStub func(context.Context, string) (*livekit.SIPDispatchRuleInfo, error) + loadSIPDispatchRuleMutex sync.RWMutex + loadSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadSIPDispatchRuleReturns struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + } + loadSIPDispatchRuleReturnsOnCall map[int]struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + } + LoadSIPTrunkStub func(context.Context, string) (*livekit.SIPTrunkInfo, error) + loadSIPTrunkMutex sync.RWMutex + loadSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadSIPTrunkReturns struct { + result1 *livekit.SIPTrunkInfo + result2 error + } + loadSIPTrunkReturnsOnCall map[int]struct { + result1 *livekit.SIPTrunkInfo + result2 error + } + StoreSIPDispatchRuleStub func(context.Context, *livekit.SIPDispatchRuleInfo) error + storeSIPDispatchRuleMutex sync.RWMutex + storeSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPDispatchRuleInfo + } + storeSIPDispatchRuleReturns struct { + result1 error + } + storeSIPDispatchRuleReturnsOnCall map[int]struct { + result1 error + } + StoreSIPTrunkStub func(context.Context, *livekit.SIPTrunkInfo) error + storeSIPTrunkMutex sync.RWMutex + storeSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPTrunkInfo + } + storeSIPTrunkReturns struct { + result1 error + } + storeSIPTrunkReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRule(arg1 context.Context, arg2 *livekit.SIPDispatchRuleInfo) error { + fake.deleteSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.deleteSIPDispatchRuleReturnsOnCall[len(fake.deleteSIPDispatchRuleArgsForCall)] + fake.deleteSIPDispatchRuleArgsForCall = append(fake.deleteSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPDispatchRuleInfo + }{arg1, arg2}) + stub := fake.DeleteSIPDispatchRuleStub + fakeReturns := fake.deleteSIPDispatchRuleReturns + fake.recordInvocation("DeleteSIPDispatchRule", []interface{}{arg1, arg2}) + fake.deleteSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleCallCount() int { + fake.deleteSIPDispatchRuleMutex.RLock() + defer fake.deleteSIPDispatchRuleMutex.RUnlock() + return len(fake.deleteSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleCalls(stub func(context.Context, *livekit.SIPDispatchRuleInfo) error) { + fake.deleteSIPDispatchRuleMutex.Lock() + defer fake.deleteSIPDispatchRuleMutex.Unlock() + fake.DeleteSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleArgsForCall(i int) (context.Context, *livekit.SIPDispatchRuleInfo) { + fake.deleteSIPDispatchRuleMutex.RLock() + defer fake.deleteSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.deleteSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleReturns(result1 error) { + fake.deleteSIPDispatchRuleMutex.Lock() + defer fake.deleteSIPDispatchRuleMutex.Unlock() + fake.DeleteSIPDispatchRuleStub = nil + fake.deleteSIPDispatchRuleReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleReturnsOnCall(i int, result1 error) { + fake.deleteSIPDispatchRuleMutex.Lock() + defer fake.deleteSIPDispatchRuleMutex.Unlock() + fake.DeleteSIPDispatchRuleStub = nil + if fake.deleteSIPDispatchRuleReturnsOnCall == nil { + fake.deleteSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteSIPDispatchRuleReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) DeleteSIPTrunk(arg1 context.Context, arg2 *livekit.SIPTrunkInfo) error { + fake.deleteSIPTrunkMutex.Lock() + ret, specificReturn := fake.deleteSIPTrunkReturnsOnCall[len(fake.deleteSIPTrunkArgsForCall)] + fake.deleteSIPTrunkArgsForCall = append(fake.deleteSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPTrunkInfo + }{arg1, arg2}) + stub := fake.DeleteSIPTrunkStub + fakeReturns := fake.deleteSIPTrunkReturns + fake.recordInvocation("DeleteSIPTrunk", []interface{}{arg1, arg2}) + fake.deleteSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) DeleteSIPTrunkCallCount() int { + fake.deleteSIPTrunkMutex.RLock() + defer fake.deleteSIPTrunkMutex.RUnlock() + return len(fake.deleteSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) DeleteSIPTrunkCalls(stub func(context.Context, *livekit.SIPTrunkInfo) error) { + fake.deleteSIPTrunkMutex.Lock() + defer fake.deleteSIPTrunkMutex.Unlock() + fake.DeleteSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) DeleteSIPTrunkArgsForCall(i int) (context.Context, *livekit.SIPTrunkInfo) { + fake.deleteSIPTrunkMutex.RLock() + defer fake.deleteSIPTrunkMutex.RUnlock() + argsForCall := fake.deleteSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) DeleteSIPTrunkReturns(result1 error) { + fake.deleteSIPTrunkMutex.Lock() + defer fake.deleteSIPTrunkMutex.Unlock() + fake.DeleteSIPTrunkStub = nil + fake.deleteSIPTrunkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) DeleteSIPTrunkReturnsOnCall(i int, result1 error) { + fake.deleteSIPTrunkMutex.Lock() + defer fake.deleteSIPTrunkMutex.Unlock() + fake.DeleteSIPTrunkStub = nil + if fake.deleteSIPTrunkReturnsOnCall == nil { + fake.deleteSIPTrunkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteSIPTrunkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) ListSIPDispatchRule(arg1 context.Context) ([]*livekit.SIPDispatchRuleInfo, error) { + fake.listSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.listSIPDispatchRuleReturnsOnCall[len(fake.listSIPDispatchRuleArgsForCall)] + fake.listSIPDispatchRuleArgsForCall = append(fake.listSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.ListSIPDispatchRuleStub + fakeReturns := fake.listSIPDispatchRuleReturns + fake.recordInvocation("ListSIPDispatchRule", []interface{}{arg1}) + fake.listSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleCallCount() int { + fake.listSIPDispatchRuleMutex.RLock() + defer fake.listSIPDispatchRuleMutex.RUnlock() + return len(fake.listSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleCalls(stub func(context.Context) ([]*livekit.SIPDispatchRuleInfo, error)) { + fake.listSIPDispatchRuleMutex.Lock() + defer fake.listSIPDispatchRuleMutex.Unlock() + fake.ListSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleArgsForCall(i int) context.Context { + fake.listSIPDispatchRuleMutex.RLock() + defer fake.listSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.listSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleReturns(result1 []*livekit.SIPDispatchRuleInfo, result2 error) { + fake.listSIPDispatchRuleMutex.Lock() + defer fake.listSIPDispatchRuleMutex.Unlock() + fake.ListSIPDispatchRuleStub = nil + fake.listSIPDispatchRuleReturns = struct { + result1 []*livekit.SIPDispatchRuleInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleReturnsOnCall(i int, result1 []*livekit.SIPDispatchRuleInfo, result2 error) { + fake.listSIPDispatchRuleMutex.Lock() + defer fake.listSIPDispatchRuleMutex.Unlock() + fake.ListSIPDispatchRuleStub = nil + if fake.listSIPDispatchRuleReturnsOnCall == nil { + fake.listSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 []*livekit.SIPDispatchRuleInfo + result2 error + }) + } + fake.listSIPDispatchRuleReturnsOnCall[i] = struct { + result1 []*livekit.SIPDispatchRuleInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPTrunk(arg1 context.Context) ([]*livekit.SIPTrunkInfo, error) { + fake.listSIPTrunkMutex.Lock() + ret, specificReturn := fake.listSIPTrunkReturnsOnCall[len(fake.listSIPTrunkArgsForCall)] + fake.listSIPTrunkArgsForCall = append(fake.listSIPTrunkArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.ListSIPTrunkStub + fakeReturns := fake.listSIPTrunkReturns + fake.recordInvocation("ListSIPTrunk", []interface{}{arg1}) + fake.listSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) ListSIPTrunkCallCount() int { + fake.listSIPTrunkMutex.RLock() + defer fake.listSIPTrunkMutex.RUnlock() + return len(fake.listSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) ListSIPTrunkCalls(stub func(context.Context) ([]*livekit.SIPTrunkInfo, error)) { + fake.listSIPTrunkMutex.Lock() + defer fake.listSIPTrunkMutex.Unlock() + fake.ListSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) ListSIPTrunkArgsForCall(i int) context.Context { + fake.listSIPTrunkMutex.RLock() + defer fake.listSIPTrunkMutex.RUnlock() + argsForCall := fake.listSIPTrunkArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSIPStore) ListSIPTrunkReturns(result1 []*livekit.SIPTrunkInfo, result2 error) { + fake.listSIPTrunkMutex.Lock() + defer fake.listSIPTrunkMutex.Unlock() + fake.ListSIPTrunkStub = nil + fake.listSIPTrunkReturns = struct { + result1 []*livekit.SIPTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPTrunkReturnsOnCall(i int, result1 []*livekit.SIPTrunkInfo, result2 error) { + fake.listSIPTrunkMutex.Lock() + defer fake.listSIPTrunkMutex.Unlock() + fake.ListSIPTrunkStub = nil + if fake.listSIPTrunkReturnsOnCall == nil { + fake.listSIPTrunkReturnsOnCall = make(map[int]struct { + result1 []*livekit.SIPTrunkInfo + result2 error + }) + } + fake.listSIPTrunkReturnsOnCall[i] = struct { + result1 []*livekit.SIPTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPDispatchRule(arg1 context.Context, arg2 string) (*livekit.SIPDispatchRuleInfo, error) { + fake.loadSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.loadSIPDispatchRuleReturnsOnCall[len(fake.loadSIPDispatchRuleArgsForCall)] + fake.loadSIPDispatchRuleArgsForCall = append(fake.loadSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadSIPDispatchRuleStub + fakeReturns := fake.loadSIPDispatchRuleReturns + fake.recordInvocation("LoadSIPDispatchRule", []interface{}{arg1, arg2}) + fake.loadSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleCallCount() int { + fake.loadSIPDispatchRuleMutex.RLock() + defer fake.loadSIPDispatchRuleMutex.RUnlock() + return len(fake.loadSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleCalls(stub func(context.Context, string) (*livekit.SIPDispatchRuleInfo, error)) { + fake.loadSIPDispatchRuleMutex.Lock() + defer fake.loadSIPDispatchRuleMutex.Unlock() + fake.LoadSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleArgsForCall(i int) (context.Context, string) { + fake.loadSIPDispatchRuleMutex.RLock() + defer fake.loadSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.loadSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleReturns(result1 *livekit.SIPDispatchRuleInfo, result2 error) { + fake.loadSIPDispatchRuleMutex.Lock() + defer fake.loadSIPDispatchRuleMutex.Unlock() + fake.LoadSIPDispatchRuleStub = nil + fake.loadSIPDispatchRuleReturns = struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleReturnsOnCall(i int, result1 *livekit.SIPDispatchRuleInfo, result2 error) { + fake.loadSIPDispatchRuleMutex.Lock() + defer fake.loadSIPDispatchRuleMutex.Unlock() + fake.LoadSIPDispatchRuleStub = nil + if fake.loadSIPDispatchRuleReturnsOnCall == nil { + fake.loadSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + }) + } + fake.loadSIPDispatchRuleReturnsOnCall[i] = struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPTrunk(arg1 context.Context, arg2 string) (*livekit.SIPTrunkInfo, error) { + fake.loadSIPTrunkMutex.Lock() + ret, specificReturn := fake.loadSIPTrunkReturnsOnCall[len(fake.loadSIPTrunkArgsForCall)] + fake.loadSIPTrunkArgsForCall = append(fake.loadSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadSIPTrunkStub + fakeReturns := fake.loadSIPTrunkReturns + fake.recordInvocation("LoadSIPTrunk", []interface{}{arg1, arg2}) + fake.loadSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) LoadSIPTrunkCallCount() int { + fake.loadSIPTrunkMutex.RLock() + defer fake.loadSIPTrunkMutex.RUnlock() + return len(fake.loadSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) LoadSIPTrunkCalls(stub func(context.Context, string) (*livekit.SIPTrunkInfo, error)) { + fake.loadSIPTrunkMutex.Lock() + defer fake.loadSIPTrunkMutex.Unlock() + fake.LoadSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) LoadSIPTrunkArgsForCall(i int) (context.Context, string) { + fake.loadSIPTrunkMutex.RLock() + defer fake.loadSIPTrunkMutex.RUnlock() + argsForCall := fake.loadSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) LoadSIPTrunkReturns(result1 *livekit.SIPTrunkInfo, result2 error) { + fake.loadSIPTrunkMutex.Lock() + defer fake.loadSIPTrunkMutex.Unlock() + fake.LoadSIPTrunkStub = nil + fake.loadSIPTrunkReturns = struct { + result1 *livekit.SIPTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPTrunkReturnsOnCall(i int, result1 *livekit.SIPTrunkInfo, result2 error) { + fake.loadSIPTrunkMutex.Lock() + defer fake.loadSIPTrunkMutex.Unlock() + fake.LoadSIPTrunkStub = nil + if fake.loadSIPTrunkReturnsOnCall == nil { + fake.loadSIPTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.SIPTrunkInfo + result2 error + }) + } + fake.loadSIPTrunkReturnsOnCall[i] = struct { + result1 *livekit.SIPTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) StoreSIPDispatchRule(arg1 context.Context, arg2 *livekit.SIPDispatchRuleInfo) error { + fake.storeSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.storeSIPDispatchRuleReturnsOnCall[len(fake.storeSIPDispatchRuleArgsForCall)] + fake.storeSIPDispatchRuleArgsForCall = append(fake.storeSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPDispatchRuleInfo + }{arg1, arg2}) + stub := fake.StoreSIPDispatchRuleStub + fakeReturns := fake.storeSIPDispatchRuleReturns + fake.recordInvocation("StoreSIPDispatchRule", []interface{}{arg1, arg2}) + fake.storeSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleCallCount() int { + fake.storeSIPDispatchRuleMutex.RLock() + defer fake.storeSIPDispatchRuleMutex.RUnlock() + return len(fake.storeSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleCalls(stub func(context.Context, *livekit.SIPDispatchRuleInfo) error) { + fake.storeSIPDispatchRuleMutex.Lock() + defer fake.storeSIPDispatchRuleMutex.Unlock() + fake.StoreSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleArgsForCall(i int) (context.Context, *livekit.SIPDispatchRuleInfo) { + fake.storeSIPDispatchRuleMutex.RLock() + defer fake.storeSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.storeSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleReturns(result1 error) { + fake.storeSIPDispatchRuleMutex.Lock() + defer fake.storeSIPDispatchRuleMutex.Unlock() + fake.StoreSIPDispatchRuleStub = nil + fake.storeSIPDispatchRuleReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleReturnsOnCall(i int, result1 error) { + fake.storeSIPDispatchRuleMutex.Lock() + defer fake.storeSIPDispatchRuleMutex.Unlock() + fake.StoreSIPDispatchRuleStub = nil + if fake.storeSIPDispatchRuleReturnsOnCall == nil { + fake.storeSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeSIPDispatchRuleReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPTrunk(arg1 context.Context, arg2 *livekit.SIPTrunkInfo) error { + fake.storeSIPTrunkMutex.Lock() + ret, specificReturn := fake.storeSIPTrunkReturnsOnCall[len(fake.storeSIPTrunkArgsForCall)] + fake.storeSIPTrunkArgsForCall = append(fake.storeSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPTrunkInfo + }{arg1, arg2}) + stub := fake.StoreSIPTrunkStub + fakeReturns := fake.storeSIPTrunkReturns + fake.recordInvocation("StoreSIPTrunk", []interface{}{arg1, arg2}) + fake.storeSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) StoreSIPTrunkCallCount() int { + fake.storeSIPTrunkMutex.RLock() + defer fake.storeSIPTrunkMutex.RUnlock() + return len(fake.storeSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) StoreSIPTrunkCalls(stub func(context.Context, *livekit.SIPTrunkInfo) error) { + fake.storeSIPTrunkMutex.Lock() + defer fake.storeSIPTrunkMutex.Unlock() + fake.StoreSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) StoreSIPTrunkArgsForCall(i int) (context.Context, *livekit.SIPTrunkInfo) { + fake.storeSIPTrunkMutex.RLock() + defer fake.storeSIPTrunkMutex.RUnlock() + argsForCall := fake.storeSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) StoreSIPTrunkReturns(result1 error) { + fake.storeSIPTrunkMutex.Lock() + defer fake.storeSIPTrunkMutex.Unlock() + fake.StoreSIPTrunkStub = nil + fake.storeSIPTrunkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPTrunkReturnsOnCall(i int, result1 error) { + fake.storeSIPTrunkMutex.Lock() + defer fake.storeSIPTrunkMutex.Unlock() + fake.StoreSIPTrunkStub = nil + if fake.storeSIPTrunkReturnsOnCall == nil { + fake.storeSIPTrunkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeSIPTrunkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.deleteSIPDispatchRuleMutex.RLock() + defer fake.deleteSIPDispatchRuleMutex.RUnlock() + fake.deleteSIPTrunkMutex.RLock() + defer fake.deleteSIPTrunkMutex.RUnlock() + fake.listSIPDispatchRuleMutex.RLock() + defer fake.listSIPDispatchRuleMutex.RUnlock() + fake.listSIPTrunkMutex.RLock() + defer fake.listSIPTrunkMutex.RUnlock() + fake.loadSIPDispatchRuleMutex.RLock() + defer fake.loadSIPDispatchRuleMutex.RUnlock() + fake.loadSIPTrunkMutex.RLock() + defer fake.loadSIPTrunkMutex.RUnlock() + fake.storeSIPDispatchRuleMutex.RLock() + defer fake.storeSIPDispatchRuleMutex.RUnlock() + fake.storeSIPTrunkMutex.RLock() + defer fake.storeSIPTrunkMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSIPStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.SIPStore = new(FakeSIPStore) diff --git a/pkg/service/signal.go b/pkg/service/signal.go index 4c9e085be..6c4ff820e 100644 --- a/pkg/service/signal.go +++ b/pkg/service/signal.go @@ -31,14 +31,21 @@ import ( "github.com/livekit/psrpc/pkg/middleware" ) -type SessionHandler func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, -) error +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . SessionHandler +type SessionHandler interface { + Logger(ctx context.Context) logger.Logger + + HandleSession( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error +} type SignalServer struct { server rpc.TypedSignalServer @@ -72,43 +79,43 @@ func NewDefaultSignalServer( router routing.Router, roomManager *RoomManager, ) (r *SignalServer, err error) { - sessionHandler := func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, - ) error { - prometheus.IncrementParticipantRtcInit(1) + return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, &defaultSessionHandler{currentNode, router, roomManager}) +} - if rr, ok := router.(*routing.RedisRouter); ok { - rtcNode, err := router.GetNodeForRoom(ctx, roomName) - if err != nil { - return err - } +type defaultSessionHandler struct { + currentNode routing.LocalNode + router routing.Router + roomManager *RoomManager +} - if rtcNode.Id != currentNode.Id { - err = routing.ErrIncorrectRTCNode - logger.Errorw("called participant on incorrect node", err, - "rtcNode", rtcNode, - ) - return err - } +func (s *defaultSessionHandler) Logger(ctx context.Context) logger.Logger { + return logger.GetLogger() +} - pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity) - pKeyB62 := routing.ParticipantKey(roomName, pi.Identity) +func (s *defaultSessionHandler) HandleSession( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, +) error { + prometheus.IncrementParticipantRtcInit(1) - // RTC session should start on this node - if err := rr.SetParticipantRTCNode(pKey, pKeyB62, currentNode.Id); err != nil { - return err - } - } - - return roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink) + rtcNode, err := s.router.GetNodeForRoom(ctx, roomName) + if err != nil { + return err } - return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, sessionHandler) + if rtcNode.Id != s.currentNode.Id { + err = routing.ErrIncorrectRTCNode + logger.Errorw("called participant on incorrect node", err, + "rtcNode", rtcNode, + ) + return err + } + + return s.roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink) } func (s *SignalServer) Start() error { @@ -142,12 +149,13 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe return errors.Wrap(err, "failed to read participant from session") } - l := logger.GetLogger().WithValues( + l := r.sessionHandler.Logger(stream.Context()).WithValues( "room", ss.RoomName, "participant", ss.Identity, "connID", ss.ConnectionId, ) + stream.Hijack() sink := routing.NewSignalMessageSink(routing.SignalSinkParams[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]{ Logger: l, Stream: stream, @@ -174,13 +182,11 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe // copy the incoming rpc headers to avoid dropping any session vars. ctx := metadata.NewContextWithIncomingHeader(context.Background(), metadata.IncomingHeader(stream.Context())) - err = r.sessionHandler(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) + err = r.sessionHandler.HandleSession(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) if err != nil { + sink.Close() l.Errorw("could not handle new participant", err) - return } - - stream.Hijack() return } diff --git a/pkg/service/signal_test.go b/pkg/service/signal_test.go index 3e202636e..951837176 100644 --- a/pkg/service/signal_test.go +++ b/pkg/service/signal_test.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package service +package service_test import ( "context" + "errors" "testing" "time" @@ -25,8 +26,11 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/service/servicefakes" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" "github.com/livekit/psrpc" ) @@ -35,65 +39,120 @@ func init() { } func TestSignal(t *testing.T) { - bus := psrpc.NewLocalMessageBus() cfg := config.SignalRelayConfig{ - Enabled: false, RetryTimeout: 30 * time.Second, MinRetryInterval: 500 * time.Millisecond, MaxRetryInterval: 5 * time.Second, StreamBufferSize: 1000, } - reqMessageIn := &livekit.SignalRequest{ - Message: &livekit.SignalRequest_Ping{Ping: 123}, - } - resMessageIn := &livekit.SignalResponse{ - Message: &livekit.SignalResponse_Pong{Pong: 321}, - } + t.Run("messages are delivered", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() - var reqMessageOut proto.Message - var resErr error - done := make(chan struct{}) + reqMessageIn := &livekit.SignalRequest{ + Message: &livekit.SignalRequest_Ping{Ping: 123}, + } + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } - client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) - require.NoError(t, err) + var reqMessageOut proto.Message + var resErr error + done := make(chan struct{}) - server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, - ) error { - go func() { - reqMessageOut = <-requestSource.ReadChan() - resErr = responseSink.WriteMessage(resMessageIn) - responseSink.Close() - close(done) - }() - return nil + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) + + handler := &servicefakes.FakeSessionHandler{ + LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() }, + HandleSessionStub: func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + go func() { + reqMessageOut = <-requestSource.ReadChan() + resErr = responseSink.WriteMessage(resMessageIn) + responseSink.Close() + close(done) + }() + return nil + }, + } + server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler) + require.NoError(t, err) + + err = server.Start() + require.NoError(t, err) + + _, reqSink, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + err = reqSink.WriteMessage(reqMessageIn) + require.NoError(t, err) + + <-done + require.True(t, proto.Equal(reqMessageIn, reqMessageOut), "req message should match %s %s", protojson.Format(reqMessageIn), protojson.Format(reqMessageOut)) + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) }) - require.NoError(t, err) - err = server.Start() - require.NoError(t, err) + t.Run("messages are delivered when session handler fails", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() - _, reqSink, resSource, err := client.StartParticipantSignal( - context.Background(), - livekit.RoomName("room1"), - routing.ParticipantInit{}, - livekit.NodeID("node1"), - ) - require.NoError(t, err) + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } - err = reqSink.WriteMessage(reqMessageIn) - require.NoError(t, err) + var resErr error + done := make(chan struct{}) - <-done - require.True(t, proto.Equal(reqMessageIn, reqMessageOut), "req message should match %s %s", protojson.Format(reqMessageIn), protojson.Format(reqMessageOut)) - require.NoError(t, resErr) + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) - resMessageOut := <-resSource.ReadChan() - require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) + handler := &servicefakes.FakeSessionHandler{ + LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() }, + HandleSessionStub: func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + defer close(done) + resErr = responseSink.WriteMessage(resMessageIn) + return errors.New("start session failed") + }, + } + server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler) + require.NoError(t, err) + + err = server.Start() + require.NoError(t, err) + + _, _, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + <-done + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) + }) } diff --git a/pkg/service/sip.go b/pkg/service/sip.go index c69540eed..a95dbad09 100644 --- a/pkg/service/sip.go +++ b/pkg/service/sip.go @@ -16,6 +16,7 @@ package service import ( "context" + "time" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -161,89 +162,43 @@ func (s *SIPService) CreateSIPParticipant(ctx context.Context, req *livekit.Crea return nil, ErrSIPNotConnected } - info := &livekit.SIPParticipantInfo{ - SipParticipantId: utils.NewGuid(utils.SIPParticipantPrefix), - SipTrunkId: req.SipTrunkId, - SipCallTo: req.SipCallTo, + AppendLogFields(ctx, "room", req.RoomName, "trunk", req.SipTrunkId, "to", req.SipCallTo) + ireq := &rpc.InternalCreateSIPParticipantRequest{ + CallTo: req.SipCallTo, RoomName: req.RoomName, ParticipantIdentity: req.ParticipantIdentity, } - - if err := s.store.StoreSIPParticipant(ctx, info); err != nil { - return nil, err - } - s.updateParticipant(ctx, info) - return info, nil -} - -func (s *SIPService) updateParticipant(ctx context.Context, info *livekit.SIPParticipantInfo) { - AppendLogFields(ctx, "participantId", info.SipParticipantId, "room", info.RoomName, "trunk", info.SipTrunkId, "to", info.SipCallTo) - req := &rpc.InternalUpdateSIPParticipantRequest{ - ParticipantId: info.SipParticipantId, - CallTo: info.SipCallTo, - RoomName: info.RoomName, - ParticipantIdentity: info.ParticipantIdentity, - } - if info.SipTrunkId != "" { - trunk, err := s.store.LoadSIPTrunk(ctx, info.SipTrunkId) + if req.SipTrunkId != "" { + trunk, err := s.store.LoadSIPTrunk(ctx, req.SipTrunkId) if err != nil { logger.Errorw("cannot get trunk to update sip participant", err) - return + return nil, err } - req.Address = trunk.OutboundAddress - req.Number = trunk.OutboundNumber - req.Username = trunk.OutboundUsername - req.Password = trunk.OutboundPassword + ireq.Address = trunk.OutboundAddress + ireq.Number = trunk.OutboundNumber + ireq.Username = trunk.OutboundUsername + ireq.Password = trunk.OutboundPassword } - if _, err := s.psrpcClient.UpdateSIPParticipant(ctx, req); err != nil { + + // CreateSIPParticipant will wait for LiveKit Participant to be created and that can take some time. + // Thus, we must set a higher deadline for it, if it's not set already. + // TODO: support context timeouts in psrpc + timeout := 30 * time.Second + if deadline, ok := ctx.Deadline(); ok { + timeout = time.Until(deadline) + } else { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + resp, err := s.psrpcClient.CreateSIPParticipant(ctx, "", ireq, psrpc.WithRequestTimeout(timeout)) + if err != nil { logger.Errorw("cannot update sip participant", err) - } -} - -func (s *SIPService) ListSIPParticipant(ctx context.Context, req *livekit.ListSIPParticipantRequest) (*livekit.ListSIPParticipantResponse, error) { - if s.store == nil { - return nil, ErrSIPNotConnected - } - - participants, err := s.store.ListSIPParticipant(ctx) - if err != nil { return nil, err } - - return &livekit.ListSIPParticipantResponse{Items: participants}, nil -} - -func (s *SIPService) DeleteSIPParticipant(ctx context.Context, req *livekit.DeleteSIPParticipantRequest) (*livekit.SIPParticipantInfo, error) { - if s.store == nil { - return nil, ErrSIPNotConnected - } - - info, err := s.store.LoadSIPParticipant(ctx, req.SipParticipantId) - if err != nil { - return nil, err - } - - if err = s.store.DeleteSIPParticipant(ctx, info); err != nil { - return nil, err - } - // These indicate that the call should be disconnected - info.SipTrunkId = "" - info.SipCallTo = "" - s.updateParticipant(ctx, info) - return info, nil -} - -func (s *SIPService) SendSIPParticipantDTMF(ctx context.Context, req *livekit.SendSIPParticipantDTMFRequest) (*livekit.SIPParticipantDTMFInfo, error) { - if s.store == nil { - return nil, ErrSIPNotConnected - } - AppendLogFields(ctx, "participantId", req.SipParticipantId) - _, err := s.psrpcClient.SendSIPParticipantDTMF(ctx, &rpc.InternalSendSIPParticipantDTMFRequest{ - ParticipantId: req.SipParticipantId, - Digits: req.Digits, - }) - if err != nil { - logger.Errorw("cannot send dtmf to sip participant", err) - } - return nil, err + return &livekit.SIPParticipantInfo{ + ParticipantId: resp.ParticipantId, + ParticipantIdentity: resp.ParticipantIdentity, + RoomName: req.RoomName, + }, nil } diff --git a/pkg/service/twirp.go b/pkg/service/twirp.go index 50f3f2dc0..b7b31d764 100644 --- a/pkg/service/twirp.go +++ b/pkg/service/twirp.go @@ -25,13 +25,11 @@ import ( "github.com/twitchtv/twirp" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" - "github.com/livekit/protocol/logger" + "github.com/livekit/livekit-server/pkg/utils" ) -var ( - loggerKey = struct{}{} - statusReporterKey = struct{ a int }{42} -) +type twirpLoggerContext struct{} +type statusReporterKey struct{} type twirpRequestFields struct { service string @@ -41,11 +39,10 @@ type twirpRequestFields struct { // logging handling inspired by https://github.com/bakins/twirpzap // License: Apache-2.0 -func TwirpLogger(logger logger.Logger) *twirp.ServerHooks { +func TwirpLogger() *twirp.ServerHooks { loggerPool := &sync.Pool{ New: func() interface{} { return &requestLogger{ - logger: logger, fieldsOrig: make([]interface{}, 0, 30), } }, @@ -65,14 +62,13 @@ func TwirpLogger(logger logger.Logger) *twirp.ServerHooks { type requestLogger struct { twirpRequestFields - logger logger.Logger fieldsOrig []interface{} fields []interface{} startedAt time.Time } func AppendLogFields(ctx context.Context, fields ...interface{}) { - r, ok := ctx.Value(loggerKey).(*requestLogger) + r, ok := ctx.Value(twirpLoggerContext{}).(*requestLogger) if !ok || r == nil { return } @@ -91,13 +87,13 @@ func requestReceived(ctx context.Context, requestLoggerPool *sync.Pool) (context r.fields = append(r.fields, "service", svc) } - ctx = context.WithValue(ctx, loggerKey, r) + ctx = context.WithValue(ctx, twirpLoggerContext{}, r) return ctx, nil } func responseRouted(ctx context.Context) (context.Context, error) { if meth, ok := twirp.MethodName(ctx); ok { - l, ok := ctx.Value(loggerKey).(*requestLogger) + l, ok := ctx.Value(twirpLoggerContext{}).(*requestLogger) if !ok || l == nil { return ctx, nil } @@ -109,7 +105,7 @@ func responseRouted(ctx context.Context) (context.Context, error) { } func responseSent(ctx context.Context, requestLoggerPool *sync.Pool) { - r, ok := ctx.Value(loggerKey).(*requestLogger) + r, ok := ctx.Value(twirpLoggerContext{}).(*requestLogger) if !ok || r == nil { return } @@ -125,7 +121,7 @@ func responseSent(ctx context.Context, requestLoggerPool *sync.Pool) { } serviceMethod := "API " + r.service + "." + r.method - r.logger.Infow(serviceMethod, r.fields...) + utils.GetLogger(ctx).WithComponent(utils.ComponentAPI).Infow(serviceMethod, r.fields...) r.fields = r.fieldsOrig r.error = nil @@ -134,7 +130,7 @@ func responseSent(ctx context.Context, requestLoggerPool *sync.Pool) { } func errorReceived(ctx context.Context, e twirp.Error) context.Context { - r, ok := ctx.Value(loggerKey).(*requestLogger) + r, ok := ctx.Value(twirpLoggerContext{}).(*requestLogger) if !ok || r == nil { return ctx } @@ -160,13 +156,13 @@ func statusReporterRequestReceived(ctx context.Context) (context.Context, error) r.service = svc } - ctx = context.WithValue(ctx, statusReporterKey, r) + ctx = context.WithValue(ctx, statusReporterKey{}, r) return ctx, nil } func statusReporterResponseRouted(ctx context.Context) (context.Context, error) { if meth, ok := twirp.MethodName(ctx); ok { - l, ok := ctx.Value(statusReporterKey).(*twirpRequestFields) + l, ok := ctx.Value(statusReporterKey{}).(*twirpRequestFields) if !ok || l == nil { return ctx, nil } @@ -177,7 +173,7 @@ func statusReporterResponseRouted(ctx context.Context) (context.Context, error) } func statusReporterResponseSent(ctx context.Context) { - r, ok := ctx.Value(statusReporterKey).(*twirpRequestFields) + r, ok := ctx.Value(statusReporterKey{}).(*twirpRequestFields) if !ok || r == nil { return } @@ -205,7 +201,7 @@ func statusReporterResponseSent(ctx context.Context) { } func statusReporterErrorReceived(ctx context.Context, e twirp.Error) context.Context { - r, ok := ctx.Value(statusReporterKey).(*twirpRequestFields) + r, ok := ctx.Value(statusReporterKey{}).(*twirpRequestFields) if !ok || r == nil { return ctx } diff --git a/pkg/service/utils.go b/pkg/service/utils.go index 32076455f..2d1bafd4e 100644 --- a/pkg/service/utils.go +++ b/pkg/service/utils.go @@ -22,8 +22,11 @@ import ( "github.com/livekit/protocol/logger" ) -func handleError(w http.ResponseWriter, status int, err error, keysAndValues ...interface{}) { +func handleError(w http.ResponseWriter, r *http.Request, status int, err error, keysAndValues ...interface{}) { keysAndValues = append(keysAndValues, "status", status) + if r != nil && r.URL != nil { + keysAndValues = append(keysAndValues, "method", r.Method, "path", r.URL.Path) + } logger.GetLogger().WithCallDepth(1).Warnw("error handling request", err, keysAndValues...) w.WriteHeader(status) _, _ = w.Write([]byte(err.Error())) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 93e3784b0..0a2501a1d 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -82,6 +82,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live getSignalRelayConfig, NewDefaultSignalServer, routing.NewSignalClient, + rpc.NewKeepalivePubSub, getPSRPCConfig, getPSRPCClientParams, rpc.NewTopicFormatter, @@ -103,7 +104,10 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi getNodeID, getMessageBus, getSignalRelayConfig, + getPSRPCConfig, + getPSRPCClientParams, routing.NewSignalClient, + rpc.NewKeepalivePubSub, routing.CreateRouter, ) diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 9c9a741e5..e76fa56c7 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -50,7 +50,12 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - router := routing.CreateRouter(conf, universalClient, currentNode, signalClient) + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams) + if err != nil { + return nil, err + } + router := routing.CreateRouter(universalClient, currentNode, signalClient, keepalivePubSub) objectStore := createStore(universalClient) roomAllocator, err := NewRoomAllocator(conf, router, objectStore) if err != nil { @@ -60,7 +65,6 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - clientParams := getPSRPCClientParams(psrpcConfig, messageBus) egressClient, err := rpc.NewEgressClient(clientParams) if err != nil { return nil, err @@ -149,7 +153,13 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi if err != nil { return nil, err } - router := routing.CreateRouter(conf, universalClient, currentNode, signalClient) + psrpcConfig := getPSRPCConfig(conf) + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams) + if err != nil { + return nil, err + } + router := routing.CreateRouter(universalClient, currentNode, signalClient, keepalivePubSub) return router, nil } diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 573b85649..f27946064 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -350,7 +350,10 @@ func (b *Buffer) Close() error { if b.rtpStats != nil { b.rtpStats.Stop() - b.logger.Infow("rtp stats", "direction", "upstream", "stats", b.rtpStats.ToString()) + b.logger.Debugw("rtp stats", + "direction", "upstream", + "stats", b.rtpStats, + ) if b.onFinalRtpStats != nil { b.onFinalRtpStats(b.rtpStats.ToProto()) } @@ -454,6 +457,10 @@ func (b *Buffer) calc(pkt []byte, arrivalTime time.Time) { b.logger.Errorw("could not exclude range", err, "sn", rtpPacket.SequenceNumber, "esn", flowState.ExtSequenceNumber) } } + // TODO-VP9-DEBUG-REMOVE-START + snAdjustment, err := b.snRangeMap.GetValue(flowState.ExtSequenceNumber) + b.logger.Debugw("dropping padding packet", "sn", rtpPacket.SequenceNumber, "osn", flowState.ExtSequenceNumber, "msn", flowState.ExtSequenceNumber-snAdjustment, "error", err) + // TODO-VP9-DEBUG-REMOVE-END return } @@ -469,7 +476,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime time.Time) { if err != nil { if errors.Is(err, bucket.ErrPacketTooOld) { packetTooOldCount := b.packetTooOldCount.Inc() - if packetTooOldCount%20 == 0 { + if (packetTooOldCount-1)%100 == 0 { b.logger.Warnw("could not add packet to bucket", err, "count", packetTooOldCount) } } else if err != bucket.ErrRTXPacket { diff --git a/pkg/sfu/buffer/dependencydescriptorparser.go b/pkg/sfu/buffer/dependencydescriptorparser.go index 5b2744716..da79c301b 100644 --- a/pkg/sfu/buffer/dependencydescriptorparser.go +++ b/pkg/sfu/buffer/dependencydescriptorparser.go @@ -19,6 +19,7 @@ import ( "sort" "github.com/pion/rtp" + "go.uber.org/atomic" dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/utils" @@ -44,6 +45,8 @@ type DependencyDescriptorParser struct { activeDecodeTargetsExtSeq uint64 activeDecodeTargetsMask uint32 frameChecker *FrameIntegrityChecker + + ddNotFoundCount atomic.Uint32 } func NewDependencyDescriptorParser(ddExtID uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser { @@ -73,7 +76,10 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr var videoLayer VideoLayer ddBuf := pkt.GetExtension(r.ddExtID) if ddBuf == nil { - r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber) + ddNotFoundCount := r.ddNotFoundCount.Inc() + if ddNotFoundCount%100 == 0 { + r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber, "count", ddNotFoundCount) + } return nil, videoLayer, nil } diff --git a/pkg/sfu/buffer/rtpstats_receiver.go b/pkg/sfu/buffer/rtpstats_receiver.go index b621f42b1..53a963fcb 100644 --- a/pkg/sfu/buffer/rtpstats_receiver.go +++ b/pkg/sfu/buffer/rtpstats_receiver.go @@ -402,7 +402,7 @@ func (r *RTPStatsReceiver) DeltaInfo(snapshotID uint32) *RTPDeltaInfo { return r.deltaInfo(snapshotID, r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest()) } -func (r *RTPStatsReceiver) ToString() string { +func (r *RTPStatsReceiver) String() string { r.lock.RLock() defer r.lock.RUnlock() diff --git a/pkg/sfu/buffer/rtpstats_receiver_test.go b/pkg/sfu/buffer/rtpstats_receiver_test.go index 7a39c1b54..1816b6689 100644 --- a/pkg/sfu/buffer/rtpstats_receiver_test.go +++ b/pkg/sfu/buffer/rtpstats_receiver_test.go @@ -20,9 +20,10 @@ import ( "testing" "time" - "github.com/livekit/protocol/logger" "github.com/pion/rtp" "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/logger" ) func getPacket(sn uint16, ts uint32, payloadSize int) *rtp.Packet { @@ -82,7 +83,7 @@ func Test_RTPStatsReceiver(t *testing.T) { } r.Stop() - fmt.Printf("%s\n", r.ToString()) + fmt.Printf("%s\n", r.String()) } func Test_RTPStatsReceiver_Update(t *testing.T) { diff --git a/pkg/sfu/buffer/rtpstats_sender.go b/pkg/sfu/buffer/rtpstats_sender.go index 1cb0ff1bd..618341e75 100644 --- a/pkg/sfu/buffer/rtpstats_sender.go +++ b/pkg/sfu/buffer/rtpstats_sender.go @@ -629,7 +629,7 @@ func (r *RTPStatsSender) GetRtcpSenderReport(ssrc uint32, calculatedClockRate ui RTPTimestampExt: nowRTPExt, At: now, } - if r.srNewest != nil { + if r.srNewest != nil && nowRTPExt >= r.srNewest.RTPTimestampExt { timeSinceLastReport := nowNTP.Time().Sub(r.srNewest.NTPTimestamp.Time()) rtpDiffSinceLastReport := nowRTPExt - r.srNewest.RTPTimestampExt windowClockRate := float64(rtpDiffSinceLastReport) / timeSinceLastReport.Seconds() @@ -807,7 +807,7 @@ func (r *RTPStatsSender) DeltaInfoSender(senderSnapshotID uint32) *RTPDeltaInfo } } -func (r *RTPStatsSender) ToString() string { +func (r *RTPStatsSender) String() string { r.lock.RLock() defer r.lock.RUnlock() diff --git a/pkg/sfu/buffer/videolayerutils.go b/pkg/sfu/buffer/videolayerutils.go index 9028283fc..f099c2e7c 100644 --- a/pkg/sfu/buffer/videolayerutils.go +++ b/pkg/sfu/buffer/videolayerutils.go @@ -25,6 +25,7 @@ const ( FullResolution = "f" ) +// SIMULCAST-CODEC-TODO: these need to be codec mime aware if and when each codec suppports different layers func LayerPresenceFromTrackInfo(trackInfo *livekit.TrackInfo) *[livekit.VideoQuality_HIGH + 1]bool { if trackInfo == nil || len(trackInfo.Layers) == 0 { return nil @@ -36,7 +37,7 @@ func LayerPresenceFromTrackInfo(trackInfo *livekit.TrackInfo) *[livekit.VideoQua if layer.Quality <= livekit.VideoQuality_HIGH { layerPresence[layer.Quality] = true } else { - logger.Warnw("unexpected quality in track info", nil, "trackInfo", logger.Proto(trackInfo)) + logger.Warnw("unexpected quality in track info", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) } } @@ -97,13 +98,13 @@ func RidToSpatialLayer(rid string, trackInfo *livekit.TrackInfo) int32 { return 2 case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: - logger.Warnw("unexpected rid f with only two qualities, low and medium", nil) + logger.Warnw("unexpected rid f with only two qualities, low and medium", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) return 1 case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: - logger.Warnw("unexpected rid f with only two qualities, low and high", nil) + logger.Warnw("unexpected rid f with only two qualities, low and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) return 1 case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: - logger.Warnw("unexpected rid f with only two qualities, medium and high", nil) + logger.Warnw("unexpected rid f with only two qualities, medium and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) return 1 default: @@ -169,13 +170,13 @@ func SpatialLayerToRid(layer int32, trackInfo *livekit.TrackInfo) string { return FullResolution case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: - logger.Warnw("unexpected layer 2 with only two qualities, low and medium", nil) + logger.Warnw("unexpected layer 2 with only two qualities, low and medium", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) return HalfResolution case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: - logger.Warnw("unexpected layer 2 with only two qualities, low and high", nil) + logger.Warnw("unexpected layer 2 with only two qualities, low and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) return HalfResolution case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: - logger.Warnw("unexpected layer 2 with only two qualities, medium and high", nil) + logger.Warnw("unexpected layer 2 with only two qualities, medium and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) return HalfResolution default: @@ -240,7 +241,7 @@ func SpatialLayerToVideoQuality(layer int32, trackInfo *livekit.TrackInfo) livek return livekit.VideoQuality_HIGH default: - logger.Errorw("invalid layer", nil, "layer", layer, "trackInfo", trackInfo) + logger.Errorw("invalid layer", nil, "trackID", trackInfo.Sid, "layer", layer, "trackInfo", logger.Proto(trackInfo)) return livekit.VideoQuality_HIGH } @@ -250,7 +251,7 @@ func SpatialLayerToVideoQuality(layer int32, trackInfo *livekit.TrackInfo) livek return livekit.VideoQuality_HIGH default: - logger.Errorw("invalid layer", nil, "layer", layer, "trackInfo", trackInfo) + logger.Errorw("invalid layer", nil, "trackID", trackInfo.Sid, "layer", layer, "trackInfo", logger.Proto(trackInfo)) return livekit.VideoQuality_HIGH } } diff --git a/pkg/sfu/connectionquality/scorer.go b/pkg/sfu/connectionquality/scorer.go index ffc3961da..d5c3b019c 100644 --- a/pkg/sfu/connectionquality/scorer.go +++ b/pkg/sfu/connectionquality/scorer.go @@ -29,9 +29,8 @@ const ( MaxMOS = float32(4.5) MinMOS = float32(1.0) - cMaxScore = float64(100.0) - cMinScore = float64(20.0) - cPausedPoorScore = float64(30.0) + cMaxScore = float64(100.0) + cMinScore = float64(30.0) increaseFactor = float64(0.4) // slower increase, i. e. when score is recovering move up slower -> conservative decreaseFactor = float64(0.7) // faster decrease, i. e. when score is dropping move down faster -> aggressive to be responsive to quality drops @@ -305,7 +304,7 @@ func (q *qualityScorer) updatePauseAtLocked(isPaused bool, at time.Time) { q.layerDistance.Reset() q.pausedAt = at - q.score = cPausedPoorScore + q.score = cMinScore } } else { if q.isPaused() { @@ -364,7 +363,7 @@ func (q *qualityScorer) updateAtLocked(stat *windowStat, at time.Time) { // considered (as long as enough time has passed since unmute). // // Similarly, when paused (possibly due to congestion), score is immediately - // set to cPausedPoorScore for responsiveness. The layer transision is reest. + // set to cMinScore for responsiveness. The layer transision is reest. // On a resume, quality climbs back up using normal operation. if q.isMuted() || !q.isUnmutedEnough(at) || q.isLayerMuted() || q.isPaused() { q.lastUpdateAt = at @@ -404,12 +403,12 @@ func (q *qualityScorer) updateAtLocked(stat *windowStat, at time.Time) { factor = decreaseFactor } score = factor*score + (1.0-factor)*q.score - } - if score < cMinScore { - // lower bound to prevent score from becoming very small values due to extreme conditions. - // Without a lower bound, it can get so low that it takes a long time to climb back to - // better quality even under excellent conditions. - score = cMinScore + if score < cMinScore { + // lower bound to prevent score from becoming very small values due to extreme conditions. + // Without a lower bound, it can get so low that it takes a long time to climb back to + // better quality even under excellent conditions. + score = cMinScore + } } prevCQ := scoreToConnectionQuality(q.score) currCQ := scoreToConnectionQuality(score) diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index c15e68dde..f12859142 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -133,7 +133,7 @@ type DownTrackState struct { func (d DownTrackState) String() string { return fmt.Sprintf("DownTrackState{rtpStats: %s, deltaSender: %d, forwarder: %s}", - d.RTPStats.ToString(), d.DeltaStatsSenderSnapshotId, d.ForwarderState.String()) + d.RTPStats, d.DeltaStatsSenderSnapshotId, d.ForwarderState.String()) } // ------------------------------------------------------------------- @@ -262,8 +262,7 @@ type DownTrack struct { bytesSent atomic.Uint32 bytesRetransmitted atomic.Uint32 - playoutDelayBytes atomic.Value //bytes of marshalled playout delay - playoudDelayAcked atomic.Bool + playoutDelay *PlayoutDelayController pacer pacer.Pacer @@ -318,6 +317,14 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { }) d.deltaStatsSenderSnapshotId = d.rtpStats.NewSenderSnapshotId() + if delay := params.PlayoutDelayLimit; delay.GetEnabled() { + var err error + d.playoutDelay, err = NewPlayoutDelayController(delay.GetMin(), delay.GetMax(), params.Logger, d.rtpStats) + if err != nil { + return nil, err + } + } + d.connectionStats = connectionquality.NewConnectionStats(connectionquality.ConnectionStatsParams{ MimeType: codecs[0].MimeType, // LK-TODO have to notify on codec change IsFECEnabled: strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus) && strings.Contains(strings.ToLower(codecs[0].SDPFmtpLine), "fec"), @@ -330,27 +337,11 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { } }) - // set initial playout delay to minimum value - if d.params.PlayoutDelayLimit.GetEnabled() { - maxDelay := uint32(rtpextension.PlayoutDelayDefaultMax) - if d.params.PlayoutDelayLimit.GetMax() > 0 { - maxDelay = d.params.PlayoutDelayLimit.GetMax() - } - delay := rtpextension.PlayoutDelayFromValue( - uint16(d.params.PlayoutDelayLimit.GetMin()), - uint16(maxDelay), - ) - b, err := delay.Marshal() - if err == nil { - d.playoutDelayBytes.Store(b) - } else { - d.params.Logger.Errorw("failed to marshal playout delay", err, "playoutDelay", d.params.PlayoutDelayLimit) - } - } if d.kind == webrtc.RTPCodecTypeVideo { go d.maxLayerNotifierWorker() go d.keyFrameRequester() } + d.params.Logger.Debugw("downtrack created") return d, nil } @@ -410,6 +401,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.onBinding(nil) } d.bound.Store(true) + d.onBindAndConnectedChange() d.bindLock.Unlock() // Bind is called under RTPSender.mu lock, call the RTPSender.GetParameters in goroutine to avoid deadlock @@ -424,9 +416,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, }() d.forwarder.DetermineCodec(d.codec, d.params.Receiver.HeaderExtensions()) - d.params.Logger.Debugw("downtrack bound") - d.onBindAndConnectedChange() return codec, nil } @@ -434,8 +424,10 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, // Unbind implements the teardown logic when the track is no longer needed. This happens // because a track has been stopped. func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { + d.bindLock.Lock() d.bound.Store(false) d.onBindAndConnectedChange() + d.bindLock.Unlock() return nil } @@ -458,7 +450,7 @@ func (d *DownTrack) SetStreamAllocatorListener(listener DownTrackStreamAllocator d.transportWideExtID = 0 } - // kick of a gratuitous allocation + // kick off a gratuitous allocation listener.OnSubscriptionChanged(d) } } @@ -714,9 +706,9 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { if tp.ddBytes != nil { extensions = []pacer.ExtensionData{{ID: uint8(d.dependencyDescriptorExtID), Payload: tp.ddBytes}} } - if d.playoutDelayExtID != 0 && !d.playoudDelayAcked.Load() { - if val := d.playoutDelayBytes.Load(); val != nil { - extensions = append(extensions, pacer.ExtensionData{ID: uint8(d.playoutDelayExtID), Payload: val.([]byte)}) + if d.playoutDelayExtID != 0 && d.playoutDelay != nil { + if val := d.playoutDelay.GetDelayExtension(hdr.SequenceNumber); val != nil { + extensions = append(extensions, pacer.ExtensionData{ID: uint8(d.playoutDelayExtID), Payload: val}) } } if d.sequencer != nil { @@ -984,10 +976,13 @@ func (d *DownTrack) CloseWithFlush(flush bool) { d.bindLock.Unlock() d.connectionStats.Close() d.rtpStats.Stop() - rtpStats := d.rtpStats.ToString() - if rtpStats != "" { - d.params.Logger.Infow("rtp stats", "direction", "downstream", "mime", d.mime, "ssrc", d.ssrc, "stats", rtpStats) - } + d.params.Logger.Debugw("rtp stats", + "direction", "downstream", + "mime", d.mime, + "ssrc", d.ssrc, + // evaluate only if log level matches + "stats", d.rtpStats, + ) d.maxLayerNotifierChMu.Lock() d.maxLayerNotifierChClosed = true @@ -1266,22 +1261,23 @@ func (d *DownTrack) Resync() { } func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk { - if !d.bound.Load() || d.transceiver.Load() == nil { + transceiver := d.transceiver.Load() + if !d.bound.Load() || transceiver == nil { return nil } return []rtcp.SourceDescriptionChunk{ { Source: d.ssrc, - Items: []rtcp.SourceDescriptionItem{{ - Type: rtcp.SDESCNAME, - Text: d.params.StreamID, - }}, - }, { - Source: d.ssrc, - Items: []rtcp.SourceDescriptionItem{{ - Type: rtcp.SDESType(15), - Text: d.transceiver.Load().Mid(), - }}, + Items: []rtcp.SourceDescriptionItem{ + { + Type: rtcp.SDESCNAME, + Text: d.params.StreamID, + }, + { + Type: rtcp.SDESType(15), + Text: transceiver.Mid(), + }, + }, }, } } @@ -1475,7 +1471,6 @@ func (d *DownTrack) handleRTCP(bytes []byte) { pliOnce := true sendPliOnce := func() { _, layer := d.forwarder.CheckSync() - d.params.Logger.Debugw("received PLI/FIR RTCP", "layer", layer) if pliOnce { if layer != buffer.InvalidLayerSpatial { d.params.Logger.Debugw("sending PLI RTCP", "layer", layer) @@ -1528,7 +1523,11 @@ func (d *DownTrack) handleRTCP(bytes []byte) { sal.OnRTCPReceiverReport(d, r) } - d.playoudDelayAcked.Store(true) + if d.playoutDelay != nil { + jitterMs := uint64(r.Jitter*1e3) / uint64(d.codec.ClockRate) + d.playoutDelay.OnSeqAcked(uint16(r.LastSequenceNumber)) + d.playoutDelay.SetJitter(uint32(jitterMs)) + } } if len(rr.Reports) > 0 { d.listenerLock.RLock() @@ -1572,9 +1571,12 @@ func (d *DownTrack) handleRTCP(bytes []byte) { } func (d *DownTrack) SetConnected() { + d.bindLock.Lock() if !d.connected.Swap(true) { d.onBindAndConnectedChange() } + d.params.Logger.Debugw("downtrack connected") + d.bindLock.Unlock() } // SetActivePaddingOnMuteUpTrack will enable padding on the track when its uptrack is muted. @@ -1623,6 +1625,17 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { if err == io.EOF { break } + // TODO-VP9-DEBUG-REMOVE-START + d.params.Logger.Debugw( + "NACK miss", + "isn", epm.sourceSeqNo, + "osn", epm.targetSeqNo, + "ots", epm.timestamp, + "eosn", epm.extSequenceNumber, + "eots", epm.extTimestamp, + "sid", epm.layer, + ) + // TODO-VP9-DEBUG-REMOVE-END nackMisses++ continue } diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index e0ac55742..44de9ba71 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -1695,8 +1695,36 @@ func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer in tp.shouldDrop = true if f.started && result.IsRelevant { // call to update highest incoming sequence number and other internal structures - if tpRTP, err := f.rtpMunger.UpdateAndGetSnTs(extPkt, result.RTPMarker); err == nil && tpRTP.snOrdering == SequenceNumberOrderingContiguous { - f.rtpMunger.PacketDropped(extPkt) + if tpRTP, err := f.rtpMunger.UpdateAndGetSnTs(extPkt, result.RTPMarker); err == nil { + if tpRTP.snOrdering == SequenceNumberOrderingContiguous { + // TODO-VP9-DEBUG-REMOVE-START + f.logger.Debugw( + "dropping packet", + "isn", extPkt.ExtSequenceNumber, + "its", extPkt.ExtTimestamp, + "osn", tpRTP.extSequenceNumber, + "ots", tpRTP.extTimestamp, + "payloadLen", len(extPkt.Packet.Payload), + "sid", extPkt.Spatial, + "tid", extPkt.Temporal, + ) + // TODO-VP9-DEBUG-REMOVE-END + f.rtpMunger.PacketDropped(extPkt) + } else { + // TODO-VP9-DEBUG-REMOVE-START + f.logger.Debugw( + "dropping packet skipped as not contiguous", + "isn", extPkt.ExtSequenceNumber, + "its", extPkt.ExtTimestamp, + "osn", tpRTP.extSequenceNumber, + "ots", tpRTP.extTimestamp, + "payloadLen", len(extPkt.Packet.Payload), + "sid", extPkt.Spatial, + "tid", extPkt.Temporal, + "snOrdering", tpRTP.snOrdering, + ) + // TODO-VP9-DEBUG-REMOVE-END + } } } return tp, nil diff --git a/pkg/sfu/playoutdelay.go b/pkg/sfu/playoutdelay.go new file mode 100644 index 000000000..2ac26a85b --- /dev/null +++ b/pkg/sfu/playoutdelay.go @@ -0,0 +1,141 @@ +package sfu + +import ( + "sync" + "sync/atomic" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/rtpextension" + "github.com/livekit/protocol/logger" +) + +type PlayoutDelayState int32 + +const ( + PlayoutDelayStateChanged PlayoutDelayState = iota + PlayoutDelaySending + PlayoutDelayAcked + + jitterLowMultiToDelay = 10 + jitterHighMultiToDelay = 15 + jitterHighThreshold = 15 +) + +func (s PlayoutDelayState) String() string { + switch s { + case PlayoutDelayStateChanged: + return "StateChanged" + case PlayoutDelaySending: + return "Sending" + case PlayoutDelayAcked: + return "Acked" + } + return "Unknown" +} + +type PlayoutDelayController struct { + lock sync.Mutex + state atomic.Int32 + minDelay, maxDelay uint32 + currentDelay uint32 + extBytes atomic.Value //[]byte + sendingAtSeq uint16 + logger logger.Logger + rtpStats *buffer.RTPStatsSender + snapshotID uint32 +} + +func NewPlayoutDelayController(minDelay, maxDelay uint32, logger logger.Logger, rtpStats *buffer.RTPStatsSender) (*PlayoutDelayController, error) { + if maxDelay == 0 && minDelay > 0 { + maxDelay = rtpextension.MaxPlayoutDelayDefault + } + if maxDelay > rtpextension.PlayoutDelayMaxValue { + maxDelay = rtpextension.PlayoutDelayMaxValue + } + c := &PlayoutDelayController{ + currentDelay: minDelay, + minDelay: minDelay, + maxDelay: maxDelay, + logger: logger, + rtpStats: rtpStats, + snapshotID: rtpStats.NewSenderSnapshotId(), + } + return c, c.createExtData() +} + +func (c *PlayoutDelayController) SetJitter(jitter uint32) { + deltaInfo := c.rtpStats.DeltaInfoSender(c.snapshotID) + var nackPercent uint32 + if deltaInfo != nil && deltaInfo.Packets > 0 { + nackPercent = deltaInfo.Nacks * 100 / deltaInfo.Packets + } + + c.lock.Lock() + multi := jitterLowMultiToDelay + if jitter >= jitterHighThreshold { + multi = jitterHighMultiToDelay + } + targetDelay := jitter * uint32(multi) + if nackPercent > 60 { + targetDelay += (nackPercent - 60) * 2 + } + + // increase delay quickly, decrease slowly to make fps more stable + if targetDelay > c.currentDelay { + targetDelay = (targetDelay-c.currentDelay)*3/4 + c.currentDelay + } else { + targetDelay = c.currentDelay - (c.currentDelay-targetDelay)/5 + } + if targetDelay < c.minDelay { + targetDelay = c.minDelay + } + if targetDelay > c.maxDelay { + targetDelay = c.maxDelay + } + if c.currentDelay == targetDelay { + c.lock.Unlock() + return + } + c.currentDelay = targetDelay + c.lock.Unlock() + c.createExtData() +} + +func (c *PlayoutDelayController) OnSeqAcked(seq uint16) { + c.lock.Lock() + defer c.lock.Unlock() + if PlayoutDelayState(c.state.Load()) == PlayoutDelaySending && (seq-c.sendingAtSeq) < 0x8000 { + c.state.Store(int32(PlayoutDelayAcked)) + } +} + +func (c *PlayoutDelayController) GetDelayExtension(seq uint16) []byte { + switch PlayoutDelayState(c.state.Load()) { + case PlayoutDelayStateChanged: + c.lock.Lock() + c.state.Store(int32(PlayoutDelaySending)) + c.sendingAtSeq = seq + c.lock.Unlock() + return c.extBytes.Load().([]byte) + case PlayoutDelaySending: + return c.extBytes.Load().([]byte) + case PlayoutDelayAcked: + return nil + } + return nil +} + +func (c *PlayoutDelayController) createExtData() error { + delay := rtpextension.PlayoutDelayFromValue( + uint16(c.currentDelay), + uint16(c.maxDelay), + ) + b, err := delay.Marshal() + if err == nil { + c.extBytes.Store(b) + c.state.Store(int32(PlayoutDelayStateChanged)) + } else { + c.logger.Errorw("failed to marshal playout delay", err, "playoutDelay", delay) + } + return err +} diff --git a/pkg/sfu/playoutdelay_test.go b/pkg/sfu/playoutdelay_test.go new file mode 100644 index 000000000..190891034 --- /dev/null +++ b/pkg/sfu/playoutdelay_test.go @@ -0,0 +1,61 @@ +package sfu + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/rtpextension" + "github.com/livekit/protocol/logger" +) + +func TestPlayoutDelay(t *testing.T) { + stats := buffer.NewRTPStatsSender(buffer.RTPStatsParams{ClockRate: 900000, Logger: logger.GetLogger()}) + c, err := NewPlayoutDelayController(100, 1000, logger.GetLogger(), stats) + require.NoError(t, err) + + ext := c.GetDelayExtension(100) + playoutDelayEqual(t, ext, 100, 1000) + + ext = c.GetDelayExtension(105) + playoutDelayEqual(t, ext, 100, 1000) + + // seq acked before delay changed + c.OnSeqAcked(65534) + ext = c.GetDelayExtension(105) + playoutDelayEqual(t, ext, 100, 1000) + + c.OnSeqAcked(90) + ext = c.GetDelayExtension(105) + playoutDelayEqual(t, ext, 100, 1000) + + // seq acked, no extension sent for new packet + c.OnSeqAcked(103) + ext = c.GetDelayExtension(106) + require.Nil(t, ext) + + // delay on change(can't go below min), no extension sent + c.SetJitter(0) + ext = c.GetDelayExtension(107) + require.Nil(t, ext) + + // delay changed, generate new extension to send + c.SetJitter(50) + ext = c.GetDelayExtension(108) + var delay rtpextension.PlayOutDelay + require.NoError(t, delay.Unmarshal(ext)) + require.Greater(t, delay.Min, uint16(100)) + + // can't go above max + c.SetJitter(10000) + ext = c.GetDelayExtension(109) + playoutDelayEqual(t, ext, 1000, 1000) +} + +func playoutDelayEqual(t *testing.T, data []byte, min, max uint16) { + var delay rtpextension.PlayOutDelay + require.NoError(t, delay.Unmarshal(data)) + require.Equal(t, min, delay.Min) + require.Equal(t, max, delay.Max) +} diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 9afff8e98..ae368b42f 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -24,6 +24,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "go.uber.org/atomic" + "google.golang.org/protobuf/proto" "github.com/livekit/mediatransportutil/pkg/bucket" "github.com/livekit/mediatransportutil/pkg/twcc" @@ -71,6 +72,7 @@ type TrackReceiver interface { DebugInfo() map[string]interface{} TrackInfo() *livekit.TrackInfo + UpdateTrackInfo(ti *livekit.TrackInfo) // Get primary receiver if this receiver represents a RED codec; otherwise it will return itself GetPrimaryReceiverForRed() TrackReceiver @@ -104,7 +106,7 @@ type WebRTCReceiver struct { closeOnce sync.Once closed atomic.Bool useTrackers bool - trackInfo *livekit.TrackInfo + trackInfo atomic.Pointer[livekit.TrackInfo] rtcpCh chan []rtcp.Packet @@ -200,21 +202,21 @@ func NewWebRTCReceiver( opts ...ReceiverOpts, ) *WebRTCReceiver { w := &WebRTCReceiver{ - logger: logger, - receiver: receiver, - trackID: livekit.TrackID(track.ID()), - streamID: track.StreamID(), - codec: track.Codec(), - kind: track.Kind(), - twcc: twcc, - trackInfo: trackInfo, - isSVC: IsSvcCodec(track.Codec().MimeType), - isRED: IsRedCodec(track.Codec().MimeType), + logger: logger, + receiver: receiver, + trackID: livekit.TrackID(track.ID()), + streamID: track.StreamID(), + codec: track.Codec(), + kind: track.Kind(), + twcc: twcc, + isSVC: IsSvcCodec(track.Codec().MimeType), + isRED: IsRedCodec(track.Codec().MimeType), } for _, opt := range opts { w = opt(w) } + w.trackInfo.Store(proto.Clone(trackInfo).(*livekit.TrackInfo)) w.downTrackSpreader = NewDownTrackSpreader(DownTrackSpreaderParams{ Threshold: w.lbThreshold, @@ -232,7 +234,7 @@ func NewWebRTCReceiver( w.onStatsUpdate(w, stat) } }) - w.connectionStats.Start(w.trackInfo) + w.connectionStats.Start(trackInfo) w.streamTrackerManager = NewStreamTrackerManager(logger, trackInfo, w.isSVC, w.codec.ClockRate, trackersConfig) w.streamTrackerManager.SetListener(w) @@ -250,7 +252,12 @@ func NewWebRTCReceiver( } func (w *WebRTCReceiver) TrackInfo() *livekit.TrackInfo { - return w.trackInfo + return w.trackInfo.Load() +} + +func (w *WebRTCReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { + w.trackInfo.Store(proto.Clone(ti).(*livekit.TrackInfo)) + w.streamTrackerManager.UpdateTrackInfo(ti) } func (w *WebRTCReceiver) OnStatsUpdate(fn func(w *WebRTCReceiver, stat *livekit.AnalyticsStat)) { @@ -328,7 +335,7 @@ func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buff layer := int32(0) if w.Kind() == webrtc.RTPCodecTypeVideo && !w.isSVC { - layer = buffer.RidToSpatialLayer(track.RID(), w.trackInfo) + layer = buffer.RidToSpatialLayer(track.RID(), w.trackInfo.Load()) } buff.SetLogger(w.logger.WithValues("layer", layer)) buff.SetTWCC(w.twcc) @@ -414,6 +421,7 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error { track.UpTrackMaxTemporalLayerSeenChange(w.streamTrackerManager.GetMaxTemporalLayerSeen()) w.downTrackSpreader.Store(track) + w.logger.Debugw("downtrack added", "subscriberID", track.SubscriberID()) return nil } @@ -498,6 +506,7 @@ func (w *WebRTCReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { } w.downTrackSpreader.Free(subscriberID) + w.logger.Debugw("downtrack deleted", "subscriberID", subscriberID) } func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) { @@ -700,9 +709,13 @@ func (w *WebRTCReceiver) closeTracks() { } func (w *WebRTCReceiver) DebugInfo() map[string]interface{} { + isSimulcast := !w.isSVC + if ti := w.trackInfo.Load(); ti != nil { + isSimulcast = isSimulcast && len(ti.Layers) > 1 + } info := map[string]interface{}{ "SVC": w.isSVC, - "Simulcast": !w.isSVC && len(w.trackInfo.Layers) > 1, + "Simulcast": isSimulcast, } w.upTrackMu.RLock() diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index e7d33099f..e9e2642f3 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -96,7 +96,10 @@ func (r *RedPrimaryReceiver) AddDownTrack(track TrackSender) error { r.logger.Infow("subscriberID already exists, replacing downtrack", "subscriberID", track.SubscriberID()) } + track.TrackInfoAvailable() + r.downTrackSpreader.Store(track) + r.logger.Debugw("red primary receiver downtrack added", "subscriberID", track.SubscriberID()) return nil } @@ -106,6 +109,7 @@ func (r *RedPrimaryReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) } r.downTrackSpreader.Free(subscriberID) + r.logger.Debugw("red primary receiver downtrack deleted", "subscriberID", subscriberID) } func (r *RedPrimaryReceiver) IsClosed() bool { diff --git a/pkg/sfu/redreceiver.go b/pkg/sfu/redreceiver.go index 9254be438..fe254a972 100644 --- a/pkg/sfu/redreceiver.go +++ b/pkg/sfu/redreceiver.go @@ -87,7 +87,10 @@ func (r *RedReceiver) AddDownTrack(track TrackSender) error { r.logger.Infow("subscriberID already exists, replacing downtrack", "subscriberID", track.SubscriberID()) } + track.TrackInfoAvailable() + r.downTrackSpreader.Store(track) + r.logger.Debugw("red receiver downtrack added", "subscriberID", track.SubscriberID()) return nil } @@ -97,6 +100,7 @@ func (r *RedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { } r.downTrackSpreader.Free(subscriberID) + r.logger.Debugw("red receiver downtrack deleted", "subscriberID", subscriberID) } func (r *RedReceiver) CanClose() bool { diff --git a/pkg/sfu/redreceiver_test.go b/pkg/sfu/redreceiver_test.go index 2aae30182..72261f7db 100644 --- a/pkg/sfu/redreceiver_test.go +++ b/pkg/sfu/redreceiver_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" ) const tsStep = uint32(48000 / 1000 * 10) @@ -38,11 +39,17 @@ func (dt *dummyDowntrack) WriteRTP(p *buffer.ExtPacket, _ int32) error { return nil } +func (dt *dummyDowntrack) TrackInfoAvailable() {} + func TestRedReceiver(t *testing.T) { dt := &dummyDowntrack{TrackSender: &DownTrack{}} t.Run("normal", func(t *testing.T) { - w := &WebRTCReceiver{isRED: true, kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + isRED: true, + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } require.Equal(t, w.GetRedReceiver(), w) w.isRED = false red := w.GetRedReceiver().(*RedReceiver) @@ -64,7 +71,10 @@ func TestRedReceiver(t *testing.T) { }) t.Run("packet lost and jump", func(t *testing.T) { - w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } red := w.GetRedReceiver().(*RedReceiver) require.NoError(t, red.AddDownTrack(dt)) @@ -112,7 +122,10 @@ func TestRedReceiver(t *testing.T) { }) t.Run("unorder and repeat", func(t *testing.T) { - w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } red := w.GetRedReceiver().(*RedReceiver) require.NoError(t, red.AddDownTrack(dt)) @@ -141,7 +154,11 @@ func TestRedReceiver(t *testing.T) { }) t.Run("encoding exceed space", func(t *testing.T) { - w := &WebRTCReceiver{isRED: true, kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + isRED: true, + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } require.Equal(t, w.GetRedReceiver(), w) w.isRED = false red := w.GetRedReceiver().(*RedReceiver) @@ -162,7 +179,11 @@ func TestRedReceiver(t *testing.T) { }) t.Run("large timestamp gap", func(t *testing.T) { - w := &WebRTCReceiver{isRED: true, kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + isRED: true, + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } require.Equal(t, w.GetRedReceiver(), w) w.isRED = false red := w.GetRedReceiver().(*RedReceiver) @@ -257,7 +278,10 @@ func generateRedPkts(t *testing.T, pkts []*rtp.Packet, redCount int) []*rtp.Pack func testRedRedPrimaryReceiver(t *testing.T, maxPktCount, redCount int, sendPktIdx, expectPktIdx []int) { dt := &dummyDowntrack{TrackSender: &DownTrack{}} - w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } require.Equal(t, w.GetPrimaryReceiverForRed(), w) w.isRED = true red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) @@ -283,7 +307,10 @@ func testRedRedPrimaryReceiver(t *testing.T, maxPktCount, redCount int, sendPktI } func TestRedPrimaryReceiver(t *testing.T) { - w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio} + w := &WebRTCReceiver{ + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + } require.Equal(t, w.GetPrimaryReceiverForRed(), w) w.isRED = true red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) diff --git a/pkg/sfu/rtpextension/playoutdelay.go b/pkg/sfu/rtpextension/playoutdelay.go index d42322be7..500bb67ea 100644 --- a/pkg/sfu/rtpextension/playoutdelay.go +++ b/pkg/sfu/rtpextension/playoutdelay.go @@ -7,7 +7,8 @@ import ( const ( PlayoutDelayURI = "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay" - PlayoutDelayDefaultMax = 4000 // 4s + MaxPlayoutDelayDefault = 10000 // 10s, equal to chrome's default max playout delay + PlayoutDelayMaxValue = 10 * (1<<12 - 1) // max value for playout delay can be represented playoutDelayExtensionSize = 3 ) @@ -28,11 +29,11 @@ type PlayOutDelay struct { } func PlayoutDelayFromValue(min, max uint16) PlayOutDelay { - if min >= (1<<12)*10 { - min = (1<<12 - 1) * 10 + if min > PlayoutDelayMaxValue { + min = PlayoutDelayMaxValue } - if max >= (1<<12)*10 { - max = (1<<12 - 1) * 10 + if max > PlayoutDelayMaxValue { + max = PlayoutDelayMaxValue } return PlayOutDelay{Min: min, Max: max} } diff --git a/pkg/sfu/rtpextension/playoutdelay_test.go b/pkg/sfu/rtpextension/playoutdelay_test.go index a92f4cae5..7262a6475 100644 --- a/pkg/sfu/rtpextension/playoutdelay_test.go +++ b/pkg/sfu/rtpextension/playoutdelay_test.go @@ -32,4 +32,12 @@ func TestPlayoutDelay(t *testing.T) { require.NoError(t, err) require.Equal(t, uint16((1<<12)-1)*10, p5.Min) require.Equal(t, uint16((1<<12)-1)*10, p5.Max) + + p6 := PlayOutDelay{Min: 100, Max: PlayoutDelayMaxValue} + bytes, err := p6.Marshal() + require.NoError(t, err) + p6Unmarshal := PlayOutDelay{} + err = p6Unmarshal.Unmarshal(bytes) + require.NoError(t, err) + require.Equal(t, p6, p6Unmarshal) } diff --git a/pkg/sfu/streamallocator/streamallocator.go b/pkg/sfu/streamallocator/streamallocator.go index 0c2d7d52b..6370692b0 100644 --- a/pkg/sfu/streamallocator/streamallocator.go +++ b/pkg/sfu/streamallocator/streamallocator.go @@ -31,6 +31,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/utils" ) const ( @@ -165,8 +166,7 @@ type StreamAllocator struct { state streamAllocatorState - eventChMu sync.RWMutex - eventCh chan Event + eventsQueue *utils.OpsQueue isStopped atomic.Bool } @@ -180,7 +180,7 @@ func NewStreamAllocator(params StreamAllocatorParams) *StreamAllocator { }), rateMonitor: NewRateMonitor(), videoTracks: make(map[livekit.TrackID]*Track), - eventCh: make(chan Event, 1000), + eventsQueue: utils.NewOpsQueue("stream-allocator", 64, true), } s.probeController = NewProbeController(ProbeControllerParams{ @@ -197,19 +197,18 @@ func NewStreamAllocator(params StreamAllocatorParams) *StreamAllocator { } func (s *StreamAllocator) Start() { - go s.processEvents() + s.eventsQueue.Start() go s.ping() } func (s *StreamAllocator) Stop() { - s.eventChMu.Lock() if s.isStopped.Swap(true) { - s.eventChMu.Unlock() return } - close(s.eventCh) - s.eventChMu.Unlock() + // wait for eventsQueue to be done + <-s.eventsQueue.Stop() + s.probeController.StopProbe() } func (s *StreamAllocator) OnStreamStateChange(f func(update *StreamStateUpdate) error) { @@ -546,30 +545,9 @@ func (s *StreamAllocator) maybePostEventAllocateTrack(downTrack *sfu.DownTrack) } func (s *StreamAllocator) postEvent(event Event) { - s.eventChMu.RLock() - if s.isStopped.Load() { - s.eventChMu.RUnlock() - return - } - - select { - case s.eventCh <- event: - default: - s.params.Logger.Warnw("stream allocator: event queue full", nil) - } - s.eventChMu.RUnlock() -} - -func (s *StreamAllocator) processEvents() { - for event := range s.eventCh { - if s.isStopped.Load() { - break - } - + s.eventsQueue.Enqueue(func() { s.handleEvent(&event) - } - - s.probeController.StopProbe() + }) } func (s *StreamAllocator) ping() { @@ -835,7 +813,7 @@ func (s *StreamAllocator) handleNewEstimateInNonProbe() { "commitThreshold(bps)", commitThreshold, "channel", s.channelObserver.ToString(), ) - s.params.Logger.Infow( + s.params.Logger.Debugw( fmt.Sprintf("stream allocator: channel congestion detected, %s channel capacity: experimental", action), "rateHistory", s.rateMonitor.GetHistory(), "expectedQueuing", s.rateMonitor.GetQueuingGuess(), @@ -1297,7 +1275,7 @@ func (s *StreamAllocator) initProbe(probeGoalDeltaBps int64) { s.channelObserver = s.newChannelObserverProbe() s.channelObserver.SeedEstimate(s.lastReceivedEstimate) - s.params.Logger.Infow( + s.params.Logger.Debugw( "stream allocator: starting probe", "probeClusterId", probeClusterId, "current usage", expectedBandwidthUsage, diff --git a/pkg/sfu/streamtrackermanager.go b/pkg/sfu/streamtrackermanager.go index c9a22bff9..a934b9e93 100644 --- a/pkg/sfu/streamtrackermanager.go +++ b/pkg/sfu/streamtrackermanager.go @@ -22,6 +22,8 @@ import ( "time" "github.com/frostbyte73/core" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/sfu/buffer" @@ -57,7 +59,7 @@ type endsSenderReport struct { type StreamTrackerManager struct { logger logger.Logger - trackInfo *livekit.TrackInfo + trackInfo atomic.Pointer[livekit.TrackInfo] isSVC bool clockRate uint32 @@ -92,15 +94,15 @@ func NewStreamTrackerManager( ) *StreamTrackerManager { s := &StreamTrackerManager{ logger: logger, - trackInfo: trackInfo, isSVC: isSVC, maxPublishedLayer: buffer.InvalidLayerSpatial, maxTemporalLayerSeen: buffer.InvalidLayerTemporal, clockRate: clockRate, closed: core.NewFuse(), } + s.trackInfo.Store(proto.Clone(trackInfo).(*livekit.TrackInfo)) - switch s.trackInfo.Source { + switch trackInfo.Source { case livekit.TrackSource_SCREEN_SHARE: s.trackerConfig = trackersConfig.Screenshare case livekit.TrackSource_CAMERA: @@ -111,7 +113,7 @@ func NewStreamTrackerManager( s.maxExpectedLayerFromTrackInfo() - if s.trackInfo.Type == livekit.TrackType_VIDEO { + if trackInfo.Type == livekit.TrackType_VIDEO { go s.bitrateReporter() } return s @@ -316,6 +318,11 @@ func (s *StreamTrackerManager) IsPaused() bool { return s.paused } +func (s *StreamTrackerManager) UpdateTrackInfo(ti *livekit.TrackInfo) { + s.trackInfo.Store(proto.Clone(ti).(*livekit.TrackInfo)) + s.maxExpectedLayerFromTrackInfo() +} + func (s *StreamTrackerManager) SetMaxExpectedSpatialLayer(layer int32) int32 { s.lock.Lock() prev := s.maxExpectedLayer @@ -540,10 +547,13 @@ func (s *StreamTrackerManager) removeAvailableLayer(layer int32) { func (s *StreamTrackerManager) maxExpectedLayerFromTrackInfo() { s.maxExpectedLayer = buffer.InvalidLayerSpatial - for _, layer := range s.trackInfo.Layers { - spatialLayer := buffer.VideoQualityToSpatialLayer(layer.Quality, s.trackInfo) - if spatialLayer > s.maxExpectedLayer { - s.maxExpectedLayer = spatialLayer + ti := s.trackInfo.Load() + if ti != nil { + for _, layer := range ti.Layers { + spatialLayer := buffer.VideoQualityToSpatialLayer(layer.Quality, ti) + if spatialLayer > s.maxExpectedLayer { + s.maxExpectedLayer = spatialLayer + } } } } diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go index 9adc54b87..15d9df1eb 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -39,12 +39,14 @@ type DependencyDescriptor struct { decodeTargetsLock sync.RWMutex decodeTargets []*DecodeTarget + fnWrapper FrameNumberWrapper } func NewDependencyDescriptor(logger logger.Logger) *DependencyDescriptor { return &DependencyDescriptor{ Base: NewBase(logger), decisions: NewSelectorDecisionCache(256, 80), + fnWrapper: FrameNumberWrapper{logger: logger}, } } @@ -52,6 +54,7 @@ func NewDependencyDescriptorFromNull(vls VideoLayerSelector) *DependencyDescript return &DependencyDescriptor{ Base: vls.(*Null).Base, decisions: NewSelectorDecisionCache(256, 80), + fnWrapper: FrameNumberWrapper{logger: vls.(*Null).logger}, } } @@ -68,7 +71,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r ddwdt := extPkt.DependencyDescriptor if ddwdt == nil { // packet doesn't have dependency descriptor - d.logger.Debugw(fmt.Sprintf("drop packet, no DD, incoming %v, sn: %d, isKeyFrame: %v", extPkt.VideoLayer, extPkt.Packet.SequenceNumber, extPkt.KeyFrame)) + // d.logger.Debugw(fmt.Sprintf("drop packet, no DD, incoming %v, sn: %d, isKeyFrame: %v", extPkt.VideoLayer, extPkt.Packet.SequenceNumber, extPkt.KeyFrame)) return } @@ -83,6 +86,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r } if !d.keyFrameValid && dd.AttachedStructure == nil { + // d.logger.Debugw(fmt.Sprintf("drop packet, no attached structure, incoming %v, sn: %d, isKeyFrame: %v", extPkt.VideoLayer, extPkt.Packet.SequenceNumber, extPkt.KeyFrame)) return } @@ -91,10 +95,10 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r if err != nil { // do not mark as dropped as only error is an old frame // d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d", - // incomingLayer, - // dd.FrameNumber, - // extFrameNum, - // extPkt.Packet.SequenceNumber, + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, // ), "err", err) return } @@ -102,24 +106,24 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r case selectorDecisionDropped: // a packet of an alreadty dropped frame, maintain decision // d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", - // incomingLayer, - // dd.FrameNumber, - // extFrameNum, - // extPkt.Packet.SequenceNumber, + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, // )) return } if ddwdt.StructureUpdated { - d.logger.Debugw("update dependency structure", - "structureID", dd.AttachedStructure.StructureId, - "structure", dd.AttachedStructure, - "decodeTargets", ddwdt.DecodeTargets, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - "currentKeyframe", d.extKeyFrameNum, - ) + // d.logger.Debugw("update dependency structure", + // "structureID", dd.AttachedStructure.StructureId, + // "structure", dd.AttachedStructure, + // "decodeTargets", ddwdt.DecodeTargets, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // "currentKeyframe", d.extKeyFrameNum, + // ) d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets, extFrameNum) } @@ -144,7 +148,8 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r "chainDiffs", fd.ChainDiffs, "chains", len(d.chains), "requiredKeyFrame", ddwdt.ExtKeyFrameNum, - "structureKeyFrame", d.extKeyFrameNum) + "structureKeyFrame", d.extKeyFrameNum, + ) d.decisions.AddDropped(extFrameNum) return } @@ -187,15 +192,15 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r if highestDecodeTarget.Target < 0 { // no active decode target, do not select // d.logger.Debugw( - // "drop packet for no target found", - // "highestDecodeTarget", highestDecodeTarget, - // "decodeTargets", d.decodeTargets, - // "tagetLayer", d.targetLayer, - // "incoming", incomingLayer, - // "fn", dd.FrameNumber, - // "efn", extFrameNum, - // "sn", extPkt.Packet.SequenceNumber, - // "isKeyFrame", extPkt.KeyFrame, + // "drop packet for no target found", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, // ) d.decisions.AddDropped(extFrameNum) return @@ -204,15 +209,15 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r // DD-TODO : if bandwidth in congest, could drop the 'Discardable' frame if dti == dede.DecodeTargetNotPresent { // d.logger.Debugw( - // "drop packet for decode target not present", - // "highestDecodeTarget", highestDecodeTarget, - // "decodeTargets", d.decodeTargets, - // "tagetLayer", d.targetLayer, - // "incoming", incomingLayer, - // "fn", dd.FrameNumber, - // "efn", extFrameNum, - // "sn", extPkt.Packet.SequenceNumber, - // "isKeyFrame", extPkt.KeyFrame, + // "drop packet for decode target not present", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, // ) d.decisions.AddDropped(extFrameNum) return @@ -234,15 +239,15 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r } if !isDecodable { // d.logger.Debugw( - // "drop packet for not decodable", - // "highestDecodeTarget", highestDecodeTarget, - // "decodeTargets", d.decodeTargets, - // "tagetLayer", d.targetLayer, - // "incoming", incomingLayer, - // "fn", dd.FrameNumber, - // "efn", extFrameNum, - // "sn", extPkt.Packet.SequenceNumber, - // "isKeyFrame", extPkt.KeyFrame, + // "drop packet for not decodable", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, // ) d.decisions.AddDropped(extFrameNum) return @@ -291,13 +296,26 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r Descriptor: dd, Structure: d.structure, } + + unWrapFn := uint16(d.fnWrapper.UpdateAndGet(extFrameNum, ddwdt.StructureUpdated)) + var ddClone *dede.DependencyDescriptor + if unWrapFn != dd.FrameNumber { + clone := *dd + ddClone = &clone + ddClone.FrameNumber = unWrapFn + ddExtension.Descriptor = ddClone + } + if dd.AttachedStructure == nil { if d.activeDecodeTargetsBitmask != nil { - // clone and override activebitmask - // DD-TODO: if the packet that contains the bitmask is acknowledged by RR, then we don't need it until it changed. - ddClone := *ddExtension.Descriptor + if ddClone == nil { + // clone and override activebitmask + // DD-TODO: if the packet that contains the bitmask is acknowledged by RR, then we don't need it until it changed. + clone := *dd + ddClone = &clone + ddExtension.Descriptor = ddClone + } ddClone.ActiveDecodeTargetsBitmask = d.activeDecodeTargetsBitmask - ddExtension.Descriptor = &ddClone // d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask) } } @@ -410,7 +428,6 @@ func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { defer d.decodeTargetsLock.RUnlock() for _, dt := range d.decodeTargets { if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() { - d.logger.Debugw(fmt.Sprintf("checking sync, matching decode target, layer: %d, dt: %s, dts: %+v", layer, dt, d.decodeTargets)) return true, layer } } diff --git a/pkg/sfu/videolayerselector/framechain.go b/pkg/sfu/videolayerselector/framechain.go index 0777fb0fa..7b226ed0c 100644 --- a/pkg/sfu/videolayerselector/framechain.go +++ b/pkg/sfu/videolayerselector/framechain.go @@ -54,7 +54,7 @@ func (fc *FrameChain) OnFrame(extFrameNum uint64, fd *dd.FrameDependencyTemplate if fd.ChainDiffs[fc.chainIdx] == 0 { if fc.broken { fc.broken = false - fc.logger.Debugw("frame chain intact", "chanIdx", fc.chainIdx, "frame", extFrameNum) + // fc.logger.Debugw("frame chain intact", "chanIdx", fc.chainIdx, "frame", extFrameNum) } fc.expectFrames = fc.expectFrames[:0] return true @@ -86,7 +86,7 @@ func (fc *FrameChain) OnFrame(extFrameNum uint64, fd *dd.FrameDependencyTemplate if !intact { fc.broken = true - fc.logger.Debugw("frame chain broken", "chanIdx", fc.chainIdx, "sd", sd, "frame", extFrameNum, "prevFrame", prevFrameInChain) + // fc.logger.Debugw("frame chain broken", "chanIdx", fc.chainIdx, "sd", sd, "frame", extFrameNum, "prevFrame", prevFrameInChain) } return intact } @@ -100,7 +100,7 @@ func (fc *FrameChain) OnExpectFrameChanged(frameNum uint64, decision selectorDec if f == frameNum { if decision != selectorDecisionForwarded { fc.broken = true - fc.logger.Debugw("frame chain broken", "chanIdx", fc.chainIdx, "sd", decision, "frame", frameNum) + // fc.logger.Debugw("frame chain broken", "chanIdx", fc.chainIdx, "sd", decision, "frame", frameNum) } fc.expectFrames[i] = fc.expectFrames[len(fc.expectFrames)-1] fc.expectFrames = fc.expectFrames[:len(fc.expectFrames)-1] diff --git a/pkg/sfu/videolayerselector/framenumberwrapper.go b/pkg/sfu/videolayerselector/framenumberwrapper.go new file mode 100644 index 000000000..d2821db80 --- /dev/null +++ b/pkg/sfu/videolayerselector/framenumberwrapper.go @@ -0,0 +1,40 @@ +package videolayerselector + +import "github.com/livekit/protocol/logger" + +type FrameNumberWrapper struct { + offset uint64 + last uint64 + inited bool + logger logger.Logger +} + +// UpdateAndGet returns the wrapped frame number from the given frame number, and updates the offset to +// make sure the returned frame number is always inorder. Should only updateOffset if the new frame is a keyframe +// because frame dependencies uses on the frame number diff so frames inside a GOP should have the same offset. +func (f *FrameNumberWrapper) UpdateAndGet(new uint64, updateOffset bool) uint64 { + if !f.inited { + f.last = new + f.inited = true + return new + } + + if new <= f.last { + return new + f.offset + } + + if updateOffset { + new16 := uint16(new + f.offset) + last16 := uint16(f.last + f.offset) + // if new frame number wraps around and is considered as earlier by client, increase offset to make it later + if diff := new16 - last16; diff > 0x8000 || (diff == 0x8000 && new16 <= last16) { + // increase offset by 6000, nearly 10 seconds for 30fps video with 3 spatial layers + prevOffset := f.offset + f.offset += uint64(65535 - diff + 6000) + + f.logger.Debugw("wrap around frame number seen, update offset", "new", new, "last", f.last, "offset", f.offset, "prevOffset", prevOffset, "lastWrapFn", last16, "newWrapFn", new16) + } + } + f.last = new + return new + f.offset +} diff --git a/pkg/sfu/videolayerselector/framenumberwrapper_test.go b/pkg/sfu/videolayerselector/framenumberwrapper_test.go new file mode 100644 index 000000000..361a7de4c --- /dev/null +++ b/pkg/sfu/videolayerselector/framenumberwrapper_test.go @@ -0,0 +1,82 @@ +package videolayerselector + +import ( + "testing" + + "math/rand" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/protocol/logger" +) + +func TestFrameNumberWrapper(t *testing.T) { + + logger.InitFromConfig(&logger.Config{Level: "debug"}, t.Name()) + + fnWrap := &FrameNumberWrapper{logger: logger.GetLogger()} + + fnWrapAround := utils.NewWrapAround[uint16, uint64](utils.WrapAroundParams{IsRestartAllowed: false}) + + firstF := uint16(1000) + + testFrameOrder := func(frame uint16, isKeyFrame bool, frame2 uint16, isKeyFrame2, expectInorder bool) { + frameUnwrap := fnWrapAround.Update(frame).ExtendedVal + wrappedFrame := uint16(fnWrap.UpdateAndGet(frameUnwrap, isKeyFrame)) + + // make sure wrap around always get in order frame number + fnWrapAround.Update(frame + (frame2-frame)/2) + + frame2Unwrap := fnWrapAround.Update(frame2).ExtendedVal + wrappedFrame2 := uint16(fnWrap.UpdateAndGet(frame2Unwrap, isKeyFrame2)) + // keeps order + require.Equal(t, expectInorder, inOrder(wrappedFrame2, wrappedFrame), "frame %d, frame2 %d, wrappedFrame %d, wrapped Frame2 %d, frameUnwrap %d, frame2Unwrap %d", frame, frame2, wrappedFrame, wrappedFrame2, frameUnwrap, frame2Unwrap) + // frame number diff should be the same if frame2 is not a key frame + if !isKeyFrame2 { + require.Equal(t, frame2-frame, wrappedFrame2-wrappedFrame) + } + } + + secondF := getFrame(firstF, true) + testFrameOrder(firstF, true, secondF, false, true) + + // non key frame keeps diff and order + for i := 0; i < 100; i++ { + // frame in order + firstF = secondF + secondF = getFrame(firstF, true) + testFrameOrder(firstF, false, secondF, false, true) + + // frame out of order + firstF = secondF + secondF = getFrame(firstF, false) + testFrameOrder(firstF, false, secondF, false, false) + + // key frame in order + firstF = secondF + secondF = getFrame(firstF, true) + testFrameOrder(firstF, false, secondF, true, true) + + // frame in order + firstF = secondF + secondF = getFrame(firstF, true) + testFrameOrder(firstF, false, secondF, false, true) + + // key frame out of order but should be in order after wrap around + firstF = secondF + secondF = getFrame(firstF, false) + testFrameOrder(firstF, false, secondF, true, true) + } +} + +func inOrder(a, b uint16) bool { + return a-b < 0x8000 || (a-b == 0x8000 && a > b) +} + +func getFrame(base uint16, inorder bool) uint16 { + if inorder { + return base + uint16(rand.Intn(0x8000)) + } + return base + uint16(rand.Intn(0x8000)) + 0x8000 +} diff --git a/pkg/telemetry/analyticsservice.go b/pkg/telemetry/analyticsservice.go index 8611337f1..3f1955873 100644 --- a/pkg/telemetry/analyticsservice.go +++ b/pkg/telemetry/analyticsservice.go @@ -17,6 +17,9 @@ package telemetry import ( "context" + "go.uber.org/atomic" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -28,14 +31,17 @@ import ( type AnalyticsService interface { SendStats(ctx context.Context, stats []*livekit.AnalyticsStat) SendEvent(ctx context.Context, events *livekit.AnalyticsEvent) + SendNodeRoomStates(ctx context.Context, nodeRooms *livekit.AnalyticsNodeRooms) } type analyticsService struct { - analyticsKey string - nodeID string + analyticsKey string + nodeID string + sequenceNumber atomic.Uint64 - events livekit.AnalyticsRecorderService_IngestEventsClient - stats livekit.AnalyticsRecorderService_IngestStatsClient + events livekit.AnalyticsRecorderService_IngestEventsClient + stats livekit.AnalyticsRecorderService_IngestStatsClient + nodeRooms livekit.AnalyticsRecorderService_IngestNodeRoomStatesClient } func NewAnalyticsService(_ *config.Config, currentNode routing.LocalNode) AnalyticsService { @@ -71,3 +77,16 @@ func (a *analyticsService) SendEvent(_ context.Context, event *livekit.Analytics logger.Errorw("failed to send event", err, "eventType", event.Type.String()) } } + +func (a *analyticsService) SendNodeRoomStates(_ context.Context, nodeRooms *livekit.AnalyticsNodeRooms) { + if a.nodeRooms == nil { + return + } + + nodeRooms.NodeId = a.nodeID + nodeRooms.SequenceNumber = a.sequenceNumber.Add(1) + nodeRooms.Timestamp = timestamppb.Now() + if err := a.nodeRooms.Send(nodeRooms); err != nil { + logger.Errorw("failed to send node room states", err) + } +} diff --git a/pkg/telemetry/prometheus/rooms.go b/pkg/telemetry/prometheus/rooms.go index ef320ad73..91c7cc2ba 100644 --- a/pkg/telemetry/prometheus/rooms.go +++ b/pkg/telemetry/prometheus/rooms.go @@ -15,6 +15,7 @@ package prometheus import ( + "strconv" "time" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +44,7 @@ var ( promTrackSubscribedCurrent *prometheus.GaugeVec promTrackPublishCounter *prometheus.CounterVec promTrackSubscribeCounter *prometheus.CounterVec + promSessionStartTime *prometheus.HistogramVec ) func initRoomStats(nodeID string, nodeType livekit.NodeType, env string) { @@ -91,6 +93,13 @@ func initRoomStats(nodeID string, nodeType livekit.NodeType, env string) { Name: "subscribe_counter", ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String(), "env": env}, }, []string{"state", "error"}) + promSessionStartTime = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "session", + Name: "start_time_ms", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String(), "env": env}, + Buckets: prometheus.ExponentialBucketsRange(100, 10000, 15), + }, []string{"protocol_version"}) prometheus.MustRegister(promRoomCurrent) prometheus.MustRegister(promRoomDuration) @@ -99,6 +108,7 @@ func initRoomStats(nodeID string, nodeType livekit.NodeType, env string) { prometheus.MustRegister(promTrackSubscribedCurrent) prometheus.MustRegister(promTrackPublishCounter) prometheus.MustRegister(promTrackSubscribeCounter) + prometheus.MustRegister(promSessionStartTime) } func RoomStarted() { @@ -172,3 +182,7 @@ func RecordTrackSubscribeFailure(err error, isUserError bool) { trackSubscribeUserError.Inc() } } + +func RecordSessionStartTime(protocolVersion int, d time.Duration) { + promSessionStartTime.WithLabelValues(strconv.Itoa(protocolVersion)).Observe(float64(d.Milliseconds())) +} diff --git a/pkg/telemetry/signalanddatastats.go b/pkg/telemetry/signalanddatastats.go index 5ed99ebb0..5ab5855ab 100644 --- a/pkg/telemetry/signalanddatastats.go +++ b/pkg/telemetry/signalanddatastats.go @@ -49,6 +49,7 @@ type BytesTrackStats struct { trackID livekit.TrackID pID livekit.ParticipantID send, recv atomic.Uint64 + sendMessages, recvMessages atomic.Uint32 totalSendBytes, totalRecvBytes atomic.Uint64 totalSendMessages, totalRecvMessages atomic.Uint32 telemetry TelemetryService @@ -68,10 +69,12 @@ func NewBytesTrackStats(trackID livekit.TrackID, pID livekit.ParticipantID, tele func (s *BytesTrackStats) AddBytes(bytes uint64, isSend bool) { if isSend { s.send.Add(bytes) + s.sendMessages.Inc() s.totalSendBytes.Add(bytes) s.totalSendMessages.Inc() } else { s.recv.Add(bytes) + s.recvMessages.Inc() s.totalRecvBytes.Add(bytes) s.totalRecvMessages.Inc() } @@ -95,7 +98,10 @@ func (s *BytesTrackStats) report() { if recv := s.recv.Swap(0); recv > 0 { s.telemetry.TrackStats(StatsKeyForData(livekit.StreamType_UPSTREAM, s.pID, s.trackID), &livekit.AnalyticsStat{ Streams: []*livekit.AnalyticsStream{ - {PrimaryBytes: recv}, + { + PrimaryBytes: recv, + PrimaryPackets: s.recvMessages.Swap(0), + }, }, }) } @@ -103,7 +109,10 @@ func (s *BytesTrackStats) report() { if send := s.send.Swap(0); send > 0 { s.telemetry.TrackStats(StatsKeyForData(livekit.StreamType_DOWNSTREAM, s.pID, s.trackID), &livekit.AnalyticsStat{ Streams: []*livekit.AnalyticsStream{ - {PrimaryBytes: send}, + { + PrimaryBytes: send, + PrimaryPackets: s.sendMessages.Swap(0), + }, }, }) } diff --git a/pkg/telemetry/telemetryfakes/fake_analytics_service.go b/pkg/telemetry/telemetryfakes/fake_analytics_service.go index 5af3a3513..21b2bb6a1 100644 --- a/pkg/telemetry/telemetryfakes/fake_analytics_service.go +++ b/pkg/telemetry/telemetryfakes/fake_analytics_service.go @@ -16,6 +16,12 @@ type FakeAnalyticsService struct { arg1 context.Context arg2 *livekit.AnalyticsEvent } + SendNodeRoomStatesStub func(context.Context, *livekit.AnalyticsNodeRooms) + sendNodeRoomStatesMutex sync.RWMutex + sendNodeRoomStatesArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + } SendStatsStub func(context.Context, []*livekit.AnalyticsStat) sendStatsMutex sync.RWMutex sendStatsArgsForCall []struct { @@ -59,6 +65,39 @@ func (fake *FakeAnalyticsService) SendEventArgsForCall(i int) (context.Context, return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeAnalyticsService) SendNodeRoomStates(arg1 context.Context, arg2 *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.Lock() + fake.sendNodeRoomStatesArgsForCall = append(fake.sendNodeRoomStatesArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + }{arg1, arg2}) + stub := fake.SendNodeRoomStatesStub + fake.recordInvocation("SendNodeRoomStates", []interface{}{arg1, arg2}) + fake.sendNodeRoomStatesMutex.Unlock() + if stub != nil { + fake.SendNodeRoomStatesStub(arg1, arg2) + } +} + +func (fake *FakeAnalyticsService) SendNodeRoomStatesCallCount() int { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + return len(fake.sendNodeRoomStatesArgsForCall) +} + +func (fake *FakeAnalyticsService) SendNodeRoomStatesCalls(stub func(context.Context, *livekit.AnalyticsNodeRooms)) { + fake.sendNodeRoomStatesMutex.Lock() + defer fake.sendNodeRoomStatesMutex.Unlock() + fake.SendNodeRoomStatesStub = stub +} + +func (fake *FakeAnalyticsService) SendNodeRoomStatesArgsForCall(i int) (context.Context, *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + argsForCall := fake.sendNodeRoomStatesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeAnalyticsService) SendStats(arg1 context.Context, arg2 []*livekit.AnalyticsStat) { var arg2Copy []*livekit.AnalyticsStat if arg2 != nil { @@ -102,6 +141,8 @@ func (fake *FakeAnalyticsService) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.sendEventMutex.RLock() defer fake.sendEventMutex.RUnlock() + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() fake.sendStatsMutex.RLock() defer fake.sendStatsMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/telemetry/telemetryfakes/fake_telemetry_service.go b/pkg/telemetry/telemetryfakes/fake_telemetry_service.go index fd3ae6ff5..3be71676b 100644 --- a/pkg/telemetry/telemetryfakes/fake_telemetry_service.go +++ b/pkg/telemetry/telemetryfakes/fake_telemetry_service.go @@ -62,6 +62,12 @@ type FakeTelemetryService struct { arg1 context.Context arg2 *livekit.IngressInfo } + LocalRoomStateStub func(context.Context, *livekit.AnalyticsNodeRooms) + localRoomStateMutex sync.RWMutex + localRoomStateArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + } NotifyEventStub func(context.Context, *livekit.WebhookEvent) notifyEventMutex sync.RWMutex notifyEventArgsForCall []struct { @@ -122,6 +128,12 @@ type FakeTelemetryService struct { arg1 context.Context arg2 *livekit.AnalyticsEvent } + SendNodeRoomStatesStub func(context.Context, *livekit.AnalyticsNodeRooms) + sendNodeRoomStatesMutex sync.RWMutex + sendNodeRoomStatesArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + } SendStatsStub func(context.Context, []*livekit.AnalyticsStat) sendStatsMutex sync.RWMutex sendStatsArgsForCall []struct { @@ -533,6 +545,39 @@ func (fake *FakeTelemetryService) IngressUpdatedArgsForCall(i int) (context.Cont return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeTelemetryService) LocalRoomState(arg1 context.Context, arg2 *livekit.AnalyticsNodeRooms) { + fake.localRoomStateMutex.Lock() + fake.localRoomStateArgsForCall = append(fake.localRoomStateArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + }{arg1, arg2}) + stub := fake.LocalRoomStateStub + fake.recordInvocation("LocalRoomState", []interface{}{arg1, arg2}) + fake.localRoomStateMutex.Unlock() + if stub != nil { + fake.LocalRoomStateStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) LocalRoomStateCallCount() int { + fake.localRoomStateMutex.RLock() + defer fake.localRoomStateMutex.RUnlock() + return len(fake.localRoomStateArgsForCall) +} + +func (fake *FakeTelemetryService) LocalRoomStateCalls(stub func(context.Context, *livekit.AnalyticsNodeRooms)) { + fake.localRoomStateMutex.Lock() + defer fake.localRoomStateMutex.Unlock() + fake.LocalRoomStateStub = stub +} + +func (fake *FakeTelemetryService) LocalRoomStateArgsForCall(i int) (context.Context, *livekit.AnalyticsNodeRooms) { + fake.localRoomStateMutex.RLock() + defer fake.localRoomStateMutex.RUnlock() + argsForCall := fake.localRoomStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeTelemetryService) NotifyEvent(arg1 context.Context, arg2 *livekit.WebhookEvent) { fake.notifyEventMutex.Lock() fake.notifyEventArgsForCall = append(fake.notifyEventArgsForCall, struct { @@ -809,6 +854,39 @@ func (fake *FakeTelemetryService) SendEventArgsForCall(i int) (context.Context, return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeTelemetryService) SendNodeRoomStates(arg1 context.Context, arg2 *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.Lock() + fake.sendNodeRoomStatesArgsForCall = append(fake.sendNodeRoomStatesArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + }{arg1, arg2}) + stub := fake.SendNodeRoomStatesStub + fake.recordInvocation("SendNodeRoomStates", []interface{}{arg1, arg2}) + fake.sendNodeRoomStatesMutex.Unlock() + if stub != nil { + fake.SendNodeRoomStatesStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) SendNodeRoomStatesCallCount() int { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + return len(fake.sendNodeRoomStatesArgsForCall) +} + +func (fake *FakeTelemetryService) SendNodeRoomStatesCalls(stub func(context.Context, *livekit.AnalyticsNodeRooms)) { + fake.sendNodeRoomStatesMutex.Lock() + defer fake.sendNodeRoomStatesMutex.Unlock() + fake.SendNodeRoomStatesStub = stub +} + +func (fake *FakeTelemetryService) SendNodeRoomStatesArgsForCall(i int) (context.Context, *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + argsForCall := fake.sendNodeRoomStatesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeTelemetryService) SendStats(arg1 context.Context, arg2 []*livekit.AnalyticsStat) { var arg2Copy []*livekit.AnalyticsStat if arg2 != nil { @@ -1359,6 +1437,8 @@ func (fake *FakeTelemetryService) Invocations() map[string][][]interface{} { defer fake.ingressStartedMutex.RUnlock() fake.ingressUpdatedMutex.RLock() defer fake.ingressUpdatedMutex.RUnlock() + fake.localRoomStateMutex.RLock() + defer fake.localRoomStateMutex.RUnlock() fake.notifyEventMutex.RLock() defer fake.notifyEventMutex.RUnlock() fake.participantActiveMutex.RLock() @@ -1375,6 +1455,8 @@ func (fake *FakeTelemetryService) Invocations() map[string][][]interface{} { defer fake.roomStartedMutex.RUnlock() fake.sendEventMutex.RLock() defer fake.sendEventMutex.RUnlock() + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() fake.sendStatsMutex.RLock() defer fake.sendStatsMutex.RUnlock() fake.trackMaxSubscribedVideoQualityMutex.RLock() diff --git a/pkg/telemetry/telemetryservice.go b/pkg/telemetry/telemetryservice.go index c74619eda..fe5e329c2 100644 --- a/pkg/telemetry/telemetryservice.go +++ b/pkg/telemetry/telemetryservice.go @@ -20,6 +20,7 @@ import ( "time" "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/webhook" @@ -73,6 +74,7 @@ type TelemetryService interface { IngressStarted(ctx context.Context, info *livekit.IngressInfo) IngressUpdated(ctx context.Context, info *livekit.IngressInfo) IngressEnded(ctx context.Context, info *livekit.IngressInfo) + LocalRoomState(ctx context.Context, info *livekit.AnalyticsNodeRooms) // helpers AnalyticsService @@ -81,15 +83,15 @@ type TelemetryService interface { } const ( - workerCleanupWait = 3 * time.Minute - jobQueueBufferSize = 10000 + workerCleanupWait = 3 * time.Minute + jobsQueueMinSize = 2048 ) type telemetryService struct { AnalyticsService - notifier webhook.QueuedNotifier - jobsChan chan func() + notifier webhook.QueuedNotifier + jobsQueue *utils.OpsQueue lock sync.RWMutex workers map[livekit.ParticipantID]*StatsWorker @@ -99,11 +101,12 @@ func NewTelemetryService(notifier webhook.QueuedNotifier, analytics AnalyticsSer t := &telemetryService{ AnalyticsService: analytics, - notifier: notifier, - jobsChan: make(chan func(), jobQueueBufferSize), - workers: make(map[livekit.ParticipantID]*StatsWorker), + notifier: notifier, + jobsQueue: utils.NewOpsQueue("telemetry", jobsQueueMinSize, true), + workers: make(map[livekit.ParticipantID]*StatsWorker), } + t.jobsQueue.Start() go t.run() return t @@ -131,19 +134,12 @@ func (t *telemetryService) run() { t.FlushStats() case <-cleanupTicker.C: t.cleanupWorkers() - case op := <-t.jobsChan: - op() } } } func (t *telemetryService) enqueue(op func()) { - select { - case t.jobsChan <- op: - // success - default: - logger.Warnw("telemetry queue full", nil) - } + t.jobsQueue.Enqueue(op) } func (t *telemetryService) getWorker(participantID livekit.ParticipantID) (worker *StatsWorker, ok bool) { @@ -187,3 +183,9 @@ func (t *telemetryService) cleanupWorkers() { } } } + +func (t *telemetryService) LocalRoomState(ctx context.Context, info *livekit.AnalyticsNodeRooms) { + t.enqueue(func() { + t.SendNodeRoomStates(ctx, info) + }) +} diff --git a/pkg/utils/context.go b/pkg/utils/context.go index f5262e178..4ad9086d2 100644 --- a/pkg/utils/context.go +++ b/pkg/utils/context.go @@ -16,17 +16,33 @@ package utils -import "context" +import ( + "context" -var attemptKey = struct{}{} + "github.com/livekit/protocol/logger" +) + +type attemptKey struct{} +type loggerKey = struct{} func ContextWithAttempt(ctx context.Context, attempt int) context.Context { - return context.WithValue(ctx, attemptKey, attempt) + return context.WithValue(ctx, attemptKey{}, attempt) } func GetAttempt(ctx context.Context) int { - if attempt, ok := ctx.Value(attemptKey).(int); ok { + if attempt, ok := ctx.Value(attemptKey{}).(int); ok { return attempt } return 0 } + +func ContextWithLogger(ctx context.Context, logger logger.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, logger) +} + +func GetLogger(ctx context.Context) logger.Logger { + if l, ok := ctx.Value(loggerKey{}).(logger.Logger); ok { + return l + } + return logger.GetLogger() +} diff --git a/pkg/utils/opsqueue.go b/pkg/utils/opsqueue.go index 3e992461c..01f9a12ff 100644 --- a/pkg/utils/opsqueue.go +++ b/pkg/utils/opsqueue.go @@ -15,33 +15,34 @@ package utils import ( + "math/bits" "sync" - "github.com/livekit/protocol/logger" + "github.com/gammazero/deque" + "github.com/livekit/protocol/utils" ) type OpsQueue struct { - logger logger.Logger - name string - size int + name string + flushOnStop bool - lock sync.RWMutex - ops chan func() + lock sync.Mutex + ops deque.Deque[func()] + wake chan struct{} isStarted bool + doneChan chan struct{} isStopped bool } -func NewOpsQueue(logger logger.Logger, name string, size int) *OpsQueue { - return &OpsQueue{ - logger: logger, - name: name, - size: size, - ops: make(chan func(), size), +func NewOpsQueue(name string, minSize uint, flushOnStop bool) *OpsQueue { + oq := &OpsQueue{ + name: name, + flushOnStop: flushOnStop, + wake: make(chan struct{}, 1), + doneChan: make(chan struct{}), } -} - -func (oq *OpsQueue) SetLogger(logger logger.Logger) { - oq.logger = logger + oq.ops.SetMinCapacity(uint(utils.Min(bits.Len64(uint64(minSize-1)), 16))) + return oq } func (oq *OpsQueue) Start() { @@ -57,42 +58,52 @@ func (oq *OpsQueue) Start() { go oq.process() } -func (oq *OpsQueue) Stop() { +func (oq *OpsQueue) Stop() <-chan struct{} { oq.lock.Lock() if oq.isStopped { oq.lock.Unlock() - return + return oq.doneChan } oq.isStopped = true - close(oq.ops) + close(oq.wake) oq.lock.Unlock() -} - -func (oq *OpsQueue) IsStarted() bool { - oq.lock.RLock() - defer oq.lock.RUnlock() - - return oq.isStarted + return oq.doneChan } func (oq *OpsQueue) Enqueue(op func()) { - oq.lock.RLock() - if oq.isStopped { - oq.lock.RUnlock() - return - } + oq.lock.Lock() + defer oq.lock.Unlock() - select { - case oq.ops <- op: - default: - oq.logger.Errorw("ops queue full", nil, "name", oq.name, "size", oq.size) + oq.ops.PushBack(op) + if oq.ops.Len() == 1 && !oq.isStopped { + select { + case oq.wake <- struct{}{}: + default: + } } - oq.lock.RUnlock() } func (oq *OpsQueue) process() { - for op := range oq.ops { - op() + defer close(oq.doneChan) + + for { + <-oq.wake + for { + oq.lock.Lock() + if oq.isStopped && (!oq.flushOnStop || oq.ops.Len() == 0) { + oq.lock.Unlock() + return + } + + if oq.ops.Len() == 0 { + oq.lock.Unlock() + break + } + op := oq.ops.PopFront() + oq.lock.Unlock() + + op() + } } } diff --git a/pkg/utils/protocol.go b/pkg/utils/protocol.go new file mode 100644 index 000000000..23844742f --- /dev/null +++ b/pkg/utils/protocol.go @@ -0,0 +1,32 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" +) + +func ClientInfoWithoutAddress(c *livekit.ClientInfo) *livekit.ClientInfo { + if c == nil { + return nil + } + clone := proto.Clone(c).(*livekit.ClientInfo) + clone.Address = "" + return clone +} diff --git a/test/client/client.go b/test/client/client.go index 70b135a6d..e0be22e47 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -39,6 +39,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" "github.com/livekit/livekit-server/pkg/rtc/types" ) @@ -114,7 +115,7 @@ type Options struct { } func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) { - u, err := url.Parse(host + "/rtc?protocol=7") + u, err := url.Parse(host + fmt.Sprintf("/rtc?protocol=%d", types.CurrentProtocol)) if err != nil { return nil, err } @@ -199,33 +200,37 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { // i. e. the publisher transport on client side has SUBSCRIBER signal target (i. e. publisher is offerer). // Same applies for subscriber transport also // + publisherHandler := &transportfakes.FakeHandler{} c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ Config: &conf, DirectionConfig: conf.Subscriber, EnabledCodecs: codecs, IsOfferer: true, IsSendSide: true, + Handler: publisherHandler, }) if err != nil { return nil, err } + subscriberHandler := &transportfakes.FakeHandler{} c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ Config: &conf, DirectionConfig: conf.Publisher, EnabledCodecs: codecs, + Handler: subscriberHandler, }) if err != nil { return nil, err } - c.publisher.OnICECandidate(func(ic *webrtc.ICECandidate) error { + publisherHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { if ic == nil { return nil } return c.SendIceCandidate(ic, livekit.SignalTarget_PUBLISHER) }) - c.publisher.OnOffer(c.onOffer) - c.publisher.OnFullyEstablished(func() { + publisherHandler.OnOfferCalls(c.onOffer) + publisherHandler.OnFullyEstablishedCalls(func() { logger.Debugw("publisher fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) c.publisherFullyEstablished.Store(true) }) @@ -245,17 +250,17 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { return nil, err } - c.subscriber.OnICECandidate(func(ic *webrtc.ICECandidate) error { + subscriberHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { if ic == nil { return nil } return c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) }) - c.subscriber.OnTrack(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + subscriberHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { go c.processTrack(track) }) - c.subscriber.OnDataPacket(c.handleDataMessage) - c.subscriber.OnInitialConnected(func() { + subscriberHandler.OnDataPacketCalls(c.handleDataMessage) + subscriberHandler.OnInitialConnectedCalls(func() { logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) c.lock.Lock() @@ -272,11 +277,11 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { go c.OnConnected() } }) - c.subscriber.OnFullyEstablished(func() { + subscriberHandler.OnFullyEstablishedCalls(func() { logger.Debugw("subscriber fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) c.subscriberFullyEstablished.Store(true) }) - c.subscriber.OnAnswer(func(answer webrtc.SessionDescription) error { + subscriberHandler.OnAnswerCalls(func(answer webrtc.SessionDescription) error { // send remote an answer logger.Infow("sending subscriber answer", "participant", c.localParticipant.Identity, @@ -493,7 +498,10 @@ func (c *RTCClient) Stop() { logger.Infow("stopping client", "ID", c.ID()) _ = c.SendRequest(&livekit.SignalRequest{ Message: &livekit.SignalRequest_Leave{ - Leave: &livekit.LeaveRequest{}, + Leave: &livekit.LeaveRequest{ + Reason: livekit.DisconnectReason_CLIENT_INITIATED, + Action: livekit.LeaveRequest_DISCONNECT, + }, }, }) c.publisherFullyEstablished.Store(false) diff --git a/test/integration_helpers.go b/test/integration_helpers.go index 7a11b353e..0ce079c23 100644 --- a/test/integration_helpers.go +++ b/test/integration_helpers.go @@ -183,7 +183,6 @@ func createMultiNodeServer(nodeID string, port uint32) *service.LivekitServer { conf.RTC.TCPPort = port + 2 conf.Redis.Address = "localhost:6379" conf.Keys = map[string]string{testApiKey: testApiSecret} - conf.SignalRelay.Enabled = true currentNode, err := routing.NewLocalNode(conf) if err != nil { diff --git a/test/webhook_test.go b/test/webhook_test.go index 678c48c80..0463eea1a 100644 --- a/test/webhook_test.go +++ b/test/webhook_test.go @@ -35,6 +35,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/service" "github.com/livekit/livekit-server/pkg/testutils" ) @@ -108,7 +109,7 @@ func TestWebhooks(t *testing.T) { // room closed rm := server.RoomManager().GetRoom(context.Background(), testRoom) - rm.Close() + rm.Close(types.ParticipantCloseReasonNone) testutils.WithTimeout(t, func() string { if ts.GetEvent(webhook.EventRoomFinished) == nil { return "did not receive RoomFinished" diff --git a/version/version.go b/version/version.go index cbd85195e..e821eae6b 100644 --- a/version/version.go +++ b/version/version.go @@ -14,4 +14,4 @@ package version -const Version = "1.5.1" +const Version = "1.5.2"