From a393d64ccca32257c224bb46ae5c21d9ff7942bd Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sun, 31 Jul 2022 10:50:55 +0530 Subject: [PATCH] Do not re-use transceiver when negotiation is pending. (#862) --- pkg/rtc/mediatracksubscriptions.go | 7 +- pkg/rtc/participant.go | 8 ++ pkg/rtc/transport.go | 15 +++ pkg/rtc/types/interfaces.go | 2 + .../typesfakes/fake_local_participant.go | 113 ++++++++++++++++++ pkg/service/wire_gen.go | 3 +- 6 files changed, 143 insertions(+), 5 deletions(-) diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index de0d3b294..fb0aaf444 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -249,7 +249,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // if cannot replace, find an unused transceiver or add new one if transceiver == nil { - if sub.ProtocolVersion().SupportsTransceiverReuse() { + if sub.ProtocolVersion().SupportsTransceiverReuse() && !sub.IsNegotiationPending(subTrack.PublisherID()) { // // AddTrack will create a new transceiver or re-use an unused one // if the attributes match. This prevents SDP from bloating @@ -268,9 +268,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * } } } else { - transceiver, err = sub.SubscriberPC().AddTransceiverFromTrack(downTrack, webrtc.RTPTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionSendonly, - }) + transceiver, err = sub.SubscriberPC().AddTransceiverFromTrack(downTrack) if err != nil { return err } @@ -778,6 +776,7 @@ func (t *MediaTrackSubscriptions) downTrackClosed( sub.RemoveSubscribedTrack(subTrack) if !willBeResumed { + sub.AddNegotiationPending(subTrack.PublisherID()) sub.Negotiate(false) } } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 4e24eb8e5..2af215e13 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -781,6 +781,14 @@ func (p *ParticipantImpl) Negotiate(force bool) { } } +func (p *ParticipantImpl) AddNegotiationPending(publisherID livekit.ParticipantID) { + p.subscriber.AddNegotiationPending(publisherID) +} + +func (p *ParticipantImpl) IsNegotiationPending(publisherID livekit.ParticipantID) bool { + return p.subscriber.IsNegotiationPending(publisherID) +} + func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { p.lock.Lock() preState := p.MigrateState() diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index ab3786172..2a008c814 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -63,6 +63,7 @@ type PCTransport struct { lock sync.RWMutex pendingCandidates []webrtc.ICECandidateInit debouncedNegotiate func(func()) + negotiationPending map[livekit.ParticipantID]bool onOffer func(offer webrtc.SessionDescription) restartAfterGathering bool restartAtNextOffer bool @@ -194,6 +195,7 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { params: params, debouncedNegotiate: debounce.New(negotiationFrequency), negotiationState: negotiationStateNone, + negotiationPending: make(map[livekit.ParticipantID]bool), } if params.Target == livekit.SignalTarget_SUBSCRIBER { t.streamAllocator = sfu.NewStreamAllocator(sfu.StreamAllocatorParams{ @@ -324,6 +326,12 @@ func (t *PCTransport) OnNegotiationFailed(f func()) { t.onNegotiationFailed = f } +func (t *PCTransport) AddNegotiationPending(publisherID livekit.ParticipantID) { + t.lock.Lock() + t.negotiationPending[publisherID] = true + t.lock.Unlock() +} + func (t *PCTransport) Negotiate(force bool) { if force { t.debouncedNegotiate(func() { @@ -341,6 +349,12 @@ func (t *PCTransport) Negotiate(force bool) { } } +func (t *PCTransport) IsNegotiationPending(publisherID livekit.ParticipantID) bool { + t.lock.RLock() + defer t.lock.RUnlock() + return t.negotiationPending[publisherID] +} + func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error { t.lock.Lock() defer t.lock.Unlock() @@ -438,6 +452,7 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { // indicate waiting for client t.negotiationState = negotiationStateClient t.restartAfterGathering = false + t.negotiationPending = make(map[livekit.ParticipantID]bool) negotiateVersion := t.negotiateCounter.Inc() if t.signalStateCheckTimer != nil { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 410bc5b43..43035ce50 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -238,6 +238,8 @@ type LocalParticipant interface { SubscriberPC() *webrtc.PeerConnection HandleAnswer(sdp webrtc.SessionDescription) error Negotiate(force bool) + AddNegotiationPending(publisherID livekit.ParticipantID) + IsNegotiationPending(publisherID livekit.ParticipantID) bool ICERestart(iceConfig *IceConfig) error AddSubscribedTrack(st SubscribedTrack) RemoveSubscribedTrack(st SubscribedTrack) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index e9c631137..4b1fe9182 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -27,6 +27,11 @@ type FakeLocalParticipant struct { addICECandidateReturnsOnCall map[int]struct { result1 error } + AddNegotiationPendingStub func(livekit.ParticipantID) + addNegotiationPendingMutex sync.RWMutex + addNegotiationPendingArgsForCall []struct { + arg1 livekit.ParticipantID + } AddSubscribedTrackStub func(types.SubscribedTrack) addSubscribedTrackMutex sync.RWMutex addSubscribedTrackArgsForCall []struct { @@ -319,6 +324,17 @@ type FakeLocalParticipant struct { identityReturnsOnCall map[int]struct { result1 livekit.ParticipantIdentity } + IsNegotiationPendingStub func(livekit.ParticipantID) bool + isNegotiationPendingMutex sync.RWMutex + isNegotiationPendingArgsForCall []struct { + arg1 livekit.ParticipantID + } + isNegotiationPendingReturns struct { + result1 bool + } + isNegotiationPendingReturnsOnCall map[int]struct { + result1 bool + } IsPublisherStub func() bool isPublisherMutex sync.RWMutex isPublisherArgsForCall []struct { @@ -774,6 +790,38 @@ func (fake *FakeLocalParticipant) AddICECandidateReturnsOnCall(i int, result1 er }{result1} } +func (fake *FakeLocalParticipant) AddNegotiationPending(arg1 livekit.ParticipantID) { + fake.addNegotiationPendingMutex.Lock() + fake.addNegotiationPendingArgsForCall = append(fake.addNegotiationPendingArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.AddNegotiationPendingStub + fake.recordInvocation("AddNegotiationPending", []interface{}{arg1}) + fake.addNegotiationPendingMutex.Unlock() + if stub != nil { + fake.AddNegotiationPendingStub(arg1) + } +} + +func (fake *FakeLocalParticipant) AddNegotiationPendingCallCount() int { + fake.addNegotiationPendingMutex.RLock() + defer fake.addNegotiationPendingMutex.RUnlock() + return len(fake.addNegotiationPendingArgsForCall) +} + +func (fake *FakeLocalParticipant) AddNegotiationPendingCalls(stub func(livekit.ParticipantID)) { + fake.addNegotiationPendingMutex.Lock() + defer fake.addNegotiationPendingMutex.Unlock() + fake.AddNegotiationPendingStub = stub +} + +func (fake *FakeLocalParticipant) AddNegotiationPendingArgsForCall(i int) livekit.ParticipantID { + fake.addNegotiationPendingMutex.RLock() + defer fake.addNegotiationPendingMutex.RUnlock() + argsForCall := fake.addNegotiationPendingArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalParticipant) AddSubscribedTrack(arg1 types.SubscribedTrack) { fake.addSubscribedTrackMutex.Lock() fake.addSubscribedTrackArgsForCall = append(fake.addSubscribedTrackArgsForCall, struct { @@ -2313,6 +2361,67 @@ func (fake *FakeLocalParticipant) IdentityReturnsOnCall(i int, result1 livekit.P }{result1} } +func (fake *FakeLocalParticipant) IsNegotiationPending(arg1 livekit.ParticipantID) bool { + fake.isNegotiationPendingMutex.Lock() + ret, specificReturn := fake.isNegotiationPendingReturnsOnCall[len(fake.isNegotiationPendingArgsForCall)] + fake.isNegotiationPendingArgsForCall = append(fake.isNegotiationPendingArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.IsNegotiationPendingStub + fakeReturns := fake.isNegotiationPendingReturns + fake.recordInvocation("IsNegotiationPending", []interface{}{arg1}) + fake.isNegotiationPendingMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsNegotiationPendingCallCount() int { + fake.isNegotiationPendingMutex.RLock() + defer fake.isNegotiationPendingMutex.RUnlock() + return len(fake.isNegotiationPendingArgsForCall) +} + +func (fake *FakeLocalParticipant) IsNegotiationPendingCalls(stub func(livekit.ParticipantID) bool) { + fake.isNegotiationPendingMutex.Lock() + defer fake.isNegotiationPendingMutex.Unlock() + fake.IsNegotiationPendingStub = stub +} + +func (fake *FakeLocalParticipant) IsNegotiationPendingArgsForCall(i int) livekit.ParticipantID { + fake.isNegotiationPendingMutex.RLock() + defer fake.isNegotiationPendingMutex.RUnlock() + argsForCall := fake.isNegotiationPendingArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) IsNegotiationPendingReturns(result1 bool) { + fake.isNegotiationPendingMutex.Lock() + defer fake.isNegotiationPendingMutex.Unlock() + fake.IsNegotiationPendingStub = nil + fake.isNegotiationPendingReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsNegotiationPendingReturnsOnCall(i int, result1 bool) { + fake.isNegotiationPendingMutex.Lock() + defer fake.isNegotiationPendingMutex.Unlock() + fake.IsNegotiationPendingStub = nil + if fake.isNegotiationPendingReturnsOnCall == nil { + fake.isNegotiationPendingReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isNegotiationPendingReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) IsPublisher() bool { fake.isPublisherMutex.Lock() ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)] @@ -4476,6 +4585,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.addICECandidateMutex.RLock() defer fake.addICECandidateMutex.RUnlock() + fake.addNegotiationPendingMutex.RLock() + defer fake.addNegotiationPendingMutex.RUnlock() fake.addSubscribedTrackMutex.RLock() defer fake.addSubscribedTrackMutex.RUnlock() fake.addSubscriberMutex.RLock() @@ -4536,6 +4647,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.identityMutex.RLock() defer fake.identityMutex.RUnlock() + fake.isNegotiationPendingMutex.RLock() + defer fake.isNegotiationPendingMutex.RUnlock() fake.isPublisherMutex.RLock() defer fake.isPublisherMutex.RUnlock() fake.isReadyMutex.RLock() diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 6d9060b60..d15c07306 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -1,7 +1,8 @@ // Code generated by Wire. DO NOT EDIT. //go:generate go run github.com/google/wire/cmd/wire -//+build !wireinject +//go:build !wireinject +// +build !wireinject package service