From 52b2e6398bfe3c35b424e80e2e5d0584acf3161d Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 29 Jul 2022 11:51:36 +0530 Subject: [PATCH] Queue `AddTrack` if a published track is not yet closed (#857) * Queue `AddTrack` if a published track is not yet closed - Adding a queue for pending track by signal cid. Ideally, there should not be more than one pending, but making a queue to be generic. - `TrackPublished` is sent if the queue has entries when a published track is closed. * Fix tests and add more checks for queueing AddTrack --- pkg/rtc/participant.go | 302 ++++++++++++++------------- pkg/rtc/participant_internal_test.go | 38 +++- 2 files changed, 195 insertions(+), 145 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 9e48191de..4e24eb8e5 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -40,8 +40,8 @@ const ( ) type pendingTrackInfo struct { - *livekit.TrackInfo - migrated bool + trackInfos []*livekit.TrackInfo + migrated bool } type downTrackState struct { @@ -598,19 +598,31 @@ func (p *ParticipantImpl) HandleOffer(sdp webrtc.SessionDescription) (answer web func (p *ParticipantImpl) handleMigrateMutedTrack() { // muted track won't send rtp packet, so we add mediatrack manually - var addedTrack []*MediaTrack + var addedTracks []*MediaTrack p.pendingTracksLock.Lock() - for cid, t := range p.pendingTracks { - if t.migrated && t.Muted && t.Type == livekit.TrackType_VIDEO { - addedTrack = append(addedTrack, p.addMigrateMutedTrack(cid, t.TrackInfo)) + for cid, pti := range p.pendingTracks { + if !pti.migrated { + continue + } + + if len(pti.trackInfos) > 1 { + p.params.Logger.Warnw("too many pending migrated tracks", nil, "count", len(pti.trackInfos), "cid", cid) + } + + ti := pti.trackInfos[0] + if ti.Muted && ti.Type == livekit.TrackType_VIDEO { + mt := p.addMigrateMutedTrack(cid, ti) + if mt != nil { + addedTracks = append(addedTracks, mt) + } else { + p.params.Logger.Warnw("could not find migrated muted track", nil, "cid", cid) + } } } p.pendingTracksLock.Unlock() - for _, t := range addedTrack { - if t != nil { - p.handleTrackPublished(t) - } + for _, t := range addedTracks { + p.handleTrackPublished(t) } } @@ -630,21 +642,13 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { return } - _ = p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_TrackPublished{ - TrackPublished: &livekit.TrackPublishedResponse{ - Cid: req.Cid, - Track: ti, - }, - }, - }) + p.sendTrackPublished(req.Cid, ti) } func (p *ParticipantImpl) SetMigrateInfo(previousAnswer *webrtc.SessionDescription, mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo) { p.pendingTracksLock.Lock() for _, t := range mediaTracks { - pendingInfo := &pendingTrackInfo{TrackInfo: t.GetTrack(), migrated: true} - p.pendingTracks[t.GetCid()] = pendingInfo + p.pendingTracks[t.GetCid()] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{t.GetTrack()}, migrated: true} } p.pendingDataChannels = dataChannels p.pendingTracksLock.Unlock() @@ -1556,18 +1560,9 @@ func (p *ParticipantImpl) onSubscribedMaxQualityChange(trackID livekit.TrackID, } func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *livekit.TrackInfo { - if p.getPublishedTrackBySignalCid(req.Cid) != nil || p.getPublishedTrackBySdpCid(req.Cid) != nil { - return nil - } - p.pendingTracksLock.Lock() defer p.pendingTracksLock.Unlock() - // if track is already published, reject - if p.pendingTracks[req.Cid] != nil { - return nil - } - if req.Sid != "" { track := p.GetPublishedTrack(livekit.TrackID(req.Sid)) if track == nil { @@ -1590,8 +1585,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l Source: req.Source, Layers: req.Layers, } - p.setStableTrackID(ti) - pendingInfo := &pendingTrackInfo{TrackInfo: ti} + p.setStableTrackID(req.Cid, ti) for _, codec := range req.SimulcastCodecs { mime := codec.Codec if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") { @@ -1605,12 +1599,33 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l }) } - p.pendingTracks[req.Cid] = pendingInfo - p.params.Logger.Debugw("pending track added", "track", ti.String(), "request", req.String()) + if p.getPublishedTrackBySignalCid(req.Cid) != nil || p.getPublishedTrackBySdpCid(req.Cid) != nil || p.pendingTracks[req.Cid] != nil { + if p.pendingTracks[req.Cid] == nil { + p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} + } else { + p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti) + } + p.params.Logger.Debugw("pending track queued", "track", ti.String(), "request", req.String()) + return nil + } + p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} + p.params.Logger.Debugw("pending track added", "track", ti.String(), "request", req.String()) return ti } +func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) { + p.params.Logger.Debugw("sending track published", "cid", cid, "trackInfo", ti.String()) + _ = p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_TrackPublished{ + TrackPublished: &livekit.TrackPublishedResponse{ + Cid: cid, + Track: ti, + }, + }, + }) +} + func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fromAdmin bool) { // when request is coming from admin, send message to current participant if fromAdmin { @@ -1622,23 +1637,20 @@ func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fro func (p *ParticipantImpl) setTrackMuted(trackID livekit.TrackID, muted bool) { track := p.UpTrackManager.SetPublishedTrackMuted(trackID, muted) - if track != nil { - // handled in UpTrackManager for a published track, no need to update state of pending track - return - } isPending := false p.pendingTracksLock.RLock() - for _, ti := range p.pendingTracks { - if livekit.TrackID(ti.Sid) == trackID { - ti.Muted = muted - isPending = true - break + for _, pti := range p.pendingTracks { + for _, ti := range pti.trackInfos { + if livekit.TrackID(ti.Sid) == trackID { + ti.Muted = muted + isPending = true + } } } p.pendingTracksLock.RUnlock() - if !isPending { + if !isPending && track == nil { p.params.Logger.Warnw("could not locate track", nil, "trackID", trackID) } } @@ -1670,11 +1682,10 @@ func (p *ParticipantImpl) getDTX() bool { // Most of the time in practice, there is going to be one // audio kind track and hence this is fine. // - for _, ti := range p.pendingTracks { - if ti.Type == livekit.TrackType_AUDIO { - if !ti.TrackInfo.DisableDtx { - return true - } + for _, pti := range p.pendingTracks { + ti := pti.trackInfos[0] + if ti != nil && ti.Type == livekit.TrackType_AUDIO { + return !ti.DisableDtx } } @@ -1702,38 +1713,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei return nil, false } - mt = NewMediaTrack(MediaTrackParams{ - TrackInfo: ti, - SignalCid: signalCid, - SdpCid: track.ID(), - ParticipantID: p.params.SID, - ParticipantIdentity: p.params.Identity, - ParticipantVersion: p.version.Load(), - RTCPChan: p.rtcpCh, - BufferFactory: p.params.Config.BufferFactory, - ReceiverConfig: p.params.Config.Receiver, - AudioConfig: p.params.AudioConfig, - VideoConfig: p.params.VideoConfig, - Telemetry: p.params.Telemetry, - Logger: LoggerWithTrack(p.params.Logger, livekit.TrackID(ti.Sid)), - SubscriberConfig: p.params.Config.Subscriber, - PLIThrottleConfig: p.params.PLIThrottleConfig, - SimTracks: p.params.SimTracks, - }) - - mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) - - // add to published and clean up pending - p.UpTrackManager.AddPublishedTrack(mt) - delete(p.pendingTracks, signalCid) - - mt.AddOnClose(func() { - // re-use track - p.lock.Lock() - p.unpublishedTracks = append(p.unpublishedTracks, ti) - p.lock.Unlock() - }) - + mt = p.addMediaTrack(signalCid, track.ID(), ti) newTrack = true } @@ -1753,54 +1733,25 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei return mt, newTrack } -func (p *ParticipantImpl) addMigrateMutedTrack(cid string, t *livekit.TrackInfo) *MediaTrack { - p.params.Logger.Debugw("add migrate muted track", "cid", cid, "track", t.String()) +func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack { + p.params.Logger.Debugw("add migrate muted track", "cid", cid, "track", ti.String()) var rtpReceiver *webrtc.RTPReceiver for _, tr := range p.publisher.pc.GetTransceivers() { - if tr.Mid() == t.Mid { + if tr.Mid() == ti.Mid { rtpReceiver = tr.Receiver() break } } if rtpReceiver == nil { - p.params.Logger.Errorw("could not find receiver for migrated track", nil, "track", t.Sid) + p.params.Logger.Errorw("could not find receiver for migrated track", nil, "track", ti.Sid) return nil } - mt := NewMediaTrack(MediaTrackParams{ - TrackInfo: proto.Clone(t).(*livekit.TrackInfo), - SignalCid: cid, - SdpCid: cid, - ParticipantID: p.params.SID, - ParticipantIdentity: p.params.Identity, - ParticipantVersion: p.version.Load(), - RTCPChan: p.rtcpCh, - BufferFactory: p.params.Config.BufferFactory, - ReceiverConfig: p.params.Config.Receiver, - AudioConfig: p.params.AudioConfig, - VideoConfig: p.params.VideoConfig, - Telemetry: p.params.Telemetry, - Logger: LoggerWithTrack(p.params.Logger, livekit.TrackID(t.Sid)), - SubscriberConfig: p.params.Config.Subscriber, - PLIThrottleConfig: p.params.PLIThrottleConfig, - SimTracks: p.params.SimTracks, - }) + mt := p.addMediaTrack(cid, cid, ti) - mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) - // add to published and clean up pending - p.UpTrackManager.AddPublishedTrack(mt) - delete(p.pendingTracks, cid) - - mt.AddOnClose(func() { - // re-use track - p.lock.Lock() - p.unpublishedTracks = append(p.unpublishedTracks, t) - p.lock.Unlock() - }) - - potentialCodecs := make([]webrtc.RTPCodecParameters, 0, len(t.Codecs)) + potentialCodecs := make([]webrtc.RTPCodecParameters, 0, len(ti.Codecs)) parameters := rtpReceiver.GetParameters() - for _, c := range t.Codecs { + for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { if strings.EqualFold(nc.MimeType, c.MimeType) { potentialCodecs = append(potentialCodecs, nc) @@ -1810,19 +1761,62 @@ func (p *ParticipantImpl) addMigrateMutedTrack(cid string, t *livekit.TrackInfo) } mt.SetPotentialCodecs(potentialCodecs, parameters.HeaderExtensions) - for _, codec := range t.Codecs { + for _, codec := range ti.Codecs { for ssrc, info := range p.params.SimTracks { if info.Mid == codec.Mid { mt.MediaTrackReceiver.SetLayerSsrc(codec.MimeType, info.Rid, ssrc) } } } - mt.SetSimulcast(t.Simulcast) + mt.SetSimulcast(ti.Simulcast) mt.SetMuted(true) return mt } +func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *livekit.TrackInfo) *MediaTrack { + mt := NewMediaTrack(MediaTrackParams{ + TrackInfo: proto.Clone(ti).(*livekit.TrackInfo), + SignalCid: signalCid, + SdpCid: sdpCid, + ParticipantID: p.params.SID, + ParticipantIdentity: p.params.Identity, + ParticipantVersion: p.version.Load(), + RTCPChan: p.rtcpCh, + BufferFactory: p.params.Config.BufferFactory, + ReceiverConfig: p.params.Config.Receiver, + AudioConfig: p.params.AudioConfig, + VideoConfig: p.params.VideoConfig, + Telemetry: p.params.Telemetry, + Logger: LoggerWithTrack(p.params.Logger, livekit.TrackID(ti.Sid)), + SubscriberConfig: p.params.Config.Subscriber, + PLIThrottleConfig: p.params.PLIThrottleConfig, + SimTracks: p.params.SimTracks, + }) + + mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) + // add to published and clean up pending + p.UpTrackManager.AddPublishedTrack(mt) + + p.pendingTracks[signalCid].trackInfos = p.pendingTracks[signalCid].trackInfos[1:] + if len(p.pendingTracks[signalCid].trackInfos) == 0 { + delete(p.pendingTracks, signalCid) + } + + mt.AddOnClose(func() { + // re-use track sid + p.pendingTracksLock.Lock() + if pti := p.pendingTracks[signalCid]; pti != nil { + p.sendTrackPublished(signalCid, pti.trackInfos[0]) + } else { + p.unpublishedTracks = append(p.unpublishedTracks, ti) + } + p.pendingTracksLock.Unlock() + }) + + return mt +} + func (p *ParticipantImpl) handleTrackPublished(track types.MediaTrack) { if !p.hasPendingMigratedTrack() { p.SetMigrateState(types.MigrateStateComplete) @@ -1855,34 +1849,36 @@ func (p *ParticipantImpl) onUpTrackManagerClose() { func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { signalCid := clientId - trackInfo := p.pendingTracks[clientId] - if trackInfo == nil { + pendingInfo := p.pendingTracks[clientId] + if pendingInfo == nil { track_loop: - for cid, ti := range p.pendingTracks { + for cid, pti := range p.pendingTracks { if cid == clientId { - trackInfo = ti + pendingInfo = pti signalCid = cid break } + ti := pti.trackInfos[0] for _, c := range ti.Codecs { if c.Cid == clientId { - trackInfo = ti + pendingInfo = pti signalCid = cid break track_loop } } } - if trackInfo == nil { + if pendingInfo == nil { // // If no match on client id, find first one matching type // as MediaStreamTrack can change client id when transceiver // is added to peer connection. // - for cid, ti := range p.pendingTracks { + for cid, pti := range p.pendingTracks { + ti := pti.trackInfos[0] if ti.Type == kind { - trackInfo = ti + pendingInfo = pti signalCid = cid break } @@ -1891,29 +1887,49 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp } // if still not found, we are done - if trackInfo == nil { + if pendingInfo == nil { p.params.Logger.Errorw("track info not published prior to track", nil, "clientId", clientId) return signalCid, nil } - return signalCid, trackInfo.TrackInfo + return signalCid, pendingInfo.trackInfos[0] } // setStableTrackID either generates a new TrackID or reuses a previously used one // for -func (p *ParticipantImpl) setStableTrackID(info *livekit.TrackInfo) { +func (p *ParticipantImpl) setStableTrackID(cid string, info *livekit.TrackInfo) { var trackID string - for i, ti := range p.unpublishedTracks { - if ti.Type == info.Type && ti.Source == info.Source && ti.Name == info.Name { - trackID = ti.Sid - if i < len(p.unpublishedTracks)-1 { - p.unpublishedTracks = append(p.unpublishedTracks[:i], p.unpublishedTracks[i+1:]...) - } else { - p.unpublishedTracks = p.unpublishedTracks[:i] + // if already pending, use the same SID + // should not happen as this means multiple `AddTrack` requests have been called, but check anyway + if pti := p.pendingTracks[cid]; pti != nil { + trackID = pti.trackInfos[0].Sid + } + + // check against published tracks as re-publish could be happening + if trackID == "" { + if pt := p.getPublishedTrackBySignalCid(cid); pt != nil { + ti := pt.ToProto() + if ti.Type == info.Type && ti.Source == info.Source && ti.Name == info.Name { + trackID = ti.Sid } - break } } + + if trackID == "" { + // check a previously published matching track + for i, ti := range p.unpublishedTracks { + if ti.Type == info.Type && ti.Source == info.Source && ti.Name == info.Name { + trackID = ti.Sid + if i < len(p.unpublishedTracks)-1 { + p.unpublishedTracks = append(p.unpublishedTracks[:i], p.unpublishedTracks[i+1:]...) + } else { + p.unpublishedTracks = p.unpublishedTracks[:i] + } + break + } + } + } + // otherwise generate if trackID == "" { trackPrefix := utils.TrackPrefix @@ -1982,11 +1998,15 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { pendingTrackInfo := make(map[string]interface{}) p.pendingTracksLock.RLock() - for clientID, ti := range p.pendingTracks { + for clientID, pti := range p.pendingTracks { + var trackInfos []string + for _, ti := range pti.trackInfos { + trackInfos = append(trackInfos, ti.String()) + } + pendingTrackInfo[clientID] = map[string]interface{}{ - "Sid": ti.Sid, - "Type": ti.Type.String(), - "Simulcast": ti.Simulcast, + "TrackInfos": trackInfos, + "Migrated": pti.migrated, } } p.pendingTracksLock.RUnlock() diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 4b6abf2c2..1452fec2e 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -116,12 +116,13 @@ func TestTrackPublishing(t *testing.T) { require.Equal(t, 1, sink.WriteMessageCallCount()) }) - t.Run("should not allow adding of duplicate tracks if already published by client id in signalling", func(t *testing.T) { + t.Run("should queue adding of duplicate tracks if already published by client id in signalling", func(t *testing.T) { p := newParticipantForTest("test") sink := p.params.Sink.(*routingfakes.FakeMessageSink) track := &typesfakes.FakeLocalMediaTrack{} track.SignalCidReturns("cid") + track.ToProtoReturns(&livekit.TrackInfo{}) // directly add to publishedTracks without lock - for testing purpose only p.UpTrackManager.publishedTracks["cid"] = track @@ -131,13 +132,27 @@ func TestTrackPublishing(t *testing.T) { Type: livekit.TrackType_VIDEO, }) require.Equal(t, 0, sink.WriteMessageCallCount()) + require.Equal(t, 1, len(p.pendingTracks["cid"].trackInfos)) + + // add again - it should be added to the queue + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + require.Equal(t, 0, sink.WriteMessageCallCount()) + require.Equal(t, 2, len(p.pendingTracks["cid"].trackInfos)) + + // check SID is the same + require.Equal(t, p.pendingTracks["cid"].trackInfos[0].Sid, p.pendingTracks["cid"].trackInfos[1].Sid) }) - t.Run("should not allow adding of duplicate tracks if already published by client id in sdp", func(t *testing.T) { + t.Run("should queue adding of duplicate tracks if already published by client id in sdp", func(t *testing.T) { p := newParticipantForTest("test") sink := p.params.Sink.(*routingfakes.FakeMessageSink) track := &typesfakes.FakeLocalMediaTrack{} + track.ToProtoReturns(&livekit.TrackInfo{}) track.HasSdpCidCalls(func(s string) bool { return s == "cid" }) // directly add to publishedTracks without lock - for testing purpose only p.UpTrackManager.publishedTracks["cid"] = track @@ -148,6 +163,19 @@ func TestTrackPublishing(t *testing.T) { Type: livekit.TrackType_VIDEO, }) require.Equal(t, 0, sink.WriteMessageCallCount()) + require.Equal(t, 1, len(p.pendingTracks["cid"].trackInfos)) + + // add again - it should be added to the queue + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + require.Equal(t, 0, sink.WriteMessageCallCount()) + require.Equal(t, 2, len(p.pendingTracks["cid"].trackInfos)) + + // check SID is the same + require.Equal(t, p.pendingTracks["cid"].trackInfos[0].Sid, p.pendingTracks["cid"].trackInfos[1].Sid) }) } @@ -202,7 +230,7 @@ func TestMuteSetting(t *testing.T) { t.Run("can set mute when track is pending", func(t *testing.T) { p := newParticipantForTest("test") ti := &livekit.TrackInfo{Sid: "testTrack"} - p.pendingTracks["cid"] = &pendingTrackInfo{TrackInfo: ti} + p.pendingTracks["cid"] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} p.SetTrackMuted(livekit.TrackID(ti.Sid), true, false) require.True(t, ti.Muted) @@ -549,6 +577,7 @@ func TestSetStableTrackID(t *testing.T) { name string trackInfo *livekit.TrackInfo unpublished []*livekit.TrackInfo + cid string prefix string remainingUnpublished int }{ @@ -578,6 +607,7 @@ func TestSetStableTrackID(t *testing.T) { Sid: "TR_VC1235", }, }, + cid: "TR_VC1235", prefix: "TR_VC1235", remainingUnpublished: 1, }, @@ -606,7 +636,7 @@ func TestSetStableTrackID(t *testing.T) { p.unpublishedTracks = tc.unpublished ti := tc.trackInfo - p.setStableTrackID(ti) + p.setStableTrackID(tc.cid, ti) require.Contains(t, ti.Sid, tc.prefix) require.Len(t, p.unpublishedTracks, tc.remainingUnpublished) })