From 890fd94249c51ff82ea5ab521198b435f727f5d7 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 28 Aug 2025 12:16:18 +0530 Subject: [PATCH] Single peer connection mode (#3873) * WIP * check using protocol version * revert * clean up * sdp cid argument * WIP * WIP * test * clean up * clean up * fixes * clean up * clean up * clean up * conditional checks * tests for both dual and single peer connection * test * test * test * type check * test * todo * munges * combined config * populate mid * limit to receive only * clean up * clean up * clean up * older test * clean up * alternative audio codec * dtx * don't need to copy * Anunay feedback * use the available peer connection * publisher check * WIP * WIP * WIP * no mid * media sections requirement * mage generate * WIP * WIP * set data channel receive size for test * handle early media better * WIP * do not do ICERestart if no subscriber * WIP * WIP * WIP * WIP * WIP * WIP * start up subscriber RTCP worker * WIP * WIP * clean up * clean up * flag to indicate use of single peer connection * remove unused interface method * clean up * clean up * Jie feedback #1 * deps * do not access subscriber in one shot mode * more places for one shot mode * more one shot fixes * deps * deps * test --- go.mod | 22 +- go.sum | 54 +- pkg/config/config.go | 2 +- pkg/routing/interfaces.go | 98 +- pkg/rtc/config.go | 90 +- pkg/rtc/mediaengine.go | 361 +++-- pkg/rtc/mediatracksubscriptions.go | 3 +- pkg/rtc/participant.go | 108 +- pkg/rtc/participant_internal_test.go | 5 +- pkg/rtc/participant_sdp.go | 45 +- pkg/rtc/participant_signal.go | 16 + pkg/rtc/room.go | 13 +- pkg/rtc/signalling/interfaces.go | 1 + pkg/rtc/signalling/signalling.go | 8 + pkg/rtc/signalling/signallingunimplemented.go | 4 + pkg/rtc/transport.go | 289 +++- pkg/rtc/transport/handler.go | 4 + .../transport/transportfakes/fake_handler.go | 74 ++ pkg/rtc/transport_test.go | 47 +- pkg/rtc/transportmanager.go | 259 ++-- pkg/rtc/types/interfaces.go | 4 +- .../typesfakes/fake_local_participant.go | 63 + pkg/rtc/wrappedreceiver.go | 10 +- pkg/service/roommanager.go | 8 + pkg/service/rtcservice.go | 13 +- test/agent_test.go | 312 ++--- test/client/client.go | 384 ++++-- test/integration_helpers.go | 27 +- test/multinode_roomservice_test.go | 172 +-- test/multinode_test.go | 458 ++++--- test/scenarios.go | 296 +++-- test/singlenode_test.go | 1172 +++++++++-------- test/webhook_test.go | 138 +- 33 files changed, 2768 insertions(+), 1792 deletions(-) diff --git a/go.mod b/go.mod index 13fbb5b9e..04d05ed99 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.24.6 require ( github.com/bep/debounce v1.2.1 github.com/d5/tengo/v2 v2.17.0 - github.com/dennwc/iters v1.1.0 + github.com/dennwc/iters v1.2.2 github.com/dustin/go-humanize v1.0.1 github.com/elliotchance/orderedmap/v2 v2.7.0 github.com/florianl/go-tc v0.4.5 @@ -23,7 +23,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 github.com/livekit/mediatransportutil v0.0.0-20250825135402-7bc31f107ade - github.com/livekit/protocol v1.40.0 + github.com/livekit/protocol v1.40.1-0.20250826073447-c714707269e5 github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741 github.com/mackerelio/go-osstat v0.2.6 github.com/magefile/mage v1.15.0 @@ -41,12 +41,12 @@ require ( github.com/pion/sdp/v3 v3.0.15 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v4 v4.1.1 - github.com/pion/webrtc/v4 v4.1.3 + github.com/pion/webrtc/v4 v4.1.5-0.20250828044558-c376d0edf977 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.23.0 github.com/redis/go-redis/v9 v9.12.1 github.com/rs/cors v1.11.1 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/thoas/go-funk v0.9.3 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/twitchtv/twirp v8.1.3+incompatible @@ -55,10 +55,10 @@ require ( go.uber.org/atomic v1.11.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 - golang.org/x/exp v0.0.0-20250811191247-51f88131bc50 + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/mod v0.27.0 golang.org/x/sync v0.16.0 - google.golang.org/protobuf v1.36.7 + google.golang.org/protobuf v1.36.8 gopkg.in/yaml.v3 v3.0.1 ) @@ -68,7 +68,7 @@ require ( ) require ( - buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.7-20250717185734-6c6e0d3c608e.1 // indirect + buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.8-20250717185734-6c6e0d3c608e.1 // indirect buf.build/go/protovalidate v0.14.0 // indirect buf.build/go/protoyaml v0.6.0 // indirect cel.dev/expr v0.24.0 // indirect @@ -109,7 +109,7 @@ require ( github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/term v0.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/nats-io/nats.go v1.44.0 // indirect + github.com/nats-io/nats.go v1.45.0 // indirect github.com/nats-io/nkeys v0.4.11 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect @@ -139,8 +139,8 @@ require ( golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect golang.org/x/tools v0.36.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250811230008-5f3141c8851a // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect - google.golang.org/grpc v1.74.2 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect + google.golang.org/grpc v1.75.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 0685362d5..06d9e517f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.7-20250717185734-6c6e0d3c608e.1 h1:/AZH8sVB6LHv8G+hZlAMCP31NevnesHwYgnlgS5Vt14= -buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.7-20250717185734-6c6e0d3c608e.1/go.mod h1:eva/VCrd8X7xuJw+JtwCEyrCKiRRASukFqmirnWBvFU= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.8-20250717185734-6c6e0d3c608e.1 h1:sjY1k5uszbIZfv11HO2keV4SLhNA47SabPO886v7Rvo= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.8-20250717185734-6c6e0d3c608e.1/go.mod h1:8EQ5GzyGJQ5tEIwMSxCl8RKJYsjCpAwkdcENoioXT6g= buf.build/go/protovalidate v0.14.0 h1:kr/rC/no+DtRyYX+8KXLDxNnI1rINz0imk5K44ZpZ3A= buf.build/go/protovalidate v0.14.0/go.mod h1:+F/oISho9MO7gJQNYC2VWLzcO1fTPmaTA08SDYJZncA= buf.build/go/protoyaml v0.6.0 h1:Nzz1lvcXF8YgNZXk+voPPwdU8FjDPTUV4ndNTXN0n2w= @@ -47,8 +47,8 @@ github.com/d5/tengo/v2 v2.17.0/go.mod h1:XRGjEs5I9jYIKTxly6HCF8oiiilk5E/RYXOZ5b0 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dennwc/iters v1.1.0 h1:PsS3DbOU7GxSUQO0e7SGmzHkPhtwOlwbqggJ++Bgnr8= -github.com/dennwc/iters v1.1.0/go.mod h1:M9KuuMBeyEXYTmB7EnI9SCyALFCmPWOIxn5W1L0CjGg= +github.com/dennwc/iters v1.2.2 h1:XH2/Etihiy9ZvPOVCR+icQXeYlhbvS7k0qro4x/2qQo= +github.com/dennwc/iters v1.2.2/go.mod h1:M9KuuMBeyEXYTmB7EnI9SCyALFCmPWOIxn5W1L0CjGg= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/docker/cli v27.4.1+incompatible h1:VzPiUlRJ/xh+otB75gva3r05isHMo5wXDfPRi5/b4hI= @@ -170,8 +170,8 @@ github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 h1:9x+U2HGLrSw5AT github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20250825135402-7bc31f107ade h1:lpxPcglwzUWNB4J0S2qZuyMehzmR7vW9whzSwV4IGoI= github.com/livekit/mediatransportutil v0.0.0-20250825135402-7bc31f107ade/go.mod h1:mSNtYzSf6iY9xM3UX42VEI+STHvMgHmrYzEHPcdhB8A= -github.com/livekit/protocol v1.40.0 h1:FzCv17ivkxI4ySE0uFXSPPVPBVepf736Cgnmi054z2o= -github.com/livekit/protocol v1.40.0/go.mod h1:YlgUxAegtU8jZ0tVXoIV/4fHeHqqLvS+6JnPKDbpFPU= +github.com/livekit/protocol v1.40.1-0.20250826073447-c714707269e5 h1:aBqHlrgCI3qzVUAoOx7n5Kt8GQ8g7UbNp69fUt2gO8I= +github.com/livekit/protocol v1.40.1-0.20250826073447-c714707269e5/go.mod h1:YlgUxAegtU8jZ0tVXoIV/4fHeHqqLvS+6JnPKDbpFPU= github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741 h1:KKL1u94l6dF9u4cBwnnfozk27GH1txWy2SlvkfgmzoY= github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741/go.mod h1:AuDC5uOoEjQJEc69v4Li3t77Ocz0e0NdjQEuFfO+vfk= github.com/mackerelio/go-osstat v0.2.6 h1:gs4U8BZeS1tjrL08tt5VUliVvSWP26Ai2Ob8Lr7f2i0= @@ -215,8 +215,8 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/nats-io/nats.go v1.44.0 h1:ECKVrDLdh/kDPV1g0gAQ+2+m2KprqZK5O/eJAyAnH2M= -github.com/nats-io/nats.go v1.44.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= +github.com/nats-io/nats.go v1.45.0 h1:/wGPbnYXDM0pLKFjZTX+2JOw9TQPoIgTFrUaH97giwA= +github.com/nats-io/nats.go v1.45.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0= github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -263,8 +263,8 @@ github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1 github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= -github.com/pion/webrtc/v4 v4.1.3 h1:YZ67Boj9X/hk190jJZ8+HFGQ6DqSZ/fYP3sLAZv7c3c= -github.com/pion/webrtc/v4 v4.1.3/go.mod h1:rsq+zQ82ryfR9vbb0L1umPJ6Ogq7zm8mcn9fcGnxomM= +github.com/pion/webrtc/v4 v4.1.5-0.20250828044558-c376d0edf977 h1:skeDG50xusAciSuWumr6OT/PvlyKdIC/E7OYmysehWw= +github.com/pion/webrtc/v4 v4.1.5-0.20250828044558-c376d0edf977/go.mod h1:L+kyaW50BzPT8fQQyllbJCz+a6T3JsFTQfJs8zbWuAI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -303,8 +303,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw= github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= @@ -339,10 +339,10 @@ go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= -go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= -go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= -go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= -go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= @@ -364,8 +364,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= -golang.org/x/exp v0.0.0-20250811191247-51f88131bc50 h1:3yiSh9fhy5/RhCSntf4Sy0Tnx50DmMpQ4MQdKKk4yg4= -golang.org/x/exp v0.0.0-20250811191247-51f88131bc50/go.mod h1:rT6SFzZ7oxADUDx58pcaKFTcZ+inxAa9fTrYx/uVYwg= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -474,14 +474,16 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20250811230008-5f3141c8851a h1:DMCgtIAIQGZqJXMVzJF4MV8BlWoJh2ZuFiRdAleyr58= -google.golang.org/genproto/googleapis/api v0.0.0-20250811230008-5f3141c8851a/go.mod h1:y2yVLIE/CSMCPXaHnSKXxu1spLPnglFLegmgdY23uuE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= -google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= -google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= -google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= -google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= +google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc= +google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= +google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/pkg/config/config.go b/pkg/config/config.go index 9c0576072..0faf481f1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -365,8 +365,8 @@ var DefaultConfig = Config{ {Mime: mime.MimeTypeH264.String()}, {Mime: mime.MimeTypeVP9.String()}, {Mime: mime.MimeTypeAV1.String()}, - {Mime: mime.MimeTypeRTX.String()}, {Mime: mime.MimeTypeH265.String()}, + {Mime: mime.MimeTypeRTX.String()}, }, EmptyTimeout: 5 * 60, DepartureTimeout: 20, diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 714f70b6f..8355eed3d 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -184,22 +184,23 @@ func CreateRouter( // ------------------------------------------------ type ParticipantInit struct { - Identity livekit.ParticipantIdentity - Name livekit.ParticipantName - Reconnect bool - ReconnectReason livekit.ReconnectReason - AutoSubscribe bool - Client *livekit.ClientInfo - Grants *auth.ClaimGrants - Region string - AdaptiveStream bool - ID livekit.ParticipantID - SubscriberAllowPause *bool - DisableICELite bool - CreateRoom *livekit.CreateRoomRequest - AddTrackRequests []*livekit.AddTrackRequest - PublisherOffer *livekit.SessionDescription - SyncState *livekit.SyncState + Identity livekit.ParticipantIdentity + Name livekit.ParticipantName + Reconnect bool + ReconnectReason livekit.ReconnectReason + AutoSubscribe bool + Client *livekit.ClientInfo + Grants *auth.ClaimGrants + Region string + AdaptiveStream bool + ID livekit.ParticipantID + SubscriberAllowPause *bool + DisableICELite bool + CreateRoom *livekit.CreateRoomRequest + AddTrackRequests []*livekit.AddTrackRequest + PublisherOffer *livekit.SessionDescription + SyncState *livekit.SyncState + UseSinglePeerConnection bool } func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error { @@ -230,6 +231,7 @@ func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error { e.AddArray("AddTrackRequests", logger.ProtoSlice(pi.AddTrackRequests)) e.AddObject("PublisherOffer", logger.Proto(pi.PublisherOffer)) e.AddObject("SyncState", logger.Proto(pi.SyncState)) + logBoolPtr("UseSinglePeerConnection", &pi.UseSinglePeerConnection) return nil } @@ -240,22 +242,23 @@ func (pi *ParticipantInit) ToStartSession(roomName livekit.RoomName, connectionI } ss := &livekit.StartSession{ - RoomName: string(roomName), - Identity: string(pi.Identity), - Name: string(pi.Name), - ConnectionId: string(connectionID), - Reconnect: pi.Reconnect, - ReconnectReason: pi.ReconnectReason, - AutoSubscribe: pi.AutoSubscribe, - Client: pi.Client, - GrantsJson: string(claims), - AdaptiveStream: pi.AdaptiveStream, - ParticipantId: string(pi.ID), - DisableIceLite: pi.DisableICELite, - CreateRoom: pi.CreateRoom, - AddTrackRequests: pi.AddTrackRequests, - PublisherOffer: pi.PublisherOffer, - SyncState: pi.SyncState, + RoomName: string(roomName), + Identity: string(pi.Identity), + Name: string(pi.Name), + ConnectionId: string(connectionID), + Reconnect: pi.Reconnect, + ReconnectReason: pi.ReconnectReason, + AutoSubscribe: pi.AutoSubscribe, + Client: pi.Client, + GrantsJson: string(claims), + AdaptiveStream: pi.AdaptiveStream, + ParticipantId: string(pi.ID), + DisableIceLite: pi.DisableICELite, + CreateRoom: pi.CreateRoom, + AddTrackRequests: pi.AddTrackRequests, + PublisherOffer: pi.PublisherOffer, + SyncState: pi.SyncState, + UseSinglePeerConnection: pi.UseSinglePeerConnection, } if pi.SubscriberAllowPause != nil { subscriberAllowPause := *pi.SubscriberAllowPause @@ -272,21 +275,22 @@ func ParticipantInitFromStartSession(ss *livekit.StartSession, region string) (* } pi := &ParticipantInit{ - Identity: livekit.ParticipantIdentity(ss.Identity), - Name: livekit.ParticipantName(ss.Name), - Reconnect: ss.Reconnect, - ReconnectReason: ss.ReconnectReason, - Client: ss.Client, - AutoSubscribe: ss.AutoSubscribe, - Grants: claims, - Region: region, - AdaptiveStream: ss.AdaptiveStream, - ID: livekit.ParticipantID(ss.ParticipantId), - DisableICELite: ss.DisableIceLite, - CreateRoom: ss.CreateRoom, - AddTrackRequests: ss.AddTrackRequests, - PublisherOffer: ss.PublisherOffer, - SyncState: ss.SyncState, + Identity: livekit.ParticipantIdentity(ss.Identity), + Name: livekit.ParticipantName(ss.Name), + Reconnect: ss.Reconnect, + ReconnectReason: ss.ReconnectReason, + Client: ss.Client, + AutoSubscribe: ss.AutoSubscribe, + Grants: claims, + Region: region, + AdaptiveStream: ss.AdaptiveStream, + ID: livekit.ParticipantID(ss.ParticipantId), + DisableICELite: ss.DisableIceLite, + CreateRoom: ss.CreateRoom, + AddTrackRequests: ss.AddTrackRequests, + PublisherOffer: ss.PublisherOffer, + SyncState: ss.SyncState, + UseSinglePeerConnection: ss.UseSinglePeerConnection, } if ss.SubscriberAllowPause != nil { subscriberAllowPause := *ss.SubscriberAllowPause diff --git a/pkg/rtc/config.go b/pkg/rtc/config.go index 1a87a0f94..a107a9106 100644 --- a/pkg/rtc/config.go +++ b/pkg/rtc/config.go @@ -25,8 +25,8 @@ import ( ) const ( - frameMarking = "urn:ietf:params:rtp-hdrext:framemarking" - repairedRTPStreamID = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id" + frameMarkingURI = "urn:ietf:params:rtp-hdrext:framemarking" + repairedRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id" ) type WebRTCConfig struct { @@ -79,8 +79,67 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { rtcConf.PacketBufferSizeAudio = rtcConf.PacketBufferSize } - // publisher configuration - publisherConfig := DirectionConfig{ + return &WebRTCConfig{ + WebRTCConfig: *webRTCConfig, + Receiver: ReceiverConfig{ + PacketBufferSizeVideo: rtcConf.PacketBufferSizeVideo, + PacketBufferSizeAudio: rtcConf.PacketBufferSizeAudio, + }, + Publisher: getPublisherConfig(false), + Subscriber: getSubscriberConfig(rtcConf.CongestionControl.UseSendSideBWEInterceptor || rtcConf.CongestionControl.UseSendSideBWE), + }, nil +} + +func (c *WebRTCConfig) UpdatePublisherConfig(consolidated bool) { + c.Publisher = getPublisherConfig(consolidated) +} + +func (c *WebRTCConfig) UpdateSubscriberConfig(ccConf config.CongestionControlConfig) { + c.Subscriber = getSubscriberConfig(ccConf.UseSendSideBWEInterceptor || ccConf.UseSendSideBWE) +} + +func (c *WebRTCConfig) SetBufferFactory(factory *buffer.Factory) { + c.BufferFactory = factory + c.SettingEngine.BufferFactory = factory.GetOrNew +} + +func getPublisherConfig(consolidated bool) DirectionConfig { + if consolidated { + return DirectionConfig{ + RTPHeaderExtension: RTPHeaderExtensionConfig{ + Audio: []string{ + sdp.SDESMidURI, + sdp.SDESRTPStreamIDURI, + sdp.AudioLevelURI, + //act.AbsCaptureTimeURI, + }, + Video: []string{ + sdp.SDESMidURI, + sdp.SDESRTPStreamIDURI, + sdp.TransportCCURI, + sdp.ABSSendTimeURI, + frameMarkingURI, + dd.ExtensionURI, + repairedRTPStreamIDURI, + //act.AbsCaptureTimeURI, + }, + }, + RTCPFeedback: RTCPFeedbackConfig{ + Audio: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBNACK}, + }, + Video: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBTransportCC}, + {Type: webrtc.TypeRTCPFBGoogREMB}, + {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, + {Type: webrtc.TypeRTCPFBNACK}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}, + }, + }, + } + } + + return DirectionConfig{ RTPHeaderExtension: RTPHeaderExtensionConfig{ Audio: []string{ sdp.SDESMidURI, @@ -92,9 +151,9 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { sdp.SDESMidURI, sdp.SDESRTPStreamIDURI, sdp.TransportCCURI, - frameMarking, + frameMarkingURI, dd.ExtensionURI, - repairedRTPStreamID, + repairedRTPStreamIDURI, //act.AbsCaptureTimeURI, }, }, @@ -110,25 +169,6 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { }, }, } - - return &WebRTCConfig{ - WebRTCConfig: *webRTCConfig, - Receiver: ReceiverConfig{ - PacketBufferSizeVideo: rtcConf.PacketBufferSizeVideo, - PacketBufferSizeAudio: rtcConf.PacketBufferSizeAudio, - }, - Publisher: publisherConfig, - Subscriber: getSubscriberConfig(rtcConf.CongestionControl.UseSendSideBWEInterceptor || rtcConf.CongestionControl.UseSendSideBWE), - }, nil -} - -func (c *WebRTCConfig) UpdateCongestionControl(ccConf config.CongestionControlConfig) { - c.Subscriber = getSubscriberConfig(ccConf.UseSendSideBWEInterceptor || ccConf.UseSendSideBWE) -} - -func (c *WebRTCConfig) SetBufferFactory(factory *buffer.Factory) { - c.BufferFactory = factory - c.SettingEngine.BufferFactory = factory.GetOrNew } func getSubscriberConfig(enableTWCC bool) DirectionConfig { diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 579cb0a42..6b24dac6e 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -24,169 +24,188 @@ import ( "github.com/livekit/protocol/livekit" ) -var OpusCodecCapability = webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeOpus.String(), - ClockRate: 48000, - Channels: 2, - SDPFmtpLine: "minptime=10;useinbandfec=1", -} -var RedCodecCapability = webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeRED.String(), - ClockRate: 48000, - Channels: 2, - SDPFmtpLine: "111/111", -} -var videoRTX = webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeRTX.String(), - ClockRate: 90000, -} +var ( + OpusCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeOpus.String(), + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", + }, + PayloadType: 111, + } + + RedCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeRED.String(), + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "111/111", + }, + PayloadType: 63, + } + + pcmuCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypePCMU.String(), + ClockRate: 8000, + }, + PayloadType: 0, + } + + pcmaCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypePCMA.String(), + ClockRate: 8000, + }, + PayloadType: 8, + } + + videoRTXCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeRTX.String(), + ClockRate: 90000, + }, + } + + vp8CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeVP8.String(), + ClockRate: 90000, + }, + PayloadType: 96, + } + + vp9ProfileId0CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeVP9.String(), + ClockRate: 90000, + SDPFmtpLine: "profile-id=0", + }, + PayloadType: 98, + } + + vp9ProfileId1CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeVP9.String(), + ClockRate: 90000, + SDPFmtpLine: "profile-id=1", + }, + PayloadType: 100, + } + + h264ProfileLevelId42e01fPacketizationMode0CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH264.String(), + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + }, + PayloadType: 125, + } + + h264ProfileLevelId42e01fPacketizationMode1CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH264.String(), + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", + }, + PayloadType: 108, + } + + h264HighProfileFmtp = "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640032" + h264HighProfileCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH264.String(), + ClockRate: 90000, + SDPFmtpLine: h264HighProfileFmtp, + }, + PayloadType: 123, + } + + av1CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeAV1.String(), + ClockRate: 90000, + }, + PayloadType: 35, + } + + h265CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH265.String(), + ClockRate: 90000, + }, + PayloadType: 116, + } + + videoCodecsParameters = []webrtc.RTPCodecParameters{ + vp8CodecParameters, + vp9ProfileId0CodecParameters, + vp9ProfileId1CodecParameters, + h264ProfileLevelId42e01fPacketizationMode0CodecParameters, + h264ProfileLevelId42e01fPacketizationMode1CodecParameters, + h264HighProfileCodecParameters, + av1CodecParameters, + h265CodecParameters, + } +) func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedback RTCPFeedbackConfig, filterOutH264HighProfile bool) error { - opusCodec := OpusCodecCapability - opusCodec.RTCPFeedback = rtcpFeedback.Audio - var opusPayload webrtc.PayloadType - if IsCodecEnabled(codecs, opusCodec) { - opusPayload = 111 - if err := me.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: opusCodec, - PayloadType: opusPayload, - }, webrtc.RTPCodecTypeAudio); err != nil { + // audio codecs + if IsCodecEnabled(codecs, OpusCodecParameters.RTPCodecCapability) { + cp := OpusCodecParameters + cp.RTPCodecCapability.RTCPFeedback = rtcpFeedback.Audio + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeAudio); err != nil { return err } - if IsCodecEnabled(codecs, RedCodecCapability) { - if err := me.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: RedCodecCapability, - PayloadType: 63, - }, webrtc.RTPCodecTypeAudio); err != nil { + if IsCodecEnabled(codecs, RedCodecParameters.RTPCodecCapability) { + if err := me.RegisterCodec(RedCodecParameters, webrtc.RTPCodecTypeAudio); err != nil { return err } } } - for _, codec := range []webrtc.RTPCodecParameters{ - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypePCMU.String(), - ClockRate: 8000, - Channels: 1, - RTCPFeedback: rtcpFeedback.Audio, - }, - PayloadType: 0, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypePCMA.String(), - ClockRate: 8000, - Channels: 1, - RTCPFeedback: rtcpFeedback.Audio, - }, - PayloadType: 8, - }, - } { - if IsCodecEnabled(codecs, codec.RTPCodecCapability) { - if err := me.RegisterCodec(codec, webrtc.RTPCodecTypeAudio); err != nil { - return err - } + for _, codec := range []webrtc.RTPCodecParameters{pcmuCodecParameters, pcmaCodecParameters} { + if !IsCodecEnabled(codecs, codec.RTPCodecCapability) { + continue + } + + cp := codec + cp.RTPCodecCapability.RTCPFeedback = rtcpFeedback.Audio + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeAudio); err != nil { + return err } } - rtxEnabled := IsCodecEnabled(codecs, videoRTX) - - h264HighProfileFmtp := "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640032" - for _, codec := range []webrtc.RTPCodecParameters{ - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeVP8.String(), - ClockRate: 90000, - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 96, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeVP9.String(), - ClockRate: 90000, - SDPFmtpLine: "profile-id=0", - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 98, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeVP9.String(), - ClockRate: 90000, - SDPFmtpLine: "profile-id=1", - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 100, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeH264.String(), - ClockRate: 90000, - SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 125, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeH264.String(), - ClockRate: 90000, - SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 108, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeH264.String(), - ClockRate: 90000, - SDPFmtpLine: h264HighProfileFmtp, - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 123, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeAV1.String(), - ClockRate: 90000, - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 35, - }, - { - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeH265.String(), - ClockRate: 90000, - RTCPFeedback: rtcpFeedback.Video, - }, - PayloadType: 116, - }, - } { + // video codecs + rtxEnabled := IsCodecEnabled(codecs, videoRTXCodecParameters.RTPCodecCapability) + for _, codec := range videoCodecsParameters { if filterOutH264HighProfile && codec.RTPCodecCapability.SDPFmtpLine == h264HighProfileFmtp { continue } if mime.IsMimeTypeStringRTX(codec.MimeType) { continue } - if IsCodecEnabled(codecs, codec.RTPCodecCapability) { - if err := me.RegisterCodec(codec, webrtc.RTPCodecTypeVideo); err != nil { - return err - } - if rtxEnabled { - if err := me.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: mime.MimeTypeRTX.String(), - ClockRate: 90000, - SDPFmtpLine: fmt.Sprintf("apt=%d", codec.PayloadType), - }, - PayloadType: codec.PayloadType + 1, - }, webrtc.RTPCodecTypeVideo); err != nil { - return err - } - } + if !IsCodecEnabled(codecs, codec.RTPCodecCapability) { + continue + } + + cp := codec + cp.RTPCodecCapability.RTCPFeedback = rtcpFeedback.Video + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeVideo); err != nil { + return err + } + + if !rtxEnabled { + continue + } + + cp = videoRTXCodecParameters + cp.RTPCodecCapability.SDPFmtpLine = fmt.Sprintf("apt=%d", codec.PayloadType) + cp.PayloadType = codec.PayloadType + 1 + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeVideo); err != nil { + return err } } return nil @@ -242,3 +261,57 @@ func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { // no viable codec in the list of enabled codecs, fall back to the most widely supported codec return mime.MimeTypeVP8.String() } + +func selectAlternativeAudioCodec(enabledCodecs []*livekit.Codec) string { + for _, c := range enabledCodecs { + if mime.IsMimeTypeStringAudio(c.Mime) { + return c.Mime + } + } + // no viable codec in the list of enabled codecs, fall back to the most widely supported codec + return mime.MimeTypeOpus.String() +} + +func filterCodecs( + codecs []webrtc.RTPCodecParameters, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, + filterOutH264HighProfile bool, +) []webrtc.RTPCodecParameters { + filteredCodecs := make([]webrtc.RTPCodecParameters, 0, len(codecs)) + for _, c := range codecs { + if filterOutH264HighProfile && isH264HighProfile(c.RTPCodecCapability.SDPFmtpLine) { + continue + } + + for _, enabledCodec := range enabledCodecs { + if mime.NormalizeMimeType(enabledCodec.Mime) == mime.NormalizeMimeType(c.RTPCodecCapability.MimeType) { + // SINGLE-PEER-CONNECTION-TOOD: remove `nack` for RED? + if mime.IsMimeTypeStringVideo(c.RTPCodecCapability.MimeType) { + c.RTPCodecCapability.RTCPFeedback = rtcpFeedbackConfig.Video + } else { + c.RTPCodecCapability.RTCPFeedback = rtcpFeedbackConfig.Audio + } + filteredCodecs = append(filteredCodecs, c) + break + } + } + } + return filteredCodecs +} + +func isH264HighProfile(fmtp string) bool { + params := strings.Split(fmtp, ";") + for _, param := range params { + parts := strings.Split(param, "=") + if len(parts) == 2 { + if parts[0] == "profile-level-id" { + // https://datatracker.ietf.org/doc/html/rfc6184#section-8.1 + // hex value 0x64 for profile_idc is high profile + return strings.HasPrefix(parts[1], "64") + } + } + } + + return false +} diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 133573f43..93b90cb8b 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -16,6 +16,7 @@ package rtc import ( "errors" + "slices" "sync" "github.com/pion/rtcp" @@ -309,7 +310,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * if transceiver == nil { info := t.params.MediaTrack.ToProto() addTrackParams := types.AddTrackParams{ - Stereo: info.Stereo, + Stereo: slices.Contains(info.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO), Red: !info.DisableRed, } if addTrackParams.Red && (len(codecs) == 1 && mime.IsMimeTypeStringOpus(codecs[0].MimeType)) { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 5fd0f4171..1a97db1ae 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -209,6 +209,7 @@ type ParticipantParams struct { LastPubReliableSeq uint32 Country string PreferVideoSizeFromMedia bool + UseSinglePeerConnection bool } type ParticipantImpl struct { @@ -1098,7 +1099,7 @@ func (p *ParticipantImpl) updateRidsFromSDP(parsed *sdp.SessionDescription, unma } p.pendingTracksLock.Lock() - pti := p.getPendingTrackPrimaryBySdpCid(mst, true) + pti := p.getPendingTrackPrimaryBySdpCid(mst) if pti != nil { pti.sdpRids = getRids(pti.sdpRids) p.pubLogger.Debugw( @@ -1201,6 +1202,7 @@ func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription, offerId u } } + p.handlePendingRemoteTracks() return err } @@ -1463,6 +1465,10 @@ func (p *ParticipantImpl) clearMigrationTimer() { } func (p *ParticipantImpl) setupMigrationTimerLocked() { + if p.params.UseSinglePeerConnection { + return + } + // // On subscriber peer connection, remote side will try ICE on both // pre- and post-migration ICE candidates as the migrating out @@ -1841,6 +1847,10 @@ func (h PublisherTransportHandler) OnDataSendError(err error) { h.p.onDataSendError(err) } +func (h PublisherTransportHandler) OnUnmatchedMedia(numAudios uint32, numVideos uint32) error { + return h.p.sendMediaSectionsRequirement(numAudios, numVideos) +} + // ---------------------------------------------------------- type SubscriberTransportHandler struct { @@ -1879,6 +1889,8 @@ func (h PrimaryTransportHandler) OnFullyEstablished() { h.p.onPrimaryTransportFullyEstablished() } +// ---------------------------------------------------------- + func (p *ParticipantImpl) setupSignalling() { p.signalling = signalling.NewSignalling(signalling.SignallingParams{ Logger: p.params.Logger, @@ -1902,7 +1914,7 @@ func (p *ParticipantImpl) setupTransportManager() error { var pth transport.Handler = PublisherTransportHandler{ath} var sth transport.Handler = SubscriberTransportHandler{ath} - subscriberAsPrimary := p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe() && !p.params.UseOneShotSignallingMode + subscriberAsPrimary := !p.params.UseOneShotSignallingMode && (p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe()) && !p.params.UseSinglePeerConnection if subscriberAsPrimary { sth = PrimaryTransportHandler{sth, p} } else { @@ -1913,6 +1925,7 @@ func (p *ParticipantImpl) setupTransportManager() error { // primary connection does not change, canSubscribe can change if permission was updated // after the participant has joined SubscriberAsPrimary: subscriberAsPrimary, + UseSinglePeerConnection: p.params.UseSinglePeerConnection, Config: p.params.Config, Twcc: p.twcc, ProtocolVersion: p.params.ProtocolVersion, @@ -2193,7 +2206,7 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver publishedTrack, isNewTrack, isReceiverAdded, sdpRids := p.mediaTrackReceived(track, rtpReceiver) if publishedTrack == nil { p.pubLogger.Debugw( - "webrtc Track published but can't find MediaTrack in pendingTracks", + "webrtc track published but can't find MediaTrack in pendingTracks", "kind", track.Kind().String(), "webrtcTrackID", track.ID(), "rid", track.RID(), @@ -2456,7 +2469,7 @@ func (p *ParticipantImpl) onPublisherInitialConnected() { p.supervisor.SetPublisherPeerConnectionConnected(true) } - if p.params.UseOneShotSignallingMode { + if p.params.UseOneShotSignallingMode || p.params.UseSinglePeerConnection { go p.subscriberRTCPWorker() p.setDownTracksConnected() @@ -2759,7 +2772,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mimeType}) { altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow( - "falling back to alternative codec", + "falling back to alternative video codec", "codec", mimeType, "altCodec", altCodec, "trackID", ti.Sid, @@ -2774,8 +2787,21 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l videoLayerMode = livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM } } - } else if req.Type == livekit.TrackType_AUDIO && !mime.IsMimeTypeStringAudio(mimeType) { - mimeType = mime.MimeTypePrefixAudio + mimeType + } else if req.Type == livekit.TrackType_AUDIO { + if !mime.IsMimeTypeStringAudio(mimeType) { + mimeType = mime.MimeTypePrefixAudio + mimeType + } + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mimeType}) { + altCodec := selectAlternativeAudioCodec(p.enabledPublishCodecs) + p.pubLogger.Infow( + "falling back to alternative audio codec", + "codec", mimeType, + "altCodec", altCodec, + "trackID", ti.Sid, + ) + // select an alternative MIME type that's generally supported + mimeType = altCodec + } } if _, ok := seenCodecs[mimeType]; ok || mimeType == "" { @@ -2821,18 +2847,23 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l } p.params.Telemetry.TrackPublishRequested(context.Background(), p.ID(), p.Identity(), utils.CloneProto(ti)) + if p.supervisor != nil { p.supervisor.AddPublication(livekit.TrackID(ti.Sid)) p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted) } + if p.getPublishedTrackBySignalCid(req.Cid) != nil || p.getPublishedTrackBySdpCid(req.Cid) != nil || p.pendingTracks[req.Cid] != nil { if p.pendingTracks[req.Cid] == nil { - p.pendingTracks[req.Cid] = &pendingTrackInfo{ + pti := &pendingTrackInfo{ trackInfos: []*livekit.TrackInfo{ti}, - sdpRids: buffer.DefaultVideoLayersRid, // could get updated from SDP createdAt: time.Now(), queued: true, } + if ti.Type == livekit.TrackType_VIDEO { + pti.sdpRids = buffer.DefaultVideoLayersRid // could get updated from SDP + } + p.pendingTracks[req.Cid] = pti } else { p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti) } @@ -2845,11 +2876,14 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l return nil } - p.pendingTracks[req.Cid] = &pendingTrackInfo{ + pti := &pendingTrackInfo{ trackInfos: []*livekit.TrackInfo{ti}, - sdpRids: buffer.DefaultVideoLayersRid, // could get updated from SDP createdAt: time.Now(), } + if ti.Type == livekit.TrackType_VIDEO { + pti.sdpRids = buffer.DefaultVideoLayersRid // could get updated from SDP + } + p.pendingTracks[req.Cid] = pti p.pubLogger.Debugw( "pending track added", "trackID", ti.Sid, @@ -3295,7 +3329,7 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp return signalCid, utils.CloneProto(pendingInfo.trackInfos[0]), pendingInfo.sdpRids, pendingInfo.migrated, pendingInfo.createdAt } -func (p *ParticipantImpl) getPendingTrackPrimaryBySdpCid(sdpCid string, skipQueued bool) *pendingTrackInfo { +func (p *ParticipantImpl) getPendingTrackPrimaryBySdpCid(sdpCid string) *pendingTrackInfo { for _, pti := range p.pendingTracks { ti := pti.trackInfos[0] if len(ti.Codecs) == 0 { @@ -3674,7 +3708,12 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod subscribeCodecs = append(subscribeCodecs, c) } p.enabledSubscribeCodecs = subscribeCodecs - p.params.Logger.Debugw("setup enabled codecs", "publish", p.enabledPublishCodecs, "subscribe", p.enabledSubscribeCodecs, "disabled", disabledCodecs) + p.params.Logger.Debugw( + "setup enabled codecs", + "publish", logger.ProtoSlice(p.enabledPublishCodecs), + "subscribe", logger.ProtoSlice(p.enabledSubscribeCodecs), + "disabled", logger.Proto(disabledCodecs), + ) } func (p *ParticipantImpl) replayJoiningReliableMessages() { @@ -3718,7 +3757,7 @@ func (p *ParticipantImpl) UpdateAudioTrack(update *livekit.UpdateLocalAudioTrack if ti.Sid == update.TrackSid { isPending = true - ti.AudioFeatures = update.Features + ti.AudioFeatures = sutils.DedupeSlice(update.Features) ti.Stereo = false ti.DisableDtx = false for _, feature := range update.Features { @@ -3895,3 +3934,44 @@ func (p *ParticipantImpl) HandleLeaveRequest(reason types.ParticipantCloseReason func (p *ParticipantImpl) HandleSignalMessage(msg proto.Message) error { return p.signalHandler.HandleMessage(msg) } + +func (p *ParticipantImpl) IsUsingSinglePeerConnection() bool { + return p.params.UseSinglePeerConnection +} + +func (p *ParticipantImpl) AddTrackLocal( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, +) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + if p.params.UseSinglePeerConnection { + return p.TransportManager.AddTrackLocal( + trackLocal, + params, + p.enabledSubscribeCodecs, + p.params.Config.Subscriber.RTCPFeedback, + ) + } else { + return p.TransportManager.AddTrackLocal(trackLocal, params, nil, RTCPFeedbackConfig{}) + } +} + +func (p *ParticipantImpl) AddTransceiverFromTrackLocal( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, +) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + if p.params.UseSinglePeerConnection { + return p.TransportManager.AddTransceiverFromTrackLocal( + trackLocal, + params, + p.enabledSubscribeCodecs, + p.params.Config.Subscriber.RTCPFeedback, + ) + } else { + return p.TransportManager.AddTransceiverFromTrackLocal( + trackLocal, + params, + nil, + RTCPFeedbackConfig{}, + ) + } +} diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 015e07bb1..b5dc5064b 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -609,10 +609,7 @@ func TestPreferAudioCodecForRed(t *testing.T) { me := webrtc.MediaEngine{} me.RegisterDefaultCodecs() - require.NoError(t, me.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: RedCodecCapability, - PayloadType: 63, - }, webrtc.RTPCodecTypeAudio)) + require.NoError(t, me.RegisterCodec(RedCodecParameters, webrtc.RTPCodecTypeAudio)) api := webrtc.NewAPI(webrtc.WithMediaEngine(&me)) pc, err := api.NewPeerConnection(webrtc.Configuration{}) diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 0229eb7a4..0904b78c1 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -16,6 +16,7 @@ package rtc import ( "fmt" + "slices" "strconv" "strings" @@ -25,6 +26,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" lksdp "github.com/livekit/protocol/sdp" "github.com/livekit/protocol/utils" ) @@ -91,6 +93,11 @@ func (p *ParticipantImpl) populateSdpCid(parsedOffer *sdp.SessionDescription) ([ ) } unmatchedTrack.(*MediaTrack).UpdateCodecSdpCid(unmatchedSdpMimeType, streamID) + p.pubLogger.Debugw( + "published track SDP cid updated", + "trackID", unmatchedTrack.ID(), + "track", logger.Proto(unmatchedTrack.ToProto()), + ) } continue } @@ -176,7 +183,11 @@ func (p *ParticipantImpl) setCodecPreferencesForPublisher( livekit.TrackType_AUDIO, ) parsedOffer = p.setCodecPreferencesOpusRedForPublisher(parsedOffer, unprocessedUnmatchAudios) - parsedOffer, _ = p.setCodecPreferencesForPublisherMedia(parsedOffer, unmatchVideos, livekit.TrackType_VIDEO) + parsedOffer, _ = p.setCodecPreferencesForPublisherMedia( + parsedOffer, + unmatchVideos, + livekit.TrackType_VIDEO, + ) return parsedOffer } @@ -191,10 +202,11 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher( } p.pendingTracksLock.RLock() - _, info, _, _, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO, false) - // if RED is disabled for this track, don't prefer RED codec in offer - disableRed := info != nil && info.DisableRed + _, ti, _, _, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO, false) p.pendingTracksLock.RUnlock() + if ti == nil { + continue + } codecs, err := lksdp.CodecsFromMediaDescription(unmatchAudio) if err != nil { @@ -217,10 +229,11 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher( continue } + // if RED is disabled for this track, don't prefer RED codec in offer var preferredCodecs, leftCodecs []string for _, codec := range codecs { // codec contain opus/red - if !disableRed && mime.IsMimeTypeCodecStringRED(codec.Name) && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { + if !ti.DisableRed && mime.IsMimeTypeCodecStringRED(codec.Name) && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) } else { leftCodecs = append(leftCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) @@ -262,38 +275,38 @@ func (p *ParticipantImpl) setCodecPreferencesForPublisherMedia( unprocessed := make([]*sdp.MediaDescription, 0, len(unmatches)) // unmatched media is pending for publish, set codec preference for _, unmatch := range unmatches { + var ti *livekit.TrackInfo + var mimeType string + streamID, ok := lksdp.ExtractStreamID(unmatch) if !ok { unprocessed = append(unprocessed, unmatch) continue } - var info *livekit.TrackInfo p.pendingTracksLock.RLock() mt := p.getPublishedTrackBySdpCid(streamID) if mt != nil { - info = mt.ToProto() + ti = mt.ToProto() } else { - _, info, _, _, _ = p.getPendingTrack(streamID, trackType, false) + _, ti, _, _, _ = p.getPendingTrack(streamID, trackType, false) } + p.pendingTracksLock.RUnlock() - if info == nil { - p.pendingTracksLock.RUnlock() + if ti == nil { unprocessed = append(unprocessed, unmatch) continue } - var mimeType string - for _, c := range info.Codecs { + for _, c := range ti.Codecs { if c.Cid == streamID || c.SdpCid == streamID { mimeType = c.MimeType break } } - if mimeType == "" && len(info.Codecs) > 0 { - mimeType = info.Codecs[0].MimeType + if mimeType == "" && len(ti.Codecs) > 0 { + mimeType = ti.Codecs[0].MimeType } - p.pendingTracksLock.RUnlock() if mimeType == "" { unprocessed = append(unprocessed, unmatch) @@ -422,7 +435,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript if !ti.DisableDtx { attr.Value += ";usedtx=1" } - if ti.Stereo { + if slices.Contains(ti.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO) { attr.Value += ";stereo=1;maxaveragebitrate=510000" } m.Attributes[i] = attr diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go index 53574ded4..52792e8d9 100644 --- a/pkg/rtc/participant_signal.go +++ b/pkg/rtc/participant_signal.go @@ -330,3 +330,19 @@ func (p *ParticipantImpl) SendSubscriptionPermissionUpdate(publisherID livekit.P } return err } + +func (p *ParticipantImpl) sendMediaSectionsRequirement(numAudios uint32, numVideos uint32) error { + p.pubLogger.Debugw( + "sending media sections requirement", + "numAudios", numAudios, + "numVideos", numVideos, + ) + err := p.signaller.WriteMessage(p.signalling.SignalMediaSectionsRequirement(&livekit.MediaSectionsRequirement{ + NumAudios: numAudios, + NumVideos: numVideos, + })) + if err != nil { + p.subLogger.Errorw("could not send media sections requirement", err) + } + return err +} diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index ae7250953..7b23d5ee5 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -610,6 +610,10 @@ func (r *Room) Join( } else { participant.Negotiate(true) } + } else { + if participant.IsUsingSinglePeerConnection() { + go r.subscribeToExistingTracks(participant, false) + } } prometheus.ServiceOperationCounter.WithLabelValues("participant_join", "success", "").Add(1) @@ -1114,13 +1118,12 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. continue } - r.logger.Debugw( + existingParticipant.GetLogger().Debugw( "subscribing to new track", - "participant", existingParticipant.Identity(), - "pID", existingParticipant.ID(), "publisher", participant.Identity(), "publisherID", participant.ID(), - "trackID", track.ID()) + "trackID", track.ID(), + ) existingParticipant.SubscribeToTrack(track.ID(), false) } onParticipantChanged := r.onParticipantChanged @@ -1393,7 +1396,7 @@ func (r *Room) subscribeToExistingTracks(p types.LocalParticipant, isSync bool) } } if len(trackIDs) > 0 { - r.logger.Debugw("subscribed participant to existing tracks", "trackID", trackIDs) + p.GetLogger().Debugw("subscribed participant to existing tracks", "trackID", trackIDs) } } diff --git a/pkg/rtc/signalling/interfaces.go b/pkg/rtc/signalling/interfaces.go index 0dd94a7f7..73b35d5b0 100644 --- a/pkg/rtc/signalling/interfaces.go +++ b/pkg/rtc/signalling/interfaces.go @@ -57,4 +57,5 @@ type ParticipantSignalling interface { SignalSubscribedQualityUpdate(subscribedQualityUpdate *livekit.SubscribedQualityUpdate) proto.Message SignalSubscriptionResponse(subscriptionResponse *livekit.SubscriptionResponse) proto.Message SignalSubscriptionPermissionUpdate(subscriptionPermissionUpdate *livekit.SubscriptionPermissionUpdate) proto.Message + SignalMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) proto.Message } diff --git a/pkg/rtc/signalling/signalling.go b/pkg/rtc/signalling/signalling.go index 81a8e8f22..2e4ca69a7 100644 --- a/pkg/rtc/signalling/signalling.go +++ b/pkg/rtc/signalling/signalling.go @@ -220,3 +220,11 @@ func (s *signalling) SignalSubscriptionPermissionUpdate(subscriptionPermissionUp }, } } + +func (u *signalling) SignalMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_MediaSectionsRequirement{ + MediaSectionsRequirement: mediaSectionsRequirement, + }, + } +} diff --git a/pkg/rtc/signalling/signallingunimplemented.go b/pkg/rtc/signalling/signallingunimplemented.go index d069d5e2d..5aa3a04cf 100644 --- a/pkg/rtc/signalling/signallingunimplemented.go +++ b/pkg/rtc/signalling/signallingunimplemented.go @@ -107,3 +107,7 @@ func (u *signallingUnimplemented) SignalSubscriptionResponse(subscriptionRespons func (u *signallingUnimplemented) SignalSubscriptionPermissionUpdate(subscriptionPermissionUpdate *livekit.SubscriptionPermissionUpdate) proto.Message { return nil } + +func (u *signallingUnimplemented) SignalMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) proto.Message { + return nil +} diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index e4af28af2..109bba158 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -246,6 +246,16 @@ type PCTransport struct { eventsQueue *utils.TypedOpsQueue[event] + connectionDetails *types.ICEConnectionDetails + selectedPair atomic.Pointer[webrtc.ICECandidatePair] + mayFailedICEStats []iceCandidatePairStats + mayFailedICEStatsTimer *time.Timer + + numOutstandingAudios uint32 + numRequestSentAudios uint32 + numOutstandingVideos uint32 + numRequestSentVideos uint32 + // the following should be accessed only in event processing go routine cacheLocalCandidates bool cachedLocalCandidates []*webrtc.ICECandidate @@ -257,11 +267,6 @@ type PCTransport struct { signalStateCheckTimer *time.Timer currentOfferIceCredential string // ice user:pwd, for publish side ice restart checking pendingRestartIceOffer *webrtc.SessionDescription - - connectionDetails *types.ICEConnectionDetails - selectedPair atomic.Pointer[webrtc.ICECandidatePair] - mayFailedICEStats []iceCandidatePairStats - mayFailedICEStatsTimer *time.Timer } type TransportParams struct { @@ -466,6 +471,7 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat params.Logger.Debugw("rtx pair found from extension", "repair", repair, "base", base) params.Config.BufferFactory.SetRTXPair(repair, base) }, params.Logger)) + api := webrtc.NewAPI( webrtc.WithMediaEngine(me), webrtc.WithSettingEngine(se), @@ -904,7 +910,12 @@ func (t *PCTransport) AddICECandidate(candidate webrtc.ICECandidateInit) { }) } -func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal, params types.AddTrackParams) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { +func (t *PCTransport) AddTrack( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, +) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { t.lock.Lock() canReuse := t.canReuseTransceiver td, ok := t.previousTrackDescription[trackLocal.ID()] @@ -924,7 +935,7 @@ func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal, params types.AddTra // if never negotiated with client, can't reuse transceiver for track not subscribed before migration if !canReuse { - return t.AddTransceiverFromTrack(trackLocal, params) + return t.AddTransceiverFromTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) } sender, err = t.pc.AddTrack(trackLocal) @@ -932,7 +943,6 @@ func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal, params types.AddTra return } - // as there is no way to get transceiver from sender, search for _, tr := range t.pc.GetTransceivers() { if tr.Sender() == sender { transceiver = tr @@ -948,10 +958,17 @@ func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal, params types.AddTra if trackLocal.Kind() == webrtc.RTPCodecTypeAudio { configureAudioTransceiver(transceiver, params.Stereo, !params.Red || !t.params.ClientInfo.SupportsAudioRED()) } + configureTransceiverCodecs(transceiver, enabledCodecs, rtcpFeedbackConfig, !t.params.IsOfferer) + t.adjustNumOutstandingMedia(transceiver) return } -func (t *PCTransport) AddTransceiverFromTrack(trackLocal webrtc.TrackLocal, params types.AddTrackParams) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { +func (t *PCTransport) AddTransceiverFromTrack( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, +) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { transceiver, err = t.pc.AddTransceiverFromTrack(trackLocal) if err != nil { return @@ -966,14 +983,42 @@ func (t *PCTransport) AddTransceiverFromTrack(trackLocal webrtc.TrackLocal, para if trackLocal.Kind() == webrtc.RTPCodecTypeAudio { configureAudioTransceiver(transceiver, params.Stereo, !params.Red || !t.params.ClientInfo.SupportsAudioRED()) } - + configureTransceiverCodecs(transceiver, enabledCodecs, rtcpFeedbackConfig, !t.params.IsOfferer) + t.adjustNumOutstandingMedia(transceiver) return } +func (t *PCTransport) AddTransceiverFromKind( + kind webrtc.RTPCodecType, + init webrtc.RTPTransceiverInit, +) (*webrtc.RTPTransceiver, error) { + return t.pc.AddTransceiverFromKind(kind, init) +} + func (t *PCTransport) RemoveTrack(sender *webrtc.RTPSender) error { return t.pc.RemoveTrack(sender) } +func (t *PCTransport) CurrentLocalDescription() *webrtc.SessionDescription { + cld := t.pc.CurrentLocalDescription() + if cld == nil { + return nil + } + + ld := *cld + return &ld +} + +func (t *PCTransport) CurrentRemoteDescription() *webrtc.SessionDescription { + crd := t.pc.CurrentRemoteDescription() + if crd == nil { + return nil + } + + rd := *crd + return &rd +} + func (t *PCTransport) GetMid(rtpReceiver *webrtc.RTPReceiver) string { for _, tr := range t.pc.GetTransceivers() { if tr.Receiver() == rtpReceiver { @@ -994,6 +1039,30 @@ func (t *PCTransport) GetRTPReceiver(mid string) *webrtc.RTPReceiver { return nil } +func (t *PCTransport) getNumUnmatchedTransceivers() (uint32, uint32) { + if t.isClosed.Load() || t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + return 0, 0 + } + + numAudios := uint32(0) + numVideos := uint32(0) + for _, tr := range t.pc.GetTransceivers() { + if tr.Mid() != "" { + continue + } + + switch tr.Kind() { + case webrtc.RTPCodecTypeAudio: + numAudios++ + + case webrtc.RTPCodecTypeVideo: + numVideos++ + } + } + + return numAudios, numVideos +} + func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelInit) error { dc, err := t.pc.CreateDataChannel(label, dci) if err != nil { @@ -1003,6 +1072,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni dcPtr **datachannel.DataChannelWriter[*webrtc.DataChannel] dcReady *bool isUnlabeled bool + kind livekit.DataPacket_Kind ) switch dc.Label() { default: @@ -1011,9 +1081,11 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni case ReliableDataChannel: dcPtr = &t.reliableDC dcReady = &t.reliableDCOpened + kind = livekit.DataPacket_RELIABLE case LossyDataChannel: dcPtr = &t.lossyDC dcReady = &t.lossyDCOpened + kind = livekit.DataPacket_LOSSY } dc.OnOpen(func() { @@ -1044,6 +1116,28 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni t.lock.Unlock() t.params.Logger.Debugw(dc.Label() + " data channel open") + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "state=Closed") { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + switch { + case isUnlabeled: + t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) + + default: + t.params.Handler.OnDataMessage(kind, buffer[:n]) + } + } + }() + t.maybeNotifyFullyEstablished() }) @@ -1131,7 +1225,6 @@ func (t *PCTransport) GetRTT() (float64, bool) { return scps.CurrentRoundTripTime, true } -// IsEstablished returns true if the PeerConnection has been established func (t *PCTransport) IsEstablished() bool { return t.pc.ConnectionState() != webrtc.PeerConnectionStateNew } @@ -1801,11 +1894,11 @@ func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error return err } - // replace client's fingerprint into dump pc's answer, for pion's dtls process, it will + // replace client's fingerprint into dummy pc's answer, for pion's dtls process, it will // keep the fingerprint at first call of SetRemoteDescription, if dummy pc and client pc use // different fingerprint, that will cause pion denied dtls data after handshake with client // complete (can't pass fingerprint change). - // in this step, we don't established connection with dump pc(no candidate swap), just use + // in this step, we don't established connection with dummy pc(no candidate swap), just use // sdp negotiation to sticky data channel and keep client's fingerprint parsedAns, _ := ans.Unmarshal() fpLine := fpHahs + " " + fp @@ -1829,9 +1922,9 @@ func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error return t.pc.SetRemoteDescription(ans) } -func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDescription) (map[string]*webrtc.RTPSender, error) { +func (t *PCTransport) initPCWithPreviousRemoteDescription(previousRemoteDescription webrtc.SessionDescription) (map[string]*webrtc.RTPSender, error) { senders := make(map[string]*webrtc.RTPSender) - parsed, err := previousAnswer.Unmarshal() + parsed, err := previousRemoteDescription.Unmarshal() if err != nil { return senders, err } @@ -1843,19 +1936,33 @@ func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDesc case "audio": codecType = webrtc.RTPCodecTypeAudio case "application": - // for pion generate unmatched sdp, it always appends data channel to last m-lines, - // that not consistent with our previous answer that data channel might at middle-line - // because sdp can negotiate multi times before migration.(it will sticky to the last m-line atfirst negotiate) - // so use a dummy pc to negotiate sdp to fixed the datachannel's mid at same position with previous answer - if err := t.preparePC(previousAnswer); err != nil { - t.params.Logger.Warnw("prepare pc for migration failed", err) - return senders, err + if t.params.IsOfferer { + // for pion generate unmatched sdp, it always appends data channel to last m-lines, + // that not consistent with our previous answer that data channel might at middle-line + // because sdp can negotiate multi times before migration.(it will sticky to the last m-line at first negotiate) + // so use a dummy pc to negotiate sdp to fixed the datachannel's mid at same position with previous answer + if err := t.preparePC(previousRemoteDescription); err != nil { + t.params.Logger.Warnw("prepare pc for migration failed", err) + return senders, err + } } continue default: continue } - tr, err := t.pc.AddTransceiverFromKind(codecType, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}) + + _, ok1 := m.Attribute(webrtc.RTPTransceiverDirectionSendrecv.String()) + _, ok2 := m.Attribute(webrtc.RTPTransceiverDirectionRecvonly.String()) + if !ok1 && !ok2 { + continue + } + + tr, err := t.pc.AddTransceiverFromKind( + codecType, + webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionSendonly, + }, + ) if err != nil { return senders, err } @@ -1870,43 +1977,51 @@ func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDesc senders[mid] = sender // set transceiver to inactive - tr.SetSender(tr.Sender(), nil) + tr.SetSender(sender, nil) } return senders, nil } -func (t *PCTransport) SetPreviousSdp(offer, answer *webrtc.SessionDescription) { - // when there is no previous answer, cannot migrate, force a full reconnect - if answer == nil { - t.onNegotiationFailed(true, "no previous answer for previous sdp") +func (t *PCTransport) SetPreviousSdp(localDescription, remoteDescription *webrtc.SessionDescription) { + // when there is no remote description, cannot migrate, force a full reconnect + if remoteDescription == nil { + t.onNegotiationFailed(true, "no previous remote description") return } t.lock.Lock() if t.pc.RemoteDescription() == nil && t.previousAnswer == nil { - t.previousAnswer = answer - if senders, err := t.initPCWithPreviousAnswer(*t.previousAnswer); err != nil { + if t.params.IsOfferer { + t.previousAnswer = remoteDescription + } + if senders, err := t.initPCWithPreviousRemoteDescription(*remoteDescription); err != nil { t.lock.Unlock() - t.onNegotiationFailed(true, fmt.Sprintf("initPCWithPreviousAnswer failed, error: %s", err)) + t.onNegotiationFailed(true, fmt.Sprintf("initPCWithPreviousRemoteDescription failed, error: %s", err)) return - } else if offer != nil { + } else if localDescription != nil { // in migration case, can't reuse transceiver before negotiated except track subscribed at previous node t.canReuseTransceiver = false - if err := t.parseTrackMid(*offer, senders); err != nil { - t.params.Logger.Warnw("parse previous offer failed", err, "offer", offer.SDP) + if err := t.parseTrackMid(*localDescription, senders); err != nil { + t.params.Logger.Warnw( + "parse previous local description failed", err, + "localDescription", localDescription.SDP, + ) } } } - // disable fast negotiation temporarily after migration to avoid sending offer - // contains part of subscribed tracks before migration, let the subscribed track - // resume at the same time. - t.lastNegotiate = time.Now().Add(iceFailedTimeoutTotal) + + if t.params.IsOfferer { + // disable fast negotiation temporarily after migration to avoid sending offer + // contains part of subscribed tracks before migration, let the subscribed track + // resume at the same time. + t.lastNegotiate = time.Now().Add(iceFailedTimeoutTotal) + } t.lock.Unlock() } -func (t *PCTransport) parseTrackMid(offer webrtc.SessionDescription, senders map[string]*webrtc.RTPSender) error { - parsed, err := offer.Unmarshal() +func (t *PCTransport) parseTrackMid(sd webrtc.SessionDescription, senders map[string]*webrtc.RTPSender) error { + parsed, err := sd.Unmarshal() if err != nil { return err } @@ -1924,9 +2039,8 @@ func (t *PCTransport) parseTrackMid(offer webrtc.SessionDescription, senders map if mid == "" { return ErrMidNotFound } - t.previousTrackDescription[trackID] = &trackDescription{ - mid: mid, - sender: senders[mid], + if sender, ok := senders[mid]; ok { + t.previousTrackDescription[trackID] = &trackDescription{mid, sender} } } } @@ -2164,6 +2278,40 @@ func (t *PCTransport) setupSignalStateCheckTimer() { }) } +func (t *PCTransport) adjustNumOutstandingMedia(transceiver *webrtc.RTPTransceiver) { + if transceiver.Mid() != "" { + return + } + + t.lock.Lock() + if transceiver.Kind() == webrtc.RTPCodecTypeAudio { + t.numOutstandingAudios++ + } else { + t.numOutstandingVideos++ + } + t.lock.Unlock() +} + +func (t *PCTransport) sendUnmatchedMediaRequirement(force bool) error { + // if there are unmatched media sections, notify remote peer to generate offer with + // enough media section in subsequent offers + t.lock.Lock() + numAudios := t.numOutstandingAudios - t.numRequestSentAudios + t.numRequestSentAudios += numAudios + + numVideos := t.numOutstandingVideos - t.numRequestSentVideos + t.numRequestSentVideos += numVideos + t.lock.Unlock() + + if force || (numAudios+numVideos) != 0 { + if err := t.params.Handler.OnUnmatchedMedia(numAudios, numVideos); err != nil { + return errors.Wrap(err, "could not send unmatched media requirements") + } + } + + return nil +} + func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { t.params.Logger.Warnw("trying to send offer on closed peer connection", nil) @@ -2262,12 +2410,16 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) return errors.Wrap(err, "could not send offer") } - prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) + return t.localDescriptionSent() } func (t *PCTransport) handleSendOffer(_ event) error { + if !t.params.IsOfferer { + return t.sendUnmatchedMediaRequirement(true) + } + return t.createAndSendOffer(nil) } @@ -2343,6 +2495,12 @@ func (t *PCTransport) setRemoteDescription(sd webrtc.SessionDescription) error { } func (t *PCTransport) createAndSendAnswer() error { + numOutstandingAudios, numOutstandingVideos := t.getNumUnmatchedTransceivers() + t.lock.Lock() + t.numOutstandingAudios, t.numOutstandingVideos = numOutstandingAudios, numOutstandingVideos + t.numRequestSentAudios, t.numRequestSentVideos = 0, 0 + t.lock.Unlock() + answer, err := t.pc.CreateAnswer(nil) if err != nil { if errors.Is(err, webrtc.ErrConnectionClosed) { @@ -2390,8 +2548,19 @@ func (t *PCTransport) createAndSendAnswer() error { return errors.Wrap(err, "could not send answer") } t.localAnswerId.Store(answerId) - prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) + + if err := t.sendUnmatchedMediaRequirement(false); err != nil { + return err + } + + t.lock.Lock() + if !t.canReuseTransceiver { + t.canReuseTransceiver = true + t.previousTrackDescription = make(map[string]*trackDescription) + } + t.lock.Unlock() + return t.localDescriptionSent() } @@ -2640,6 +2809,36 @@ func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool tr.SetCodecPreferences(configCodecs) } +// In single peer connection mode, set up enebled codecs, +// the config provides config of direction, for publisher peer connection, it is publish enabled codecs +// and for subscriber peer connection, it is subscribe enabled codecs. +// +// But, in single peer connection mode, if setting up a transceiver where the media is +// flowing in the other direction, the other direction codec config needs to be set. +func configureTransceiverCodecs( + tr *webrtc.RTPTransceiver, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, + filterOutH264HighProfile bool, +) { + if len(enabledCodecs) == 0 { + return + } + + sender := tr.Sender() + if sender == nil { + return + } + + filteredCodecs := filterCodecs( + sender.GetParameters().Codecs, + enabledCodecs, + rtcpFeedbackConfig, + filterOutH264HighProfile, + ) + tr.SetCodecPreferences(filteredCodecs) +} + func nonSimulcastRTXRepairsFromSDP(s *sdp.SessionDescription, logger logger.Logger) map[uint32]uint32 { rtxRepairFlows := map[uint32]uint32{} for _, media := range s.MediaDescriptions { diff --git a/pkg/rtc/transport/handler.go b/pkg/rtc/transport/handler.go index e43e3bd7b..72c4cc9b7 100644 --- a/pkg/rtc/transport/handler.go +++ b/pkg/rtc/transport/handler.go @@ -47,6 +47,7 @@ type Handler interface { OnNegotiationStateChanged(state NegotiationState) OnNegotiationFailed() OnStreamStateChange(update *streamallocator.StreamStateUpdate) error + OnUnmatchedMedia(numAudios uint32, numVideos uint32) error } type UnimplementedHandler struct{} @@ -72,3 +73,6 @@ func (h UnimplementedHandler) OnNegotiationFailed() func (h UnimplementedHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { return nil } +func (h UnimplementedHandler) OnUnmatchedMedia(numAudios uint32, numVideos uint32) error { + return nil +} diff --git a/pkg/rtc/transport/transportfakes/fake_handler.go b/pkg/rtc/transport/transportfakes/fake_handler.go index 05618f8bd..17151a70e 100644 --- a/pkg/rtc/transport/transportfakes/fake_handler.go +++ b/pkg/rtc/transport/transportfakes/fake_handler.go @@ -104,6 +104,18 @@ type FakeHandler struct { arg1 *webrtc.TrackRemote arg2 *webrtc.RTPReceiver } + OnUnmatchedMediaStub func(uint32, uint32) error + onUnmatchedMediaMutex sync.RWMutex + onUnmatchedMediaArgsForCall []struct { + arg1 uint32 + arg2 uint32 + } + onUnmatchedMediaReturns struct { + result1 error + } + onUnmatchedMediaReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -632,6 +644,68 @@ func (fake *FakeHandler) OnTrackArgsForCall(i int) (*webrtc.TrackRemote, *webrtc return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeHandler) OnUnmatchedMedia(arg1 uint32, arg2 uint32) error { + fake.onUnmatchedMediaMutex.Lock() + ret, specificReturn := fake.onUnmatchedMediaReturnsOnCall[len(fake.onUnmatchedMediaArgsForCall)] + fake.onUnmatchedMediaArgsForCall = append(fake.onUnmatchedMediaArgsForCall, struct { + arg1 uint32 + arg2 uint32 + }{arg1, arg2}) + stub := fake.OnUnmatchedMediaStub + fakeReturns := fake.onUnmatchedMediaReturns + fake.recordInvocation("OnUnmatchedMedia", []interface{}{arg1, arg2}) + fake.onUnmatchedMediaMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnUnmatchedMediaCallCount() int { + fake.onUnmatchedMediaMutex.RLock() + defer fake.onUnmatchedMediaMutex.RUnlock() + return len(fake.onUnmatchedMediaArgsForCall) +} + +func (fake *FakeHandler) OnUnmatchedMediaCalls(stub func(uint32, uint32) error) { + fake.onUnmatchedMediaMutex.Lock() + defer fake.onUnmatchedMediaMutex.Unlock() + fake.OnUnmatchedMediaStub = stub +} + +func (fake *FakeHandler) OnUnmatchedMediaArgsForCall(i int) (uint32, uint32) { + fake.onUnmatchedMediaMutex.RLock() + defer fake.onUnmatchedMediaMutex.RUnlock() + argsForCall := fake.onUnmatchedMediaArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnUnmatchedMediaReturns(result1 error) { + fake.onUnmatchedMediaMutex.Lock() + defer fake.onUnmatchedMediaMutex.Unlock() + fake.OnUnmatchedMediaStub = nil + fake.onUnmatchedMediaReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnUnmatchedMediaReturnsOnCall(i int, result1 error) { + fake.onUnmatchedMediaMutex.Lock() + defer fake.onUnmatchedMediaMutex.Unlock() + fake.OnUnmatchedMediaStub = nil + if fake.onUnmatchedMediaReturnsOnCall == nil { + fake.onUnmatchedMediaReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onUnmatchedMediaReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeHandler) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 035058ec5..0cc9ea56a 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -112,9 +112,16 @@ func TestNegotiationTiming(t *testing.T) { require.False(t, transportB.IsEstablished()) handleICEExchange(t, transportA, transportB, handlerA, handlerB) - offer := atomic.Value{} - handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, _offerId uint32) error { - offer.Store(&sd) + firstOffer := atomic.Value{} + firstOfferId := atomic.Uint32{} + secondOffer := atomic.Value{} + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32) error { + if _, ok := firstOffer.Load().(*webrtc.SessionDescription); !ok { + firstOffer.Store(&sd) + firstOfferId.Store(offerId) + } else { + secondOffer.Store(&sd) + } return nil }) @@ -157,15 +164,22 @@ func TestNegotiationTiming(t *testing.T) { return state == transport.NegotiationStateRetry }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") - time.Sleep(5 * time.Millisecond) - actualOffer, ok := offer.Load().(*webrtc.SessionDescription) - require.True(t, ok) + require.Eventually(t, func() bool { + _, ok := firstOffer.Load().(*webrtc.SessionDescription) + if !ok { + return false + } + if firstOfferId.Load() == 0 { + return false + } + return true + }, 10*time.Second, 10*time.Millisecond, "first offer not received yet") handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32) error { transportA.HandleRemoteDescription(answer, answerId) return nil }) - transportB.HandleRemoteDescription(*actualOffer, 10) + transportB.HandleRemoteDescription(*firstOffer.Load().(*webrtc.SessionDescription), firstOfferId.Load()) require.Eventually(t, func() bool { return transportA.IsEstablished() @@ -174,11 +188,18 @@ func TestNegotiationTiming(t *testing.T) { return transportB.IsEstablished() }, 10*time.Second, time.Millisecond*10, "transportB is not established") - // it should still be negotiating again - require.Equal(t, transport.NegotiationStateRemote, negotiationState.Load().(transport.NegotiationState)) - offer2, ok := offer.Load().(*webrtc.SessionDescription) + // offerer should send another offer after processing the answer + // as there were forced negotiations a couple of time above + require.Eventually(t, func() bool { + state, ok := negotiationState.Load().(transport.NegotiationState) + if !ok { + return false + } + + return state == transport.NegotiationStateRemote + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRemote") + _, ok := secondOffer.Load().(*webrtc.SessionDescription) require.True(t, ok) - require.False(t, offer2 == actualOffer) transportA.Close() transportB.Close() @@ -361,7 +382,9 @@ func TestNegotiationFailed(t *testing.T) { connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) // reset OnOffer to force a negotiation failure - handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32) error { return nil }) + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32) error { + return nil + }) var failed atomic.Int32 handlerA.OnNegotiationFailedCalls(func() { failed.Inc() diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index a064fdf4c..68bcfb63f 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -71,19 +71,9 @@ func (h TransportManagerTransportHandler) OnFailed(isShortLived bool, iceConnect // ------------------------------- -type TransportManagerPublisherTransportHandler struct { - TransportManagerTransportHandler -} - -func (h TransportManagerPublisherTransportHandler) OnAnswer(sd webrtc.SessionDescription, answerId uint32) error { - h.t.lastPublisherAnswer.Store(sd) - return h.Handler.OnAnswer(sd, answerId) -} - -// ------------------------------- - type TransportManagerParams struct { SubscriberAsPrimary bool + UseSinglePeerConnection bool Config *WebRTCConfig Twcc *twcc.Responder ProtocolVersion types.ProtocolVersion @@ -124,8 +114,6 @@ type TransportManager struct { pendingOfferPublisher *webrtc.SessionDescription pendingOfferIdPublisher uint32 pendingDataChannelsPublisher []*livekit.DataChannelInfo - lastPublisherAnswer atomic.Value - lastPublisherOffer atomic.Value iceConfig *livekit.ICEConfig mediaLossProxy *MediaLossProxy @@ -159,8 +147,10 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro Logger: lgr, SimTracks: params.SimTracks, ClientInfo: params.ClientInfo, + IsSendSide: params.UseOneShotSignallingMode || params.UseSinglePeerConnection, + AllowPlayoutDelay: params.AllowPlayoutDelay, Transport: livekit.SignalTarget_PUBLISHER, - Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t, lgr}}, + Handler: params.PublisherHandler, UseOneShotSignallingMode: params.UseOneShotSignallingMode, DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, DatachannelSlowThreshold: params.DatachannelSlowThreshold, @@ -171,27 +161,31 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro } t.publisher = publisher - lgr = LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER) - subscriber, err := NewPCTransport(TransportParams{ - ProtocolVersion: params.ProtocolVersion, - Config: params.Config, - DirectionConfig: params.Config.Subscriber, - CongestionControlConfig: params.CongestionControlConfig, - EnabledCodecs: params.EnabledSubscribeCodecs, - Logger: lgr, - ClientInfo: params.ClientInfo, - IsOfferer: true, - IsSendSide: true, - AllowPlayoutDelay: params.AllowPlayoutDelay, - DatachannelSlowThreshold: params.DatachannelSlowThreshold, - Transport: livekit.SignalTarget_SUBSCRIBER, - Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, - }) - if err != nil { - return nil, err + if !t.params.UseOneShotSignallingMode && !t.params.UseSinglePeerConnection { + lgr := LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER) + subscriber, err := NewPCTransport(TransportParams{ + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + DirectionConfig: params.Config.Subscriber, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledSubscribeCodecs, + Logger: lgr, + ClientInfo: params.ClientInfo, + IsOfferer: true, + IsSendSide: true, + AllowPlayoutDelay: params.AllowPlayoutDelay, + DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + Transport: livekit.SignalTarget_SUBSCRIBER, + Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, + FireOnTrackBySdp: params.FireOnTrackBySdp, + }) + if err != nil { + return nil, err + } + t.subscriber = subscriber } - t.subscriber = subscriber - if !t.params.Migration { + if !t.params.Migration && t.params.SubscriberAsPrimary { if err := t.createDataChannelsForSubscriber(nil); err != nil { return nil, err } @@ -202,8 +196,12 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro } func (t *TransportManager) Close() { - t.publisher.Close() - t.subscriber.Close() + if t.publisher != nil { + t.publisher.Close() + } + if t.subscriber != nil { + t.subscriber.Close() + } } func (t *TransportManager) SubscriberClose() { @@ -235,37 +233,49 @@ func (t *TransportManager) WritePublisherRTCP(pkts []rtcp.Packet) error { } func (t *TransportManager) GetSubscriberRTT() (float64, bool) { - return t.subscriber.GetRTT() + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.GetRTT() + } else { + return t.subscriber.GetRTT() + } } func (t *TransportManager) HasSubscriberEverConnected() bool { - return t.subscriber.HasEverConnected() + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.HasEverConnected() + } else { + return t.subscriber.HasEverConnected() + } } func (t *TransportManager) AddTrackLocal( trackLocal webrtc.TrackLocal, params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, ) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { - if t.params.UseOneShotSignallingMode { - return t.publisher.AddTrack(trackLocal, params) + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.AddTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) } else { - return t.subscriber.AddTrack(trackLocal, params) + return t.subscriber.AddTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) } } func (t *TransportManager) AddTransceiverFromTrackLocal( trackLocal webrtc.TrackLocal, params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, ) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { - if t.params.UseOneShotSignallingMode { - return t.publisher.AddTransceiverFromTrack(trackLocal, params) + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.AddTransceiverFromTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) } else { - return t.subscriber.AddTransceiverFromTrack(trackLocal, params) + return t.subscriber.AddTransceiverFromTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) } } func (t *TransportManager) RemoveTrackLocal(sender *webrtc.RTPSender) error { - if t.params.UseOneShotSignallingMode { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { return t.publisher.RemoveTrack(sender) } else { return t.subscriber.RemoveTrack(sender) @@ -273,7 +283,7 @@ func (t *TransportManager) RemoveTrackLocal(sender *webrtc.RTPSender) error { } func (t *TransportManager) WriteSubscriberRTCP(pkts []rtcp.Packet) error { - if t.params.UseOneShotSignallingMode { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { return t.publisher.WriteRTCP(pkts) } else { return t.subscriber.WriteRTCP(pkts) @@ -281,15 +291,27 @@ func (t *TransportManager) WriteSubscriberRTCP(pkts []rtcp.Packet) error { } func (t *TransportManager) GetSubscriberPacer() pacer.Pacer { - return t.subscriber.GetPacer() + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.GetPacer() + } else { + return t.subscriber.GetPacer() + } } func (t *TransportManager) AddSubscribedTrack(subTrack types.SubscribedTrack) { - t.subscriber.AddTrackToStreamAllocator(subTrack) + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.AddTrackToStreamAllocator(subTrack) + } else { + t.subscriber.AddTrackToStreamAllocator(subTrack) + } } func (t *TransportManager) RemoveSubscribedTrack(subTrack types.SubscribedTrack) { - t.subscriber.RemoveTrackFromStreamAllocator(subTrack) + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.RemoveTrackFromStreamAllocator(subTrack) + } else { + t.subscriber.RemoveTrackFromStreamAllocator(subTrack) + } } func (t *TransportManager) SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error { @@ -393,12 +415,9 @@ func (t *TransportManager) createDataChannelsForSubscriber(pendingDataChannels [ } func (t *TransportManager) GetUnmatchMediaForOffer(parsedOffer *sdp.SessionDescription, mediaType string) (unmatched []*sdp.MediaDescription, err error) { - // prefer codec from offer for clients that don't support setCodecPreferences var lastMatchedMid string - lastAnswer := t.lastPublisherAnswer.Load() - if lastAnswer != nil { - answer := lastAnswer.(webrtc.SessionDescription) - parsedAnswer, err1 := answer.Unmarshal() + if lastAnswer := t.publisher.CurrentLocalDescription(); lastAnswer != nil { + parsedAnswer, err1 := lastAnswer.Unmarshal() if err1 != nil { // should not happen t.params.Logger.Errorw("failed to parse last answer", err1) @@ -428,11 +447,8 @@ func (t *TransportManager) GetUnmatchMediaForOffer(parsedOffer *sdp.SessionDescr return } -func (t *TransportManager) LastPublisherOffer() webrtc.SessionDescription { - if sd := t.lastPublisherOffer.Load(); sd != nil { - return sd.(webrtc.SessionDescription) - } - return webrtc.SessionDescription{} +func (t *TransportManager) LastPublisherOffer() *webrtc.SessionDescription { + return t.publisher.CurrentRemoteDescription() } func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, offerId uint32, shouldPend bool) error { @@ -444,17 +460,12 @@ func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, offerId return nil } t.lock.Unlock() - t.lastPublisherOffer.Store(offer) return t.publisher.HandleRemoteDescription(offer, offerId) } func (t *TransportManager) GetAnswer() (webrtc.SessionDescription, uint32, error) { - answer, answerId, err := t.publisher.GetAnswer() - if err == nil { - t.lastPublisherAnswer.Store(answer) - } - return answer, answerId, err + return t.publisher.GetAnswer() } func (t *TransportManager) GetPublisherICESessionUfrag() (string, error) { @@ -501,7 +512,11 @@ func (t *TransportManager) AddICECandidate(candidate webrtc.ICECandidateInit, ta } func (t *TransportManager) NegotiateSubscriber(force bool) { - t.subscriber.Negotiate(force) + if t.subscriber != nil { + t.subscriber.Negotiate(force) + } else { + t.publisher.Negotiate(force) + } } func (t *TransportManager) HandleClientReconnect(reason livekit.ReconnectReason) { @@ -512,12 +527,16 @@ func (t *TransportManager) HandleClientReconnect(reason livekit.ReconnectReason) ) switch reason { case livekit.ReconnectReason_RR_PUBLISHER_FAILED: - resetShortConnection = true - isShort, duration = t.publisher.IsShortConnection(time.Now()) + if t.publisher != nil { + resetShortConnection = true + isShort, duration = t.publisher.IsShortConnection(time.Now()) + } case livekit.ReconnectReason_RR_SUBSCRIBER_FAILED: - resetShortConnection = true - isShort, duration = t.subscriber.IsShortConnection(time.Now()) + if t.subscriber != nil { + resetShortConnection = true + isShort, duration = t.subscriber.IsShortConnection(time.Now()) + } } if isShort { @@ -529,15 +548,23 @@ func (t *TransportManager) HandleClientReconnect(reason livekit.ReconnectReason) } if resetShortConnection { - t.publisher.ResetShortConnOnICERestart() - t.subscriber.ResetShortConnOnICERestart() + if t.publisher != nil { + t.publisher.ResetShortConnOnICERestart() + } + if t.subscriber != nil { + t.subscriber.ResetShortConnOnICERestart() + } } } func (t *TransportManager) ICERestart(iceConfig *livekit.ICEConfig) error { t.SetICEConfig(iceConfig) - return t.subscriber.ICERestart() + if t.subscriber != nil { + return t.subscriber.ICERestart() + } + + return nil } func (t *TransportManager) OnICEConfigChanged(f func(iceConfig *livekit.ICEConfig)) { @@ -589,8 +616,12 @@ func (t *TransportManager) configureICE(iceConfig *livekit.ICEConfig, reset bool t.mediaLossProxy.OnMediaLossUpdate(nil) } - t.publisher.SetPreferTCP(iceConfig.PreferencePublisher == livekit.ICECandidateType_ICT_TCP) - t.subscriber.SetPreferTCP(iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TCP) + if t.publisher != nil { + t.publisher.SetPreferTCP(iceConfig.PreferencePublisher == livekit.ICECandidateType_ICT_TCP) + } + if t.subscriber != nil { + t.subscriber.SetPreferTCP(iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TCP) + } if onICEConfigChanged != nil { onICEConfigChanged(iceConfig) @@ -604,6 +635,10 @@ func (t *TransportManager) SubscriberAsPrimary() bool { func (t *TransportManager) GetICEConnectionInfo() []*types.ICEConnectionInfo { infos := make([]*types.ICEConnectionInfo, 0, 2) for _, pc := range []*PCTransport{t.publisher, t.subscriber} { + if pc == nil { + continue + } + info := pc.GetICEConnectionInfo() if info.HasCandidates() { infos = append(infos, info) @@ -613,20 +648,38 @@ func (t *TransportManager) GetICEConnectionInfo() []*types.ICEConnectionInfo { } func (t *TransportManager) getTransport(isPrimary bool) *PCTransport { - pcTransport := t.publisher - if (isPrimary && t.params.SubscriberAsPrimary) || (!isPrimary && !t.params.SubscriberAsPrimary) { - pcTransport = t.subscriber - } + switch { + case t.publisher == nil: + return t.subscriber - return pcTransport + case t.subscriber == nil: + return t.publisher + + default: + pcTransport := t.publisher + if (isPrimary && t.params.SubscriberAsPrimary) || (!isPrimary && !t.params.SubscriberAsPrimary) { + pcTransport = t.subscriber + } + + return pcTransport + } } func (t *TransportManager) getLowestPriorityConnectionType() types.ICEConnectionType { - ctype := t.publisher.GetICEConnectionType() - if stype := t.subscriber.GetICEConnectionType(); stype > ctype { - ctype = stype + switch { + case t.publisher == nil: + return t.subscriber.GetICEConnectionType() + + case t.subscriber == nil: + return t.publisher.GetICEConnectionType() + + default: + ctype := t.publisher.GetICEConnectionType() + if stype := t.subscriber.GetICEConnectionType(); stype > ctype { + ctype = stype + } + return ctype } - return ctype } func (t *TransportManager) handleConnectionFailed(isShortLived bool) { @@ -739,7 +792,11 @@ func (t *TransportManager) handleConnectionFailed(isShortLived bool) { }, false) } -func (t *TransportManager) SetMigrateInfo(previousOffer, previousAnswer *webrtc.SessionDescription, dataChannels []*livekit.DataChannelInfo) { +func (t *TransportManager) SetMigrateInfo( + previousOffer *webrtc.SessionDescription, + previousAnswer *webrtc.SessionDescription, + dataChannels []*livekit.DataChannelInfo, +) { t.lock.Lock() t.pendingDataChannelsPublisher = make([]*livekit.DataChannelInfo, 0, len(dataChannels)) pendingDataChannelsSubscriber := make([]*livekit.DataChannelInfo, 0, len(dataChannels)) @@ -758,7 +815,11 @@ func (t *TransportManager) SetMigrateInfo(previousOffer, previousAnswer *webrtc. } } - t.subscriber.SetPreviousSdp(previousOffer, previousAnswer) + if t.params.UseSinglePeerConnection { + t.publisher.SetPreviousSdp(previousAnswer, previousOffer) + } else { + t.subscriber.SetPreviousSdp(previousOffer, previousAnswer) + } } func (t *TransportManager) ProcessPendingPublisherDataChannels() { @@ -823,7 +884,11 @@ func (t *TransportManager) onMediaLossUpdate(loss uint8) { t.lock.Unlock() t.params.Logger.Infow("udp connection unstable, switch to tcp", "signalingRTT", t.signalingRTT) - t.params.SubscriberHandler.OnFailed(true, t.subscriber.GetICEConnectionInfo()) + if t.params.UseSinglePeerConnection { + t.params.PublisherHandler.OnFailed(true, t.publisher.GetICEConnectionInfo()) + } else { + t.params.SubscriberHandler.OnFailed(true, t.subscriber.GetICEConnectionInfo()) + } return } } @@ -835,8 +900,12 @@ func (t *TransportManager) UpdateSignalingRTT(rtt uint32) { t.lock.Lock() t.signalingRTT = rtt t.lock.Unlock() - t.publisher.SetSignalingRTT(rtt) - t.subscriber.SetSignalingRTT(rtt) + if t.publisher != nil { + t.publisher.SetSignalingRTT(rtt) + } + if t.subscriber != nil { + t.subscriber.SetSignalingRTT(rtt) + } // TODO: considering using tcp rtt to calculate ice connection cost, if ice connection can't be established // within 5 * tcp rtt(at least 5s), means udp traffic might be block/dropped, switch to tcp. @@ -881,11 +950,19 @@ func (t *TransportManager) SetSignalSourceValid(valid bool) { } func (t *TransportManager) SetSubscriberAllowPause(allowPause bool) { - t.subscriber.SetAllowPauseOfStreamAllocator(allowPause) + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.SetAllowPauseOfStreamAllocator(allowPause) + } else { + t.subscriber.SetAllowPauseOfStreamAllocator(allowPause) + } } func (t *TransportManager) SetSubscriberChannelCapacity(channelCapacity int64) { - t.subscriber.SetChannelCapacityOfStreamAllocator(channelCapacity) + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.SetChannelCapacityOfStreamAllocator(channelCapacity) + } else { + t.subscriber.SetChannelCapacityOfStreamAllocator(channelCapacity) + } } func (t *TransportManager) hasRecentSignalLocked() bool { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 020d3cf72..eacf95f15 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -360,6 +360,7 @@ type LocalParticipant interface { ProtocolVersion() ProtocolVersion SupportsSyncStreamID() bool SupportsTransceiverReuse() bool + IsUsingSinglePeerConnection() bool IsClosed() bool IsReady() bool IsDisconnected() bool @@ -490,7 +491,8 @@ type LocalParticipant interface { SetMigrateState(s MigrateState) MigrateState() MigrateState SetMigrateInfo( - previousOffer, previousAnswer *webrtc.SessionDescription, + previousOffer *webrtc.SessionDescription, + previousAnswer *webrtc.SessionDescription, mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo, dataChannelReceiveState []*livekit.DataChannelReceiveState, diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index b6494291c..d4c19d8d3 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -799,6 +799,16 @@ type FakeLocalParticipant struct { isTrackNameSubscribedReturnsOnCall map[int]struct { result1 bool } + IsUsingSinglePeerConnectionStub func() bool + isUsingSinglePeerConnectionMutex sync.RWMutex + isUsingSinglePeerConnectionArgsForCall []struct { + } + isUsingSinglePeerConnectionReturns struct { + result1 bool + } + isUsingSinglePeerConnectionReturnsOnCall map[int]struct { + result1 bool + } IssueFullReconnectStub func(types.ParticipantCloseReason) issueFullReconnectMutex sync.RWMutex issueFullReconnectArgsForCall []struct { @@ -5521,6 +5531,59 @@ func (fake *FakeLocalParticipant) IsTrackNameSubscribedReturnsOnCall(i int, resu }{result1} } +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnection() bool { + fake.isUsingSinglePeerConnectionMutex.Lock() + ret, specificReturn := fake.isUsingSinglePeerConnectionReturnsOnCall[len(fake.isUsingSinglePeerConnectionArgsForCall)] + fake.isUsingSinglePeerConnectionArgsForCall = append(fake.isUsingSinglePeerConnectionArgsForCall, struct { + }{}) + stub := fake.IsUsingSinglePeerConnectionStub + fakeReturns := fake.isUsingSinglePeerConnectionReturns + fake.recordInvocation("IsUsingSinglePeerConnection", []interface{}{}) + fake.isUsingSinglePeerConnectionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionCallCount() int { + fake.isUsingSinglePeerConnectionMutex.RLock() + defer fake.isUsingSinglePeerConnectionMutex.RUnlock() + return len(fake.isUsingSinglePeerConnectionArgsForCall) +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionCalls(stub func() bool) { + fake.isUsingSinglePeerConnectionMutex.Lock() + defer fake.isUsingSinglePeerConnectionMutex.Unlock() + fake.IsUsingSinglePeerConnectionStub = stub +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionReturns(result1 bool) { + fake.isUsingSinglePeerConnectionMutex.Lock() + defer fake.isUsingSinglePeerConnectionMutex.Unlock() + fake.IsUsingSinglePeerConnectionStub = nil + fake.isUsingSinglePeerConnectionReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionReturnsOnCall(i int, result1 bool) { + fake.isUsingSinglePeerConnectionMutex.Lock() + defer fake.isUsingSinglePeerConnectionMutex.Unlock() + fake.IsUsingSinglePeerConnectionStub = nil + if fake.isUsingSinglePeerConnectionReturnsOnCall == nil { + fake.isUsingSinglePeerConnectionReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isUsingSinglePeerConnectionReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) IssueFullReconnect(arg1 types.ParticipantCloseReason) { fake.issueFullReconnectMutex.Lock() fake.issueFullReconnectArgsForCall = append(fake.issueFullReconnectArgsForCall, struct { diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index 5195908d2..37c7947b1 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -63,16 +63,10 @@ func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { normalizedMimeType := mime.NormalizeMimeType(codecs[0].MimeType) if normalizedMimeType == mime.MimeTypeRED { // if upstream is opus/red, then add opus to match clients that don't support red - codecs = append(codecs, webrtc.RTPCodecParameters{ - RTPCodecCapability: OpusCodecCapability, - PayloadType: 111, - }) + codecs = append(codecs, OpusCodecParameters) } else if !params.DisableRed && normalizedMimeType == mime.MimeTypeOpus { // if upstream is opus only and red enabled, add red to match clients that support red - codecs = append(codecs, webrtc.RTPCodecParameters{ - RTPCodecCapability: RedCodecCapability, - PayloadType: 63, - }) + codecs = append(codecs, RedCodecParameters) // prefer red codec codecs[0], codecs[1] = codecs[1], codecs[0] } diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 23dc53849..3901d3b78 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -416,30 +416,37 @@ func (r *RoomManager) StartSession( if pi.DisableICELite { rtcConf.SettingEngine.SetLite(false) } + rtcConf.UpdatePublisherConfig(pi.UseSinglePeerConnection) + // default allow forceTCP allowFallback := true if r.config.RTC.AllowTCPFallback != nil { allowFallback = *r.config.RTC.AllowTCPFallback } + // default do not force full reconnect on a publication error reconnectOnPublicationError := false if r.config.RTC.ReconnectOnPublicationError != nil { reconnectOnPublicationError = *r.config.RTC.ReconnectOnPublicationError } + // default do not force full reconnect on a subscription error reconnectOnSubscriptionError := false if r.config.RTC.ReconnectOnSubscriptionError != nil { reconnectOnSubscriptionError = *r.config.RTC.ReconnectOnSubscriptionError } + // default do not force full reconnect on a data channel error reconnectOnDataChannelError := false if r.config.RTC.ReconnectOnDataChannelError != nil { reconnectOnDataChannelError = *r.config.RTC.ReconnectOnDataChannelError } + subscriberAllowPause := r.config.RTC.CongestionControl.AllowPause if pi.SubscriberAllowPause != nil { subscriberAllowPause = *pi.SubscriberAllowPause } + participant, err = rtc.NewParticipant(rtc.ParticipantParams{ Identity: pi.Identity, Name: pi.Name, @@ -486,6 +493,7 @@ func (r *RoomManager) StartSession( DataChannelMaxBufferedAmount: r.config.RTC.DataChannelMaxBufferedAmount, DatachannelSlowThreshold: r.config.RTC.DatachannelSlowThreshold, FireOnTrackBySdp: true, + UseSinglePeerConnection: pi.UseSinglePeerConnection, }) if err != nil { return err diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index 1495c4328..69715892b 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -118,6 +118,7 @@ func decodeAttributes(str string) (map[string]string, error) { func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict bool) (livekit.RoomName, routing.ParticipantInit, int, error) { var params ValidateConnectRequestParams + useSinglePeerConnection := false joinRequest := &livekit.JoinRequest{} wrappedJoinRequestBase64 := r.FormValue("join_request") @@ -137,6 +138,7 @@ func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict params.attributes = attrs } } else { + useSinglePeerConnection = true if wrappedProtoBytes, err := base64.URLEncoding.DecodeString(wrappedJoinRequestBase64); err != nil { return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot base64 decode wrapped join request") } else { @@ -186,11 +188,12 @@ func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict } pi := routing.ParticipantInit{ - Identity: livekit.ParticipantIdentity(res.grants.Identity), - Name: livekit.ParticipantName(res.grants.Name), - Grants: res.grants, - Region: res.region, - CreateRoom: res.createRoomRequest, + Identity: livekit.ParticipantIdentity(res.grants.Identity), + Name: livekit.ParticipantName(res.grants.Name), + Grants: res.grants, + Region: res.region, + CreateRoom: res.createRoomRequest, + UseSinglePeerConnection: useSinglePeerConnection, } if wrappedJoinRequestBase64 == "" { diff --git a/test/agent_test.go b/test/agent_test.go index d6a2571f2..0b1a42f44 100644 --- a/test/agent_test.go +++ b/test/agent_test.go @@ -22,7 +22,6 @@ import ( "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/testutils" - testclient "github.com/livekit/livekit-server/test/client" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" ) @@ -33,194 +32,205 @@ var ( ) func TestAgents(t *testing.T) { - _, finish := setupSingleNodeTest("TestAgents") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, finish := setupSingleNodeTest("TestAgents") + defer finish() - ac1, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac2, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac3, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac4, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac5, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac6, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - defer ac1.close() - defer ac2.close() - defer ac3.close() - defer ac4.close() - defer ac5.close() - defer ac6.close() - ac1.Run(livekit.JobType_JT_ROOM, "default") - ac2.Run(livekit.JobType_JT_ROOM, "default") - ac3.Run(livekit.JobType_JT_PUBLISHER, "default") - ac4.Run(livekit.JobType_JT_PUBLISHER, "default") - ac5.Run(livekit.JobType_JT_PARTICIPANT, "default") - ac6.Run(livekit.JobType_JT_PARTICIPANT, "default") + ac1, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac3, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac4, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac5, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac6, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + defer ac3.close() + defer ac4.close() + defer ac5.close() + defer ac6.close() + ac1.Run(livekit.JobType_JT_ROOM, "default") + ac2.Run(livekit.JobType_JT_ROOM, "default") + ac3.Run(livekit.JobType_JT_PUBLISHER, "default") + ac4.Run(livekit.JobType_JT_PUBLISHER, "default") + ac5.Run(livekit.JobType_JT_PARTICIPANT, "default") + ac6.Run(livekit.JobType_JT_PARTICIPANT, "default") - testutils.WithTimeout(t, func() string { - if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 || ac3.registered.Load() != 1 || ac4.registered.Load() != 1 || ac5.registered.Load() != 1 || ac6.registered.Load() != 1 { - return "worker not registered" - } + testutils.WithTimeout(t, func() string { + if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 || ac3.registered.Load() != 1 || ac4.registered.Load() != 1 || ac5.registered.Load() != 1 || ac6.registered.Load() != 1 { + return "worker not registered" + } - return "" - }, RegisterTimeout) + return "" + }, RegisterTimeout) - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", defaultServerPort, &testclient.Options{UseJoinRequestQueryParam: true}) - waitUntilConnected(t, c1, c2) + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) - // publish 2 tracks - t1, err := c1.AddStaticTrack("audio/opus", "audio", "micro") - require.NoError(t, err) - defer t1.Stop() - t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") - require.NoError(t, err) - defer t2.Stop() + // publish 2 tracks + t1, err := c1.AddStaticTrack("audio/opus", "audio", "micro") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() - testutils.WithTimeout(t, func() string { - if ac1.roomJobs.Load()+ac2.roomJobs.Load() != 1 { - return "room job not assigned" - } + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load()+ac2.roomJobs.Load() != 1 { + return "room job not assigned" + } - if ac3.publisherJobs.Load()+ac4.publisherJobs.Load() != 1 { - return fmt.Sprintf("publisher jobs not assigned, ac3: %d, ac4: %d", ac3.publisherJobs.Load(), ac4.publisherJobs.Load()) - } + if ac3.publisherJobs.Load()+ac4.publisherJobs.Load() != 1 { + return fmt.Sprintf("publisher jobs not assigned, ac3: %d, ac4: %d", ac3.publisherJobs.Load(), ac4.publisherJobs.Load()) + } - if ac5.participantJobs.Load()+ac6.participantJobs.Load() != 2 { - return fmt.Sprintf("participant jobs not assigned, ac5: %d, ac6: %d", ac5.participantJobs.Load(), ac6.participantJobs.Load()) - } + if ac5.participantJobs.Load()+ac6.participantJobs.Load() != 2 { + return fmt.Sprintf("participant jobs not assigned, ac5: %d, ac6: %d", ac5.participantJobs.Load(), ac6.participantJobs.Load()) + } - return "" - }, 6*time.Second) + return "" + }, 6*time.Second) - // publish 2 tracks - t3, err := c2.AddStaticTrack("audio/opus", "audio", "micro") - require.NoError(t, err) - defer t3.Stop() - t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam") - require.NoError(t, err) - defer t4.Stop() + // publish 2 tracks + t3, err := c2.AddStaticTrack("audio/opus", "audio", "micro") + require.NoError(t, err) + defer t3.Stop() + t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t4.Stop() - testutils.WithTimeout(t, func() string { - if ac1.roomJobs.Load()+ac2.roomJobs.Load() != 1 { - return "room job must be assigned 1 time" - } + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load()+ac2.roomJobs.Load() != 1 { + return "room job must be assigned 1 time" + } - if ac3.publisherJobs.Load()+ac4.publisherJobs.Load() != 2 { - return "2 publisher jobs must assigned" - } + if ac3.publisherJobs.Load()+ac4.publisherJobs.Load() != 2 { + return "2 publisher jobs must assigned" + } - if ac5.participantJobs.Load()+ac6.participantJobs.Load() != 2 { - return "2 participant jobs must assigned" - } + if ac5.participantJobs.Load()+ac6.participantJobs.Load() != 2 { + return "2 participant jobs must assigned" + } - return "" - }, AssignJobTimeout) + return "" + }, AssignJobTimeout) + }) + } } func TestAgentNamespaces(t *testing.T) { - _, finish := setupSingleNodeTest("TestAgentNamespaces") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, finish := setupSingleNodeTest("TestAgentNamespaces") + defer finish() - ac1, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac2, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - defer ac1.close() - defer ac2.close() - ac1.Run(livekit.JobType_JT_ROOM, "namespace1") - ac2.Run(livekit.JobType_JT_ROOM, "namespace2") + ac1, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + ac1.Run(livekit.JobType_JT_ROOM, "namespace1") + ac2.Run(livekit.JobType_JT_ROOM, "namespace2") - _, err = roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ - Name: testRoom, - Agents: []*livekit.RoomAgentDispatch{ - {}, - { - AgentName: "ag", - }, - }, - }) - require.NoError(t, err) + _, err = roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: testRoom, + Agents: []*livekit.RoomAgentDispatch{ + {}, + { + AgentName: "ag", + }, + }, + }) + require.NoError(t, err) - testutils.WithTimeout(t, func() string { - if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 { - return "worker not registered" - } - return "" - }, RegisterTimeout) + testutils.WithTimeout(t, func() string { + if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 { + return "worker not registered" + } + return "" + }, RegisterTimeout) - c1 := createRTCClient("c1", defaultServerPort, nil) - waitUntilConnected(t, c1) + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - testutils.WithTimeout(t, func() string { - if ac1.roomJobs.Load() != 1 || ac2.roomJobs.Load() != 1 { - return "room job not assigned" - } + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load() != 1 || ac2.roomJobs.Load() != 1 { + return "room job not assigned" + } - job1 := <-ac1.requestedJobs - job2 := <-ac2.requestedJobs + job1 := <-ac1.requestedJobs + job2 := <-ac2.requestedJobs - if job1.Namespace != "namespace1" { - return "namespace is not 'namespace'" - } + if job1.Namespace != "namespace1" { + return "namespace is not 'namespace'" + } - if job2.Namespace != "namespace2" { - return "namespace is not 'namespace2'" - } + if job2.Namespace != "namespace2" { + return "namespace is not 'namespace2'" + } - if job1.Id == job2.Id { - return "job ids are the same" - } - - return "" - }, AssignJobTimeout) + if job1.Id == job2.Id { + return "job ids are the same" + } + return "" + }, AssignJobTimeout) + }) + } } func TestAgentMultiNode(t *testing.T) { - _, _, finish := setupMultiNodeTest("TestAgentMultiNode") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestAgentMultiNode") + defer finish() - ac1, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - ac2, err := newAgentClient(agentToken(), defaultServerPort) - require.NoError(t, err) - defer ac1.close() - defer ac2.close() - ac1.Run(livekit.JobType_JT_ROOM, "default") - ac2.Run(livekit.JobType_JT_PUBLISHER, "default") + ac1, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + ac1.Run(livekit.JobType_JT_ROOM, "default") + ac2.Run(livekit.JobType_JT_PUBLISHER, "default") - testutils.WithTimeout(t, func() string { - if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 { - return "worker not registered" - } - return "" - }, RegisterTimeout) + testutils.WithTimeout(t, func() string { + if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 { + return "worker not registered" + } + return "" + }, RegisterTimeout) - c1 := createRTCClient("c1", secondServerPort, nil) // Create a room on the second node - waitUntilConnected(t, c1) + c1 := createRTCClient("c1", secondServerPort, useSinglePeerConnection, nil) // Create a room on the second node + waitUntilConnected(t, c1) - t1, err := c1.AddStaticTrack("audio/opus", "audio", "micro") - require.NoError(t, err) - defer t1.Stop() + t1, err := c1.AddStaticTrack("audio/opus", "audio", "micro") + require.NoError(t, err) + defer t1.Stop() - time.Sleep(time.Second * 10) + time.Sleep(time.Second * 10) - testutils.WithTimeout(t, func() string { - if ac1.roomJobs.Load() != 1 { - return "room job not assigned" - } + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load() != 1 { + return "room job not assigned" + } - if ac2.publisherJobs.Load() != 1 { - return "participant job not assigned" - } + if ac2.publisherJobs.Load() != 1 { + return "participant job not assigned" + } - return "" - }, AssignJobTimeout) + return "" + }, AssignJobTimeout) + }) + } } func agentToken() string { diff --git a/test/client/client.go b/test/client/client.go index c364f3d0c..5bdf98c1d 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -55,10 +55,11 @@ type SignalResponseHandler func(msg *livekit.SignalResponse) error type SignalResponseInterceptor func(msg *livekit.SignalResponse, next SignalResponseHandler) error type RTCClient struct { - id livekit.ParticipantID - conn *websocket.Conn - publisher *rtc.PCTransport - subscriber *rtc.PCTransport + useSinglePeerConnection bool + id livekit.ParticipantID + conn *websocket.Conn + publisher *rtc.PCTransport + subscriber *rtc.PCTransport // sid => track localTracks map[string]webrtc.TrackLocal trackSenders map[string]*webrtc.RTPSender @@ -80,11 +81,13 @@ type RTCClient struct { publisherFullyEstablished atomic.Bool subscriberFullyEstablished atomic.Bool pongReceivedAt atomic.Int64 - lastAnswer atomic.Pointer[webrtc.SessionDescription] // tracks waiting to be acked, cid => trackInfo pendingPublishedTracks map[string]*livekit.TrackInfo + // remote tracks waiting to be processed + pendingRemoteTracks []*webrtc.TrackRemote + pendingTrackWriters []*TrackWriter OnConnected func() OnDataReceived func(data []byte, sid string) @@ -139,7 +142,7 @@ func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error clientInfo := &livekit.ClientInfo{ Os: runtime.GOOS, Sdk: livekit.ClientInfo_GO, - Protocol: types.CurrentProtocol, + Protocol: int32(types.CurrentProtocol), } if opts.ClientInfo != nil { clientInfo = opts.ClientInfo @@ -202,19 +205,20 @@ func SetAuthorizationToken(header http.Header, token string) { header.Set("Authorization", "Bearer "+token) } -func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { +func NewRTCClient(conn *websocket.Conn, useSinglePeerConnection bool, opts *Options) (*RTCClient, error) { var err error c := &RTCClient{ - conn: conn, - localTracks: make(map[string]webrtc.TrackLocal), - trackSenders: make(map[string]*webrtc.RTPSender), - pendingPublishedTracks: make(map[string]*livekit.TrackInfo), - subscribedTracks: make(map[livekit.ParticipantID][]*webrtc.TrackRemote), - remoteParticipants: make(map[livekit.ParticipantID]*livekit.ParticipantInfo), - me: &webrtc.MediaEngine{}, - lastPackets: make(map[livekit.ParticipantID]*rtp.Packet), - bytesReceived: make(map[livekit.ParticipantID]uint64), + useSinglePeerConnection: useSinglePeerConnection, + conn: conn, + localTracks: make(map[string]webrtc.TrackLocal), + trackSenders: make(map[string]*webrtc.RTPSender), + pendingPublishedTracks: make(map[string]*livekit.TrackInfo), + subscribedTracks: make(map[livekit.ParticipantID][]*webrtc.TrackRemote), + remoteParticipants: make(map[livekit.ParticipantID]*livekit.ParticipantInfo), + me: &webrtc.MediaEngine{}, + lastPackets: make(map[livekit.ParticipantID]*rtp.Packet), + bytesReceived: make(map[livekit.ParticipantID]uint64), } c.ctx, c.cancel = context.WithCancel(context.Background()) @@ -261,24 +265,14 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { // publisherHandler := &transportfakes.FakeHandler{} c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ - Config: &conf, - DirectionConfig: conf.Subscriber, - EnabledCodecs: codecs, - IsOfferer: true, - IsSendSide: true, - Handler: publisherHandler, - DatachannelSlowThreshold: 1024 * 1024 * 1024, - }) - if err != nil { - return nil, err - } - subscriberHandler := &transportfakes.FakeHandler{} - c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ Config: &conf, - DirectionConfig: conf.Publisher, + DirectionConfig: conf.Subscriber, EnabledCodecs: codecs, - Handler: subscriberHandler, + IsOfferer: true, + IsSendSide: true, + Handler: publisherHandler, DatachannelMaxReceiverBufferSize: 1500, + DatachannelSlowThreshold: 1024 * 1024 * 1024, FireOnTrackBySdp: true, }) if err != nil { @@ -288,6 +282,28 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { publisherHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { return c.SendIceCandidate(ic, livekit.SignalTarget_PUBLISHER) }) + publisherHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + go c.processRemoteTrack(track) + }) + publisherHandler.OnDataMessageCalls(c.handleDataMessage) + publisherHandler.OnDataMessageUnlabeledCalls(c.handleDataMessageUnlabeled) + publisherHandler.OnInitialConnectedCalls(func() { + logger.Debugw("publisher initial connected", "participant", c.localParticipant.Identity) + + c.lock.Lock() + defer c.lock.Unlock() + for _, tw := range c.pendingTrackWriters { + if err := tw.Start(); err != nil { + logger.Errorw("track writer error", err) + } + } + + c.pendingTrackWriters = nil + + if c.OnConnected != nil { + go c.OnConnected() + } + }) publisherHandler.OnOfferCalls(c.onOffer) publisherHandler.OnFullyEstablishedCalls(func() { logger.Debugw("publisher fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) @@ -316,56 +332,74 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { return nil, err } - if err := c.subscriber.CreateReadableDataChannel("subraw", &webrtc.DataChannelInit{ - Ordered: &ordered, - }); err != nil { - return nil, err - } - - subscriberHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { - if ic == nil { - return nil - } - return c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) - }) - subscriberHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { - go c.processTrack(track) - }) - subscriberHandler.OnDataMessageCalls(c.handleDataMessage) - subscriberHandler.OnDataMessageUnlabeledCalls(c.handleDataMessageUnlabeled) - subscriberHandler.OnInitialConnectedCalls(func() { - logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) - - c.lock.Lock() - defer c.lock.Unlock() - for _, tw := range c.pendingTrackWriters { - if err := tw.Start(); err != nil { - logger.Errorw("track writer error", err) - } - } - - c.pendingTrackWriters = nil - - if c.OnConnected != nil { - go c.OnConnected() - } - }) - subscriberHandler.OnFullyEstablishedCalls(func() { - logger.Debugw("subscriber fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) - c.subscriberFullyEstablished.Store(true) - }) - subscriberHandler.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32) error { - // send remote an answer - logger.Infow("sending subscriber answer", - "participant", c.localParticipant.Identity, - // "sdp", answer, - ) - return c.SendRequest(&livekit.SignalRequest{ - Message: &livekit.SignalRequest_Answer{ - Answer: signalling.ToProtoSessionDescription(answer, answerId), - }, + if !c.useSinglePeerConnection { + subscriberHandler := &transportfakes.FakeHandler{} + c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ + Config: &conf, + DirectionConfig: conf.Publisher, + EnabledCodecs: codecs, + Handler: subscriberHandler, + DatachannelMaxReceiverBufferSize: 1500, + DatachannelSlowThreshold: 1024 * 1024 * 1024, + FireOnTrackBySdp: true, }) - }) + if err != nil { + return nil, err + } + + ordered := false + if err := c.subscriber.CreateReadableDataChannel("subraw", &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { + return nil, err + } + + subscriberHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { + if ic == nil { + return nil + } + return c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) + }) + subscriberHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + go c.processRemoteTrack(track) + }) + subscriberHandler.OnDataMessageCalls(c.handleDataMessage) + subscriberHandler.OnDataMessageUnlabeledCalls(c.handleDataMessageUnlabeled) + subscriberHandler.OnInitialConnectedCalls(func() { + logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) + + c.lock.Lock() + defer c.lock.Unlock() + for _, tw := range c.pendingTrackWriters { + if err := tw.Start(); err != nil { + logger.Errorw("track writer error", err) + } + } + + c.pendingTrackWriters = nil + + if c.OnConnected != nil { + go c.OnConnected() + } + }) + subscriberHandler.OnFullyEstablishedCalls(func() { + logger.Debugw("subscriber fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) + c.subscriberFullyEstablished.Store(true) + }) + subscriberHandler.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32) error { + // send remote an answer + logger.Infow( + "sending subscriber answer", + "participant", c.localParticipant.Identity, + "sdp", answer, + ) + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Answer{ + Answer: signalling.ToProtoSessionDescription(answer, answerId), + }, + }) + }) + } if opts != nil { c.signalRequestInterceptor = opts.SignalRequestInterceptor @@ -428,15 +462,20 @@ func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error { logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity) case *livekit.SignalResponse_Answer: - // logger.Debugw("received server answer", - // "participant", c.localParticipant.Identity, - // "answer", msg.Answer.Sdp) + logger.Infow( + "received server answer", + "participant", c.localParticipant.Identity, + "answer", msg.Answer.Sdp, + ) c.handleAnswer(signalling.FromProtoSessionDescription(msg.Answer)) case *livekit.SignalResponse_Offer: - logger.Infow("received server offer", - "participant", c.localParticipant.Identity, - ) desc, offerId := signalling.FromProtoSessionDescription(msg.Offer) + logger.Infow( + "received server offer", + "participant", c.localParticipant.Identity, + "sdp", desc, + "offerId", offerId, + ) c.handleOffer(desc, offerId) case *livekit.SignalResponse_Trickle: candidateInit, err := signalling.FromProtoTrickle(msg.Trickle) @@ -474,8 +513,7 @@ func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error { case *livekit.SignalResponse_TrackUnpublished: sid := msg.TrackUnpublished.TrackSid c.lock.Lock() - sender := c.trackSenders[sid] - if sender != nil { + if sender := c.trackSenders[sid]; sender != nil { if err := c.publisher.RemoveTrack(sender); err != nil { logger.Errorw("Could not unpublish track", err) } @@ -488,6 +526,14 @@ func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error { c.pongReceivedAt.Store(msg.Pong) case *livekit.SignalResponse_SubscriptionResponse: c.subscriptionResponse.Store(msg.SubscriptionResponse) + case *livekit.SignalResponse_MediaSectionsRequirement: + logger.Infow( + "received media sections requirement", + "participant", c.localParticipant.Identity, + "numAudios", msg.MediaSectionsRequirement.NumAudios, + "numVideos", msg.MediaSectionsRequirement.NumVideos, + ) + c.handleMediaSectionsRequirement(msg.MediaSectionsRequirement) } return nil } @@ -580,8 +626,12 @@ func (c *RTCClient) Stop() { c.publisherFullyEstablished.Store(false) c.subscriberFullyEstablished.Store(false) _ = c.conn.Close() - c.publisher.Close() - c.subscriber.Close() + if c.publisher != nil { + c.publisher.Close() + } + if c.subscriber != nil { + c.subscriber.Close() + } c.cancel() } @@ -679,7 +729,18 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string, trackType = livekit.TrackType_VIDEO } - if err = c.SendAddTrack(track.ID(), track.StreamID(), trackType); err != nil { + sender, _, err := c.publisher.AddTrack(track, types.AddTrackParams{}, nil, rtc.RTCPFeedbackConfig{}) + if err != nil { + logger.Errorw( + "add track failed", err, + "participant", c.localParticipant.Identity, + "pID", c.localParticipant.Sid, + "trackID", track.ID(), + ) + return + } + + if err = c.SendAddTrack(track.ID(), track.Codec().MimeType, track.StreamID(), trackType); err != nil { return } @@ -693,10 +754,12 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string, default: c.lock.Lock() ti = c.pendingPublishedTracks[track.ID()] - c.lock.Unlock() if ti != nil { + delete(c.pendingPublishedTracks, track.ID()) + c.lock.Unlock() break } + c.lock.Unlock() time.Sleep(50 * time.Millisecond) } if ti != nil { @@ -707,11 +770,6 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string, c.lock.Lock() defer c.lock.Unlock() - sender, _, err := c.publisher.AddTrack(track, types.AddTrackParams{}) - if err != nil { - logger.Errorw("add track failed", err, "trackID", ti.Sid, "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) - return - } c.localTracks[ti.Sid] = track c.trackSenders[ti.Sid] = sender c.publisher.Negotiate(false) @@ -750,9 +808,7 @@ func (c *RTCClient) AddFileTrack(path string, id string, label string) (writer * return nil, fmt.Errorf("%s has an unsupported extension", filepath.Base(path)) } - logger.Debugw("adding file track", - "mime", mime, - ) + logger.Debugw("adding file track", "mime", mime) track, err := webrtc.NewTrackLocalStaticSample( webrtc.RTPCodecCapability{MimeType: mime}, @@ -767,13 +823,19 @@ func (c *RTCClient) AddFileTrack(path string, id string, label string) (writer * } // send AddTrack command to server to initiate server-side negotiation -func (c *RTCClient) SendAddTrack(cid string, name string, trackType livekit.TrackType) error { +func (c *RTCClient) SendAddTrack(cid string, mimeType string, name string, trackType livekit.TrackType) error { return c.SendRequest(&livekit.SignalRequest{ Message: &livekit.SignalRequest_AddTrack{ AddTrack: &livekit.AddTrackRequest{ Cid: cid, Name: name, Type: trackType, + SimulcastCodecs: []*livekit.SimulcastCodec{ + { + Cid: cid, + Codec: mimeType, + }, + }, }, }, }) @@ -816,7 +878,7 @@ func (c *RTCClient) GetPublishedTrackIDs() []string { // LastAnswer return SDP of the last answer for the publisher connection func (c *RTCClient) LastAnswer() *webrtc.SessionDescription { - return c.lastAnswer.Load() + return c.publisher.CurrentRemoteDescription() } func (c *RTCClient) ensurePublisherConnected() error { @@ -864,21 +926,58 @@ func (c *RTCClient) handleDataMessageUnlabeled(data []byte) { // handles a server initiated offer, handle on subscriber PC func (c *RTCClient) handleOffer(desc webrtc.SessionDescription, offerId uint32) { + logger.Infow("handling server offer", "participant", c.localParticipant.Identity) c.subscriber.HandleRemoteDescription(desc, offerId) + c.processPendingRemoteTracks() } // the client handles answer on the publisher PC func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription, answerId uint32) { logger.Infow("handling server answer", "participant", c.localParticipant.Identity) - c.lastAnswer.Store(&desc) // remote answered the offer, establish connection c.publisher.HandleRemoteDescription(desc, answerId) + c.processPendingRemoteTracks() +} + +// the client handles media sections requirement on the publisher PC +func (c *RTCClient) handleMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) { + addTransceivers := func(kind webrtc.RTPCodecType, count uint32) { + for i := uint32(0); i < count; i++ { + if _, err := c.publisher.AddTransceiverFromKind( + kind, + webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }, + ); err != nil { + logger.Warnw( + "could not add transceiver", err, + "participant", c.localParticipant.Identity, + "kind", kind, + ) + } else { + logger.Infow( + "added transceiver of kind", + "participant", c.localParticipant.Identity, + "kind", kind, + ) + } + } + } + + addTransceivers(webrtc.RTPCodecTypeAudio, mediaSectionsRequirement.NumAudios) + addTransceivers(webrtc.RTPCodecTypeVideo, mediaSectionsRequirement.NumVideos) + c.publisher.Negotiate(false) } func (c *RTCClient) onOffer(offer webrtc.SessionDescription, offerId uint32) error { if c.localParticipant != nil { logger.Infow("starting negotiation", "participant", c.localParticipant.Identity) + logger.Infow( + "sending publisher offer", + "participant", c.localParticipant.Identity, + "offer", offer, + ) } return c.SendRequest(&livekit.SignalRequest{ Message: &livekit.SignalRequest_Offer{ @@ -887,25 +986,58 @@ func (c *RTCClient) onOffer(offer webrtc.SessionDescription, offerId uint32) err }) } -func (c *RTCClient) processTrack(track *webrtc.TrackRemote) { - lastUpdate := time.Time{} - pId, trackId := rtc.UnpackStreamID(track.StreamID()) - if trackId == "" { - trackId = livekit.TrackID(track.ID()) - } +func (c *RTCClient) processPendingRemoteTracks() { c.lock.Lock() - c.subscribedTracks[pId] = append(c.subscribedTracks[pId], track) + pendingRemoteTracks := c.pendingRemoteTracks + c.pendingRemoteTracks = nil c.lock.Unlock() - logger.Infow("client added track", "participant", c.localParticipant.Identity, - "pID", pId, - "trackID", trackId, + for _, pendingRemoteTrack := range pendingRemoteTracks { + go c.processRemoteTrack(pendingRemoteTrack) + } +} + +func (c *RTCClient) processRemoteTrack(track *webrtc.TrackRemote) { + lastUpdate := time.Time{} + + // because of FireOnTrackBySdp, it is possible get an empty streamID + // if media comes before SDP, cache and try later + streamID := track.StreamID() + if streamID == "" { + logger.Infow( + "client caching track", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "codec", track.Codec(), + "ssrc", track.SSRC(), + ) + c.lock.Lock() + c.pendingRemoteTracks = append(c.pendingRemoteTracks, track) + c.lock.Unlock() + return + } + + publisherID, trackID := rtc.UnpackStreamID(streamID) + if trackID == "" { + trackID = livekit.TrackID(track.ID()) + } + c.lock.Lock() + c.subscribedTracks[publisherID] = append(c.subscribedTracks[publisherID], track) + c.lock.Unlock() + + logger.Infow( + "client added track", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "publisherID", publisherID, + "trackID", trackID, "codec", track.Codec(), + "ssrc", track.SSRC(), ) defer func() { c.lock.Lock() - c.subscribedTracks[pId] = funk.Without(c.subscribedTracks[pId], track).([]*webrtc.TrackRemote) + c.subscribedTracks[publisherID] = funk.Without(c.subscribedTracks[publisherID], track).([]*webrtc.TrackRemote) c.lock.Unlock() }() @@ -916,6 +1048,15 @@ func (c *RTCClient) processTrack(track *webrtc.TrackRemote) { break } if rtc.IsEOF(err) { + logger.Infow( + "client track removed", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "publisherID", publisherID, + "trackID", trackID, + "codec", track.Codec(), + "ssrc", track.SSRC(), + ) break } if err != nil { @@ -923,14 +1064,19 @@ func (c *RTCClient) processTrack(track *webrtc.TrackRemote) { continue } c.lock.Lock() - c.lastPackets[pId] = pkt - c.bytesReceived[pId] += uint64(pkt.MarshalSize()) + c.lastPackets[publisherID] = pkt + c.bytesReceived[publisherID] += uint64(pkt.MarshalSize()) c.lock.Unlock() numBytes += pkt.MarshalSize() if time.Since(lastUpdate) > 30*time.Second { - logger.Infow("consumed from participant", - "trackID", trackId, "pID", pId, - "size", numBytes) + logger.Infow( + "consumed from participant", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "publisherID", publisherID, + "trackID", trackID, + "size", numBytes, + ) lastUpdate = time.Now() } } diff --git a/test/integration_helpers.go b/test/integration_helpers.go index abb6bc149..6fed71acd 100644 --- a/test/integration_helpers.go +++ b/test/integration_helpers.go @@ -202,35 +202,32 @@ func createMultiNodeServer(nodeID string, port uint32) *service.LivekitServer { } // creates a client and runs against server -func createRTCClient(name string, port int, opts *testclient.Options) *testclient.RTCClient { +func createRTCClient(name string, port int, useSinglePeerConnection bool, opts *testclient.Options) *testclient.RTCClient { var customizer func(token *auth.AccessToken, grants *auth.VideoGrant) if opts != nil { customizer = opts.TokenCustomizer } token := joinToken(testRoom, name, customizer) - ws, err := testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", port), token, opts) - if err != nil { - panic(err) - } - c, err := testclient.NewRTCClient(ws, opts) - if err != nil { - panic(err) - } - - go c.Run() - - return c + return createRTCClientWithToken(token, port, useSinglePeerConnection, opts) } // creates a client and runs against server -func createRTCClientWithToken(token string, port int, opts *testclient.Options) *testclient.RTCClient { +func createRTCClientWithToken(token string, port int, useSinglePeerConnection bool, opts *testclient.Options) *testclient.RTCClient { + if opts == nil { + opts = &testclient.Options{ + AutoSubscribe: true, + } + } + if useSinglePeerConnection { + opts.UseJoinRequestQueryParam = true + } ws, err := testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", port), token, opts) if err != nil { panic(err) } - c, err := testclient.NewRTCClient(ws, opts) + c, err := testclient.NewRTCClient(ws, useSinglePeerConnection, opts) if err != nil { panic(err) } diff --git a/test/multinode_roomservice_test.go b/test/multinode_roomservice_test.go index a1ec1cb8b..7bec21b4b 100644 --- a/test/multinode_roomservice_test.go +++ b/test/multinode_roomservice_test.go @@ -61,24 +61,28 @@ func TestMultiNodeUpdateRoomMetadata(t *testing.T) { }) t.Run("when room has a participant", func(t *testing.T) { - _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateRoomMetadata_with_participant") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateRoomMetadata_with_participant") + defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - waitUntilConnected(t, c1) - defer c1.Stop() + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) + defer c1.Stop() - _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ - Name: "emptyRoom", - }) - require.NoError(t, err) + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: "emptyRoom", + }) + require.NoError(t, err) - rm, err := roomClient.UpdateRoomMetadata(contextWithToken(adminRoomToken("emptyRoom")), &livekit.UpdateRoomMetadataRequest{ - Room: "emptyRoom", - Metadata: "updated metadata", - }) - require.NoError(t, err) - require.Equal(t, "updated metadata", rm.Metadata) + rm, err := roomClient.UpdateRoomMetadata(contextWithToken(adminRoomToken("emptyRoom")), &livekit.UpdateRoomMetadataRequest{ + Room: "emptyRoom", + Metadata: "updated metadata", + }) + require.NoError(t, err) + require.Equal(t, "updated metadata", rm.Metadata) + }) + } }) } @@ -89,85 +93,97 @@ func TestMultiNodeRemoveParticipant(t *testing.T) { return } - _, _, finish := setupMultiNodeTest("TestMultiNodeRemoveParticipant") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeRemoveParticipant") + defer finish() - c1 := createRTCClient("mn_remove_participant", defaultServerPort, nil) - defer c1.Stop() - waitUntilConnected(t, c1) + c1 := createRTCClient("mn_remove_participant", defaultServerPort, useSinglePeerConnection, nil) + defer c1.Stop() + waitUntilConnected(t, c1) - ctx := contextWithToken(adminRoomToken(testRoom)) - _, err := roomClient.RemoveParticipant(ctx, &livekit.RoomParticipantIdentity{ - Room: testRoom, - Identity: "mn_remove_participant", - }) - require.NoError(t, err) + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.RemoveParticipant(ctx, &livekit.RoomParticipantIdentity{ + Room: testRoom, + Identity: "mn_remove_participant", + }) + require.NoError(t, err) - // participant list doesn't show the participant - listRes, err := roomClient.ListParticipants(ctx, &livekit.ListParticipantsRequest{ - Room: testRoom, - }) - require.NoError(t, err) - require.Len(t, listRes.Participants, 0) + // participant list doesn't show the participant + listRes, err := roomClient.ListParticipants(ctx, &livekit.ListParticipantsRequest{ + Room: testRoom, + }) + require.NoError(t, err) + require.Len(t, listRes.Participants, 0) + }) + } } // update participant metadata func TestMultiNodeUpdateParticipantMetadata(t *testing.T) { - _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateParticipantMetadata") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateParticipantMetadata") + defer finish() - c1 := createRTCClient("update_participant_metadata", defaultServerPort, nil) - defer c1.Stop() - waitUntilConnected(t, c1) + c1 := createRTCClient("update_participant_metadata", defaultServerPort, useSinglePeerConnection, nil) + defer c1.Stop() + waitUntilConnected(t, c1) - ctx := contextWithToken(adminRoomToken(testRoom)) - res, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ - Room: testRoom, - Identity: "update_participant_metadata", - Metadata: "the new metadata", - }) - require.NoError(t, err) - require.Equal(t, "the new metadata", res.Metadata) + ctx := contextWithToken(adminRoomToken(testRoom)) + res, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "update_participant_metadata", + Metadata: "the new metadata", + }) + require.NoError(t, err) + require.Equal(t, "the new metadata", res.Metadata) + }) + } } // admin mute published track func TestMultiNodeMutePublishedTrack(t *testing.T) { - _, _, finish := setupMultiNodeTest("TestMultiNodeMutePublishedTrack") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeMutePublishedTrack") + defer finish() - identity := "mute_published_track" - c1 := createRTCClient(identity, defaultServerPort, nil) - defer c1.Stop() - waitUntilConnected(t, c1) + identity := "mute_published_track" + c1 := createRTCClient(identity, defaultServerPort, useSinglePeerConnection, nil) + defer c1.Stop() + waitUntilConnected(t, c1) - writers := publishTracksForClients(t, c1) - defer stopWriters(writers...) + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) - trackIDs := c1.GetPublishedTrackIDs() - require.NotEmpty(t, trackIDs) + trackIDs := c1.GetPublishedTrackIDs() + require.NotEmpty(t, trackIDs) - ctx := contextWithToken(adminRoomToken(testRoom)) - // wait for it to be published before - testutils.WithTimeout(t, func() string { - res, err := roomClient.GetParticipant(ctx, &livekit.RoomParticipantIdentity{ - Room: testRoom, - Identity: identity, + ctx := contextWithToken(adminRoomToken(testRoom)) + // wait for it to be published before + testutils.WithTimeout(t, func() string { + res, err := roomClient.GetParticipant(ctx, &livekit.RoomParticipantIdentity{ + Room: testRoom, + Identity: identity, + }) + require.NoError(t, err) + if len(res.Tracks) == 2 { + return "" + } else { + return fmt.Sprintf("expected 2 tracks to be published, actual: %d", len(res.Tracks)) + } + }) + + res, err := roomClient.MutePublishedTrack(ctx, &livekit.MuteRoomTrackRequest{ + Room: testRoom, + Identity: identity, + TrackSid: trackIDs[0], + Muted: true, + }) + require.NoError(t, err) + require.Equal(t, trackIDs[0], res.Track.Sid) + require.True(t, res.Track.Muted) }) - require.NoError(t, err) - if len(res.Tracks) == 2 { - return "" - } else { - return fmt.Sprintf("expected 2 tracks to be published, actual: %d", len(res.Tracks)) - } - }) - - res, err := roomClient.MutePublishedTrack(ctx, &livekit.MuteRoomTrackRequest{ - Room: testRoom, - Identity: identity, - TrackSid: trackIDs[0], - Muted: true, - }) - require.NoError(t, err) - require.Equal(t, trackIDs[0], res.Track.Sid) - require.True(t, res.Track.Muted) + } } diff --git a/test/multinode_test.go b/test/multinode_test.go index e79c8316c..25372e3af 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -15,6 +15,7 @@ package test import ( + "fmt" "testing" "time" @@ -33,6 +34,7 @@ func TestMultiNodeRouting(t *testing.T) { t.SkipNow() return } + _, _, finish := setupMultiNodeTest("TestMultiNodeRouting") defer finish() @@ -42,35 +44,39 @@ func TestMultiNodeRouting(t *testing.T) { }) require.NoError(t, err) - // one node connecting to node 1, and another connecting to node 2 - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", secondServerPort, nil) - waitUntilConnected(t, c1, c2) - defer stopClients(c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + // one node connecting to node 1, and another connecting to node 2 + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", secondServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) - // c1 publishing, and c2 receiving - t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") - require.NoError(t, err) - if t1 != nil { - defer t1.Stop() + // c1 publishing, and c2 receiving + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + if t1 != nil { + defer t1.Stop() + } + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 received no tracks" + } + if len(c2.SubscribedTracks()[c1.ID()]) != 1 { + return "c2 didn't receive track published by c1" + } + tr1 := c2.SubscribedTracks()[c1.ID()][0] + streamID, _ := rtc.UnpackStreamID(tr1.StreamID()) + require.Equal(t, c1.ID(), streamID) + return "" + }) + + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + require.Equal(t, "c1", remoteC1.Name) + require.Equal(t, "metadatac1", remoteC1.Metadata) + }) } - - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2 received no tracks" - } - if len(c2.SubscribedTracks()[c1.ID()]) != 1 { - return "c2 didn't receive track published by c1" - } - tr1 := c2.SubscribedTracks()[c1.ID()][0] - streamID, _ := rtc.UnpackStreamID(tr1.StreamID()) - require.Equal(t, c1.ID(), streamID) - return "" - }) - - remoteC1 := c2.GetRemoteParticipant(c1.ID()) - require.Equal(t, "c1", remoteC1.Name) - require.Equal(t, "metadatac1", remoteC1.Metadata) } func TestConnectWithoutCreation(t *testing.T) { @@ -82,10 +88,14 @@ func TestConnectWithoutCreation(t *testing.T) { _, _, finish := setupMultiNodeTest("TestConnectWithoutCreation") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - c1.Stop() + c1.Stop() + }) + } } // testing multiple scenarios rooms @@ -118,30 +128,34 @@ func TestMultinodeReconnectAfterNodeShutdown(t *testing.T) { return } - _, s2, finish := setupMultiNodeTest("TestMultinodeReconnectAfterNodeShutdown") - defer finish() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + _, s2, finish := setupMultiNodeTest("TestMultinodeReconnectAfterNodeShutdown") + defer finish() - // creating room on node 1 - _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ - Name: testRoom, - NodeId: s2.Node().Id, - }) - require.NoError(t, err) + // creating room on node 1 + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: testRoom, + NodeId: s2.Node().Id, + }) + require.NoError(t, err) - // one node connecting to node 1, and another connecting to node 2 - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", secondServerPort, nil) + // one node connecting to node 1, and another connecting to node 2 + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", secondServerPort, useSinglePeerConnection, nil) - waitUntilConnected(t, c1, c2) - stopClients(c1, c2) + waitUntilConnected(t, c1, c2) + stopClients(c1, c2) - // stop s2, and connect to room again - s2.Stop(true) + // stop s2, and connect to room again + s2.Stop(true) - time.Sleep(syncDelay) + time.Sleep(syncDelay) - c3 := createRTCClient("c3", defaultServerPort, nil) - waitUntilConnected(t, c3) + c3 := createRTCClient("c3", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c3) + }) + } } func TestMultinodeDataPublishing(t *testing.T) { @@ -186,48 +200,52 @@ func TestMultiNodeRefreshToken(t *testing.T) { _, _, finish := setupMultiNodeTest("TestMultiNodeJoinAfterClose") defer finish() - // a participant joining with full permissions - c1 := createRTCClient("c1", defaultServerPort, nil) - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + // a participant joining with full permissions + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - // update permissions and metadata - ctx := contextWithToken(adminRoomToken(testRoom)) - _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ - Room: testRoom, - Identity: "c1", - Permission: &livekit.ParticipantPermission{ - CanPublish: false, - CanSubscribe: true, - }, - Metadata: "metadata", - }) - require.NoError(t, err) + // update permissions and metadata + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "c1", + Permission: &livekit.ParticipantPermission{ + CanPublish: false, + CanSubscribe: true, + }, + Metadata: "metadata", + }) + require.NoError(t, err) - testutils.WithTimeout(t, func() string { - if c1.RefreshToken() == "" { - return "did not receive refresh token" - } - // parse token to ensure it's correct - verifier, err := auth.ParseAPIToken(c1.RefreshToken()) - require.NoError(t, err) + testutils.WithTimeout(t, func() string { + if c1.RefreshToken() == "" { + return "did not receive refresh token" + } + // parse token to ensure it's correct + verifier, err := auth.ParseAPIToken(c1.RefreshToken()) + require.NoError(t, err) - grants, err := verifier.Verify(testApiSecret) - require.NoError(t, err) + grants, err := verifier.Verify(testApiSecret) + require.NoError(t, err) - if grants.Metadata != "metadata" { - return "metadata did not match" - } - if *grants.Video.CanPublish { - return "canPublish should be false" - } - if *grants.Video.CanPublishData { - return "canPublishData should be false" - } - if !*grants.Video.CanSubscribe { - return "canSubscribe should be true" - } - return "" - }) + if grants.Metadata != "metadata" { + return "metadata did not match" + } + if *grants.Video.CanPublish { + return "canPublish should be false" + } + if *grants.Video.CanPublishData { + return "canPublishData should be false" + } + if !*grants.Video.CanSubscribe { + return "canSubscribe should be true" + } + return "" + }) + }) + } } // ensure that token accurately reflects out of band updates @@ -240,158 +258,170 @@ func TestMultiNodeUpdateAttributes(t *testing.T) { _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateAttributes") defer finish() - c1 := createRTCClient("au1", defaultServerPort, &client.Options{ - TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { - token.SetAttributes(map[string]string{ - "mykey": "au1", + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("au1", defaultServerPort, useSinglePeerConnection, &client.Options{ + TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { + token.SetAttributes(map[string]string{ + "mykey": "au1", + }) + }, }) - }, - }) - c2 := createRTCClient("au2", secondServerPort, &client.Options{ - TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { - token.SetAttributes(map[string]string{ - "mykey": "au2", + c2 := createRTCClient("au2", secondServerPort, useSinglePeerConnection, &client.Options{ + TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { + token.SetAttributes(map[string]string{ + "mykey": "au2", + }) + grants.SetCanUpdateOwnMetadata(true) + }, }) - grants.SetCanUpdateOwnMetadata(true) - }, - }) - waitUntilConnected(t, c1, c2) + waitUntilConnected(t, c1, c2) - testutils.WithTimeout(t, func() string { - rc2 := c1.GetRemoteParticipant(c2.ID()) - rc1 := c2.GetRemoteParticipant(c1.ID()) - if rc2 == nil || rc1 == nil { - return "participants could not see each other" - } - if rc1.Attributes == nil || rc1.Attributes["mykey"] != "au1" { - return "rc1's initial attributes are incorrect" - } - if rc2.Attributes == nil || rc2.Attributes["mykey"] != "au2" { - return "rc2's initial attributes are incorrect" - } - return "" - }) + testutils.WithTimeout(t, func() string { + rc2 := c1.GetRemoteParticipant(c2.ID()) + rc1 := c2.GetRemoteParticipant(c1.ID()) + if rc2 == nil || rc1 == nil { + return "participants could not see each other" + } + if rc1.Attributes == nil || rc1.Attributes["mykey"] != "au1" { + return "rc1's initial attributes are incorrect" + } + if rc2.Attributes == nil || rc2.Attributes["mykey"] != "au2" { + return "rc2's initial attributes are incorrect" + } + return "" + }) - // this one should not go through - _ = c1.SetAttributes(map[string]string{"mykey": "shouldnotchange"}) - _ = c2.SetAttributes(map[string]string{"secondkey": "au2"}) + // this one should not go through + _ = c1.SetAttributes(map[string]string{"mykey": "shouldnotchange"}) + _ = c2.SetAttributes(map[string]string{"secondkey": "au2"}) - // updates using room API should succeed - _, err := roomClient.UpdateParticipant(contextWithToken(adminRoomToken(testRoom)), &livekit.UpdateParticipantRequest{ - Room: testRoom, - Identity: "au1", - Attributes: map[string]string{ - "secondkey": "au1", - }, - }) - require.NoError(t, err) + // updates using room API should succeed + _, err := roomClient.UpdateParticipant(contextWithToken(adminRoomToken(testRoom)), &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "au1", + Attributes: map[string]string{ + "secondkey": "au1", + }, + }) + require.NoError(t, err) - testutils.WithTimeout(t, func() string { - rc1 := c2.GetRemoteParticipant(c1.ID()) - rc2 := c1.GetRemoteParticipant(c2.ID()) - if rc1.Attributes["secondkey"] != "au1" { - return "au1's attribute update failed" - } - if rc2.Attributes["secondkey"] != "au2" { - return "au2's attribute update failed" - } - if rc1.Attributes["mykey"] != "au1" { - return "au1's mykey should not change" - } - if rc2.Attributes["mykey"] != "au2" { - return "au2's mykey should not change" - } - return "" - }) + testutils.WithTimeout(t, func() string { + rc1 := c2.GetRemoteParticipant(c1.ID()) + rc2 := c1.GetRemoteParticipant(c2.ID()) + if rc1.Attributes["secondkey"] != "au1" { + return "au1's attribute update failed" + } + if rc2.Attributes["secondkey"] != "au2" { + return "au2's attribute update failed" + } + if rc1.Attributes["mykey"] != "au1" { + return "au1's mykey should not change" + } + if rc2.Attributes["mykey"] != "au2" { + return "au2's mykey should not change" + } + return "" + }) + }) + } } func TestMultiNodeRevokePublishPermission(t *testing.T) { _, _, finish := setupMultiNodeTest("TestMultiNodeRevokePublishPermission") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", secondServerPort, nil) - waitUntilConnected(t, c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", secondServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) - // c1 publishes a track for c2 - writers := publishTracksForClients(t, c1) - defer stopWriters(writers...) + // c1 publishes a track for c2 + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()[c1.ID()]) != 2 { - return "c2 did not receive c1's tracks" - } - return "" - }) + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 did not receive c1's tracks" + } + return "" + }) - // revoke permission - ctx := contextWithToken(adminRoomToken(testRoom)) - _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ - Room: testRoom, - Identity: "c1", - Permission: &livekit.ParticipantPermission{ - CanPublish: false, - CanPublishData: true, - CanSubscribe: true, - }, - }) - require.NoError(t, err) + // revoke permission + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "c1", + Permission: &livekit.ParticipantPermission{ + CanPublish: false, + CanPublishData: true, + CanSubscribe: true, + }, + }) + require.NoError(t, err) - // ensure c1 no longer has track published, c2 no longer see track under C1 - testutils.WithTimeout(t, func() string { - if len(c1.GetPublishedTrackIDs()) != 0 { - return "c1 did not unpublish tracks" - } - remoteC1 := c2.GetRemoteParticipant(c1.ID()) - if remoteC1 == nil { - return "c2 doesn't know about c1" - } - if len(remoteC1.Tracks) != 0 { - return "c2 still has c1's tracks" - } - return "" - }) + // ensure c1 no longer has track published, c2 no longer see track under C1 + testutils.WithTimeout(t, func() string { + if len(c1.GetPublishedTrackIDs()) != 0 { + return "c1 did not unpublish tracks" + } + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + if remoteC1 == nil { + return "c2 doesn't know about c1" + } + if len(remoteC1.Tracks) != 0 { + return "c2 still has c1's tracks" + } + return "" + }) + }) + } } func TestCloseDisconnectedParticipantOnSignalClose(t *testing.T) { _, _, finish := setupMultiNodeTest("TestCloseDisconnectedParticipantOnSignalClose") defer finish() - c1 := createRTCClient("c1", secondServerPort, nil) - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", secondServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - c2 := createRTCClient("c2", defaultServerPort, &client.Options{ - SignalRequestInterceptor: func(msg *livekit.SignalRequest, next client.SignalRequestHandler) error { - switch msg.Message.(type) { - case *livekit.SignalRequest_Offer, *livekit.SignalRequest_Answer, *livekit.SignalRequest_Leave: - return nil - default: - return next(msg) - } - }, - SignalResponseInterceptor: func(msg *livekit.SignalResponse, next client.SignalResponseHandler) error { - switch msg.Message.(type) { - case *livekit.SignalResponse_Offer, *livekit.SignalResponse_Answer: - return nil - default: - return next(msg) - } - }, - }) + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, &client.Options{ + SignalRequestInterceptor: func(msg *livekit.SignalRequest, next client.SignalRequestHandler) error { + switch msg.Message.(type) { + case *livekit.SignalRequest_Offer, *livekit.SignalRequest_Answer, *livekit.SignalRequest_Leave: + return nil + default: + return next(msg) + } + }, + SignalResponseInterceptor: func(msg *livekit.SignalResponse, next client.SignalResponseHandler) error { + switch msg.Message.(type) { + case *livekit.SignalResponse_Offer, *livekit.SignalResponse_Answer: + return nil + default: + return next(msg) + } + }, + }) - testutils.WithTimeout(t, func() string { - if len(c1.RemoteParticipants()) != 1 { - return "c1 did not see c2 join" - } - return "" - }) + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) != 1 { + return "c1 did not see c2 join" + } + return "" + }) - c2.Stop() + c2.Stop() - testutils.WithTimeout(t, func() string { - if len(c1.RemoteParticipants()) != 0 { - return "c1 did not see c2 removed" - } - return "" - }) + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) != 0 { + return "c1 did not see c2 removed" + } + return "" + }) + }) + } } diff --git a/test/scenarios.go b/test/scenarios.go index 9a00776fd..ef73aa412 100644 --- a/test/scenarios.go +++ b/test/scenarios.go @@ -31,179 +31,199 @@ import ( // a scenario with lots of clients connecting, publishing, and leaving at random periods func scenarioPublishingUponJoining(t *testing.T) { - c1 := createRTCClient("puj_1", defaultServerPort, nil) - c2 := createRTCClient("puj_2", secondServerPort, &testclient.Options{AutoSubscribe: true}) - c3 := createRTCClient("puj_3", defaultServerPort, &testclient.Options{AutoSubscribe: true}) - defer stopClients(c1, c2, c3) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("puj_1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("puj_2", secondServerPort, useSinglePeerConnection, &testclient.Options{AutoSubscribe: true}) + c3 := createRTCClient("puj_3", defaultServerPort, useSinglePeerConnection, &testclient.Options{AutoSubscribe: true}) + defer stopClients(c1, c2, c3) - waitUntilConnected(t, c1, c2, c3) + waitUntilConnected(t, c1, c2, c3) - // c1 and c2 publishing, c3 just receiving - writers := publishTracksForClients(t, c1, c2) - defer stopWriters(writers...) + // c1 and c2 publishing, c3 just receiving + writers := publishTracksForClients(t, c1, c2) + defer stopWriters(writers...) - logger.Infow("waiting to receive tracks from c1 and c2") - testutils.WithTimeout(t, func() string { - tracks := c3.SubscribedTracks() - if len(tracks[c1.ID()]) != 2 { - return "did not receive tracks from c1" - } - if len(tracks[c2.ID()]) != 2 { - return "did not receive tracks from c2" - } - return "" - }) + logger.Infow("waiting to receive tracks from c1 and c2") + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedTracks() + if len(tracks[c1.ID()]) != 2 { + return "did not receive tracks from c1" + } + if len(tracks[c2.ID()]) != 2 { + return "did not receive tracks from c2" + } + return "" + }) - // after a delay, c2 reconnects, then publishing - time.Sleep(syncDelay) - c2.Stop() + // after a delay, c2 reconnects, then publishing + time.Sleep(syncDelay) + c2.Stop() - logger.Infow("waiting for c2 tracks to be gone") - testutils.WithTimeout(t, func() string { - tracks := c3.SubscribedTracks() + logger.Infow("waiting for c2 tracks to be gone") + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedTracks() - if len(tracks[c1.ID()]) != 2 { - return fmt.Sprintf("c3 should be subscribed to 2 tracks from c1, actual: %d", len(tracks[c1.ID()])) - } - if len(tracks[c2.ID()]) != 0 { - return fmt.Sprintf("c3 should be subscribed to 0 tracks from c2, actual: %d", len(tracks[c2.ID()])) - } - if len(c1.SubscribedTracks()[c2.ID()]) != 0 { - return fmt.Sprintf("c3 should be subscribed to 0 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) - } - return "" - }) + if len(tracks[c1.ID()]) != 2 { + return fmt.Sprintf("c3 should be subscribed to 2 tracks from c1, actual: %d", len(tracks[c1.ID()])) + } + if len(tracks[c2.ID()]) != 0 { + return fmt.Sprintf("c3 should be subscribed to 0 tracks from c2, actual: %d", len(tracks[c2.ID()])) + } + if len(c1.SubscribedTracks()[c2.ID()]) != 0 { + return fmt.Sprintf("c3 should be subscribed to 0 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + return "" + }) - logger.Infow("c2 reconnecting") - // connect to a diff port - c2 = createRTCClient("puj_2", defaultServerPort, nil) - defer c2.Stop() - waitUntilConnected(t, c2) - writers = publishTracksForClients(t, c2) - defer stopWriters(writers...) + logger.Infow("c2 reconnecting") + // connect to a diff port + c2 = createRTCClient("puj_2", defaultServerPort, useSinglePeerConnection, nil) + defer c2.Stop() + waitUntilConnected(t, c2) + writers = publishTracksForClients(t, c2) + defer stopWriters(writers...) - testutils.WithTimeout(t, func() string { - tracks := c3.SubscribedTracks() - // "new c2 tracks should be published again", - if len(tracks[c2.ID()]) != 2 { - return fmt.Sprintf("c3 should be subscribed to 2 tracks from c2, actual: %d", len(tracks[c2.ID()])) - } - if len(c1.SubscribedTracks()[c2.ID()]) != 2 { - return fmt.Sprintf("c1 should be subscribed to 2 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) - } - return "" - }) + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedTracks() + // "new c2 tracks should be published again", + if len(tracks[c2.ID()]) != 2 { + return fmt.Sprintf("c3 should be subscribed to 2 tracks from c2, actual: %d", len(tracks[c2.ID()])) + } + if len(c1.SubscribedTracks()[c2.ID()]) != 2 { + return fmt.Sprintf("c1 should be subscribed to 2 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + return "" + }) + }) + } } func scenarioReceiveBeforePublish(t *testing.T) { - c1 := createRTCClient("rbp_1", defaultServerPort, nil) - c2 := createRTCClient("rbp_2", defaultServerPort, nil) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("rbp_1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("rbp_2", defaultServerPort, useSinglePeerConnection, nil) - waitUntilConnected(t, c1, c2) - defer stopClients(c1, c2) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) - // c1 publishes - writers := publishTracksForClients(t, c1) - defer stopWriters(writers...) + // c1 publishes + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) - // c2 should see some bytes flowing through - testutils.WithTimeout(t, func() string { - if c2.BytesReceived() > 20 { - return "" - } else { - return fmt.Sprintf("c2 only received %d bytes", c2.BytesReceived()) - } - }) + // c2 should see some bytes flowing through + testutils.WithTimeout(t, func() string { + if c2.BytesReceived() > 20 { + return "" + } else { + return fmt.Sprintf("c2 only received %d bytes", c2.BytesReceived()) + } + }) - // now publish on C2 - writers = publishTracksForClients(t, c2) - defer stopWriters(writers...) + // now publish on C2 + writers = publishTracksForClients(t, c2) + defer stopWriters(writers...) - testutils.WithTimeout(t, func() string { - if len(c1.SubscribedTracks()[c2.ID()]) == 2 { - return "" - } else { - return fmt.Sprintf("expected c1 to receive 2 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) - } - }) + testutils.WithTimeout(t, func() string { + if len(c1.SubscribedTracks()[c2.ID()]) == 2 { + return "" + } else { + return fmt.Sprintf("expected c1 to receive 2 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + }) - // now leave, and ensure that it's immediate - c2.Stop() + // now leave, and ensure that it's immediate + c2.Stop() - testutils.WithTimeout(t, func() string { - if len(c1.RemoteParticipants()) > 0 { - return fmt.Sprintf("expected no remote participants, actual: %v", c1.RemoteParticipants()) - } - return "" - }) + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) > 0 { + return fmt.Sprintf("expected no remote participants, actual: %v", c1.RemoteParticipants()) + } + return "" + }) + }) + } } func scenarioDataPublish(t *testing.T) { - c1 := createRTCClient("dp1", defaultServerPort, nil) - c2 := createRTCClient("dp2", secondServerPort, nil) - waitUntilConnected(t, c1, c2) - defer stopClients(c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("dp1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("dp2", secondServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) - payload := "test bytes" + payload := "test bytes" - received := atomic.NewBool(false) - c2.OnDataReceived = func(data []byte, sid string) { - if string(data) == payload && livekit.ParticipantID(sid) == c1.ID() { - received.Store(true) - } + received := atomic.NewBool(false) + c2.OnDataReceived = func(data []byte, sid string) { + if string(data) == payload && livekit.ParticipantID(sid) == c1.ID() { + received.Store(true) + } + } + + require.NoError(t, c1.PublishData([]byte(payload), livekit.DataPacket_RELIABLE)) + + testutils.WithTimeout(t, func() string { + if received.Load() { + return "" + } else { + return "c2 did not receive published data" + } + }) + }) } - - require.NoError(t, c1.PublishData([]byte(payload), livekit.DataPacket_RELIABLE)) - - testutils.WithTimeout(t, func() string { - if received.Load() { - return "" - } else { - return "c2 did not receive published data" - } - }) } func scenarioDataUnlabeledPublish(t *testing.T) { - c1 := createRTCClient("dp1", defaultServerPort, nil) - c2 := createRTCClient("dp2", secondServerPort, nil) - waitUntilConnected(t, c1, c2) - defer stopClients(c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("dp1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("dp2", secondServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) - payload := "test unlabeled bytes" + payload := "test unlabeled bytes" - received := atomic.NewBool(false) - c2.OnDataReceived = func(data []byte, _sid string) { - if string(data) == payload { - received.Store(true) - } + received := atomic.NewBool(false) + c2.OnDataReceived = func(data []byte, _sid string) { + if string(data) == payload { + received.Store(true) + } + } + + require.NoError(t, c1.PublishDataUnlabeled([]byte(payload))) + + testutils.WithTimeout(t, func() string { + if received.Load() { + return "" + } else { + return "c2 did not receive published data unlabeled" + } + }) + }) } - - require.NoError(t, c1.PublishDataUnlabeled([]byte(payload))) - - testutils.WithTimeout(t, func() string { - if received.Load() { - return "" - } else { - return "c2 did not receive published data unlabeled" - } - }) } func scenarioJoinClosedRoom(t *testing.T) { - c1 := createRTCClient("jcr1", defaultServerPort, nil) - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("jcr1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - // close room with room client - _, err := roomClient.DeleteRoom(contextWithToken(createRoomToken()), &livekit.DeleteRoomRequest{ - Room: testRoom, - }) - require.NoError(t, err) + // close room with room client + _, err := roomClient.DeleteRoom(contextWithToken(createRoomToken()), &livekit.DeleteRoomRequest{ + Room: testRoom, + }) + require.NoError(t, err) - // now join again - c2 := createRTCClient("jcr2", defaultServerPort, nil) - waitUntilConnected(t, c2) - stopClients(c2) + // now join again + c2 := createRTCClient("jcr2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c2) + stopClients(c2) + }) + } } // close a room that has been created, but no participant has joined diff --git a/test/singlenode_test.go b/test/singlenode_test.go index fe76a8f94..510fcf08b 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -58,20 +58,24 @@ func TestClientCouldConnect(t *testing.T) { _, finish := setupSingleNodeTest("TestClientCouldConnect") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", defaultServerPort, nil) - waitUntilConnected(t, c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) - // ensure they both see each other - testutils.WithTimeout(t, func() string { - if len(c1.RemoteParticipants()) == 0 { - return "c1 did not see c2" - } - if len(c2.RemoteParticipants()) == 0 { - return "c2 did not see c1" - } - return "" - }) + // ensure they both see each other + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) == 0 { + return "c1 did not see c2" + } + if len(c2.RemoteParticipants()) == 0 { + return "c2 did not see c1" + } + return "" + }) + }) + } } func TestClientConnectDuplicate(t *testing.T) { @@ -80,68 +84,71 @@ func TestClientConnectDuplicate(t *testing.T) { return } - _, finish := setupSingleNodeTest("TestClientCouldConnect") + _, finish := setupSingleNodeTest("TestClientConnectDuplicate") defer finish() - grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} - grant.SetCanPublish(true) - grant.SetCanSubscribe(true) - token := joinTokenWithGrant("c1", grant) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanPublish(true) + grant.SetCanSubscribe(true) + token := joinTokenWithGrant("c1", grant) + c1 := createRTCClientWithToken(token, defaultServerPort, useSinglePeerConnection, nil) - c1 := createRTCClientWithToken(token, defaultServerPort, nil) + // publish 2 tracks + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() - // publish 2 tracks - t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") - require.NoError(t, err) - defer t1.Stop() - t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") - require.NoError(t, err) - defer t2.Stop() + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) - c2 := createRTCClient("c2", defaultServerPort, nil) - waitUntilConnected(t, c1, c2) + opts := &testclient.Options{ + Publish: "duplicate_connection", + } + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 didn't subscribe to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 didn't subscribe to both tracks from c1" + } - opts := &testclient.Options{ - Publish: "duplicate_connection", + // participant ID can be appended with '#..' . but should contain orig id as prefix + tr1 := c2.SubscribedTracks()[c1.ID()][0] + participantId1, _ := rtc.UnpackStreamID(tr1.StreamID()) + require.Equal(t, c1.ID(), participantId1) + tr2 := c2.SubscribedTracks()[c1.ID()][1] + participantId2, _ := rtc.UnpackStreamID(tr2.StreamID()) + require.Equal(t, c1.ID(), participantId2) + return "" + }) + + c1Dup := createRTCClientWithToken(token, defaultServerPort, useSinglePeerConnection, opts) + + waitUntilConnected(t, c1Dup) + + t3, err := c1Dup.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t3.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()[c1Dup.ID()]) != 1 { + return "c2 was not subscribed to track from duplicated c1" + } + + tr3 := c2.SubscribedTracks()[c1Dup.ID()][0] + participantId3, _ := rtc.UnpackStreamID(tr3.StreamID()) + require.Contains(t, c1Dup.ID(), participantId3) + + return "" + }) + }) } - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2 didn't subscribe to anything" - } - // should have received three tracks - if len(c2.SubscribedTracks()[c1.ID()]) != 2 { - return "c2 didn't subscribe to both tracks from c1" - } - - // participant ID can be appended with '#..' . but should contain orig id as prefix - tr1 := c2.SubscribedTracks()[c1.ID()][0] - participantId1, _ := rtc.UnpackStreamID(tr1.StreamID()) - require.Equal(t, c1.ID(), participantId1) - tr2 := c2.SubscribedTracks()[c1.ID()][1] - participantId2, _ := rtc.UnpackStreamID(tr2.StreamID()) - require.Equal(t, c1.ID(), participantId2) - return "" - }) - - c1Dup := createRTCClientWithToken(token, defaultServerPort, opts) - - waitUntilConnected(t, c1Dup) - - t3, err := c1Dup.AddStaticTrack("video/vp8", "video", "webcam") - require.NoError(t, err) - defer t3.Stop() - - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()[c1Dup.ID()]) != 1 { - return "c2 was not subscribed to track from duplicated c1" - } - - tr3 := c2.SubscribedTracks()[c1Dup.ID()][0] - participantId3, _ := rtc.UnpackStreamID(tr3.StreamID()) - require.Contains(t, c1Dup.ID(), participantId3) - - return "" - }) } func TestSinglePublisher(t *testing.T) { @@ -153,76 +160,80 @@ func TestSinglePublisher(t *testing.T) { s, finish := setupSingleNodeTest("TestSinglePublisher") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", defaultServerPort, nil) - waitUntilConnected(t, c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) - // publish a track and ensure clients receive it ok - t1, err := c1.AddStaticTrack("audio/OPUS", "audio", "webcamaudio") - require.NoError(t, err) - defer t1.Stop() - t2, err := c1.AddStaticTrack("video/vp8", "video", "webcamvideo") - require.NoError(t, err) - defer t2.Stop() + // publish a track and ensure clients receive it ok + t1, err := c1.AddStaticTrack("audio/OPUS", "audio", "webcamaudio") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcamvideo") + require.NoError(t, err) + defer t2.Stop() - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2 was not subscribed to anything" - } - // should have received two tracks - if len(c2.SubscribedTracks()[c1.ID()]) != 2 { - return "c2 didn't subscribe to both tracks from c1" - } + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 didn't subscribe to both tracks from c1" + } - tr1 := c2.SubscribedTracks()[c1.ID()][0] - participantId, _ := rtc.UnpackStreamID(tr1.StreamID()) - require.Equal(t, c1.ID(), participantId) - return "" - }) - // ensure mime type is received - remoteC1 := c2.GetRemoteParticipant(c1.ID()) - audioTrack := funk.Find(remoteC1.Tracks, func(ti *livekit.TrackInfo) bool { - return ti.Name == "webcamaudio" - }).(*livekit.TrackInfo) - require.Equal(t, "audio/opus", audioTrack.MimeType) + tr1 := c2.SubscribedTracks()[c1.ID()][0] + participantId, _ := rtc.UnpackStreamID(tr1.StreamID()) + require.Equal(t, c1.ID(), participantId) + return "" + }) + // ensure mime type is received + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + audioTrack := funk.Find(remoteC1.Tracks, func(ti *livekit.TrackInfo) bool { + return ti.Name == "webcamaudio" + }).(*livekit.TrackInfo) + require.Equal(t, "audio/opus", audioTrack.MimeType) - // a new client joins and should get the initial stream - c3 := createRTCClient("c3", defaultServerPort, nil) + // a new client joins and should get the initial stream + c3 := createRTCClient("c3", defaultServerPort, useSinglePeerConnection, nil) - // ensure that new client that has joined also received tracks - waitUntilConnected(t, c3) - testutils.WithTimeout(t, func() string { - if len(c3.SubscribedTracks()) == 0 { - return "c3 didn't subscribe to anything" - } - // should have received two tracks - if len(c3.SubscribedTracks()[c1.ID()]) != 2 { - return "c3 didn't subscribe to tracks from c1" - } - return "" - }) + // ensure that new client that has joined also received tracks + waitUntilConnected(t, c3) + testutils.WithTimeout(t, func() string { + if len(c3.SubscribedTracks()) == 0 { + return "c3 didn't subscribe to anything" + } + // should have received two tracks + if len(c3.SubscribedTracks()[c1.ID()]) != 2 { + return "c3 didn't subscribe to tracks from c1" + } + return "" + }) - // ensure that the track ids are generated by server - tracks := c3.SubscribedTracks()[c1.ID()] - for _, tr := range tracks { - require.True(t, strings.HasPrefix(tr.ID(), "TR_"), "track should begin with TR") - } - - // when c3 disconnects, ensure subscriber is cleaned up correctly - c3.Stop() - - testutils.WithTimeout(t, func() string { - room := s.RoomManager().GetRoom(context.Background(), testRoom) - p := room.GetParticipant("c1") - require.NotNil(t, p) - - for _, t := range p.GetPublishedTracks() { - if t.IsSubscriber(c3.ID()) { - return "c3 was not a subscriber of c1's tracks" + // ensure that the track ids are generated by server + tracks := c3.SubscribedTracks()[c1.ID()] + for _, tr := range tracks { + require.True(t, strings.HasPrefix(tr.ID(), "TR_"), "track should begin with TR") } - } - return "" - }) + + // when c3 disconnects, ensure subscriber is cleaned up correctly + c3.Stop() + + testutils.WithTimeout(t, func() string { + room := s.RoomManager().GetRoom(context.Background(), testRoom) + p := room.GetParticipant("c1") + require.NotNil(t, p) + + for _, t := range p.GetPublishedTracks() { + if t.IsSubscriber(c3.ID()) { + return "c3 was not a subscriber of c1's tracks" + } + } + return "" + }) + }) + } } func Test_WhenAutoSubscriptionDisabled_ClientShouldNotReceiveAnyPublishedTracks(t *testing.T) { @@ -234,20 +245,24 @@ func Test_WhenAutoSubscriptionDisabled_ClientShouldNotReceiveAnyPublishedTracks( _, finish := setupSingleNodeTest("Test_WhenAutoSubscriptionDisabled_ClientShouldNotReceiveAnyPublishedTracks") defer finish() - opts := testclient.Options{AutoSubscribe: false} - publisher := createRTCClient("publisher", defaultServerPort, &opts) - client := createRTCClient("client", defaultServerPort, &opts) - defer publisher.Stop() - defer client.Stop() - waitUntilConnected(t, publisher, client) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + opts := testclient.Options{AutoSubscribe: false} + publisher := createRTCClient("publisher", defaultServerPort, useSinglePeerConnection, &opts) + client := createRTCClient("client", defaultServerPort, useSinglePeerConnection, &opts) + defer publisher.Stop() + defer client.Stop() + waitUntilConnected(t, publisher, client) - track, err := publisher.AddStaticTrack("audio/opus", "audio", "webcam") - require.NoError(t, err) - defer track.Stop() + track, err := publisher.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer track.Stop() - time.Sleep(syncDelay) + time.Sleep(syncDelay) - require.Empty(t, client.SubscribedTracks()[publisher.ID()]) + require.Empty(t, client.SubscribedTracks()[publisher.ID()]) + }) + } } func Test_RenegotiationWithDifferentCodecs(t *testing.T) { @@ -259,71 +274,75 @@ func Test_RenegotiationWithDifferentCodecs(t *testing.T) { _, finish := setupSingleNodeTest("TestRenegotiationWithDifferentCodecs") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", defaultServerPort, nil) - waitUntilConnected(t, c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1, c2) - // publish a vp8 video track and ensure clients receive it ok - t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") - require.NoError(t, err) - defer t1.Stop() - t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") - require.NoError(t, err) - defer t2.Stop() + // publish a vp8 video track and ensure clients receive it ok + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2 was not subscribed to anything" - } - // should have received two tracks - if len(c2.SubscribedTracks()[c1.ID()]) != 2 { - return "c2 was not subscribed to tracks from c1" - } + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 was not subscribed to tracks from c1" + } - tracks := c2.SubscribedTracks()[c1.ID()] - for _, t := range tracks { - if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + return "" + + } + } + return "did not receive track with vp8" + }) + + t3, err := c1.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{ + MimeType: "video/h264", + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + }, "videoscreen", "screen") + defer t3.Stop() + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2's not subscribed to anything" + } + // should have received three tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 3 { + return "c2's not subscribed to 3 tracks from c1" + } + + var vp8Found, h264Found bool + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + vp8Found = true + } else if mime.IsMimeTypeStringH264(t.Codec().MimeType) { + h264Found = true + } + } + if !vp8Found { + return "did not receive track with vp8" + } + if !h264Found { + return "did not receive track with h264" + } return "" - - } - } - return "did not receive track with vp8" - }) - - t3, err := c1.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{ - MimeType: "video/h264", - ClockRate: 90000, - SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - }, "videoscreen", "screen") - defer t3.Stop() - require.NoError(t, err) - - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2's not subscribed to anything" - } - // should have received three tracks - if len(c2.SubscribedTracks()[c1.ID()]) != 3 { - return "c2's not subscribed to 3 tracks from c1" - } - - var vp8Found, h264Found bool - tracks := c2.SubscribedTracks()[c1.ID()] - for _, t := range tracks { - if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { - vp8Found = true - } else if mime.IsMimeTypeStringH264(t.Codec().MimeType) { - h264Found = true - } - } - if !vp8Found { - return "did not receive track with vp8" - } - if !h264Found { - return "did not receive track with h264" - } - return "" - }) + }) + }) + } } func TestSingleNodeRoomList(t *testing.T) { @@ -401,13 +420,17 @@ func TestPingPong(t *testing.T) { _, finish := setupSingleNodeTest("TestPingPong") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - require.NoError(t, c1.SendPing()) - require.Eventually(t, func() bool { - return c1.PongReceivedAt() > 0 - }, time.Second, 10*time.Millisecond) + require.NoError(t, c1.SendPing()) + require.Eventually(t, func() bool { + return c1.PongReceivedAt() > 0 + }, time.Second, 10*time.Millisecond) + }) + } } func TestSingleNodeJoinAfterClose(t *testing.T) { @@ -453,14 +476,18 @@ func TestAutoCreate(t *testing.T) { waitForServerToStart(s) - token := joinToken(testRoom, "start-before-create", nil) - _, err := testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", defaultServerPort), token, nil) - require.Error(t, err) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + token := joinToken(testRoom, "start-before-create", nil) + _, err := testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", defaultServerPort), token, &testclient.Options{UseJoinRequestQueryParam: useSinglePeerConnection}) + require.Error(t, err) - // second join should also fail - token = joinToken(testRoom, "start-before-create-2", nil) - _, err = testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", defaultServerPort), token, nil) - require.Error(t, err) + // second join should also fail + token = joinToken(testRoom, "start-before-create-2", nil) + _, err = testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", defaultServerPort), token, &testclient.Options{UseJoinRequestQueryParam: useSinglePeerConnection}) + require.Error(t, err) + }) + } }) t.Run("join with explicit createRoom", func(t *testing.T) { @@ -478,10 +505,14 @@ func TestAutoCreate(t *testing.T) { _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{Name: testRoom}) require.NoError(t, err) - c1 := createRTCClient("join-after-create", defaultServerPort, nil) - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("join-after-create", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) - c1.Stop() + c1.Stop() + }) + } }) } @@ -494,53 +525,58 @@ func TestSingleNodeUpdateSubscriptionPermissions(t *testing.T) { _, finish := setupSingleNodeTest("TestSingleNodeUpdateSubscriptionPermissions") defer finish() - pub := createRTCClient("pub", defaultServerPort, nil) - grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} - grant.SetCanSubscribe(false) - at := auth.NewAccessToken(testApiKey, testApiSecret). - AddGrant(grant). - SetIdentity("sub") - token, err := at.ToJWT() - require.NoError(t, err) - sub := createRTCClientWithToken(token, defaultServerPort, nil) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + pub := createRTCClient("pub", defaultServerPort, useSinglePeerConnection, nil) - waitUntilConnected(t, pub, sub) + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanSubscribe(false) + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(grant). + SetIdentity("sub") + token, err := at.ToJWT() + require.NoError(t, err) + sub := createRTCClientWithToken(token, defaultServerPort, useSinglePeerConnection, nil) - writers := publishTracksForClients(t, pub) - defer stopWriters(writers...) + waitUntilConnected(t, pub, sub) - // wait sub receives tracks - testutils.WithTimeout(t, func() string { - pubRemote := sub.GetRemoteParticipant(pub.ID()) - if pubRemote == nil { - return "could not find remote publisher" - } - if len(pubRemote.Tracks) != 2 { - return "did not receive metadata for published tracks" - } - return "" - }) + writers := publishTracksForClients(t, pub) + defer stopWriters(writers...) - // set permissions out of band - ctx := contextWithToken(adminRoomToken(testRoom)) - _, err = roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ - Room: testRoom, - Identity: "sub", - Permission: &livekit.ParticipantPermission{ - CanSubscribe: true, - CanPublish: true, - }, - }) - require.NoError(t, err) + // wait sub receives tracks + testutils.WithTimeout(t, func() string { + pubRemote := sub.GetRemoteParticipant(pub.ID()) + if pubRemote == nil { + return "could not find remote publisher" + } + if len(pubRemote.Tracks) != 2 { + return "did not receive metadata for published tracks" + } + return "" + }) - testutils.WithTimeout(t, func() string { - tracks := sub.SubscribedTracks()[pub.ID()] - if len(tracks) == 2 { - return "" - } else { - return fmt.Sprintf("expected 2 tracks subscribed, actual: %d", len(tracks)) - } - }) + // set permissions out of band + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err = roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "sub", + Permission: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, + }) + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + tracks := sub.SubscribedTracks()[pub.ID()] + if len(tracks) == 2 { + return "" + } else { + return fmt.Sprintf("expected 2 tracks subscribed, actual: %d", len(tracks)) + } + }) + }) + } } func TestSingleNodeAttributes(t *testing.T) { @@ -551,48 +587,52 @@ func TestSingleNodeAttributes(t *testing.T) { _, finish := setupSingleNodeTest("TestSingleNodeAttributes") defer finish() - pub := createRTCClient("pub", defaultServerPort, &testclient.Options{ - Attributes: map[string]string{ - "b": "2", - "c": "3", - }, - TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { - T := true - grants.CanUpdateOwnMetadata = &T - token.SetAttributes(map[string]string{ - "a": "0", - "b": "1", + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + pub := createRTCClient("pub", defaultServerPort, useSinglePeerConnection, &testclient.Options{ + Attributes: map[string]string{ + "b": "2", + "c": "3", + }, + TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { + T := true + grants.CanUpdateOwnMetadata = &T + token.SetAttributes(map[string]string{ + "a": "0", + "b": "1", + }) + }, }) - }, - UseJoinRequestQueryParam: true, - }) - grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} - grant.SetCanSubscribe(false) - at := auth.NewAccessToken(testApiKey, testApiSecret). - SetVideoGrant(grant). - SetIdentity("sub") - token, err := at.ToJWT() - require.NoError(t, err) - sub := createRTCClientWithToken(token, defaultServerPort, nil) - waitUntilConnected(t, pub, sub) + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanSubscribe(false) + at := auth.NewAccessToken(testApiKey, testApiSecret). + SetVideoGrant(grant). + SetIdentity("sub") + token, err := at.ToJWT() + require.NoError(t, err) + sub := createRTCClientWithToken(token, defaultServerPort, useSinglePeerConnection, nil) - // wait sub receives initial attributes - testutils.WithTimeout(t, func() string { - pubRemote := sub.GetRemoteParticipant(pub.ID()) - if pubRemote == nil { - return "could not find remote publisher" - } - attrs := pubRemote.Attributes - if !reflect.DeepEqual(attrs, map[string]string{ - "a": "0", - "b": "2", - "c": "3", - }) { - return fmt.Sprintf("did not receive expected attributes: %v", attrs) - } - return "" - }) + waitUntilConnected(t, pub, sub) + + // wait sub receives initial attributes + testutils.WithTimeout(t, func() string { + pubRemote := sub.GetRemoteParticipant(pub.ID()) + if pubRemote == nil { + return "could not find remote publisher" + } + attrs := pubRemote.Attributes + if !reflect.DeepEqual(attrs, map[string]string{ + "a": "0", + "b": "2", + "c": "3", + }) { + return fmt.Sprintf("did not receive expected attributes: %v", attrs) + } + return "" + }) + }) + } } // TestDeviceCodecOverride checks that codecs that are incompatible with a device is not @@ -606,53 +646,67 @@ func TestDeviceCodecOverride(t *testing.T) { _, finish := setupSingleNodeTest("TestDeviceCodecOverride") defer finish() - // simulate device that isn't compatible with H.264 - c1 := createRTCClient("c1", defaultServerPort, &testclient.Options{ - ClientInfo: &livekit.ClientInfo{ - Os: "android", - DeviceModel: "Xiaomi 2201117TI", - }, - }) - defer c1.Stop() - waitUntilConnected(t, c1) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + // simulate device that isn't compatible with H.264 + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, &testclient.Options{ + ClientInfo: &livekit.ClientInfo{ + Os: "android", + DeviceModel: "Xiaomi 2201117TI", + }, + }) + defer c1.Stop() + waitUntilConnected(t, c1) - // it doesn't really matter what the codec set here is, uses default Pion MediaEngine codecs - tw, err := c1.AddStaticTrack("video/h264", "video", "webcam") - require.NoError(t, err) - defer stopWriters(tw) + // it doesn't really matter what the codec set here is, uses default Pion MediaEngine codecs + tw, err := c1.AddStaticTrack("video/h264", "video", "webcam") + require.NoError(t, err) + defer stopWriters(tw) - // wait for server to receive track - require.Eventually(t, func() bool { - return c1.LastAnswer() != nil - }, waitTimeout, waitTick, "did not receive answer") + var desc *sdp.MediaDescription + require.Eventually(t, func() bool { + lastAnswer := c1.LastAnswer() + if lastAnswer == nil { + return false + } - sd := webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, - SDP: c1.LastAnswer().SDP, - } - answer, err := sd.Unmarshal() - require.NoError(t, err) + sd := webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: lastAnswer.SDP, + } + answer, err := sd.Unmarshal() + require.NoError(t, err) - // video and data channel - require.Len(t, answer.MediaDescriptions, 2) - var desc *sdp.MediaDescription - for _, md := range answer.MediaDescriptions { - if md.MediaName.Media == "video" { - desc = md - break - } - } - require.NotNil(t, desc) - hasSeenVP8 := false - for _, a := range desc.Attributes { - if a.Key == "rtpmap" { - require.NotContains(t, a.Value, mime.MimeTypeCodecH264.String(), "should not contain H264 codec") - if strings.Contains(a.Value, mime.MimeTypeCodecVP8.String()) { - hasSeenVP8 = true + // video and data channel + if len(answer.MediaDescriptions) < 2 { + return false + } + + for _, md := range answer.MediaDescriptions { + if md.MediaName.Media == "video" { + desc = md + break + } + } + if desc == nil { + return false + } + + return true + }, waitTimeout, waitTick, "did not receive answer") + + hasSeenVP8 := false + for _, a := range desc.Attributes { + if a.Key == "rtpmap" { + require.NotContains(t, a.Value, mime.MimeTypeCodecH264.String(), "should not contain H264 codec") + if strings.Contains(a.Value, mime.MimeTypeCodecVP8.String()) { + hasSeenVP8 = true + } + } } - } + require.True(t, hasSeenVP8, "should have seen VP8 codec in SDP") + }) } - require.True(t, hasSeenVP8, "should have seen VP8 codec in SDP") } func TestSubscribeToCodecUnsupported(t *testing.T) { @@ -664,104 +718,106 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { _, finish := setupSingleNodeTest("TestSubscribeToCodecUnsupported") defer finish() - c1 := createRTCClient("c1", defaultServerPort, &testclient.Options{ - UseJoinRequestQueryParam: true, - }) - // create a client that doesn't support H264 - c2 := createRTCClient("c2", defaultServerPort, &testclient.Options{ - AutoSubscribe: true, - DisabledCodecs: []webrtc.RTPCodecCapability{ - {MimeType: "video/H264"}, - }, - }) - waitUntilConnected(t, c1, c2) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + // create a client that doesn't support H264 + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, &testclient.Options{ + AutoSubscribe: true, + DisabledCodecs: []webrtc.RTPCodecCapability{ + {MimeType: "video/H264"}, + }, + }) + waitUntilConnected(t, c1, c2) - // publish a vp8 video track and ensure c2 receives it ok - t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") - require.NoError(t, err) - defer t1.Stop() - t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") - require.NoError(t, err) - defer t2.Stop() + // publish a vp8 video track and ensure c2 receives it ok + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2 was not subscribed to anything" - } - // should have received two tracks - if len(c2.SubscribedTracks()[c1.ID()]) != 2 { - return "c2 was not subscribed to tracks from c1" - } + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 was not subscribed to tracks from c1" + } - tracks := c2.SubscribedTracks()[c1.ID()] - for _, t := range tracks { - if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { - return "" - } - } - return "did not receive track with vp8" - }) - require.Nil(t, c2.GetSubscriptionResponseAndClear()) + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + return "" + } + } + return "did not receive track with vp8" + }) + require.Nil(t, c2.GetSubscriptionResponseAndClear()) - // publish a h264 track and ensure c2 got subscription error - t3, err := c1.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{ - MimeType: "video/h264", - ClockRate: 90000, - SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", - }, "videoscreen", "screen") - defer t3.Stop() - require.NoError(t, err) + // publish a h264 track and ensure c2 got subscription error + t3, err := c1.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{ + MimeType: "video/h264", + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + }, "videoscreen", "screen") + defer t3.Stop() + require.NoError(t, err) - var h264TrackID string - require.Eventually(t, func() bool { - remoteC1 := c2.GetRemoteParticipant(c1.ID()) - require.NotNil(t, remoteC1) - for _, track := range remoteC1.Tracks { - if mime.IsMimeTypeStringH264(track.MimeType) { - h264TrackID = track.Sid + var h264TrackID string + require.Eventually(t, func() bool { + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + require.NotNil(t, remoteC1) + for _, track := range remoteC1.Tracks { + if mime.IsMimeTypeStringH264(track.MimeType) { + h264TrackID = track.Sid + return true + } + } + return false + }, time.Second, 10*time.Millisecond, "did not receive track info with h264") + + require.Eventually(t, func() bool { + sr := c2.GetSubscriptionResponseAndClear() + if sr == nil { + return false + } + require.Equal(t, h264TrackID, sr.TrackSid) + require.Equal(t, livekit.SubscriptionError_SE_CODEC_UNSUPPORTED, sr.Err) return true - } - } - return false - }, time.Second, 10*time.Millisecond, "did not receive track info with h264") + }, 5*time.Second, 10*time.Millisecond, "did not receive subscription response") - require.Eventually(t, func() bool { - sr := c2.GetSubscriptionResponseAndClear() - if sr == nil { - return false - } - require.Equal(t, h264TrackID, sr.TrackSid) - require.Equal(t, livekit.SubscriptionError_SE_CODEC_UNSUPPORTED, sr.Err) - return true - }, 5*time.Second, 10*time.Millisecond, "did not receive subscription response") + // publish another vp8 track again, ensure the transport recovered by sfu and c2 can receive it + t4, err := c1.AddStaticTrack("video/vp8", "video2", "webcam2") + require.NoError(t, err) + defer t4.Stop() - // publish another vp8 track again, ensure the transport recovered by sfu and c2 can receive it - t4, err := c1.AddStaticTrack("video/vp8", "video2", "webcam2") - require.NoError(t, err) - defer t4.Stop() + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 3 { + return "c2 was not subscribed to tracks from c1" + } - testutils.WithTimeout(t, func() string { - if len(c2.SubscribedTracks()) == 0 { - return "c2 was not subscribed to anything" - } - // should have received two tracks - if len(c2.SubscribedTracks()[c1.ID()]) != 3 { - return "c2 was not subscribed to tracks from c1" - } - - var vp8Count int - tracks := c2.SubscribedTracks()[c1.ID()] - for _, t := range tracks { - if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { - vp8Count++ - } - } - if vp8Count == 2 { - return "" - } - return "did not 2 receive track with vp8" - }) - require.Nil(t, c2.GetSubscriptionResponseAndClear()) + var vp8Count int + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + vp8Count++ + } + } + if vp8Count == 2 { + return "" + } + return "did not 2 receive track with vp8" + }) + require.Nil(t, c2.GetSubscriptionResponseAndClear()) + }) + } } func TestDataPublishSlowSubscriber(t *testing.T) { @@ -789,97 +845,101 @@ func TestDataPublishSlowSubscriber(t *testing.T) { logger.Infow("----------------FINISHING TEST----------------", "test", t.Name()) }() - pub := createRTCClient("pub", defaultServerPort, nil) - fastSub := createRTCClient("fastSub", defaultServerPort, nil) - slowSubNotDrop := createRTCClient("slowSubNotDrop", defaultServerPort, nil) - slowSubDrop := createRTCClient("slowSubDrop", defaultServerPort, nil) - waitUntilConnected(t, pub, fastSub, slowSubDrop, slowSubNotDrop) - defer func() { - pub.Stop() - fastSub.Stop() - slowSubNotDrop.Stop() - slowSubDrop.Stop() - }() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + pub := createRTCClient("pub", defaultServerPort, useSinglePeerConnection, nil) + fastSub := createRTCClient("fastSub", defaultServerPort, useSinglePeerConnection, nil) + slowSubNotDrop := createRTCClient("slowSubNotDrop", defaultServerPort, useSinglePeerConnection, nil) + slowSubDrop := createRTCClient("slowSubDrop", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, pub, fastSub, slowSubDrop, slowSubNotDrop) + defer func() { + pub.Stop() + fastSub.Stop() + slowSubNotDrop.Stop() + slowSubDrop.Stop() + }() - // no data should be dropped for fast subscriber - var fastDataIndex atomic.Uint64 - fastSub.OnDataReceived = func(data []byte, sid string) { - idx := binary.BigEndian.Uint64(data[len(data)-8:]) - require.Equal(t, fastDataIndex.Load()+1, idx) - fastDataIndex.Store(idx) - } + // no data should be dropped for fast subscriber + var fastDataIndex atomic.Uint64 + fastSub.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, fastDataIndex.Load()+1, idx) + fastDataIndex.Store(idx) + } - // no data should be dropped for slow subscriber that is above threshold - var slowNoDropDataIndex atomic.Uint64 - var drainSlowSubNotDrop atomic.Bool - slowNoDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold * 2) - slowSubNotDrop.OnDataReceived = func(data []byte, sid string) { - idx := binary.BigEndian.Uint64(data[len(data)-8:]) - require.Equal(t, slowNoDropDataIndex.Load()+1, idx) - slowNoDropDataIndex.Store(idx) - if !drainSlowSubNotDrop.Load() { - slowNoDropReader.Read(data, sid) - } - } - - // data should be dropped for slow subscriber that is below threshold - var slowDropDataIndex atomic.Uint64 - dropped := make(chan struct{}) - slowDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold / 2) - slowSubDrop.OnDataReceived = func(data []byte, sid string) { - select { - case <-dropped: - return - default: - } - idx := binary.BigEndian.Uint64(data[len(data)-8:]) - if idx != slowDropDataIndex.Load()+1 { - close(dropped) - } - slowDropDataIndex.Store(idx) - slowDropReader.Read(data, sid) - } - - // publisher sends data as fast as possible, it will block by the slowest subscriber above the slow threshold - var ( - blocked atomic.Bool - stopWrite atomic.Bool - writeIdx atomic.Uint64 - ) - writeStopped := make(chan struct{}) - go func() { - defer close(writeStopped) - var i int - buf := make([]byte, 100) - for !stopWrite.Load() { - i++ - binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(i)) - if err := pub.PublishData(buf, livekit.DataPacket_RELIABLE); err != nil { - if errors.Is(err, datachannel.ErrDataDroppedBySlowReader) { - blocked.Store(true) - i-- - continue - } else { - t.Log("error writing", err) - break + // no data should be dropped for slow subscriber that is above threshold + var slowNoDropDataIndex atomic.Uint64 + var drainSlowSubNotDrop atomic.Bool + slowNoDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold * 2) + slowSubNotDrop.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, slowNoDropDataIndex.Load()+1, idx) + slowNoDropDataIndex.Store(idx) + if !drainSlowSubNotDrop.Load() { + slowNoDropReader.Read(data, sid) } } - writeIdx.Store(uint64(i)) - } - }() - <-dropped + // data should be dropped for slow subscriber that is below threshold + var slowDropDataIndex atomic.Uint64 + dropped := make(chan struct{}) + slowDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold / 2) + slowSubDrop.OnDataReceived = func(data []byte, sid string) { + select { + case <-dropped: + return + default: + } + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + if idx != slowDropDataIndex.Load()+1 { + close(dropped) + } + slowDropDataIndex.Store(idx) + slowDropReader.Read(data, sid) + } - time.Sleep(time.Second) - blocked.Store(false) - require.Eventually(t, func() bool { return blocked.Load() }, 30*time.Second, 100*time.Millisecond) - stopWrite.Store(true) - <-writeStopped - drainSlowSubNotDrop.Store(true) - require.Eventually(t, func() bool { - return writeIdx.Load() == fastDataIndex.Load() && - writeIdx.Load() == slowNoDropDataIndex.Load() - }, 10*time.Second, 50*time.Millisecond, "writeIdx %d, fast %d, slowNoDrop %d", writeIdx.Load(), fastDataIndex.Load(), slowNoDropDataIndex.Load()) + // publisher sends data as fast as possible, it will block by the slowest subscriber above the slow threshold + var ( + blocked atomic.Bool + stopWrite atomic.Bool + writeIdx atomic.Uint64 + ) + writeStopped := make(chan struct{}) + go func() { + defer close(writeStopped) + var i int + buf := make([]byte, 100) + for !stopWrite.Load() { + i++ + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(i)) + if err := pub.PublishData(buf, livekit.DataPacket_RELIABLE); err != nil { + if errors.Is(err, datachannel.ErrDataDroppedBySlowReader) { + blocked.Store(true) + i-- + continue + } else { + t.Log("error writing", err) + break + } + } + writeIdx.Store(uint64(i)) + } + }() + + <-dropped + + time.Sleep(time.Second) + blocked.Store(false) + require.Eventually(t, func() bool { return blocked.Load() }, 30*time.Second, 100*time.Millisecond) + stopWrite.Store(true) + <-writeStopped + drainSlowSubNotDrop.Store(true) + require.Eventually(t, func() bool { + return writeIdx.Load() == fastDataIndex.Load() && + writeIdx.Load() == slowNoDropDataIndex.Load() + }, 10*time.Second, 50*time.Millisecond, "writeIdx %d, fast %d, slowNoDrop %d", writeIdx.Load(), fastDataIndex.Load(), slowNoDropDataIndex.Load()) + }) + } } func TestFireTrackBySdp(t *testing.T) { @@ -894,8 +954,8 @@ func TestFireTrackBySdp(t *testing.T) { { name: "js client could pub a/v tracks", codecs: []webrtc.RTPCodecCapability{ - {MimeType: "video/H264"}, - {MimeType: "audio/opus"}, + {MimeType: mime.MimeTypeH264.String()}, + {MimeType: mime.MimeTypeOpus.String()}, }, pubSDK: livekit.ClientInfo_JS, }, @@ -911,46 +971,50 @@ func TestFireTrackBySdp(t *testing.T) { for _, c := range cases { codecs, sdk := c.codecs, c.pubSDK t.Run(c.name, func(t *testing.T) { - c1 := createRTCClient(c.name+"_c1", defaultServerPort, &testclient.Options{ - ClientInfo: &livekit.ClientInfo{ - Sdk: sdk, - }, - }) - c2 := createRTCClient(c.name+"_c2", defaultServerPort, &testclient.Options{ - AutoSubscribe: true, - ClientInfo: &livekit.ClientInfo{ - Sdk: livekit.ClientInfo_JS, - }, - }) - waitUntilConnected(t, c1, c2) - defer func() { - c1.Stop() - c2.Stop() - }() + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient(c.name+"_c1", defaultServerPort, useSinglePeerConnection, &testclient.Options{ + ClientInfo: &livekit.ClientInfo{ + Sdk: sdk, + }, + }) + c2 := createRTCClient(c.name+"_c2", defaultServerPort, useSinglePeerConnection, &testclient.Options{ + AutoSubscribe: true, + ClientInfo: &livekit.ClientInfo{ + Sdk: livekit.ClientInfo_JS, + }, + }) + waitUntilConnected(t, c1, c2) + defer func() { + c1.Stop() + c2.Stop() + }() - // publish tracks and don't write any packets - for _, codec := range codecs { - _, err := c1.AddStaticTrackWithCodec(codec, codec.MimeType, codec.MimeType, testclient.AddTrackNoWriter()) - require.NoError(t, err) - } - - require.Eventually(t, func() bool { - return len(c2.SubscribedTracks()[c1.ID()]) == len(codecs) - }, 5*time.Second, 10*time.Millisecond) - - var found int - for _, pubTrack := range c1.GetPublishedTrackIDs() { - t.Log("pub track", pubTrack) - tracks := c2.SubscribedTracks()[c1.ID()] - for _, track := range tracks { - t.Log("sub track", track.ID(), track.Codec()) - if track.Codec().PayloadType == 0 && track.ID() == pubTrack { - found++ - break + // publish tracks and don't write any packets + for _, codec := range codecs { + _, err := c1.AddStaticTrackWithCodec(codec, codec.MimeType, codec.MimeType, testclient.AddTrackNoWriter()) + require.NoError(t, err) } - } + + require.Eventually(t, func() bool { + return len(c2.SubscribedTracks()[c1.ID()]) == len(codecs) + }, 5*time.Second, 10*time.Millisecond) + + var found int + for _, pubTrack := range c1.GetPublishedTrackIDs() { + t.Log("pub track", pubTrack) + tracks := c2.SubscribedTracks()[c1.ID()] + for _, track := range tracks { + t.Log("sub track", track.ID(), track.Codec()) + if track.Codec().PayloadType == 0 && track.ID() == pubTrack { + found++ + break + } + } + } + require.Equal(t, len(codecs), found) + }) } - require.Equal(t, len(codecs), found) }) } } diff --git a/test/webhook_test.go b/test/webhook_test.go index e0469a4e0..040130dcc 100644 --- a/test/webhook_test.go +++ b/test/webhook_test.go @@ -45,78 +45,82 @@ func TestWebhooks(t *testing.T) { require.NoError(t, err) defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) - waitUntilConnected(t, c1) - testutils.WithTimeout(t, func() string { - if ts.GetEvent(webhook.EventRoomStarted) == nil { - return "did not receive RoomStarted" - } - if ts.GetEvent(webhook.EventParticipantJoined) == nil { - return "did not receive ParticipantJoined" - } - return "" - }) + for _, useSinglePeerConnection := range []bool{false, true} { + t.Run(fmt.Sprintf("singlePeerConnection=%+v", useSinglePeerConnection), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c1) + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventRoomStarted) == nil { + return "did not receive RoomStarted" + } + if ts.GetEvent(webhook.EventParticipantJoined) == nil { + return "did not receive ParticipantJoined" + } + return "" + }) - // first participant join should have started the room - started := ts.GetEvent(webhook.EventRoomStarted) - require.Equal(t, testRoom, started.Room.Name) - require.NotEmpty(t, started.Id) - require.Greater(t, started.CreatedAt, time.Now().Unix()-100) - require.GreaterOrEqual(t, time.Now().Unix(), started.CreatedAt) - joined := ts.GetEvent(webhook.EventParticipantJoined) - require.Equal(t, "c1", joined.Participant.Identity) - ts.ClearEvents() + // first participant join should have started the room + started := ts.GetEvent(webhook.EventRoomStarted) + require.Equal(t, testRoom, started.Room.Name) + require.NotEmpty(t, started.Id) + require.Greater(t, started.CreatedAt, time.Now().Unix()-100) + require.GreaterOrEqual(t, time.Now().Unix(), started.CreatedAt) + joined := ts.GetEvent(webhook.EventParticipantJoined) + require.Equal(t, "c1", joined.Participant.Identity) + ts.ClearEvents() - // another participant joins - c2 := createRTCClient("c2", defaultServerPort, nil) - waitUntilConnected(t, c2) - defer c2.Stop() - testutils.WithTimeout(t, func() string { - if ts.GetEvent(webhook.EventParticipantJoined) == nil { - return "did not receive ParticipantJoined" - } - return "" - }) - joined = ts.GetEvent(webhook.EventParticipantJoined) - require.Equal(t, "c2", joined.Participant.Identity) - ts.ClearEvents() + // another participant joins + c2 := createRTCClient("c2", defaultServerPort, useSinglePeerConnection, nil) + waitUntilConnected(t, c2) + defer c2.Stop() + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventParticipantJoined) == nil { + return "did not receive ParticipantJoined" + } + return "" + }) + joined = ts.GetEvent(webhook.EventParticipantJoined) + require.Equal(t, "c2", joined.Participant.Identity) + ts.ClearEvents() - // track published - writers := publishTracksForClients(t, c1) - defer stopWriters(writers...) - testutils.WithTimeout(t, func() string { - ev := ts.GetEvent(webhook.EventTrackPublished) - if ev == nil { - return "did not receive TrackPublished" - } - require.NotNil(t, ev.Track, "TrackPublished did not include trackInfo") - require.Equal(t, string(c1.ID()), ev.Participant.Sid) - return "" - }) - ts.ClearEvents() + // track published + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) + testutils.WithTimeout(t, func() string { + ev := ts.GetEvent(webhook.EventTrackPublished) + if ev == nil { + return "did not receive TrackPublished" + } + require.NotNil(t, ev.Track, "TrackPublished did not include trackInfo") + require.Equal(t, string(c1.ID()), ev.Participant.Sid) + return "" + }) + ts.ClearEvents() - // first participant leaves - c1.Stop() - testutils.WithTimeout(t, func() string { - if ts.GetEvent(webhook.EventParticipantLeft) == nil { - return "did not receive ParticipantLeft" - } - return "" - }) - left := ts.GetEvent(webhook.EventParticipantLeft) - require.Equal(t, "c1", left.Participant.Identity) - ts.ClearEvents() + // first participant leaves + c1.Stop() + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventParticipantLeft) == nil { + return "did not receive ParticipantLeft" + } + return "" + }) + left := ts.GetEvent(webhook.EventParticipantLeft) + require.Equal(t, "c1", left.Participant.Identity) + ts.ClearEvents() - // room closed - rm := server.RoomManager().GetRoom(context.Background(), testRoom) - rm.Close(types.ParticipantCloseReasonNone) - testutils.WithTimeout(t, func() string { - if ts.GetEvent(webhook.EventRoomFinished) == nil { - return "did not receive RoomFinished" - } - return "" - }) - require.Equal(t, testRoom, ts.GetEvent(webhook.EventRoomFinished).Room.Name) + // room closed + rm := server.RoomManager().GetRoom(context.Background(), testRoom) + rm.Close(types.ParticipantCloseReasonNone) + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventRoomFinished) == nil { + return "did not receive RoomFinished" + } + return "" + }) + require.Equal(t, testRoom, ts.GetEvent(webhook.EventRoomFinished).Room.Name) + }) + } } func setupServerWithWebhook() (server *service.LivekitServer, testServer *webhookTestServer, finishFunc func(), err error) {