From ed5e2f16b2146828c27aa75ecae3ffbd33a571d3 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Mon, 31 Mar 2025 19:25:57 +0530 Subject: [PATCH] Keep simulcast information tied to receiver. (#3563) * Keep simulcast information tied to receiver. `simulcast` flag in `TrackInfo` is at track lavel. With codec simulcast, the primary codec (in most cases) is SVC and the backup codec is simulcast. Back up codec publish changing the track info setting to true meant that the primary receiver was treated as simulcast if a subscriber for primary codec joined after the backup codec was published. Keep track of simulcast flag in receiver. Also, TrackInfo Cids are from signal. So, keep track of SDP cids separately. The `simulcastTrackIds` map uses SDP cid. Clean up by all the SDP cids of a track * clean up * clean up * clean up * clean up * test * Store SdpCid and IsSimulcast in Trackinfo * clean up * mock --- go.mod | 6 +- go.sum | 12 +- pkg/rtc/mediatrack.go | 39 +-- pkg/rtc/mediatrackreceiver.go | 36 ++- pkg/rtc/mediatracksubscriptions.go | 3 +- pkg/rtc/participant.go | 57 ++-- pkg/rtc/participant_internal_test.go | 2 +- pkg/rtc/participant_sdp.go | 7 +- pkg/rtc/transport.go | 2 +- pkg/rtc/types/interfaces.go | 5 +- .../typesfakes/fake_local_media_track.go | 284 +++++++++++------- pkg/rtc/types/typesfakes/fake_media_track.go | 65 ++++ pkg/rtc/wrappedreceiver.go | 8 + pkg/sfu/downtrack.go | 8 +- pkg/sfu/forwarder.go | 13 +- pkg/sfu/forwarder_test.go | 4 +- pkg/sfu/receiver.go | 30 +- 17 files changed, 393 insertions(+), 188 deletions(-) diff --git a/go.mod b/go.mod index b6bdc6357..1b5f2474d 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20250310153736-45596af895b6 - github.com/livekit/protocol v1.36.2-0.20250326174620-fbbb1c3ae28a + github.com/livekit/protocol v1.36.2-0.20250331123911-67af9b92e4ac github.com/livekit/psrpc v0.6.1-0.20250205181828-a0beed2e4126 github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 @@ -32,7 +32,7 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/ory/dockertest/v3 v3.11.0 github.com/pion/datachannel v1.5.10 - github.com/pion/dtls/v3 v3.0.5 + github.com/pion/dtls/v3 v3.0.6 github.com/pion/ice/v4 v4.0.9 github.com/pion/interceptor v0.1.37 github.com/pion/rtcp v1.2.15 @@ -133,7 +133,7 @@ require ( go.uber.org/zap/exp v0.3.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/mod v0.24.0 // indirect - golang.org/x/net v0.37.0 // indirect + golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/tools v0.31.0 // indirect diff --git a/go.sum b/go.sum index 06b0b297f..b07f337da 100644 --- a/go.sum +++ b/go.sum @@ -170,8 +170,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20250310153736-45596af895b6 h1:6ZhtnY9I9knfm3ieIPpznQSEU2rDECO8yliW/ANLQ7U= github.com/livekit/mediatransportutil v0.0.0-20250310153736-45596af895b6/go.mod h1:36s+wwmU3O40IAhE+MjBWP3W71QRiEE9SfooSBvtBqY= -github.com/livekit/protocol v1.36.2-0.20250326174620-fbbb1c3ae28a h1:3oH/yRx6OTFc0JbUNkfhZXmW5zLZ61ZW02fqYFKjSrM= -github.com/livekit/protocol v1.36.2-0.20250326174620-fbbb1c3ae28a/go.mod h1:WrT/CYRxtMNOVUjnIPm5OjWtEkmreffTeE1PRZwlRg4= +github.com/livekit/protocol v1.36.2-0.20250331123911-67af9b92e4ac h1:mEb60UmuJdilpY9WbrdA7L1OgdfBKX9SlRa7GuKMYL8= +github.com/livekit/protocol v1.36.2-0.20250331123911-67af9b92e4ac/go.mod h1:WrT/CYRxtMNOVUjnIPm5OjWtEkmreffTeE1PRZwlRg4= github.com/livekit/psrpc v0.6.1-0.20250205181828-a0beed2e4126 h1:fzuYpAQbCid7ySPpQWWePfQOWUrs8x6dJ0T3Wl07n+Y= github.com/livekit/psrpc v0.6.1-0.20250205181828-a0beed2e4126/go.mod h1:X5WtEZ7OnEs72Fi5/J+i0on3964F1aynQpCalcgMqRo= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= @@ -235,8 +235,8 @@ github.com/ory/dockertest/v3 v3.11.0 h1:OiHcxKAvSDUwsEVh2BjxQQc/5EHz9n0va9awCtNG github.com/ory/dockertest/v3 v3.11.0/go.mod h1:VIPxS1gwT9NpPOrfD3rACs8Y9Z7yhzO4SB194iUDnUI= github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= -github.com/pion/dtls/v3 v3.0.5 h1:OGWLu21/Wc5+H8R75F1BWvedH7H+nYUPFzJOew4k1iA= -github.com/pion/dtls/v3 v3.0.5/go.mod h1:JVCnfmbgq45QoU07AaxFbdjF2iomKzYouVNy+W5kqmY= +github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E= +github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU= github.com/pion/ice/v4 v4.0.9 h1:VKgU4MwA2LUDVLq+WBkpEHTcAb8c5iCvFMECeuPOZNk= github.com/pion/ice/v4 v4.0.9/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI= @@ -398,8 +398,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 8d8eacd81..39ab47f26 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -42,7 +42,6 @@ import ( // Implements MediaTrack and PublishedTrack interface type MediaTrack struct { params MediaTrackParams - numUpTracks atomic.Uint32 buffer *buffer.Buffer everSubscribed atomic.Bool @@ -61,8 +60,6 @@ type MediaTrack struct { } type MediaTrackParams struct { - SignalCid string - SdpCid string ParticipantID livekit.ParticipantID ParticipantIdentity livekit.ParticipantIdentity ParticipantVersion uint32 @@ -177,22 +174,28 @@ func (t *MediaTrack) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, quali } } -func (t *MediaTrack) SignalCid() string { - return t.params.SignalCid +func (t *MediaTrack) HasSignalCid(cid string) bool { + for _, c := range t.MediaTrackReceiver.TrackInfo().Codecs { + if c.SignalCid == cid { + return true + } + } + return false } -func (t *MediaTrack) SdpCid() string { - return t.params.SdpCid +func (t *MediaTrack) SdpCids() []string { + var sdpCids []string + for _, c := range t.MediaTrackReceiver.TrackInfo().Codecs { + if c.SdpCid != "" { + sdpCids = append(sdpCids, c.SdpCid) + } + } + return sdpCids } func (t *MediaTrack) HasSdpCid(cid string) bool { - if t.params.SdpCid == cid { - return true - } - - ti := t.MediaTrackReceiver.TrackInfoClone() - for _, c := range ti.Codecs { - if c.Cid == cid { + for _, c := range t.MediaTrackReceiver.TrackInfo().Codecs { + if c.SdpCid == cid { return true } } @@ -203,12 +206,8 @@ func (t *MediaTrack) ToProto() *livekit.TrackInfo { return t.MediaTrackReceiver.TrackInfoClone() } -func (t *MediaTrack) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { - t.MediaTrackReceiver.UpdateCodecCid(codecs) -} - // AddReceiver adds a new RTP receiver to the track, returns true when receiver represents a new codec -func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRemote, mid string) bool { +func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRemote, mid string, isSimulcast bool) bool { var newCodec bool ssrc := uint32(track.SSRC()) buff, rtcpReader := t.params.BufferFactory.GetBufferPair(ssrc) @@ -255,6 +254,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe } }) + t.MediaTrackReceiver.UpdateCodecInfo(track.Codec().MimeType, track.ID(), isSimulcast) ti := t.MediaTrackReceiver.TrackInfoClone() t.lock.Lock() var regressCodec bool @@ -262,6 +262,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe layer := buffer.RidToSpatialLayer(track.RID(), ti) t.params.Logger.Debugw( "AddReceiver", + "cid", track.ID(), "rid", track.RID(), "layer", layer, "ssrc", track.SSRC(), diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 88b45abc5..245bfd764 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -505,6 +505,10 @@ func (t *MediaTrackReceiver) SetSimulcast(simulcast bool) { } } +func (t *MediaTrackReceiver) HasMultipleSpatialLayers() bool { + return len(t.TrackInfo().Layers) > 1 +} + func (t *MediaTrackReceiver) Name() string { return t.TrackInfo().Name } @@ -595,7 +599,7 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su Logger: tLogger, DisableRed: t.TrackInfo().GetDisableRed() || !t.params.AudioConfig.ActiveREDEncoding, }) - subTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, wr, t.IsSimulcast()) + subTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, wr) // media track could have been closed while adding subscription remove := false @@ -709,13 +713,13 @@ func (t *MediaTrackReceiver) SetLayerSsrc(mimeType mime.MimeType, rid string, ss t.updateTrackInfoOfReceivers() } -func (t *MediaTrackReceiver) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { +func (t *MediaTrackReceiver) UpdateCodecSignalCid(codecs []*livekit.SimulcastCodec) { t.lock.Lock() trackInfo := t.TrackInfoClone() for _, c := range codecs { for _, origin := range trackInfo.Codecs { if mime.GetMimeTypeCodec(origin.MimeType) == mime.NormalizeMimeTypeCodec(c.Codec) { - origin.Cid = c.Cid + origin.SignalCid = c.Cid break } } @@ -726,6 +730,32 @@ func (t *MediaTrackReceiver) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { t.updateTrackInfoOfReceivers() } +func (t *MediaTrackReceiver) UpdateCodecInfo(mimeType string, cid string, isSimulcast bool) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + for _, origin := range trackInfo.Codecs { + if mime.IsMimeTypeStringEqual(origin.MimeType, mimeType) { + if origin.SdpCid != "" { + if origin.SdpCid != cid || origin.IsSimulcast != isSimulcast { + t.params.Logger.Warnw( + "uexpected codec info change", nil, + "oldCid", origin.SdpCid, "newCid", cid, + "oldIsSimulcast", origin.IsSimulcast, "newIsSimulcast", isSimulcast, + ) + } + + } + origin.SdpCid = cid + origin.IsSimulcast = isSimulcast + break + } + } + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() +} + func (t *MediaTrackReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { updateMute := false clonedInfo := utils.CloneProto(ti) diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index c7af3d77c..63d0f175c 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -92,7 +92,7 @@ func (t *MediaTrackSubscriptions) IsSubscriber(subID livekit.ParticipantID) bool } // AddSubscriber subscribes sub to current mediaTrack -func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *WrappedReceiver, isReceiverSimulcast bool) (types.SubscribedTrack, error) { +func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *WrappedReceiver) (types.SubscribedTrack, error) { trackID := t.params.MediaTrack.ID() subscriberID := sub.ID() @@ -146,7 +146,6 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * RTCPWriter: sub.WriteSubscriberRTCP, DisableSenderReportPassThrough: sub.GetDisableSenderReportPassThrough(), SupportsCodecChange: sub.SupportsCodecChange(), - IsReceiverSimulcast: isReceiverSimulcast, }) if err != nil { return nil, err diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 9cdc3aa03..ebf162ca5 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1762,7 +1762,9 @@ func (p *ParticipantImpl) removePublishedTrack(track types.MediaTrack) { p.RemovePublishedTrack(track, false, true) if lmt, ok := track.(types.LocalMediaTrack); ok { - p.simulcastTrackIds.Delete(lmt.SdpCid()) + for _, sdpCid := range lmt.SdpCids() { + p.simulcastTrackIds.Delete(sdpCid) + } } if p.ProtocolVersion().SupportsUnpublish() { @@ -1808,7 +1810,13 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver codec = codecs[0] fromSdp = true } - p.params.Logger.Debugw("onMediaTrack", "codec", codec, "payloadType", codec.PayloadType, "fromSdp", fromSdp, "parameters", rtpReceiver.GetParameters()) + p.params.Logger.Debugw( + "onMediaTrack", + "codec", codec, + "payloadType", codec.PayloadType, + "fromSdp", fromSdp, + "parameters", rtpReceiver.GetParameters(), + ) var track sfu.TrackRemote = sfu.NewTrackRemoteFromSdp(rtcTrack, codec) publishedTrack, isNewTrack := p.mediaTrackReceived(track, rtpReceiver) @@ -2239,7 +2247,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l return nil } - track.(*MediaTrack).UpdateCodecCid(req.SimulcastCodecs) + track.(*MediaTrack).UpdateCodecSignalCid(req.SimulcastCodecs) ti := track.ToProto() return ti } @@ -2279,8 +2287,8 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l if req.Type == livekit.TrackType_VIDEO { // clients not supporting simulcast codecs, synthesise a codec ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ - Cid: req.Cid, - Layers: req.Layers, + SignalCid: req.Cid, + Layers: req.Layers, }) } } else { @@ -2315,9 +2323,9 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l clonedLayers = append(clonedLayers, utils.CloneProto(l)) } ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ - MimeType: mimeType, - Cid: codec.Cid, - Layers: clonedLayers, + MimeType: mimeType, + SignalCid: codec.Cid, + Layers: clonedLayers, }) } } @@ -2441,10 +2449,20 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver return nil, false } + receiverIsSimulcast := false + cidIsSimulcast, ok := p.simulcastTrackIds.Load(track.ID()) + if ok { + receiverIsSimulcast = cidIsSimulcast.(bool) + } + // use existing media track to handle simulcast var pubTime time.Duration var isMigrated bool mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) + if !ok { + // only works for clients using same cid in signal and SDP, so won't work for clients like Firefox + mt, ok = p.getPublishedTrackBySignalCid(track.ID()).(*MediaTrack) + } if !ok { signalCid, ti, migrated, createdAt := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()), true) if ti == nil { @@ -2483,6 +2501,9 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver // only assign version on a fresh publish, i. e. avoid updating version in scenarios like migration ti.Version = p.params.VersionGenerator.Next().ToProto() } + // track level simulcast set up only for the primary codec + // assumption: when a new track is created, it is the primary codec that causes that + ti.Simulcast = receiverIsSimulcast mt = p.addMediaTrack(signalCid, track.ID(), ti) newTrack = true @@ -2496,16 +2517,9 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver pubTime = time.Since(createdAt) p.dirty.Store(true) } - - isSimulcast, ok := p.simulcastTrackIds.Load(track.ID()) - if !ok { - isSimulcast = false - } - mt.SetSimulcast(isSimulcast.(bool)) - p.pendingTracksLock.Unlock() - mt.AddReceiver(rtpReceiver, track, mid) + mt.AddReceiver(rtpReceiver, track, mid, receiverIsSimulcast) if newTrack { go func() { @@ -2593,8 +2607,6 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *livekit.TrackInfo) *MediaTrack { mt := NewMediaTrack(MediaTrackParams{ - SignalCid: signalCid, - SdpCid: sdpCid, ParticipantID: p.params.SID, ParticipantIdentity: p.params.Identity, ParticipantVersion: p.version.Load(), @@ -2656,7 +2668,9 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *liv ) } - p.simulcastTrackIds.Delete(mt.SdpCid()) + for _, sdpCid := range mt.SdpCids() { + p.simulcastTrackIds.Delete(sdpCid) + } p.pendingTracksLock.Lock() if pti := p.pendingTracks[signalCid]; pti != nil { @@ -2730,7 +2744,7 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp for cid, pti := range p.pendingTracks { ti := pti.trackInfos[0] for _, c := range ti.Codecs { - if c.Cid == clientId { + if c.SignalCid == clientId { pendingInfo = pti signalCid = cid break track_loop @@ -2797,7 +2811,8 @@ func (p *ParticipantImpl) setTrackID(cid string, info *livekit.TrackInfo) { func (p *ParticipantImpl) getPublishedTrackBySignalCid(clientId string) types.MediaTrack { for _, publishedTrack := range p.GetPublishedTracks() { - if publishedTrack.(types.LocalMediaTrack).SignalCid() == clientId { + if publishedTrack.(types.LocalMediaTrack).HasSignalCid(clientId) { + p.pubLogger.Debugw("found track by signal cid", "signalCid", clientId, "trackID", publishedTrack.ID()) return publishedTrack } } diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 4c4d12962..dd45cb27d 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -137,7 +137,7 @@ func TestTrackPublishing(t *testing.T) { sink := p.params.Sink.(*routingfakes.FakeMessageSink) track := &typesfakes.FakeLocalMediaTrack{} - track.SignalCidReturns("cid") + track.HasSignalCidReturns(true) track.ToProtoReturns(&livekit.TrackInfo{}) // directly add to publishedTracks without lock - for testing purpose only p.UpTrackManager.publishedTracks["cid"] = track diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 59d75ada9..ea01e7a36 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -141,7 +141,12 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess } var mimeType string for _, c := range info.Codecs { - if c.Cid == streamID { + // this is reading streamID from SDP which is technically SDP cid, + // but it is not set in TrackInfo by the time this is checked, + // hence the check on signal cid. + // As a result, this will fail for clients that use different cid + // while signalling and while doing negotiation (e. g. Firefox) + if c.SignalCid == streamID { mimeType = c.MimeType break } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index e7d302ae8..b2888ddd9 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -1322,7 +1322,7 @@ func (t *PCTransport) AddTrackToStreamAllocator(subTrack types.SubscribedTrack) t.streamAllocator.AddTrack(subTrack.DownTrack(), streamallocator.AddTrackParams{ Source: subTrack.MediaTrack().Source(), - IsSimulcast: subTrack.MediaTrack().IsSimulcast(), + IsSimulcast: subTrack.MediaTrack().HasMultipleSpatialLayers(), PublisherID: subTrack.MediaTrack().PublisherID(), }) } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 0c2ead4f2..1d4beecfd 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -506,6 +506,7 @@ type MediaTrack interface { SetMuted(muted bool) IsSimulcast() bool + HasMultipleSpatialLayers() bool GetAudioLevel() (level float64, active bool) @@ -542,8 +543,8 @@ type LocalMediaTrack interface { Restart() - SignalCid() string - SdpCid() string + HasSignalCid(cid string) bool + SdpCids() []string HasSdpCid(cid string) bool GetConnectionScoreAndQuality() (float32, livekit.ConnectionQuality) diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index eee0f33fb..4db6e0923 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -118,6 +118,16 @@ type FakeLocalMediaTrack struct { getTrackStatsReturnsOnCall map[int]struct { result1 *livekit.RTPStats } + HasMultipleSpatialLayersStub func() bool + hasMultipleSpatialLayersMutex sync.RWMutex + hasMultipleSpatialLayersArgsForCall []struct { + } + hasMultipleSpatialLayersReturns struct { + result1 bool + } + hasMultipleSpatialLayersReturnsOnCall map[int]struct { + result1 bool + } HasSdpCidStub func(string) bool hasSdpCidMutex sync.RWMutex hasSdpCidArgsForCall []struct { @@ -129,6 +139,17 @@ type FakeLocalMediaTrack struct { hasSdpCidReturnsOnCall map[int]struct { result1 bool } + HasSignalCidStub func(string) bool + hasSignalCidMutex sync.RWMutex + hasSignalCidArgsForCall []struct { + arg1 string + } + hasSignalCidReturns struct { + result1 bool + } + hasSignalCidReturnsOnCall map[int]struct { + result1 bool + } IDStub func() livekit.TrackID iDMutex sync.RWMutex iDArgsForCall []struct { @@ -287,15 +308,15 @@ type FakeLocalMediaTrack struct { revokeDisallowedSubscribersReturnsOnCall map[int]struct { result1 []livekit.ParticipantIdentity } - SdpCidStub func() string - sdpCidMutex sync.RWMutex - sdpCidArgsForCall []struct { + SdpCidsStub func() []string + sdpCidsMutex sync.RWMutex + sdpCidsArgsForCall []struct { } - sdpCidReturns struct { - result1 string + sdpCidsReturns struct { + result1 []string } - sdpCidReturnsOnCall map[int]struct { - result1 string + sdpCidsReturnsOnCall map[int]struct { + result1 []string } SetMutedStub func(bool) setMutedMutex sync.RWMutex @@ -307,16 +328,6 @@ type FakeLocalMediaTrack struct { setRTTArgsForCall []struct { arg1 uint32 } - SignalCidStub func() string - signalCidMutex sync.RWMutex - signalCidArgsForCall []struct { - } - signalCidReturns struct { - result1 string - } - signalCidReturnsOnCall map[int]struct { - result1 string - } SourceStub func() livekit.TrackSource sourceMutex sync.RWMutex sourceArgsForCall []struct { @@ -922,6 +933,59 @@ func (fake *FakeLocalMediaTrack) GetTrackStatsReturnsOnCall(i int, result1 *live }{result1} } +func (fake *FakeLocalMediaTrack) HasMultipleSpatialLayers() bool { + fake.hasMultipleSpatialLayersMutex.Lock() + ret, specificReturn := fake.hasMultipleSpatialLayersReturnsOnCall[len(fake.hasMultipleSpatialLayersArgsForCall)] + fake.hasMultipleSpatialLayersArgsForCall = append(fake.hasMultipleSpatialLayersArgsForCall, struct { + }{}) + stub := fake.HasMultipleSpatialLayersStub + fakeReturns := fake.hasMultipleSpatialLayersReturns + fake.recordInvocation("HasMultipleSpatialLayers", []interface{}{}) + fake.hasMultipleSpatialLayersMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) HasMultipleSpatialLayersCallCount() int { + fake.hasMultipleSpatialLayersMutex.RLock() + defer fake.hasMultipleSpatialLayersMutex.RUnlock() + return len(fake.hasMultipleSpatialLayersArgsForCall) +} + +func (fake *FakeLocalMediaTrack) HasMultipleSpatialLayersCalls(stub func() bool) { + fake.hasMultipleSpatialLayersMutex.Lock() + defer fake.hasMultipleSpatialLayersMutex.Unlock() + fake.HasMultipleSpatialLayersStub = stub +} + +func (fake *FakeLocalMediaTrack) HasMultipleSpatialLayersReturns(result1 bool) { + fake.hasMultipleSpatialLayersMutex.Lock() + defer fake.hasMultipleSpatialLayersMutex.Unlock() + fake.HasMultipleSpatialLayersStub = nil + fake.hasMultipleSpatialLayersReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) HasMultipleSpatialLayersReturnsOnCall(i int, result1 bool) { + fake.hasMultipleSpatialLayersMutex.Lock() + defer fake.hasMultipleSpatialLayersMutex.Unlock() + fake.HasMultipleSpatialLayersStub = nil + if fake.hasMultipleSpatialLayersReturnsOnCall == nil { + fake.hasMultipleSpatialLayersReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasMultipleSpatialLayersReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalMediaTrack) HasSdpCid(arg1 string) bool { fake.hasSdpCidMutex.Lock() ret, specificReturn := fake.hasSdpCidReturnsOnCall[len(fake.hasSdpCidArgsForCall)] @@ -983,6 +1047,67 @@ func (fake *FakeLocalMediaTrack) HasSdpCidReturnsOnCall(i int, result1 bool) { }{result1} } +func (fake *FakeLocalMediaTrack) HasSignalCid(arg1 string) bool { + fake.hasSignalCidMutex.Lock() + ret, specificReturn := fake.hasSignalCidReturnsOnCall[len(fake.hasSignalCidArgsForCall)] + fake.hasSignalCidArgsForCall = append(fake.hasSignalCidArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.HasSignalCidStub + fakeReturns := fake.hasSignalCidReturns + fake.recordInvocation("HasSignalCid", []interface{}{arg1}) + fake.hasSignalCidMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) HasSignalCidCallCount() int { + fake.hasSignalCidMutex.RLock() + defer fake.hasSignalCidMutex.RUnlock() + return len(fake.hasSignalCidArgsForCall) +} + +func (fake *FakeLocalMediaTrack) HasSignalCidCalls(stub func(string) bool) { + fake.hasSignalCidMutex.Lock() + defer fake.hasSignalCidMutex.Unlock() + fake.HasSignalCidStub = stub +} + +func (fake *FakeLocalMediaTrack) HasSignalCidArgsForCall(i int) string { + fake.hasSignalCidMutex.RLock() + defer fake.hasSignalCidMutex.RUnlock() + argsForCall := fake.hasSignalCidArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) HasSignalCidReturns(result1 bool) { + fake.hasSignalCidMutex.Lock() + defer fake.hasSignalCidMutex.Unlock() + fake.HasSignalCidStub = nil + fake.hasSignalCidReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) HasSignalCidReturnsOnCall(i int, result1 bool) { + fake.hasSignalCidMutex.Lock() + defer fake.hasSignalCidMutex.Unlock() + fake.HasSignalCidStub = nil + if fake.hasSignalCidReturnsOnCall == nil { + fake.hasSignalCidReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasSignalCidReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalMediaTrack) ID() livekit.TrackID { fake.iDMutex.Lock() ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] @@ -1845,15 +1970,15 @@ func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, }{result1} } -func (fake *FakeLocalMediaTrack) SdpCid() string { - fake.sdpCidMutex.Lock() - ret, specificReturn := fake.sdpCidReturnsOnCall[len(fake.sdpCidArgsForCall)] - fake.sdpCidArgsForCall = append(fake.sdpCidArgsForCall, struct { +func (fake *FakeLocalMediaTrack) SdpCids() []string { + fake.sdpCidsMutex.Lock() + ret, specificReturn := fake.sdpCidsReturnsOnCall[len(fake.sdpCidsArgsForCall)] + fake.sdpCidsArgsForCall = append(fake.sdpCidsArgsForCall, struct { }{}) - stub := fake.SdpCidStub - fakeReturns := fake.sdpCidReturns - fake.recordInvocation("SdpCid", []interface{}{}) - fake.sdpCidMutex.Unlock() + stub := fake.SdpCidsStub + fakeReturns := fake.sdpCidsReturns + fake.recordInvocation("SdpCids", []interface{}{}) + fake.sdpCidsMutex.Unlock() if stub != nil { return stub() } @@ -1863,38 +1988,38 @@ func (fake *FakeLocalMediaTrack) SdpCid() string { return fakeReturns.result1 } -func (fake *FakeLocalMediaTrack) SdpCidCallCount() int { - fake.sdpCidMutex.RLock() - defer fake.sdpCidMutex.RUnlock() - return len(fake.sdpCidArgsForCall) +func (fake *FakeLocalMediaTrack) SdpCidsCallCount() int { + fake.sdpCidsMutex.RLock() + defer fake.sdpCidsMutex.RUnlock() + return len(fake.sdpCidsArgsForCall) } -func (fake *FakeLocalMediaTrack) SdpCidCalls(stub func() string) { - fake.sdpCidMutex.Lock() - defer fake.sdpCidMutex.Unlock() - fake.SdpCidStub = stub +func (fake *FakeLocalMediaTrack) SdpCidsCalls(stub func() []string) { + fake.sdpCidsMutex.Lock() + defer fake.sdpCidsMutex.Unlock() + fake.SdpCidsStub = stub } -func (fake *FakeLocalMediaTrack) SdpCidReturns(result1 string) { - fake.sdpCidMutex.Lock() - defer fake.sdpCidMutex.Unlock() - fake.SdpCidStub = nil - fake.sdpCidReturns = struct { - result1 string +func (fake *FakeLocalMediaTrack) SdpCidsReturns(result1 []string) { + fake.sdpCidsMutex.Lock() + defer fake.sdpCidsMutex.Unlock() + fake.SdpCidsStub = nil + fake.sdpCidsReturns = struct { + result1 []string }{result1} } -func (fake *FakeLocalMediaTrack) SdpCidReturnsOnCall(i int, result1 string) { - fake.sdpCidMutex.Lock() - defer fake.sdpCidMutex.Unlock() - fake.SdpCidStub = nil - if fake.sdpCidReturnsOnCall == nil { - fake.sdpCidReturnsOnCall = make(map[int]struct { - result1 string +func (fake *FakeLocalMediaTrack) SdpCidsReturnsOnCall(i int, result1 []string) { + fake.sdpCidsMutex.Lock() + defer fake.sdpCidsMutex.Unlock() + fake.SdpCidsStub = nil + if fake.sdpCidsReturnsOnCall == nil { + fake.sdpCidsReturnsOnCall = make(map[int]struct { + result1 []string }) } - fake.sdpCidReturnsOnCall[i] = struct { - result1 string + fake.sdpCidsReturnsOnCall[i] = struct { + result1 []string }{result1} } @@ -1962,59 +2087,6 @@ func (fake *FakeLocalMediaTrack) SetRTTArgsForCall(i int) uint32 { return argsForCall.arg1 } -func (fake *FakeLocalMediaTrack) SignalCid() string { - fake.signalCidMutex.Lock() - ret, specificReturn := fake.signalCidReturnsOnCall[len(fake.signalCidArgsForCall)] - fake.signalCidArgsForCall = append(fake.signalCidArgsForCall, struct { - }{}) - stub := fake.SignalCidStub - fakeReturns := fake.signalCidReturns - fake.recordInvocation("SignalCid", []interface{}{}) - fake.signalCidMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeLocalMediaTrack) SignalCidCallCount() int { - fake.signalCidMutex.RLock() - defer fake.signalCidMutex.RUnlock() - return len(fake.signalCidArgsForCall) -} - -func (fake *FakeLocalMediaTrack) SignalCidCalls(stub func() string) { - fake.signalCidMutex.Lock() - defer fake.signalCidMutex.Unlock() - fake.SignalCidStub = stub -} - -func (fake *FakeLocalMediaTrack) SignalCidReturns(result1 string) { - fake.signalCidMutex.Lock() - defer fake.signalCidMutex.Unlock() - fake.SignalCidStub = nil - fake.signalCidReturns = struct { - result1 string - }{result1} -} - -func (fake *FakeLocalMediaTrack) SignalCidReturnsOnCall(i int, result1 string) { - fake.signalCidMutex.Lock() - defer fake.signalCidMutex.Unlock() - fake.SignalCidStub = nil - if fake.signalCidReturnsOnCall == nil { - fake.signalCidReturnsOnCall = make(map[int]struct { - result1 string - }) - } - fake.signalCidReturnsOnCall[i] = struct { - result1 string - }{result1} -} - func (fake *FakeLocalMediaTrack) Source() livekit.TrackSource { fake.sourceMutex.Lock() ret, specificReturn := fake.sourceReturnsOnCall[len(fake.sourceArgsForCall)] @@ -2295,8 +2367,12 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() fake.getTrackStatsMutex.RLock() defer fake.getTrackStatsMutex.RUnlock() + fake.hasMultipleSpatialLayersMutex.RLock() + defer fake.hasMultipleSpatialLayersMutex.RUnlock() fake.hasSdpCidMutex.RLock() defer fake.hasSdpCidMutex.RUnlock() + fake.hasSignalCidMutex.RLock() + defer fake.hasSignalCidMutex.RUnlock() fake.iDMutex.RLock() defer fake.iDMutex.RUnlock() fake.isEncryptedMutex.RLock() @@ -2333,14 +2409,12 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.restartMutex.RUnlock() fake.revokeDisallowedSubscribersMutex.RLock() defer fake.revokeDisallowedSubscribersMutex.RUnlock() - fake.sdpCidMutex.RLock() - defer fake.sdpCidMutex.RUnlock() + fake.sdpCidsMutex.RLock() + defer fake.sdpCidsMutex.RUnlock() fake.setMutedMutex.RLock() defer fake.setMutedMutex.RUnlock() fake.setRTTMutex.RLock() defer fake.setRTTMutex.RUnlock() - fake.signalCidMutex.RLock() - defer fake.signalCidMutex.RUnlock() fake.sourceMutex.RLock() defer fake.sourceMutex.RUnlock() fake.streamMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 67304714d..9bef455ca 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -96,6 +96,16 @@ type FakeMediaTrack struct { getTemporalLayerForSpatialFpsReturnsOnCall map[int]struct { result1 int32 } + HasMultipleSpatialLayersStub func() bool + hasMultipleSpatialLayersMutex sync.RWMutex + hasMultipleSpatialLayersArgsForCall []struct { + } + hasMultipleSpatialLayersReturns struct { + result1 bool + } + hasMultipleSpatialLayersReturnsOnCall map[int]struct { + result1 bool + } IDStub func() livekit.TrackID iDMutex sync.RWMutex iDArgsForCall []struct { @@ -739,6 +749,59 @@ func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsReturnsOnCall(i int, re }{result1} } +func (fake *FakeMediaTrack) HasMultipleSpatialLayers() bool { + fake.hasMultipleSpatialLayersMutex.Lock() + ret, specificReturn := fake.hasMultipleSpatialLayersReturnsOnCall[len(fake.hasMultipleSpatialLayersArgsForCall)] + fake.hasMultipleSpatialLayersArgsForCall = append(fake.hasMultipleSpatialLayersArgsForCall, struct { + }{}) + stub := fake.HasMultipleSpatialLayersStub + fakeReturns := fake.hasMultipleSpatialLayersReturns + fake.recordInvocation("HasMultipleSpatialLayers", []interface{}{}) + fake.hasMultipleSpatialLayersMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) HasMultipleSpatialLayersCallCount() int { + fake.hasMultipleSpatialLayersMutex.RLock() + defer fake.hasMultipleSpatialLayersMutex.RUnlock() + return len(fake.hasMultipleSpatialLayersArgsForCall) +} + +func (fake *FakeMediaTrack) HasMultipleSpatialLayersCalls(stub func() bool) { + fake.hasMultipleSpatialLayersMutex.Lock() + defer fake.hasMultipleSpatialLayersMutex.Unlock() + fake.HasMultipleSpatialLayersStub = stub +} + +func (fake *FakeMediaTrack) HasMultipleSpatialLayersReturns(result1 bool) { + fake.hasMultipleSpatialLayersMutex.Lock() + defer fake.hasMultipleSpatialLayersMutex.Unlock() + fake.HasMultipleSpatialLayersStub = nil + fake.hasMultipleSpatialLayersReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) HasMultipleSpatialLayersReturnsOnCall(i int, result1 bool) { + fake.hasMultipleSpatialLayersMutex.Lock() + defer fake.hasMultipleSpatialLayersMutex.Unlock() + fake.HasMultipleSpatialLayersStub = nil + if fake.hasMultipleSpatialLayersReturnsOnCall == nil { + fake.hasMultipleSpatialLayersReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasMultipleSpatialLayersReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeMediaTrack) ID() livekit.TrackID { fake.iDMutex.Lock() ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] @@ -1814,6 +1877,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.getQualityForDimensionMutex.RUnlock() fake.getTemporalLayerForSpatialFpsMutex.RLock() defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() + fake.hasMultipleSpatialLayersMutex.RLock() + defer fake.hasMultipleSpatialLayersMutex.RUnlock() fake.iDMutex.RLock() defer fake.iDMutex.RUnlock() fake.isEncryptedMutex.RLock() diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index a4d41a715..221975dd1 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -391,6 +391,14 @@ func (d *DummyReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { } } +func (d *DummyReceiver) IsSimulcast() bool { + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return r.IsSimulcast() + } + + return false +} + func (d *DummyReceiver) IsClosed() bool { if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { return r.IsClosed() diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 3c7d2d80e..1cba43b23 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -240,7 +240,6 @@ type DowntrackParams struct { RTCPWriter func([]rtcp.Packet) error DisableSenderReportPassThrough bool SupportsCodecChange bool - IsReceiverSimulcast bool } // DownTrack implements TrackLocal, is the track used to write packets @@ -406,7 +405,6 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { d.params.Logger, false, d.rtpStats, - d.params.IsReceiverSimulcast, ) d.connectionStats = connectionquality.NewConnectionStats(connectionquality.ConnectionStatsParams{ @@ -588,7 +586,8 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.setBindStateLocked(bindStateBound) d.bindLock.Unlock() - d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) + receiver := d.Receiver() + d.forwarder.DetermineCodec(codec.RTPCodecCapability, receiver.HeaderExtensions(), receiver.IsSimulcast()) d.connectionStats.Start(d.Mime(), isFECEnabled) d.params.Logger.Debugw("downtrack bound") } @@ -697,7 +696,8 @@ func (d *DownTrack) handleUpstreamCodecChange(mimeType string) { ) d.forwarder.Restart() - d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) + receiver := d.Receiver() + d.forwarder.DetermineCodec(codec.RTPCodecCapability, receiver.HeaderExtensions(), receiver.IsSimulcast()) d.connectionStats.UpdateCodec(d.Mime(), isFECEnabled) } diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index 1acfd6834..3d70f10b2 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -223,7 +223,6 @@ type Forwarder struct { started bool preStartTime time.Time - isReceiverSimulcast bool extFirstTS uint64 lastSSRC uint32 lastReferencePayloadType int8 @@ -249,7 +248,6 @@ func NewForwarder( logger logger.Logger, skipReferenceTS bool, rtpStats *rtpstats.RTPStatsSender, - isReceiverSimulcast bool, ) *Forwarder { f := &Forwarder{ mime: mime.MimeTypeUnknown, @@ -257,7 +255,6 @@ func NewForwarder( logger: logger, skipReferenceTS: skipReferenceTS, rtpStats: rtpStats, - isReceiverSimulcast: isReceiverSimulcast, referenceLayerSpatial: buffer.InvalidLayerSpatial, lastAllocation: VideoAllocationDefault, lastReferencePayloadType: -1, @@ -300,7 +297,11 @@ func (f *Forwarder) SetMaxTemporalLayerSeen(maxTemporalLayerSeen int32) bool { return true } -func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions []webrtc.RTPHeaderExtensionParameter) { +func (f *Forwarder) DetermineCodec( + codec webrtc.RTPCodecCapability, + extensions []webrtc.RTPHeaderExtensionParameter, + isReceiverSimulcast bool, +) { f.lock.Lock() defer f.lock.Unlock() @@ -343,7 +344,7 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ } case mime.MimeTypeVP9: - if f.isReceiverSimulcast { + if isReceiverSimulcast { f.logger.Debugw("selecting simulcast video layer selector for VP9") if f.vls != nil { f.vls = videolayerselector.NewSimulcastFromOther(f.vls) @@ -372,7 +373,7 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ case mime.MimeTypeAV1: isDDAvailable := ddAvailable(extensions) - if f.isReceiverSimulcast || !isDDAvailable { + if isReceiverSimulcast || !isDDAvailable { // AV1-SIMULCAST-TODO: Add temporal layer selector for AV1 f.logger.Debugw("selecting simulcast video layer selector for AV1") if f.vls != nil { diff --git a/pkg/sfu/forwarder_test.go b/pkg/sfu/forwarder_test.go index 2239ee690..ef7b3e0c2 100644 --- a/pkg/sfu/forwarder_test.go +++ b/pkg/sfu/forwarder_test.go @@ -32,8 +32,8 @@ func disable(f *Forwarder) { } func newForwarder(codec webrtc.RTPCodecCapability, kind webrtc.RTPCodecType) *Forwarder { - f := NewForwarder(kind, logger.GetLogger(), true, nil, false) - f.DetermineCodec(codec, nil) + f := NewForwarder(kind, logger.GetLogger(), true, nil) + f.DetermineCodec(codec, nil, false) return f } diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 4e3de0699..382e6332a 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -123,6 +123,7 @@ type TrackReceiver interface { TrackInfo() *livekit.TrackInfo UpdateTrackInfo(ti *livekit.TrackInfo) + IsSimulcast() bool // Get primary receiver if this receiver represents a RED codec; otherwise it will return itself GetPrimaryReceiverForRed() TrackReceiver @@ -268,18 +269,17 @@ func NewWebRTCReceiver( isRED: mime.IsMimeTypeStringRED(track.Codec().MimeType), } - isSVC := false - isSimulcast := trackInfo.GetSimulcast() - if !isSimulcast { - isSVC = mime.IsMimeTypeStringSVC(track.Codec().MimeType) - } - w.isSVC = isSVC - for _, opt := range opts { w = opt(w) } w.trackInfo.Store(utils.CloneProto(trackInfo)) + isSVC := false + if !w.IsSimulcast() { + isSVC = mime.IsMimeTypeStringSVC(track.Codec().MimeType) + } + w.isSVC = isSVC + w.downTrackSpreader = NewDownTrackSpreader(DownTrackSpreaderParams{ Threshold: w.lbThreshold, Logger: logger, @@ -324,6 +324,16 @@ func (w *WebRTCReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { w.streamTrackerManager.UpdateTrackInfo(ti) } +func (w *WebRTCReceiver) IsSimulcast() bool { + for _, codec := range w.trackInfo.Load().Codecs { + if mime.IsMimeTypeStringEqual(codec.MimeType, w.codec.MimeType) { + return codec.IsSimulcast + } + } + + return false +} + func (w *WebRTCReceiver) OnStatsUpdate(fn func(w *WebRTCReceiver, stat *livekit.AnalyticsStat)) { w.onStatsUpdate = fn } @@ -861,13 +871,9 @@ func (w *WebRTCReceiver) closeTracks() { } func (w *WebRTCReceiver) DebugInfo() map[string]interface{} { - isSimulcast := !w.isSVC - if ti := w.trackInfo.Load(); ti != nil { - isSimulcast = isSimulcast && len(ti.Layers) > 1 - } info := map[string]interface{}{ "SVC": w.isSVC, - "Simulcast": isSimulcast, + "Simulcast": w.IsSimulcast(), } w.bufferMu.RLock()