From 11360cd257c7a009a5ba798f49e49ba4afbc744d Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Mon, 29 Jul 2024 16:20:52 +0800 Subject: [PATCH] WHEP support --- go.mod | 8 +- go.sum | 29 +-- pkg/service/roommanager.go | 67 ++++--- pkg/service/server.go | 4 +- pkg/service/wire.go | 3 + pkg/service/wire_gen.go | 10 +- pkg/whep/handler.go | 1 - pkg/whep/server.go | 101 +++++----- pkg/whep/session.go | 381 ++++++++++++++++++++++++++++++++++++- pkg/whep/sessionmanager.go | 123 ++++++++++++ 10 files changed, 617 insertions(+), 110 deletions(-) delete mode 100644 pkg/whep/handler.go create mode 100644 pkg/whep/sessionmanager.go diff --git a/go.mod b/go.mod index add999cac..c52d290c6 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/livekit/mediatransportutil v0.0.0-20240625074155-301bb4a816b7 github.com/livekit/protocol v1.19.2-0.20240719172332-0df8e893874b github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a + github.com/livekit/server-sdk-go/v2 v2.2.0 github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.8.1 @@ -83,6 +84,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-jose/go-jose/v3 v3.0.3 // indirect github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/cel-go v0.20.1 // indirect github.com/google/go-cmp v0.6.0 // indirect @@ -94,7 +96,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/josharian/native v1.1.0 // indirect github.com/klauspost/compress v1.17.9 // indirect - github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/lithammer/shortuuid/v4 v4.0.0 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mdlayher/netlink v1.7.1 // indirect @@ -108,7 +110,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/runc v1.1.13 // indirect - github.com/pion/datachannel v1.5.5 // indirect + github.com/pion/datachannel v1.5.6 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns v0.0.12 // indirect github.com/pion/randutil v0.1.0 // indirect @@ -139,3 +141,5 @@ require ( google.golang.org/grpc v1.65.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) + +replace github.com/livekit/protocol => ../protocol diff --git a/go.sum b/go.sum index 6b708f506..f5ffce429 100644 --- a/go.sum +++ b/go.sum @@ -82,8 +82,11 @@ github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44 github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc= github.com/go-jose/go-jose/v3 v3.0.3 h1:fFKWeig/irsp7XD2zBxvnmA/XaRWp5V3CBsZXJF7G7k= github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -148,8 +151,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= -github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= @@ -167,10 +170,10 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20240625074155-301bb4a816b7 h1:F1L8inJoynwIAYpZENNYS+1xHJMF5RFRorsnAlcxfSY= github.com/livekit/mediatransportutil v0.0.0-20240625074155-301bb4a816b7/go.mod h1:jwKUCmObuiEDH0iiuJHaGMXwRs3RjrB4G6qqgkr/5oE= -github.com/livekit/protocol v1.19.2-0.20240719172332-0df8e893874b h1:Wn6D+B5YbMe1tH7WCazLJz+msBQzR69dK2wTdgJsF5k= -github.com/livekit/protocol v1.19.2-0.20240719172332-0df8e893874b/go.mod h1:bNjJi+8frdvC84xG0CJ/7VfVvqerLg2MzjOks0ucyC4= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a h1:EQAHmcYEGlc6V517cQ3Iy0+jHgP6+tM/B4l2vGuLpQo= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= +github.com/livekit/server-sdk-go/v2 v2.2.0 h1:E0Yp45v6Yjhzt0ixGltuQQuBk7ToJkyxIe0931Y7aU4= +github.com/livekit/server-sdk-go/v2 v2.2.0/go.mod h1:nYjTi34qkgUvvS9T83KtkQEHTXPEsKoNZ0MQIskVD48= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= github.com/mackerelio/go-osstat v0.2.5/go.mod h1:atxwWF+POUZcdtR1wnsUcQxTytoHG4uhl2AKKzrOajY= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= @@ -228,8 +231,8 @@ github.com/opencontainers/runc v1.1.13 h1:98S2srgG9vw0zWcDpFMn5TRrh8kLxa/5OFUstu github.com/opencontainers/runc v1.1.13/go.mod h1:R016aXacfp/gwQBYw2FDGa9m+n6atbLWrYY8hNMT/sA= github.com/ory/dockertest/v3 v3.10.0 h1:4K3z2VMe8Woe++invjaTB7VRyQXQy5UY+loujO4aNE4= github.com/ory/dockertest/v3 v3.10.0/go.mod h1:nr57ZbRWMqfsdGdFNLHz5jjNdDb7VVFnzAeW1n5N1Lg= -github.com/pion/datachannel v1.5.5 h1:10ef4kwdjije+M9d7Xm9im2Y3O6A6ccQb0zcqZcJew8= -github.com/pion/datachannel v1.5.5/go.mod h1:iMz+lECmfdCMqFRhXhcA/219B0SQlbpoR2V118yimL0= +github.com/pion/datachannel v1.5.6 h1:1IxKJntfSlYkpUj8LlYRSWpYiTTC02nUrOE8T3DqGeg= +github.com/pion/datachannel v1.5.6/go.mod h1:1eKT6Q85pRnr2mHiWHxJwO50SfZRtWHTsNIVb/NfGW4= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.11 h1:9U/dpCYl1ySttROPWJgqWKEylUdT0fXp/xst6JwY5Ks= github.com/pion/dtls/v2 v2.2.11/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= @@ -249,7 +252,7 @@ github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9 github.com/pion/rtp v1.8.3/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= github.com/pion/rtp v1.8.7 h1:qslKkG8qxvQ7hqaxkmL7Pl0XcUm+/Er7nMnu6Vq+ZxM= github.com/pion/rtp v1.8.7/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= -github.com/pion/sctp v1.8.5/go.mod h1:SUFFfDpViyKejTAdwD1d/HQsCu+V/40cCs2nZIvC3s0= +github.com/pion/sctp v1.8.13/go.mod h1:YKSgO/bO/6aOMP9LCie1DuD7m+GamiK2yIiPM6vH+GA= github.com/pion/sctp v1.8.19 h1:2CYuw+SQ5vkQ9t0HdOPccsCz1GQMDuVy5PglLgKVBW8= github.com/pion/sctp v1.8.19/go.mod h1:P6PbDVA++OJMrVNg2AL3XtYHV4uD6dvfyOovCgMs0PE= github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY= @@ -258,8 +261,6 @@ github.com/pion/srtp/v2 v2.0.18 h1:vKpAXfawO9RtTRKZJbG4y0v1b11NZxQnxRl85kGuUlo= github.com/pion/srtp/v2 v2.0.18/go.mod h1:0KJQjA99A6/a0DOVTu1PhDSw0CXF2jTkqOoMg3ODqdA= github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= -github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= -github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.2/go.mod h1:OJg3ojoBJopjEeECq2yJdXH9YVrUJ1uQ++NjXLOUorc= github.com/pion/transport/v2 v2.2.3/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= @@ -363,6 +364,7 @@ golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98y golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20240716160929-1d5bc16f04a8 h1:Z+vTUQyBb738QmIhbJx3z4htsxDeI+rd0EHvNm8jHkg= @@ -393,7 +395,6 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= @@ -401,6 +402,8 @@ golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -441,8 +444,6 @@ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -452,11 +453,11 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= @@ -465,11 +466,11 @@ golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 06f783a43..ca867d854 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -27,6 +27,7 @@ import ( "github.com/livekit/livekit-server/pkg/agent" "github.com/livekit/livekit-server/pkg/sfu" sutils "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/livekit-server/pkg/whep" "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -65,20 +66,21 @@ type iceConfigCacheKey struct { type RoomManager struct { lock sync.RWMutex - config *config.Config - rtcConfig *rtc.WebRTCConfig - serverInfo *livekit.ServerInfo - currentNode routing.LocalNode - router routing.Router - roomStore ObjectStore - telemetry telemetry.TelemetryService - clientConfManager clientconfiguration.ClientConfigurationManager - agentClient agent.Client - agentStore AgentStore - egressLauncher rtc.EgressLauncher - versionGenerator utils.TimedVersionGenerator - turnAuthHandler *TURNAuthHandler - bus psrpc.MessageBus + config *config.Config + rtcConfig *rtc.WebRTCConfig + serverInfo *livekit.ServerInfo + currentNode routing.LocalNode + router routing.Router + roomStore ObjectStore + telemetry telemetry.TelemetryService + clientConfManager clientconfiguration.ClientConfigurationManager + agentClient agent.Client + agentStore AgentStore + egressLauncher rtc.EgressLauncher + versionGenerator utils.TimedVersionGenerator + turnAuthHandler *TURNAuthHandler + bus psrpc.MessageBus + whepSessionManagers *whep.SessionManager rooms map[livekit.RoomName]*rtc.Room @@ -110,21 +112,27 @@ func NewLocalRoomManager( return nil, err } + whepSessionManager, err := whep.NewSessionManager(bus, rtcConf.WebRTCConfig) + if err != nil { + return nil, err + } + return &RoomManager{ - config: conf, - rtcConfig: rtcConf, - currentNode: currentNode, - router: router, - roomStore: roomStore, - telemetry: telemetry, - clientConfManager: clientConfManager, - egressLauncher: egressLauncher, - agentClient: agentClient, - agentStore: agentStore, - versionGenerator: versionGenerator, - turnAuthHandler: turnAuthHandler, - bus: bus, - forwardStats: forwardStats, + config: conf, + rtcConfig: rtcConf, + currentNode: currentNode, + router: router, + roomStore: roomStore, + telemetry: telemetry, + clientConfManager: clientConfManager, + egressLauncher: egressLauncher, + agentClient: agentClient, + agentStore: agentStore, + versionGenerator: versionGenerator, + turnAuthHandler: turnAuthHandler, + bus: bus, + whepSessionManagers: whepSessionManager, + forwardStats: forwardStats, rooms: make(map[livekit.RoomName]*rtc.Room), @@ -245,6 +253,9 @@ func (r *RoomManager) Stop() { if r.forwardStats != nil { r.forwardStats.Stop() } + if r.whepSessionManagers != nil { + r.whepSessionManagers.Stop() + } } // StartSession starts WebRTC session when a new participant is connected, takes place on RTC node diff --git a/pkg/service/server.go b/pkg/service/server.go index 3bbbd712a..b86c6927c 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -37,6 +37,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/whep" "github.com/livekit/livekit-server/version" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -66,6 +67,7 @@ func NewLivekitServer(conf *config.Config, ingressService *IngressService, sipService *SIPService, ioService *IOInfoService, + whepService *whep.Server, rtcService *RTCService, agentService *AgentService, keyProvider auth.KeyProvider, @@ -132,7 +134,7 @@ func NewLivekitServer(conf *config.Config, mux.Handle(sipServer.PathPrefix(), sipServer) mux.Handle("/rtc", rtcService) mux.Handle("/agent", agentService) - mux.Handle("/whep/", http.StripPrefix("/whep", whepServer)) + mux.Handle("/whep/", http.StripPrefix("/whep", whepService)) mux.HandleFunc("/rtc/validate", rtcService.Validate) mux.HandleFunc("/", s.defaultHandler) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 75b3509a1..d4d822ac1 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -33,6 +33,7 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/whep" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -65,6 +66,8 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live wire.Bind(new(IOClient), new(*IOInfoService)), rpc.NewEgressClient, rpc.NewIngressClient, + rpc.NewWHEPClient, + whep.NewServer, getEgressStore, NewEgressLauncher, NewEgressService, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 15f81061f..c6bc67e71 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -14,6 +14,7 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/whep" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -113,6 +114,11 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live return nil, err } sipService := NewSIPService(sipConfig, nodeID, messageBus, sipClient, sipStore, roomService, telemetryService) + whepClient, err := rpc.NewWHEPClient(clientParams) + if err != nil { + return nil, err + } + server := whep.NewServer(whepClient) rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, client, telemetryService) agentService, err := NewAgentService(conf, currentNode, messageBus, keyProvider) if err != nil { @@ -132,11 +138,11 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live return nil, err } authHandler := getTURNAuthHandlerFunc(turnAuthHandler) - server, err := newInProcessTurnServer(conf, authHandler) + turnServer, err := newInProcessTurnServer(conf, authHandler) if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, sipService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, sipService, ioInfoService, server, rtcService, agentService, keyProvider, router, roomManager, signalServer, turnServer, currentNode) if err != nil { return nil, err } diff --git a/pkg/whep/handler.go b/pkg/whep/handler.go deleted file mode 100644 index 5712989af..000000000 --- a/pkg/whep/handler.go +++ /dev/null @@ -1 +0,0 @@ -package whep diff --git a/pkg/whep/server.go b/pkg/whep/server.go index 70a53248c..f60061709 100644 --- a/pkg/whep/server.go +++ b/pkg/whep/server.go @@ -3,6 +3,7 @@ package whep import ( "bytes" "context" + "errors" "fmt" "hash/crc32" "io" @@ -10,65 +11,37 @@ import ( "strings" "time" - "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/rpc" "github.com/livekit/psrpc" ) const ( authorizationHeader = "Authorization" bearerPrefix = "Bearer " + accessTokenParam = "access_token" ) type Server struct { *http.ServeMux - sessions map[string]*Session + psrpcClient rpc.WHEPClient + logger logger.Logger } -func (s *Server) Start() error { - s.ctx, s.cancel = context.WithCancel(context.Background()) - - logger.Infow("starting WHEP server") - - if onPublish == nil { - return psrpc.NewErrorf(psrpc.Internal, "no onPublish callback provided") +func NewServer(psrpcClient rpc.WHEPClient) *Server { + s := &Server{ + psrpcClient: psrpcClient, + logger: logger.GetLogger(), } - - s.onPublish = onPublish - - var err error - s.webRTCConfig, err = rtcconfig.NewWebRTCConfig(&conf.RTCConfig, conf.Development) - if err != nil { - return err - } - r := http.NewServeMux() r.HandleFunc("POST /{room}/{participant}", s.handleNewSession) r.HandleFunc("PATCH /{room}/{participant}/{resource_id}", s.handleICE) r.HandleFunc("DELETE /{room}/{participant}/{resource_id}", s.handleDeleteSession) + s.ServeMux = r - hs := &http.Server{ - Addr: fmt.Sprintf(":%d", conf.WHIPPort), - Handler: r, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - } - - go func() { - err := hs.ListenAndServe() - if err != http.ErrServerClosed { - logger.Errorw("WHIP server start failed", err) - } - }() - - return nil -} - -func (s *Server) Stop() { - + return s } func (s *Server) handleNewSession(w http.ResponseWriter, r *http.Request) { @@ -76,12 +49,14 @@ func (s *Server) handleNewSession(w http.ResponseWriter, r *http.Request) { var authToken string if authHeader != "" { + // the request has already passed authorize middleware so this should not happen if !strings.HasPrefix(authHeader, bearerPrefix) { - handleError(w, r, http.StatusUnauthorized, ErrMissingAuthorization) + w.WriteHeader(http.StatusUnauthorized) return } - authToken = authHeader[len(bearerPrefix):] + } else { + authToken = r.FormValue(accessTokenParam) } room := r.PathValue("room") @@ -91,25 +66,29 @@ func (s *Server) handleNewSession(w http.ResponseWriter, r *http.Request) { _, err := io.Copy(&offer, r.Body) if err != nil { - return err + s.handleError(w, r, err) + return } - resourceID := utils.NewGuid(utils.WHIPResourcePrefix) - session := NewSession(room, participant, authToken, offer) - - answer := session.CreateAnswer() - - s.sessions[resourceID] = session - - go session.Run() + wsUrl := r.FormValue("livekit_url") + resp, err := s.psrpcClient.StartWHEP(context.Background(), &rpc.StartWHEPRequest{ + Token: authToken, + WsUrl: wsUrl, + Participant: participant, + Offer: offer.String(), + }, psrpc.WithRequestTimeout(5*time.Second)) + if err != nil { + s.handleError(w, r, err) + return + } w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Expose-Headers", "Location") w.Header().Set("Content-Type", "application/sdp") - w.Header().Set("Location", fmt.Sprintf("/%s/%s/%s", room, participant, resourceID)) + w.Header().Set("Location", fmt.Sprintf("/%s/%s/%s", room, participant, resp.ResourceId)) w.Header().Set("ETag", fmt.Sprintf("%08x", crc32.ChecksumIEEE(offer.Bytes()))) w.WriteHeader(http.StatusCreated) - _, _ = w.Write([]byte(answer)) + _, _ = w.Write([]byte(resp.Answer)) } func (s *Server) handleICE(w http.ResponseWriter, r *http.Request) { @@ -119,7 +98,23 @@ func (s *Server) handleICE(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDeleteSession(w http.ResponseWriter, r *http.Request) { resourceID := r.PathValue("resource_id") - if session, ok := s.sessions[resourceID]; ok { - session.Close() + _, err := s.psrpcClient.DeleteWHEP(context.Background(), resourceID, &rpc.DeleteWHEPRequest{ + ResourceId: resourceID, + }, psrpc.WithRequestTimeout(5*time.Second)) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } +} + +func (s *Server) handleError(w http.ResponseWriter, r *http.Request, err error) { + var psrpcErr psrpc.Error + switch { + case errors.As(err, &psrpcErr): + w.WriteHeader(psrpcErr.ToHttp()) + _, _ = w.Write([]byte(psrpcErr.Error())) + default: + s.logger.Infow("whip request failed", "error", err, "method", r.Method, "path", r.URL.Path) + w.WriteHeader(http.StatusInternalServerError) } } diff --git a/pkg/whep/session.go b/pkg/whep/session.go index 735e50525..5e88ecedd 100644 --- a/pkg/whep/session.go +++ b/pkg/whep/session.go @@ -1,24 +1,387 @@ package whep +import ( + "bufio" + "context" + "errors" + "io" + "strings" + "sync" + + google_protobuf "google.golang.org/protobuf/types/known/emptypb" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + sdputil "github.com/livekit/protocol/sdp" + "github.com/livekit/protocol/tracer" + "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" + + lksdk "github.com/livekit/server-sdk-go/v2" +) + +var ( + ErrResourceNotFound = psrpc.NewErrorf(psrpc.NotFound, "WHEP resource not found") +) + +type SessionParams struct { + Logger logger.Logger + Token string + Participant string + Offer string + WsUrl string + RtcAPI *webrtc.API +} + type Session struct { + params SessionParams + + room *lksdk.Room + pc *webrtc.PeerConnection + lock sync.Mutex + forwardingSenders map[string]*webrtc.RTPSender // RemoteTrack.SID -> Local RTPSender + availableSenders []*webrtc.RTPSender + closed bool + + onClose func() } -func NewSession() *Session { - +func NewSession(params SessionParams) *Session { + return &Session{ + params: params, + forwardingSenders: make(map[string]*webrtc.RTPSender), + } } -func (s *Session) CreateAnswer() { +func (s *Session) CreateAnswer() (string, error) { + pc, err := s.params.RtcAPI.NewPeerConnection(webrtc.Configuration{SDPSemantics: webrtc.SDPSemanticsUnifiedPlan}) + if err != nil { + return "", err + } + s.pc = pc + offerSdp := webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: s.params.Offer, + } + + // create track placeholders for forwarding sender + parsedOffer, err := offerSdp.Unmarshal() + if err != nil { + return "", err + } + for _, media := range parsedOffer.MediaDescriptions { + switch media.MediaName.Media { + case "audio": + track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", + }, utils.NewGuid(utils.TrackPrefix), s.params.Participant) + if err != nil { + return "", err + } + if t, err := pc.AddTransceiverFromTrack(track, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}); err != nil { + return "", err + } else { + s.availableSenders = append(s.availableSenders, t.Sender()) + } + + case "video": + track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + ClockRate: 90000, + }, utils.NewGuid(utils.TrackPrefix), s.params.Participant) + if err != nil { + return "", err + } + if t, err := pc.AddTransceiverFromTrack(track, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}); err != nil { + return "", err + } else { + s.availableSenders = append(s.availableSenders, t.Sender()) + } + + default: + continue + } + } + + if err = pc.SetRemoteDescription(offerSdp); err != nil { + return "", err + } + gatherCompelte := webrtc.GatheringCompletePromise(pc) + + answer, err := pc.CreateAnswer(nil) + if err != nil { + return "", err + } + if err = pc.SetLocalDescription(answer); err != nil { + return "", err + } + + room, err := lksdk.ConnectToRoomWithToken(s.params.WsUrl, s.params.Token, &lksdk.RoomCallback{ + ParticipantCallback: lksdk.ParticipantCallback{ + OnTrackPublished: s.onTrackPublished, + OnTrackSubscribed: s.onTrackSubscribed, + OnTrackUnsubscribed: s.onTrackUnsubscribed, + }, + OnDisconnectedWithReason: func(reason lksdk.DisconnectionReason) { + s.params.Logger.Infow("sdk client disconnected", "reason", reason) + go s.Close() + }, + }, lksdk.WithAutoSubscribe(false)) + if err != nil { + return "", err + } + s.room = room + + <-gatherCompelte + answerSdp := *pc.CurrentLocalDescription() + + pc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { + if pcs == webrtc.PeerConnectionStateFailed { + s.params.Logger.Infow("pc connection failed") + s.Close() + } + + }) + return answerSdp.SDP, err } -func (s *Session) Run() { - -} - -func (s *Session) HandleICERequest() { - +func (s *Session) OnClose(f func()) { + s.lock.Lock() + s.onClose = f + s.lock.Unlock() } func (s *Session) Close() { + s.lock.Lock() + if s.closed { + s.lock.Unlock() + return + } + s.closed = true + pc := s.pc + room := s.room + onClose := s.onClose + s.pc = nil + s.room = nil + s.forwardingSenders = make(map[string]*webrtc.RTPSender) + s.availableSenders = nil + s.lock.Unlock() + if pc != nil { + pc.Close() + } + + if room != nil { + room.Disconnect() + } + + if onClose != nil { + onClose() + } +} + +func (s *Session) onTrackPublished(publication *lksdk.RemoteTrackPublication, rp *lksdk.RemoteParticipant) { + if rp.Identity() != s.params.Participant { + return + } + + publication.SetSubscribed(true) +} + +func (s *Session) onTrackSubscribed(track *webrtc.TrackRemote, publication *lksdk.RemoteTrackPublication, rp *lksdk.RemoteParticipant) { + if rp.Identity() != s.params.Participant { + return + } + + trackSID := publication.SID() + s.params.Logger.Debugw("onTrackSubscribed", "trackID", trackSID) + s.lock.Lock() + defer s.lock.Unlock() + + for i, sender := range s.availableSenders { + if sender.Track().Kind() != track.Kind() { + continue + } + + localTrack, err := webrtc.NewTrackLocalStaticRTP(track.Codec().RTPCodecCapability, sender.Track().ID(), sender.Track().StreamID()) + if err != nil { + s.params.Logger.Errorw("failed to create local track", err, "trackID", trackSID) + return + } + if err = sender.ReplaceTrack(localTrack); err != nil { + s.params.Logger.Errorw("failed to replace track", err, "trackID", trackSID) + return + } + s.forwardingSenders[trackSID] = sender + s.availableSenders[i] = s.availableSenders[len(s.availableSenders)-1] + s.availableSenders[len(s.availableSenders)-1] = nil + s.availableSenders = s.availableSenders[:len(s.availableSenders)-1] + + s.params.Logger.Infow("start forwarding", "trackID", trackSID) + go s.forward(track, trackSID, rp, localTrack, sender) + return + } + + s.params.Logger.Infow("no available sender for track", "trackID", trackSID) +} + +func (s *Session) onTrackUnsubscribed(track *webrtc.TrackRemote, publication *lksdk.RemoteTrackPublication, rp *lksdk.RemoteParticipant) { + if rp.Identity() != s.params.Participant { + return + } + + trackSID := publication.SID() + s.params.Logger.Debugw("onTrackUnSubscribed", "trackID", trackSID) + s.lock.Lock() + defer s.lock.Unlock() + _, ok := s.forwardingSenders[trackSID] + if !ok { + return + } + + delete(s.forwardingSenders, trackSID) + // TODO: support sender reuse, this requires munged continuous timestamp/sequence + // s.availableSenders = append(s.availableSenders, sender) +} + +func (s *Session) forward(remoteTrack *webrtc.TrackRemote, trakcSID string, rp *lksdk.RemoteParticipant, localTrack *webrtc.TrackLocalStaticRTP, sender *webrtc.RTPSender) { + done := make(chan struct{}) + defer close(done) + var rtpPkt rtp.Packet + buf := make([]byte, 1500) + for { + n, _, err := remoteTrack.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) { + s.params.Logger.Warnw("read remote track failed", err, "trackID", trakcSID) + } + break + } + err = rtpPkt.Unmarshal(buf[:n]) + + if err != nil { + continue + } + rtpPkt.Header.Extension = false + rtpPkt.Header.ExtensionProfile = 0 + rtpPkt.Header.Extensions = []rtp.Extension{} + if err = localTrack.WriteRTP(&rtpPkt); err != nil { + break + } + } + + go func() { + for { + select { + case <-done: + return + default: + rtcpPkts, _, err := sender.ReadRTCP() + if err != nil { + if !errors.Is(err, io.EOF) { + s.params.Logger.Warnw("read forwarding track rtcp failed", err, "trackID", trakcSID) + return + } + } + for _, pkt := range rtcpPkts { + if _, ok := pkt.(*rtcp.PictureLossIndication); ok { + rp.WritePLI(remoteTrack.SSRC()) + break + } + } + } + } + + }() +} + +// TODO: support trickle ICE +func (s *Session) ICETrickle(ctx context.Context, req *rpc.ICETrickleRequest) (*google_protobuf.Empty, error) { + _, span := tracer.Start(ctx, "WHEPSession.ICETrickle") + defer span.End() + + return nil, psrpc.Unimplemented +} + +func (s *Session) ICERestart(ctx context.Context, req *rpc.ICERestartWHIPResourceRequest) (*rpc.ICERestartWHIPResourceResponse, error) { + _, span := tracer.Start(ctx, "WHEPSession.ICERestart") + defer span.End() + + s.lock.Lock() + if s.closed { + s.lock.Unlock() + return nil, ErrResourceNotFound + } + pc := s.pc + s.lock.Unlock() + + if pc == nil { + return nil, ErrResourceNotFound + } + + remoteDescription := pc.CurrentRemoteDescription() + if remoteDescription == nil { + return nil, ErrResourceNotFound + } + + // Replace the current remote description with the values from remote + newRemoteDescription, err := sdputil.ReplaceICEDetails(remoteDescription.SDP, req.UserFragment, req.Password) + if err != nil { + s.params.Logger.Errorw("RepalceICEDetails failed", err, "ufrag", req.UserFragment, "candidates", req.Candidates) + return nil, psrpc.Internal + } + remoteDescription.SDP = newRemoteDescription + + if err := pc.SetRemoteDescription(*remoteDescription); err != nil { + s.params.Logger.Errorw("SetRemoteDescription failed", err, "sdp", remoteDescription.SDP) + return nil, psrpc.Internal + } + + answer, err := pc.CreateAnswer(nil) + if err != nil { + s.params.Logger.Errorw("CreateAnswer failed", err) + return nil, psrpc.Internal + } + + gatherComplete := webrtc.GatheringCompletePromise(pc) + if err = pc.SetLocalDescription(answer); err != nil { + s.params.Logger.Errorw("SetLocalDescription failed", err) + return nil, psrpc.Internal + } + <-gatherComplete + + // Discard all `a=` lines that aren't ICE related + // "WHIP does not support renegotiation of non-ICE related SDP information" + // + // https://www.ietf.org/archive/id/draft-ietf-wish-whip-14.html#name-ice-restarts + var trickleIceSdpfrag strings.Builder + scanner := bufio.NewScanner(strings.NewReader(pc.LocalDescription().SDP)) + for scanner.Scan() { + l := scanner.Text() + if strings.HasPrefix(l, "a=") && !strings.HasPrefix(l, "a=ice-pwd") && !strings.HasPrefix(l, "a=ice-ufrag") && !strings.HasPrefix(l, "a=candidate") { + continue + } + + trickleIceSdpfrag.WriteString(l + "\n") + } + + return &rpc.ICERestartWHIPResourceResponse{TrickleIceSdpfrag: trickleIceSdpfrag.String()}, nil +} + +func (s *Session) DeleteWHEP(ctx context.Context, req *rpc.DeleteWHEPRequest) (*google_protobuf.Empty, error) { + _, span := tracer.Start(ctx, "WHEPSession.DeleteWHEP") + defer span.End() + + s.params.Logger.Debugw("DeleteWHEP") + s.Close() + + return nil, nil } diff --git a/pkg/whep/sessionmanager.go b/pkg/whep/sessionmanager.go new file mode 100644 index 000000000..1f4d005d6 --- /dev/null +++ b/pkg/whep/sessionmanager.go @@ -0,0 +1,123 @@ +package whep + +import ( + "context" + "sync" + + "github.com/pion/interceptor" + "github.com/pion/webrtc/v3" + + "github.com/livekit/mediatransportutil/pkg/rtcconfig" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/tracer" + "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" +) + +type SessionManager struct { + bus psrpc.MessageBus + sessionLock sync.Mutex + sessions map[string]*Session + rtcAPI *webrtc.API + logger logger.Logger + psrpcServer rpc.WHEPInternalServer +} + +func NewSessionManager(bus psrpc.MessageBus, rtccfg rtcconfig.WebRTCConfig) (*SessionManager, error) { + var me webrtc.MediaEngine + var ir interceptor.Registry + me.RegisterDefaultCodecs() + if err := webrtc.ConfigureNack(&me, &ir); err != nil { + return nil, err + } + + if err := webrtc.ConfigureRTCPReports(&ir); err != nil { + return nil, err + } + + s := &SessionManager{ + bus: bus, + logger: logger.GetLogger(), + sessions: make(map[string]*Session), + rtcAPI: webrtc.NewAPI( + webrtc.WithSettingEngine(rtccfg.SettingEngine), + webrtc.WithMediaEngine(&me), + webrtc.WithInterceptorRegistry(&ir), + ), + } + psrpcServer, err := rpc.NewWHEPInternalServer(s, bus) + if err != nil { + return nil, err + } + s.psrpcServer = psrpcServer + return s, nil +} + +func (sm *SessionManager) StartWHEP(ctx context.Context, req *rpc.StartWHEPRequest) (*rpc.StartWHEPResponse, error) { + _, span := tracer.Start(ctx, "WHEPSessionManager.StartWHEP") + defer span.End() + + sm.logger.Debugw("handle StartWhEP", "participant", req.Participant, "offer", req.Offer) + resourceID := utils.NewGuid(utils.WHIPResourcePrefix) + session := NewSession(SessionParams{ + Participant: req.Participant, + Token: req.Token, + Offer: req.Offer, + WsUrl: req.WsUrl, + RtcAPI: sm.rtcAPI, + Logger: sm.logger.WithValues("resourceID", resourceID), + }) + + answer, err := session.CreateAnswer() + if err != nil { + session.Close() + session.params.Logger.Warnw("failed to create answer", err) + return nil, err + } + + rpcServer, err := rpc.NewWHEPHandlerServer(session, sm.bus) + if err != nil { + session.Close() + return nil, err + } + if err = rpcServer.RegisterAllResuourceTopics(resourceID); err != nil { + rpcServer.Shutdown() + session.Close() + return nil, err + } + sm.sessionLock.Lock() + sm.sessions[resourceID] = session + sm.sessionLock.Unlock() + + session.OnClose(func() { + rpcServer.DeregisterAllResuourceTopics(resourceID) + sm.sessionLock.Lock() + delete(sm.sessions, resourceID) + sm.sessionLock.Unlock() + }) + + sm.logger.Debugw("WHEP session started", "resourceID", resourceID, "answer", answer) + + return &rpc.StartWHEPResponse{ + ResourceId: resourceID, + Answer: answer, + }, nil +} + +func (sm *SessionManager) StartWHEPAffinity(context.Context, *rpc.StartWHEPRequest) float32 { + // TODO: if NoCapacity return -1; if hasRoom return 1; else return 0 + return 1 + +} + +func (sm *SessionManager) Stop() { + sm.psrpcServer.Shutdown() + sm.sessionLock.Lock() + sessions := sm.sessions + sm.sessions = make(map[string]*Session) + sm.sessionLock.Unlock() + for _, s := range sessions { + s.Close() + } +}