diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index fcecc3c59..702c2537b 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -28,6 +28,7 @@ const ( lossyDataChannel = "_lossy" reliableDataChannel = "_reliable" sdBatchSize = 20 + trackCleanupDelay = 5 * time.Second ) type ParticipantParams struct { @@ -582,12 +583,10 @@ func (p *ParticipantImpl) SetTrackMuted(trackId string, muted bool, fromAdmin bo track := p.publishedTracks[trackId] p.lock.RUnlock() - // already handled - if isPending { - return - } if track == nil { - logger.Warnw("could not locate track", nil, "track", trackId) + if !isPending { + logger.Warnw("could not locate track", nil, "track", trackId) + } return } currentMuted := track.IsMuted() @@ -783,12 +782,13 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w // use existing mediatrack to handle simulcast p.lock.Lock() - ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()), false) + cid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) if ti == nil { p.lock.Unlock() return } ptrack := p.publishedTracks[ti.Sid] + p.lock.Unlock() var mt *MediaTrack var newTrack bool @@ -822,10 +822,12 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w mt.AddReceiver(rtpReceiver, track, p.twcc) // cleanup pendingTracks - if !mt.simulcasted || mt.NumUpTracks() == 3 { - _ = p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()), true) - } - p.lock.Unlock() + defer func() { + time.Sleep(trackCleanupDelay) + p.lock.Lock() + delete(p.pendingTracks, cid) + p.lock.Unlock() + }() if newTrack { p.handleTrackPublished(mt) @@ -852,7 +854,7 @@ func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { } } -func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType, deleteAfter bool) *livekit.TrackInfo { +func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { ti := p.pendingTracks[clientId] // then find the first one that matches type. with MediaStreamTrack, it's possible for the client id to @@ -870,10 +872,8 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp // if still not found, we are done if ti == nil { logger.Errorw("track info not published prior to track", nil, "clientId", clientId) - } else if deleteAfter { - delete(p.pendingTracks, clientId) } - return ti + return clientId, ti } func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) { diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 1c8cb46ec..9d2d56e91 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -178,7 +178,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO, false) + _, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) })