diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 68fb8d168..08e0b78cb 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -53,6 +53,8 @@ type MediaTrack struct { type MediaTrackParams struct { TrackInfo *livekit.TrackInfo + SignalCid string + SdpCid string ParticipantID string RTCPChan chan []rtcp.Packet BufferFactory *buffer.Factory @@ -83,6 +85,14 @@ func (t *MediaTrack) ID() string { return t.params.TrackInfo.Sid } +func (t *MediaTrack) SignalCid() string { + return t.params.SignalCid +} + +func (t *MediaTrack) SdpCid() string { + return t.params.SdpCid +} + func (t *MediaTrack) Kind() livekit.TrackType { return t.params.TrackInfo.Type } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 237545cda..87ee64c2c 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -29,7 +29,6 @@ const ( lossyDataChannel = "_lossy" reliableDataChannel = "_reliable" sdBatchSize = 20 - trackCleanupDelay = 5 * time.Second ) type ParticipantParams struct { @@ -330,6 +329,10 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { return } + if p.getPublishedTrackBySignalCid(req.Cid) != nil || p.getPublishedTrackBySdpCid(req.Cid) != nil { + return + } + if !p.CanPublish() { logger.Warnw("no permission to publish track", nil, "participant", p.Identity(), "pID", p.ID()) @@ -809,7 +812,8 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w "participant", p.Identity(), "pID", p.ID(), "track", track.ID(), - "rid", track.RID()) + "rid", track.RID(), + "SSRC", track.SSRC()) if !p.CanPublish() { logger.Warnw("no permission to publish mediaTrack", nil, @@ -817,23 +821,22 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w return } + var newTrack bool + // use existing mediatrack to handle simulcast p.lock.Lock() - cid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) - if ti == nil { - p.lock.Unlock() - return - } - ptrack := p.publishedTracks[ti.Sid] - p.lock.Unlock() + mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) + if !ok { + signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) + if ti == nil { + p.lock.Unlock() + return + } - var mt *MediaTrack - var newTrack bool - if trk, ok := ptrack.(*MediaTrack); ok { - mt = trk - } else { mt = NewMediaTrack(track, MediaTrackParams{ TrackInfo: ti, + SignalCid: signalCid, + SdpCid: track.ID(), ParticipantID: p.id, RTCPChan: p.rtcpCh, BufferFactory: p.params.Config.BufferFactory, @@ -841,8 +844,14 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w AudioConfig: p.params.AudioConfig, Stats: p.params.Stats, }) + + // add to published and clean up pending + p.publishedTracks[mt.ID()] = mt + delete(p.pendingTracks, signalCid) + newTrack = true } + p.lock.Unlock() ssrc := uint32(track.SSRC()) p.pliThrottle.addTrack(ssrc, track.RID()) @@ -854,14 +863,6 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w } mt.AddReceiver(rtpReceiver, track, p.twcc) - // cleanup pendingTracks - defer func() { - time.Sleep(trackCleanupDelay) - p.lock.Lock() - delete(p.pendingTracks, cid) - p.lock.Unlock() - }() - if newTrack { p.handleTrackPublished(mt) } @@ -887,7 +888,31 @@ func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { } } +// should be called with lock held +func (p *ParticipantImpl) getPublishedTrackBySignalCid(clientId string) types.PublishedTrack { + for _, publishedTrack := range p.publishedTracks { + if publishedTrack.SignalCid() == clientId { + return publishedTrack + } + } + + return nil +} + +// should be called with lock held +func (p *ParticipantImpl) getPublishedTrackBySdpCid(clientId string) types.PublishedTrack { + for _, publishedTrack := range p.publishedTracks { + if publishedTrack.SdpCid() == clientId { + return publishedTrack + } + } + + return nil +} + +// should be called with lock held func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { + signalCid := clientId ti := p.pendingTracks[clientId] // then find the first one that matches type. with MediaStreamTrack, it's possible for the client id to @@ -896,7 +921,7 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp for cid, info := range p.pendingTracks { if info.Type == kind { ti = info - clientId = cid + signalCid = cid break } } @@ -906,7 +931,7 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp if ti == nil { logger.Errorw("track info not published prior to track", nil, "clientId", clientId) } - return clientId, ti + return signalCid, ti } func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) { @@ -932,9 +957,10 @@ func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data [ } func (p *ParticipantImpl) handleTrackPublished(track types.PublishedTrack) { - // fill in p.lock.Lock() - p.publishedTracks[track.ID()] = track + if _, ok := p.publishedTracks[track.ID()]; !ok { + p.publishedTracks[track.ID()] = track + } p.lock.Unlock() track.Start() diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 9d2d56e91..4b2384b1b 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -131,6 +131,40 @@ func TestTrackPublishing(t *testing.T) { require.Equal(t, 1, sink.WriteMessageCallCount()) }) + + t.Run("should not allow adding of duplicate tracks if already published by client id in signalling", func(t *testing.T) { + p := newParticipantForTest("test") + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + + track := &typesfakes.FakePublishedTrack{} + track.SignalCidReturns("cid") + // directly add to publishedTracks without lock - for testing purpose only + p.publishedTracks["cid"] = track + + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + require.Equal(t, 0, sink.WriteMessageCallCount()) + }) + + t.Run("should not allow adding of duplicate tracks if already published by client id in sdp", func(t *testing.T) { + p := newParticipantForTest("test") + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + + track := &typesfakes.FakePublishedTrack{} + track.SdpCidReturns("cid") + // directly add to publishedTracks without lock - for testing purpose only + p.publishedTracks["cid"] = track + + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + require.Equal(t, 0, sink.WriteMessageCallCount()) + }) } // after disconnection, things should continue to function and not panic diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 6f88b7fcc..b5926580b 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -91,6 +91,8 @@ type Participant interface { type PublishedTrack interface { Start() ID() string + SignalCid() string + SdpCid() string Kind() livekit.TrackType Name() string IsMuted() bool diff --git a/pkg/rtc/types/typesfakes/fake_published_track.go b/pkg/rtc/types/typesfakes/fake_published_track.go index 01891ec58..60aeb7c5b 100644 --- a/pkg/rtc/types/typesfakes/fake_published_track.go +++ b/pkg/rtc/types/typesfakes/fake_published_track.go @@ -85,11 +85,31 @@ type FakePublishedTrack struct { removeSubscriberArgsForCall []struct { arg1 string } + SdpCidStub func() string + sdpCidMutex sync.RWMutex + sdpCidArgsForCall []struct { + } + sdpCidReturns struct { + result1 string + } + sdpCidReturnsOnCall map[int]struct { + result1 string + } SetMutedStub func(bool) setMutedMutex sync.RWMutex setMutedArgsForCall []struct { arg1 bool } + SignalCidStub func() string + signalCidMutex sync.RWMutex + signalCidArgsForCall []struct { + } + signalCidReturns struct { + result1 string + } + signalCidReturnsOnCall map[int]struct { + result1 string + } StartStub func() startMutex sync.RWMutex startArgsForCall []struct { @@ -530,6 +550,59 @@ func (fake *FakePublishedTrack) RemoveSubscriberArgsForCall(i int) string { return argsForCall.arg1 } +func (fake *FakePublishedTrack) SdpCid() string { + fake.sdpCidMutex.Lock() + ret, specificReturn := fake.sdpCidReturnsOnCall[len(fake.sdpCidArgsForCall)] + fake.sdpCidArgsForCall = append(fake.sdpCidArgsForCall, struct { + }{}) + stub := fake.SdpCidStub + fakeReturns := fake.sdpCidReturns + fake.recordInvocation("SdpCid", []interface{}{}) + fake.sdpCidMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePublishedTrack) SdpCidCallCount() int { + fake.sdpCidMutex.RLock() + defer fake.sdpCidMutex.RUnlock() + return len(fake.sdpCidArgsForCall) +} + +func (fake *FakePublishedTrack) SdpCidCalls(stub func() string) { + fake.sdpCidMutex.Lock() + defer fake.sdpCidMutex.Unlock() + fake.SdpCidStub = stub +} + +func (fake *FakePublishedTrack) SdpCidReturns(result1 string) { + fake.sdpCidMutex.Lock() + defer fake.sdpCidMutex.Unlock() + fake.SdpCidStub = nil + fake.sdpCidReturns = struct { + result1 string + }{result1} +} + +func (fake *FakePublishedTrack) SdpCidReturnsOnCall(i int, result1 string) { + fake.sdpCidMutex.Lock() + defer fake.sdpCidMutex.Unlock() + fake.SdpCidStub = nil + if fake.sdpCidReturnsOnCall == nil { + fake.sdpCidReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.sdpCidReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakePublishedTrack) SetMuted(arg1 bool) { fake.setMutedMutex.Lock() fake.setMutedArgsForCall = append(fake.setMutedArgsForCall, struct { @@ -562,6 +635,59 @@ func (fake *FakePublishedTrack) SetMutedArgsForCall(i int) bool { return argsForCall.arg1 } +func (fake *FakePublishedTrack) SignalCid() string { + fake.signalCidMutex.Lock() + ret, specificReturn := fake.signalCidReturnsOnCall[len(fake.signalCidArgsForCall)] + fake.signalCidArgsForCall = append(fake.signalCidArgsForCall, struct { + }{}) + stub := fake.SignalCidStub + fakeReturns := fake.signalCidReturns + fake.recordInvocation("SignalCid", []interface{}{}) + fake.signalCidMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePublishedTrack) SignalCidCallCount() int { + fake.signalCidMutex.RLock() + defer fake.signalCidMutex.RUnlock() + return len(fake.signalCidArgsForCall) +} + +func (fake *FakePublishedTrack) SignalCidCalls(stub func() string) { + fake.signalCidMutex.Lock() + defer fake.signalCidMutex.Unlock() + fake.SignalCidStub = stub +} + +func (fake *FakePublishedTrack) SignalCidReturns(result1 string) { + fake.signalCidMutex.Lock() + defer fake.signalCidMutex.Unlock() + fake.SignalCidStub = nil + fake.signalCidReturns = struct { + result1 string + }{result1} +} + +func (fake *FakePublishedTrack) SignalCidReturnsOnCall(i int, result1 string) { + fake.signalCidMutex.Lock() + defer fake.signalCidMutex.Unlock() + fake.SignalCidStub = nil + if fake.signalCidReturnsOnCall == nil { + fake.signalCidReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.signalCidReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakePublishedTrack) Start() { fake.startMutex.Lock() fake.startArgsForCall = append(fake.startArgsForCall, struct { @@ -660,8 +786,12 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() + fake.sdpCidMutex.RLock() + defer fake.sdpCidMutex.RUnlock() fake.setMutedMutex.RLock() defer fake.setMutedMutex.RUnlock() + fake.signalCidMutex.RLock() + defer fake.signalCidMutex.RUnlock() fake.startMutex.RLock() defer fake.startMutex.RUnlock() fake.toProtoMutex.RLock()