diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 0c290cf7a..b74d253cb 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -23,7 +23,6 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/connectionquality" "github.com/livekit/livekit-server/pkg/sfu/twcc" "github.com/livekit/livekit-server/pkg/telemetry" - "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -388,7 +387,7 @@ func (p *ParticipantImpl) OnClaimsChanged(callback func(types.LocalParticipant)) } // HandleOffer an offer from remote participant, used when clients make the initial connection -func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription) error { +func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription) { p.params.Logger.Infow("received offer", "transport", livekit.SignalTarget_PUBLISHER) shouldPend := false if p.MigrateState() == types.MigrateStateInit { @@ -397,12 +396,7 @@ func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription) error { offer = p.setCodecPreferencesForPublisher(offer) - if err := p.TransportManager.HandleOffer(offer, shouldPend); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "remote_description").Add(1) - return err - } - - return nil + p.TransportManager.HandleOffer(offer, shouldPend) } func (p *ParticipantImpl) setCodecPreferencesForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { @@ -534,18 +528,15 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess } } -func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription) { +func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription) error { p.params.Logger.Infow("sending answer", "transport", livekit.SignalTarget_PUBLISHER) if err := p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Answer{ Answer: ToProtoSessionDescription(answer), }, }); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "write_message").Add(1) - return + return err } - prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) - p.TransportManager.PublisherLocalDescriptionSent() if p.isPublisher.Load() != p.CanPublish() { p.isPublisher.Store(p.CanPublish()) @@ -564,6 +555,7 @@ func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription) { if p.MigrateState() == types.MigrateStateSync { go p.handleMigrateMutedTrack() } + return nil } func (p *ParticipantImpl) handleMigrateMutedTrack() { @@ -717,10 +709,7 @@ func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { } if processPendingOffer { - err := p.TransportManager.ProcessPendingPublisherOffer() - if err != nil { - p.params.Logger.Errorw("could not handle pending offer during migration", err) - } + p.TransportManager.ProcessPendingPublisherOffer() } } @@ -729,14 +718,14 @@ func (p *ParticipantImpl) MigrateState() types.MigrateState { } // ICERestart restarts subscriber ICE connections -func (p *ParticipantImpl) ICERestart(iceConfig *types.IceConfig) error { +func (p *ParticipantImpl) ICERestart(iceConfig *types.IceConfig) { p.clearDisconnectTimer() for _, t := range p.GetPublishedTracks() { t.(types.LocalMediaTrack).Restart() } - return p.TransportManager.ICERestart(iceConfig) + p.TransportManager.ICERestart(iceConfig) } func (p *ParticipantImpl) OnICEConfigChanged(f func(participant types.LocalParticipant, iceConfig types.IceConfig)) { @@ -1055,8 +1044,8 @@ func (p *ParticipantImpl) setupTransportManager() error { } }) - tm.OnPublisherICECandidate(func(c *webrtc.ICECandidate) { - p.onICECandidate(c, livekit.SignalTarget_PUBLISHER) + tm.OnPublisherICECandidate(func(c *webrtc.ICECandidate) error { + return p.onICECandidate(c, livekit.SignalTarget_PUBLISHER) }) tm.OnPublisherGetDTX(p.onPublisherGetDTX) tm.OnPublisherAnswer(p.onPublisherAnswer) @@ -1064,16 +1053,16 @@ func (p *ParticipantImpl) setupTransportManager() error { tm.OnPublisherInitialConnected(p.onPublisherInitialConnected) tm.OnSubscriberOffer(p.onSubscriberOffer) - tm.OnSubscriberICECandidate(func(c *webrtc.ICECandidate) { - p.onICECandidate(c, livekit.SignalTarget_SUBSCRIBER) + tm.OnSubscriberICECandidate(func(c *webrtc.ICECandidate) error { + return p.onICECandidate(c, livekit.SignalTarget_SUBSCRIBER) }) tm.OnSubscriberInitialConnected(p.onSubscriberInitialConnected) - tm.OnSubscriberNegotiationFailed(p.handleSubscriberNegotiationFailed) tm.OnSubscriberStreamStateChange(p.onStreamStateChange) tm.OnPrimaryTransportInitialConnected(p.onPrimaryTransportInitialConnected) tm.OnPrimaryTransportFullyEstablished(p.onPrimaryTransportFullyEstablished) tm.OnAnyTransportFailed(p.onAnyTransportFailed) + tm.OnAnyTransportNegotiationFailed(p.onAnyTransportNegotiationFailed) tm.OnDataMessage(p.onDataMessage) p.TransportManager = tm @@ -1121,24 +1110,18 @@ func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { } // when the server has an offer for participant -func (p *ParticipantImpl) onSubscriberOffer(offer webrtc.SessionDescription) { +func (p *ParticipantImpl) onSubscriberOffer(offer webrtc.SessionDescription) error { if p.State() == livekit.ParticipantInfo_DISCONNECTED { // skip when disconnected - return + return nil } p.params.Logger.Infow("sending offer", "transport", livekit.SignalTarget_SUBSCRIBER) - err := p.writeMessage(&livekit.SignalResponse{ + return p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Offer{ Offer: ToProtoSessionDescription(offer), }, }) - if err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) - return - } - prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) - p.TransportManager.SubscriberLocalDescriptionSent() } // when a new remoteTrack is created, creates a Track and adds it to room @@ -1207,16 +1190,16 @@ func (p *ParticipantImpl) onDataMessage(kind livekit.DataPacket_Kind, data []byt } } -func (p *ParticipantImpl) onICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { +func (p *ParticipantImpl) onICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { if c == nil || p.State() == livekit.ParticipantInfo_DISCONNECTED { - return + return nil } if target == livekit.SignalTarget_SUBSCRIBER && p.MigrateState() == types.MigrateStateInit { - return + return nil } - p.sendICECandidate(c, target) + return p.sendICECandidate(c, target) } func (p *ParticipantImpl) onPublisherInitialConnected() { @@ -1941,8 +1924,8 @@ func (p *ParticipantImpl) GetCachedDownTrack(trackID livekit.TrackID) (*webrtc.R return nil, sfu.ForwarderState{} } -func (p *ParticipantImpl) handleSubscriberNegotiationFailed() { - p.params.Logger.Infow("subscriber negotiation failed, starting full reconnect") +func (p *ParticipantImpl) onAnyTransportNegotiationFailed() { + p.params.Logger.Infow("negotiation failed, starting full reconnect") _ = p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Leave{ Leave: &livekit.LeaveRequest{ diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 2ea4c68f7..3c1cbfe60 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -407,8 +407,7 @@ func TestDisableCodecs(t *testing.T) { } return nil } - err = participant.HandleOffer(sdp) - require.NoError(t, err) + participant.HandleOffer(sdp) testutils.WithTimeout(t, func() string { if answerReceived.Load() { diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go index 6bf632238..18cd48e6a 100644 --- a/pkg/rtc/participant_signal.go +++ b/pkg/rtc/participant_signal.go @@ -138,11 +138,11 @@ func (p *ParticipantImpl) SendRefreshToken(token string) error { }) } -func (p *ParticipantImpl) sendICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { +func (p *ParticipantImpl) sendICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { p.params.Logger.Infow("sending ice candidate", "candidate", c.String(), "target", target) trickle := ToProtoTrickle(c.ToJSON()) trickle.Target = target - _ = p.writeMessage(&livekit.SignalResponse{ + return p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Trickle{ Trickle: trickle, }, diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 672db4436..b27663f52 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -341,9 +341,7 @@ func (r *Room) ResumeParticipant(p types.LocalParticipant, responseSink routing. return err } - if err := p.ICERestart(nil); err != nil { - return err - } + p.ICERestart(nil) return nil } @@ -638,12 +636,10 @@ func (r *Room) SimulateScenario(participant types.LocalParticipant, simulateScen case *livekit.SimulateScenario_SwitchCandidateProtocol: r.Logger.Infow("simulating switch candidate protocol", "participant", participant.Identity()) - if err := participant.ICERestart(&types.IceConfig{ + participant.ICERestart(&types.IceConfig{ PreferSubTcp: scenario.SwitchCandidateProtocol == livekit.CandidateProtocol_TCP, PreferPubTcp: scenario.SwitchCandidateProtocol == livekit.CandidateProtocol_TCP, - }); err != nil { - return err - } + }) } return nil } diff --git a/pkg/rtc/signalhandler.go b/pkg/rtc/signalhandler.go index b6b9403fa..a638f40b0 100644 --- a/pkg/rtc/signalhandler.go +++ b/pkg/rtc/signalhandler.go @@ -10,30 +10,19 @@ import ( func HandleParticipantSignal(room types.Room, participant types.LocalParticipant, req *livekit.SignalRequest, pLogger logger.Logger) error { switch msg := req.Message.(type) { case *livekit.SignalRequest_Offer: - err := participant.HandleOffer(FromProtoSessionDescription(msg.Offer)) - if err != nil { - pLogger.Errorw("could not handle offer", err) - return err - } + participant.HandleOffer(FromProtoSessionDescription(msg.Offer)) case *livekit.SignalRequest_AddTrack: pLogger.Debugw("add track request", "trackID", msg.AddTrack.Cid) participant.AddTrack(msg.AddTrack) case *livekit.SignalRequest_Answer: - sd := FromProtoSessionDescription(msg.Answer) - if err := participant.HandleAnswer(sd); err != nil { - pLogger.Errorw("could not handle answer", err) - // connection cannot be successful if we can't answer - return err - } + participant.HandleAnswer(FromProtoSessionDescription(msg.Answer)) case *livekit.SignalRequest_Trickle: candidateInit, err := FromProtoTrickle(msg.Trickle) if err != nil { pLogger.Warnw("could not decode trickle", err) return nil } - if err := participant.AddICECandidate(candidateInit, msg.Trickle.Target); err != nil { - pLogger.Warnw("could not add ICE candidate", err) - } + participant.AddICECandidate(candidateInit, msg.Trickle.Target) case *livekit.SignalRequest_Mute: participant.SetTrackMuted(livekit.TrackID(msg.Mute.Sid), msg.Mute.Muted, false) case *livekit.SignalRequest_Subscription: diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index d507c7d88..1f0131761 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -1,7 +1,6 @@ package rtc import ( - "errors" "fmt" "strings" "sync" @@ -16,6 +15,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" + "github.com/pkg/errors" "go.uber.org/atomic" "google.golang.org/protobuf/proto" @@ -50,16 +50,84 @@ var ( ErrIceRestartWithoutLocalSDP = errors.New("ICE restart without local SDP settled") ErrNoTransceiver = errors.New("no transceiver") ErrNoSender = errors.New("no sender") + ErrNoICECandidateHandler = errors.New("no ICE candidate handler") + ErrNoOfferHandler = errors.New("no offer handler") + ErrNoAnswerHandler = errors.New("no answer handler") ) +// ------------------------------------------------------------------------- + +type signal int + const ( - negotiationStateNone = iota - // waiting for client answer - negotiationStateClient - // need to Negotiate again - negotiationRetry + signalICEGatheringComplete signal = iota + signalLocalICECandidate + signalRemoteICECandidate + signalLogICECandidates + signalSendOffer + signalRemoteDescriptionReceived + signalICERestart ) +func (s signal) String() string { + switch s { + case signalICEGatheringComplete: + return "ICE_GATHERING_COMPLETE" + case signalLocalICECandidate: + return "LOCAL_ICE_CANDIDATE" + case signalRemoteICECandidate: + return "REMOTE_ICE_CANDIDATE" + case signalLogICECandidates: + return "LOG_ICE_CANDIDATES" + case signalSendOffer: + return "SEND_OFFER" + case signalRemoteDescriptionReceived: + return "REMOTE_DESCRIPTION_RECEIVED" + case signalICERestart: + return "ICE_RESTART" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +// ------------------------------------------------------- + +type event struct { + signal signal + data interface{} +} + +func (e event) String() string { + return fmt.Sprintf("PCTransport:Event{signal: %s, data: %+v}", e.signal, e.data) +} + +// ------------------------------------------------------- + +type NegotiationState int + +const ( + NegotiationStateNone NegotiationState = iota + // waiting for remote description + NegotiationStateRemote + // need to Negotiate again + NegotiationStateRetry +) + +func (n NegotiationState) String() string { + switch n { + case NegotiationStateNone: + return "NONE" + case NegotiationStateRemote: + return "WAITING_FOR_REMOTE" + case NegotiationStateRetry: + return "RETRY" + default: + return fmt.Sprintf("%d", int(n)) + } +} + +// ------------------------------------------------------- + type SimulcastTrackInfo struct { Mid string Rid string @@ -84,31 +152,39 @@ type PCTransport struct { onFullyEstablished func() - localDescriptionSent bool - cachedLocalCandidates []*webrtc.ICECandidate - pendingRemoteCandidates []webrtc.ICECandidateInit - debouncedNegotiate func(func()) - negotiationPending map[livekit.ParticipantID]bool - onICECandidate func(c *webrtc.ICECandidate) - onOffer func(offer webrtc.SessionDescription) - onAnswer func(offer webrtc.SessionDescription) - onRemoteDescriptionSettled func() error - onInitialConnected func() - onFailed func(isShortLived bool) - restartAfterGathering bool - restartAtNextOffer bool - negotiationState int - negotiateCounter atomic.Int32 - signalStateCheckTimer *time.Timer - onNegotiationFailed func() + debouncedNegotiate func(func()) + debouncePending bool + + onICECandidate func(c *webrtc.ICECandidate) error + onOffer func(offer webrtc.SessionDescription) error + onAnswer func(answer webrtc.SessionDescription) error + onInitialConnected func() + onFailed func(isShortLived bool) + onGetDTX func() bool + onNegotiationStateChanged func(state NegotiationState) + onNegotiationFailed func() // stream allocator for subscriber PC streamAllocator *sfu.StreamAllocator previousAnswer *webrtc.SessionDescription - preferTCP bool + preferTCP atomic.Bool + isClosed atomic.Bool + eventChMu sync.RWMutex + eventCh chan event + + // the following should be accessed only in event processing go routine + cacheLocalCandidates bool + cachedLocalCandidates []*webrtc.ICECandidate + pendingRemoteCandidates []*webrtc.ICECandidateInit + negotiationPending map[livekit.ParticipantID]bool + restartAfterGathering bool + restartAtNextOffer bool + negotiationState NegotiationState + negotiateCounter atomic.Int32 + signalStateCheckTimer *time.Timer currentOfferIceCredential string // ice user:pwd, for publish side ice restart checking pendingRestartIceOffer *webrtc.SessionDescription @@ -242,8 +318,9 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { t := &PCTransport{ params: params, debouncedNegotiate: debounce.New(negotiationFrequency), - negotiationState: negotiationStateNone, + negotiationState: NegotiationStateNone, negotiationPending: make(map[livekit.ParticipantID]bool), + eventCh: make(chan event, 50), } if params.Target == livekit.SignalTarget_SUBSCRIBER { t.streamAllocator = sfu.NewStreamAllocator(sfu.StreamAllocatorParams{ @@ -257,9 +334,38 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { return nil, err } + go t.processEvents() + return t, nil } +func (t *PCTransport) createPeerConnection() error { + var bwe cc.BandwidthEstimator + pc, me, err := newPeerConnection(t.params, func(estimator cc.BandwidthEstimator) { + bwe = estimator + }) + if err != nil { + return err + } + + t.pc = pc + t.pc.OnICEGatheringStateChange(t.onICEGatheringStateChange) + t.pc.OnICEConnectionStateChange(t.onICEConnectionStateChange) + t.pc.OnICECandidate(t.onICECandidateTrickle) + + t.pc.OnConnectionStateChange(t.onPeerConnectionStateChange) + + t.pc.OnDataChannel(t.onDataChannel) + + t.me = me + + if bwe != nil && t.streamAllocator != nil { + t.streamAllocator.SetBandwidthEstimator(bwe) + } + + return nil +} + func (t *PCTransport) setICEConnectedAt(at time.Time) { t.lock.Lock() if t.iceConnectedAt.IsZero() { @@ -305,15 +411,9 @@ func (t *PCTransport) getSelectedPair() (*webrtc.ICECandidatePair, error) { } func (t *PCTransport) logICECandidates() { - t.lock.RLock() - t.params.Logger.Infow( - "ice candidates", - "lc", t.allowedLocalCandidates, - "rc", t.allowedRemoteCandidates, - "lc (filtered)", t.filteredLocalCandidates, - "rc (filtered)", t.filteredRemoteCandidates, - ) - t.lock.RUnlock() + t.postEvent(event{ + signal: signalLogICECandidates, + }) } func (t *PCTransport) setConnectedAt(at time.Time) bool { @@ -335,51 +435,16 @@ func (t *PCTransport) onICEGatheringStateChange(state webrtc.ICEGathererState) { return } - go func() { - t.lock.Lock() - if t.restartAfterGathering { - t.params.Logger.Debugw("restarting ICE after ICE gathering") - if err := t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}); err != nil { - t.params.Logger.Warnw("could not restart ICE", err) - } - t.lock.Unlock() - } else if t.pendingRestartIceOffer != nil { - t.params.Logger.Debugw("accept remote restart ice offer after ICE gathering") - offer := t.pendingRestartIceOffer - t.pendingRestartIceOffer = nil - t.lock.Unlock() - if err := t.SetRemoteDescription(*offer); err != nil { - t.params.Logger.Warnw("could not accept remote restart ice offer", err) - } - } else { - t.lock.Unlock() - } - }() + t.postEvent(event{ + signal: signalICEGatheringComplete, + }) } func (t *PCTransport) onICECandidateTrickle(c *webrtc.ICECandidate) { - t.lock.Lock() - if t.preferTCP && c != nil && c.Protocol != webrtc.ICEProtocolTCP { - cstr := c.String() - t.params.Logger.Infow("filtering out local candidate", "candidate", cstr) - t.filteredLocalCandidates = append(t.filteredLocalCandidates, cstr) - t.lock.Unlock() - return - } - - if c != nil { - t.allowedLocalCandidates = append(t.allowedLocalCandidates, c.String()) - } - if !t.localDescriptionSent { - t.cachedLocalCandidates = append(t.cachedLocalCandidates, c) - t.lock.Unlock() - return - } - t.lock.Unlock() - - if t.onICECandidate != nil { - t.onICECandidate(c) - } + t.postEvent(event{ + signal: signalLocalICECandidate, + data: c, + }) } func (t *PCTransport) handleConnectionFailed() { @@ -393,8 +458,8 @@ func (t *PCTransport) handleConnectionFailed() { } } - if t.onFailed != nil { - t.onFailed(isShort) + if onFailed := t.getOnFailed(); onFailed != nil { + onFailed(isShort) } } @@ -420,8 +485,8 @@ func (t *PCTransport) onPeerConnectionStateChange(state webrtc.PeerConnectionSta t.logICECandidates() isInitialConnection := t.setConnectedAt(time.Now()) if isInitialConnection { - if t.onInitialConnected != nil { - t.onInitialConnected() + if onInitialConnected := t.getOnInitialConnected(); onInitialConnected != nil { + onInitialConnected() } t.maybeNotifyFullyEstablished() @@ -440,8 +505,8 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.reliableDCOpened = true t.lock.Unlock() dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if t.onDataPacket != nil { - t.onDataPacket(livekit.DataPacket_RELIABLE, msg.Data) + if onDataPacket := t.getOnDataPacket(); onDataPacket != nil { + onDataPacket(livekit.DataPacket_RELIABLE, msg.Data) } }) @@ -453,8 +518,8 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.lossyDCOpened = true t.lock.Unlock() dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if t.onDataPacket != nil { - t.onDataPacket(livekit.DataPacket_LOSSY, msg.Data) + if onDataPacket := t.getOnDataPacket(); onDataPacket != nil { + onDataPacket(livekit.DataPacket_LOSSY, msg.Data) } }) @@ -469,94 +534,22 @@ func (t *PCTransport) maybeNotifyFullyEstablished() { fullyEstablished := t.reliableDCOpened && t.lossyDCOpened && !t.connectedAt.IsZero() t.lock.RUnlock() - if fullyEstablished && t.onFullyEstablished != nil { - t.onFullyEstablished() - } -} - -func (t *PCTransport) LocalDescriptionSent() { - var cachedLocalCandidates []*webrtc.ICECandidate - t.lock.Lock() - t.localDescriptionSent = true - - cachedLocalCandidates = t.cachedLocalCandidates - t.cachedLocalCandidates = nil - t.lock.Unlock() - - if t.onICECandidate != nil { - for _, c := range cachedLocalCandidates { - t.onICECandidate(c) + if fullyEstablished { + if onFullyEstablished := t.getOnFullyEstablished(); onFullyEstablished != nil { + onFullyEstablished() } } } -func (t *PCTransport) clearLocalDescriptionSentLocked() { - t.localDescriptionSent = false - - t.allowedLocalCandidates = nil - t.allowedRemoteCandidates = nil - t.filteredLocalCandidates = nil - t.filteredRemoteCandidates = nil -} - func (t *PCTransport) SetPreferTCP(preferTCP bool) { - t.lock.Lock() - t.preferTCP = preferTCP - t.lock.Unlock() + t.preferTCP.Store(preferTCP) } -func (t *PCTransport) createPeerConnection() error { - var bwe cc.BandwidthEstimator - pc, me, err := newPeerConnection(t.params, func(estimator cc.BandwidthEstimator) { - bwe = estimator +func (t *PCTransport) AddICECandidate(candidate webrtc.ICECandidateInit) { + t.postEvent(event{ + signal: signalRemoteICECandidate, + data: &candidate, }) - if err != nil { - return err - } - - t.pc = pc - t.pc.OnICEGatheringStateChange(t.onICEGatheringStateChange) - t.pc.OnICEConnectionStateChange(t.onICEConnectionStateChange) - t.pc.OnICECandidate(t.onICECandidateTrickle) - - t.pc.OnConnectionStateChange(t.onPeerConnectionStateChange) - - t.pc.OnDataChannel(t.onDataChannel) - - t.me = me - - if bwe != nil && t.streamAllocator != nil { - t.streamAllocator.SetBandwidthEstimator(bwe) - } - - return nil -} - -func (t *PCTransport) AddICECandidate(candidate webrtc.ICECandidateInit) error { - if t.pc.RemoteDescription() == nil { - t.lock.Lock() - t.pendingRemoteCandidates = append(t.pendingRemoteCandidates, candidate) - t.lock.Unlock() - return nil - } - - t.lock.Lock() - if t.preferTCP && !strings.Contains(candidate.Candidate, "tcp") { - t.params.Logger.Infow("filtering out remote candidate", "candidate", candidate.Candidate) - t.filteredRemoteCandidates = append(t.filteredRemoteCandidates, candidate.Candidate) - t.lock.Unlock() - return nil - } - - t.allowedRemoteCandidates = append(t.allowedRemoteCandidates, candidate.Candidate) - t.lock.Unlock() - - t.params.Logger.Infow("add candidate ", "candidate", candidate.Candidate) - return t.pc.AddICECandidate(candidate) -} - -func (t *PCTransport) PeerConnection() *webrtc.PeerConnection { - return t.pc } func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { @@ -725,7 +718,14 @@ func (t *PCTransport) SendDataPacket(dp *livekit.DataPacket) error { } func (t *PCTransport) Close() { - t.clearSignalStateCheckTimer() + t.eventChMu.Lock() + if t.isClosed.Swap(true) { + t.eventChMu.Unlock() + return + } + + close(t.eventCh) + t.eventChMu.Unlock() if t.streamAllocator != nil { t.streamAllocator.Stop() @@ -734,111 +734,76 @@ func (t *PCTransport) Close() { _ = t.pc.Close() } -func (t *PCTransport) SetRemoteDescription(sd webrtc.SessionDescription) error { +func (t *PCTransport) HandleRemoteDescription(sd webrtc.SessionDescription) { + t.postEvent(event{ + signal: signalRemoteDescriptionReceived, + data: &sd, + }) +} + +func (t *PCTransport) OnICECandidate(f func(c *webrtc.ICECandidate) error) { t.lock.Lock() - - var ( - iceCredential string - offerRestartICE bool - ) - if sd.Type == webrtc.SDPTypeOffer { - var err error - iceCredential, offerRestartICE, err = t.isRemoteOfferRestartICE(sd) - if err != nil { - t.params.Logger.Errorw("check remote offer restart ice failed", err) - t.lock.Unlock() - return err - } - } - - if offerRestartICE && t.pendingRestartIceOffer == nil { - t.clearLocalDescriptionSentLocked() - } - - if offerRestartICE && t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { - t.params.Logger.Debugw("remote offer restart ice while ice gathering") - t.pendingRestartIceOffer = &sd - t.lock.Unlock() - return nil - } - - // filter before setting remote description so that pion does not see filtered remote candidates - if t.preferTCP { - t.params.Logger.Infow("remote description (unfiltered)", "type", sd.Type, "sdp", sd.SDP) - } - sd = t.filterCandidates(sd) - if t.preferTCP { - t.params.Logger.Infow("remote description (filtered)", "type", sd.Type, "sdp", sd.SDP) - } - if err := t.pc.SetRemoteDescription(sd); err != nil { - t.lock.Unlock() - return err - } - - if t.currentOfferIceCredential == "" || offerRestartICE { - t.currentOfferIceCredential = iceCredential - } - - // negotiated, reset flag - lastState := t.negotiationState - t.negotiationState = negotiationStateNone - - t.clearSignalStateCheckTimerLocked() - - for _, c := range t.pendingRemoteCandidates { - if err := t.pc.AddICECandidate(c); err != nil { - t.lock.Unlock() - return err - } - } - t.pendingRemoteCandidates = nil - - // only initiate when we are the offerer - if lastState == negotiationRetry && sd.Type == webrtc.SDPTypeAnswer { - t.params.Logger.Debugw("re-negotiate after receiving answer") - if err := t.createAndSendOffer(nil); err != nil { - t.params.Logger.Errorw("could not negotiate", err) - } - } - onRemoteDescriptionSettled := t.onRemoteDescriptionSettled - t.lock.Unlock() - - if onRemoteDescriptionSettled != nil { - return onRemoteDescriptionSettled() - } - return nil -} - -func (t *PCTransport) isRemoteOfferRestartICE(sd webrtc.SessionDescription) (string, bool, error) { - parsed, err := sd.Unmarshal() - if err != nil { - return "", false, err - } - user, pwd, err := lksdp.ExtractICECredential(parsed) - if err != nil { - return "", false, err - } - - credential := fmt.Sprintf("%s:%s", user, pwd) - // ice credential changed, remote offer restart ice - restartICE := t.currentOfferIceCredential != "" && t.currentOfferIceCredential != credential - return credential, restartICE, nil -} - -func (t *PCTransport) OnICECandidate(f func(c *webrtc.ICECandidate)) { t.onICECandidate = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnICECandidate() func(c *webrtc.ICECandidate) error { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onICECandidate } func (t *PCTransport) OnInitialConnected(f func()) { + t.lock.Lock() t.onInitialConnected = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnInitialConnected() func() { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onInitialConnected } func (t *PCTransport) OnFullyEstablished(f func()) { + t.lock.Lock() t.onFullyEstablished = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnFullyEstablished() func() { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onFullyEstablished } func (t *PCTransport) OnFailed(f func(isShortLived bool)) { + t.lock.Lock() t.onFailed = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnFailed() func(isShortLived bool) { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onFailed +} + +func (t *PCTransport) OnGetDTX(f func() bool) { + t.lock.Lock() + t.onGetDTX = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnGetDTX() func() bool { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onGetDTX } func (t *PCTransport) OnTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { @@ -846,26 +811,69 @@ func (t *PCTransport) OnTrack(f func(track *webrtc.TrackRemote, rtpReceiver *web } func (t *PCTransport) OnDataPacket(f func(kind livekit.DataPacket_Kind, data []byte)) { - t.onDataPacket = f -} - -// OnOffer is called when the PeerConnection starts negotiation and prepares an offer -func (t *PCTransport) OnOffer(f func(sd webrtc.SessionDescription)) { - t.onOffer = f -} - -func (t *PCTransport) OnAnswer(f func(sd webrtc.SessionDescription)) { - t.onAnswer = f -} - -func (t *PCTransport) OnRemoteDescriptionSettled(f func() error) { t.lock.Lock() - t.onRemoteDescriptionSettled = f + t.onDataPacket = f t.lock.Unlock() } +func (t *PCTransport) getOnDataPacket() func(kind livekit.DataPacket_Kind, data []byte) { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onDataPacket +} + +// OnOffer is called when the PeerConnection starts negotiation and prepares an offer +func (t *PCTransport) OnOffer(f func(sd webrtc.SessionDescription) error) { + t.lock.Lock() + t.onOffer = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnOffer() func(sd webrtc.SessionDescription) error { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onOffer +} + +func (t *PCTransport) OnAnswer(f func(sd webrtc.SessionDescription) error) { + t.lock.Lock() + t.onAnswer = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnAnswer() func(sd webrtc.SessionDescription) error { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onAnswer +} + +func (t *PCTransport) OnNegotiationStateChanged(f func(state NegotiationState)) { + t.lock.Lock() + t.onNegotiationStateChanged = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnNegotiationStateChanged() func(state NegotiationState) { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onNegotiationStateChanged +} + func (t *PCTransport) OnNegotiationFailed(f func()) { + t.lock.Lock() t.onNegotiationFailed = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnNegotiationFailed() func() { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onNegotiationFailed } func (t *PCTransport) AddNegotiationPending(publisherID livekit.ParticipantID) { @@ -876,18 +884,31 @@ func (t *PCTransport) AddNegotiationPending(publisherID livekit.ParticipantID) { func (t *PCTransport) Negotiate(force bool) { if force { + t.lock.Lock() t.debouncedNegotiate(func() { // no op to cancel pending negotiation }) - if err := t.CreateAndSendOffer(nil); err != nil { - t.params.Logger.Errorw("could not negotiate", err) - } - } else { - t.debouncedNegotiate(func() { - if err := t.CreateAndSendOffer(nil); err != nil { - t.params.Logger.Errorw("could not negotiate", err) - } + t.debouncePending = false + t.lock.Unlock() + + t.postEvent(event{ + signal: signalSendOffer, }) + } else { + t.lock.Lock() + if !t.debouncePending { + t.debouncedNegotiate(func() { + t.lock.Lock() + t.debouncePending = false + t.lock.Unlock() + + t.postEvent(event{ + signal: signalSendOffer, + }) + }) + t.debouncePending = true + } + t.lock.Unlock() } } @@ -981,195 +1002,38 @@ func (t *PCTransport) configureReceiverDTX(enableDTX bool) { } } -func (t *PCTransport) CreateAndSendAnswer(enableDTX bool) error { - t.lock.RLock() - defer t.lock.RUnlock() - - t.configureReceiverDTX(enableDTX) - - answer, err := t.pc.CreateAnswer(nil) - if err != nil { - return err - } - - if t.preferTCP { - t.params.Logger.Infow("local answer (unfiltered)", "sdp", answer.SDP) - } - if err = t.pc.SetLocalDescription(answer); err != nil { - return err - } - - // - // Filter after setting local description as pion expects the answer - // to match between CreateAnswer and SetLocalDescription. - // Filtered answer is sent to remote so that remote does not - // see filtered candidates. - // - answer = t.filterCandidates(answer) - if t.preferTCP { - t.params.Logger.Infow("local answer (filtered)", "sdp", answer.SDP) - } - - if t.onAnswer != nil { - go t.onAnswer(answer) - } - - return nil -} - -func (t *PCTransport) clearSignalStateCheckTimer() { - t.lock.Lock() - t.clearSignalStateCheckTimerLocked() - t.lock.Unlock() -} - -func (t *PCTransport) clearSignalStateCheckTimerLocked() { - if t.signalStateCheckTimer != nil { - t.signalStateCheckTimer.Stop() - t.signalStateCheckTimer = nil - } -} - -func (t *PCTransport) setupSignalStateCheckTimerLocked() { - negotiateVersion := t.negotiateCounter.Inc() - t.clearSignalStateCheckTimerLocked() - t.signalStateCheckTimer = time.AfterFunc(negotiationFailedTimeout, func() { - t.lock.Lock() - t.clearSignalStateCheckTimerLocked() - - failed := t.negotiationState != negotiationStateNone - t.lock.Unlock() - - if t.negotiateCounter.Load() == negotiateVersion && failed { - if t.onNegotiationFailed != nil { - t.onNegotiationFailed() - } - } +func (t *PCTransport) ICERestart() { + t.postEvent(event{ + signal: signalICERestart, }) } -func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error { - t.lock.Lock() - defer t.lock.Unlock() - - if options != nil && options.ICERestart { - t.clearLocalDescriptionSentLocked() +func (t *PCTransport) OnStreamStateChange(f func(update *sfu.StreamStateUpdate) error) { + if t.streamAllocator == nil { + return } - return t.createAndSendOffer(options) + t.streamAllocator.OnStreamStateChange(f) } -// creates and sends offer assuming lock has been acquired -func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { - if t.onOffer == nil { - return nil - } - if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { - return nil +func (t *PCTransport) AddTrackToStreamAllocator(subTrack types.SubscribedTrack) { + if t.streamAllocator == nil { + return } - iceRestart := (options != nil && options.ICERestart) || t.restartAtNextOffer + t.streamAllocator.AddTrack(subTrack.DownTrack(), sfu.AddTrackParams{ + Source: subTrack.MediaTrack().Source(), + IsSimulcast: subTrack.MediaTrack().IsSimulcast(), + PublisherID: subTrack.MediaTrack().PublisherID(), + }) +} - // if restart is requested, and we are not ready, then continue afterwards - if iceRestart { - if t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { - t.params.Logger.Debugw("restart ICE after gathering") - t.restartAfterGathering = true - return nil - } - t.params.Logger.Debugw("restarting ICE") +func (t *PCTransport) RemoveTrackFromStreamAllocator(subTrack types.SubscribedTrack) { + if t.streamAllocator == nil { + return } - if iceRestart && t.negotiationState != negotiationStateNone { - currentSD := t.pc.CurrentRemoteDescription() - if currentSD == nil { - // restart without current remote description, send current local description again to try recover - offer := t.pc.LocalDescription() - if offer == nil { - // it should not happen, log just in case - t.params.Logger.Warnw("ice restart without local offer", nil) - return ErrIceRestartWithoutLocalSDP - } else { - t.negotiationState = negotiationRetry - t.restartAtNextOffer = true - go t.onOffer(*offer) - return nil - } - } else { - // recover by re-applying the last answer - t.params.Logger.Infow("recovering from client negotiation state on ICE restart") - if err := t.pc.SetRemoteDescription(*currentSD); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "remote_description").Add(1) - return err - } - } - } else { - // when there's an ongoing negotiation, let it finish and not disrupt its state - if t.negotiationState == negotiationStateClient { - t.params.Logger.Infow("skipping negotiation, trying again later") - t.negotiationState = negotiationRetry - return nil - } else if t.negotiationState == negotiationRetry { - // already set to retry, we can safely skip this attempt - return nil - } - } - - ensureICERestart := func(options *webrtc.OfferOptions) *webrtc.OfferOptions { - if options == nil { - options = &webrtc.OfferOptions{} - } - options.ICERestart = true - return options - } - - if t.previousAnswer != nil { - t.previousAnswer = nil - options = ensureICERestart(options) - } - - if t.restartAtNextOffer { - t.restartAtNextOffer = false - options = ensureICERestart(options) - } - - offer, err := t.pc.CreateOffer(options) - if err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "create").Add(1) - t.params.Logger.Errorw("could not create offer", err) - return err - } - - if t.preferTCP { - t.params.Logger.Infow("local offer (unfiltered)", "sdp", offer.SDP) - } - err = t.pc.SetLocalDescription(offer) - if err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "local_description").Add(1) - t.params.Logger.Errorw("could not set local description", err) - return err - } - - // - // Filter after setting local description as pion expects the offer - // to match between CreateOffer and SetLocalDescription. - // Filtered offer is sent to remote so that remote does not - // see filtered candidates. - // - offer = t.filterCandidates(offer) - if t.preferTCP { - t.params.Logger.Infow("local offer (filtered)", "sdp", offer.SDP) - } - - // indicate waiting for client - t.negotiationState = negotiationStateClient - t.restartAfterGathering = false - t.negotiationPending = make(map[livekit.ParticipantID]bool) - - t.setupSignalStateCheckTimerLocked() - - go t.onOffer(offer) - return nil + t.streamAllocator.RemoveTrack(subTrack.DownTrack()) } func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error { @@ -1188,7 +1052,9 @@ func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error if err != nil { return err } - t.pc.SetLocalDescription(offer) + if err := t.pc.SetLocalDescription(offer); err != nil { + return err + } // // Simulate client side peer connection and set DTLS role from previous answer. @@ -1210,7 +1076,9 @@ func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error } defer pc2.Close() - pc2.SetRemoteDescription(offer) + if err := pc2.SetRemoteDescription(offer); err != nil { + return err + } ans, err := pc2.CreateAnswer(nil) if err != nil { return err @@ -1283,46 +1151,216 @@ func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDesc return nil } -func (t *PCTransport) OnStreamStateChange(f func(update *sfu.StreamStateUpdate) error) { - if t.streamAllocator == nil { - return - } - - t.streamAllocator.OnStreamStateChange(f) -} - -func (t *PCTransport) AddTrackToStreamAllocator(subTrack types.SubscribedTrack) { - if t.streamAllocator == nil { - return - } - - t.streamAllocator.AddTrack(subTrack.DownTrack(), sfu.AddTrackParams{ - Source: subTrack.MediaTrack().Source(), - IsSimulcast: subTrack.MediaTrack().IsSimulcast(), - PublisherID: subTrack.MediaTrack().PublisherID(), - }) -} - -func (t *PCTransport) RemoveTrackFromStreamAllocator(subTrack types.SubscribedTrack) { - if t.streamAllocator == nil { - return - } - - t.streamAllocator.RemoveTrack(subTrack.DownTrack()) -} - func (t *PCTransport) SetPreviousAnswer(answer *webrtc.SessionDescription) { t.lock.Lock() - defer t.lock.Unlock() if t.pc.RemoteDescription() == nil && t.previousAnswer == nil { t.previousAnswer = answer if err := t.initPCWithPreviousAnswer(*t.previousAnswer); err != nil { t.params.Logger.Errorw("initPCWithPreviousAnswer failed", err) + t.lock.Unlock() + + if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { + onNegotiationFailed() + } + return } } + t.lock.Unlock() +} + +func (t *PCTransport) postEvent(event event) { + t.eventChMu.RLock() + if t.isClosed.Load() { + t.eventChMu.RUnlock() + return + } + + select { + case t.eventCh <- event: + default: + t.params.Logger.Warnw("event queue full", nil, "event", event.String()) + } + t.eventChMu.RUnlock() +} + +func (t *PCTransport) processEvents() { + for event := range t.eventCh { + err := t.handleEvent(&event) + if err != nil { + t.params.Logger.Errorw("error handling event", err, "event", event.String()) + if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { + onNegotiationFailed() + } + break + } + } + + t.clearSignalStateCheckTimer() + t.params.Logger.Infow("leaving events processor") +} + +func (t *PCTransport) handleEvent(e *event) error { + switch e.signal { + case signalICEGatheringComplete: + return t.handleICEGatheringComplete(e) + case signalLocalICECandidate: + return t.handleLocalICECandidate(e) + case signalRemoteICECandidate: + return t.handleRemoteICECandidate(e) + case signalLogICECandidates: + return t.handleLogICECandidates(e) + case signalSendOffer: + return t.handleSendOffer(e) + case signalRemoteDescriptionReceived: + return t.handleRemoteDescriptionReceived(e) + case signalICERestart: + return t.handleICERestart(e) + } + + return nil +} + +func (t *PCTransport) handleICEGatheringComplete(e *event) error { + if t.params.Target == livekit.SignalTarget_SUBSCRIBER { + return t.handleICEGatheringCompleteOfferer() + } else { + return t.handleICEGatheringCompleteAnswerer() + } } -func (t *PCTransport) filterCandidates(sd webrtc.SessionDescription) webrtc.SessionDescription { +func (t *PCTransport) handleICEGatheringCompleteOfferer() error { + if !t.restartAfterGathering { + return nil + } + + t.params.Logger.Debugw("restarting ICE after ICE gathering") + t.restartAfterGathering = false + return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) +} + +func (t *PCTransport) handleICEGatheringCompleteAnswerer() error { + if t.pendingRestartIceOffer == nil { + return nil + } + + t.params.Logger.Debugw("accept remote restart ice offer after ICE gathering") + err := t.setRemoteDescription(*t.pendingRestartIceOffer) + t.pendingRestartIceOffer = nil + return err +} + +func (t *PCTransport) localDescriptionSent() error { + if !t.cacheLocalCandidates { + return nil + } + + t.cacheLocalCandidates = false + + cachedLocalCandidates := t.cachedLocalCandidates + t.cachedLocalCandidates = nil + + if onICECandidate := t.getOnICECandidate(); onICECandidate != nil { + for _, c := range cachedLocalCandidates { + if err := onICECandidate(c); err != nil { + return err + } + } + + return nil + } + + return ErrNoICECandidateHandler +} + +func (t *PCTransport) clearLocalDescriptionSent() { + t.cacheLocalCandidates = true + t.cachedLocalCandidates = nil + + t.allowedLocalCandidates = nil + t.allowedRemoteCandidates = nil + t.filteredLocalCandidates = nil + t.filteredRemoteCandidates = nil +} + +func (t *PCTransport) handleLocalICECandidate(e *event) error { + c := e.data.(*webrtc.ICECandidate) + + filtered := false + if t.preferTCP.Load() && c != nil && c.Protocol != webrtc.ICEProtocolTCP { + cstr := c.String() + t.params.Logger.Infow("filtering out local candidate", "candidate", cstr) + t.filteredLocalCandidates = append(t.filteredLocalCandidates, cstr) + filtered = true + } + + if filtered { + return nil + } + + if c != nil { + t.allowedLocalCandidates = append(t.allowedLocalCandidates, c.String()) + } + if t.cacheLocalCandidates { + t.cachedLocalCandidates = append(t.cachedLocalCandidates, c) + return nil + } + + if onICECandidate := t.getOnICECandidate(); onICECandidate != nil { + return onICECandidate(c) + } + + return ErrNoICECandidateHandler +} + +func (t *PCTransport) handleRemoteICECandidate(e *event) error { + c := e.data.(*webrtc.ICECandidateInit) + + filtered := false + if t.preferTCP.Load() && !strings.Contains(c.Candidate, "tcp") { + t.params.Logger.Infow("filtering out remote candidate", "candidate", c.Candidate) + t.filteredRemoteCandidates = append(t.filteredRemoteCandidates, c.Candidate) + filtered = true + } + + if filtered { + return nil + } + + if t.pc.RemoteDescription() == nil { + t.pendingRemoteCandidates = append(t.pendingRemoteCandidates, c) + return nil + } + + t.allowedRemoteCandidates = append(t.allowedRemoteCandidates, c.Candidate) + + t.params.Logger.Infow("add candidate ", "candidate", c.Candidate) + if err := t.pc.AddICECandidate(*c); err != nil { + return errors.Wrap(err, "add ice candidate failed") + } + + return nil +} + +func (t *PCTransport) handleLogICECandidates(e *event) error { + t.params.Logger.Infow( + "ice candidates", + "lc", t.allowedLocalCandidates, + "rc", t.allowedRemoteCandidates, + "lc (filtered)", t.filteredLocalCandidates, + "rc (filtered)", t.filteredRemoteCandidates, + ) + + return nil +} + +func (t *PCTransport) setNegotiationState(state NegotiationState) { + t.negotiationState = state + if onNegotiationStateChanged := t.getOnNegotiationStateChanged(); onNegotiationStateChanged != nil { + onNegotiationStateChanged(t.negotiationState) + } +} + +func (t *PCTransport) filterCandidates(sd webrtc.SessionDescription, preferTCP bool) webrtc.SessionDescription { parsed, err := sd.Unmarshal() if err != nil { t.params.Logger.Errorw("could not unmarshal SDP to filter candidates", err) @@ -1333,7 +1371,7 @@ func (t *PCTransport) filterCandidates(sd webrtc.SessionDescription) webrtc.Sess filteredAttrs := make([]sdp.Attribute, 0, len(attrs)) for _, a := range attrs { if a.Key == sdp.AttrKeyCandidate { - if t.preferTCP { + if preferTCP { if strings.Contains(a.Value, "tcp") { filteredAttrs = append(filteredAttrs, a) } @@ -1361,3 +1399,321 @@ func (t *PCTransport) filterCandidates(sd webrtc.SessionDescription) webrtc.Sess sd.SDP = string(bytes) return sd } + +func (t *PCTransport) clearSignalStateCheckTimer() { + if t.signalStateCheckTimer != nil { + t.signalStateCheckTimer.Stop() + t.signalStateCheckTimer = nil + } +} + +func (t *PCTransport) setupSignalStateCheckTimer() { + t.clearSignalStateCheckTimer() + + negotiateVersion := t.negotiateCounter.Inc() + t.signalStateCheckTimer = time.AfterFunc(negotiationFailedTimeout, func() { + t.clearSignalStateCheckTimer() + + failed := t.negotiationState != NegotiationStateNone + + if t.negotiateCounter.Load() == negotiateVersion && failed { + if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { + onNegotiationFailed() + } + } + }) +} + +func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { + if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + t.params.Logger.Warnw("trying to send offer on closed peer connection", nil) + return nil + } + + // when there's an ongoing negotiation, let it finish and not disrupt its state + if t.negotiationState == NegotiationStateRemote { + t.params.Logger.Infow("skipping negotiation, trying again later") + t.setNegotiationState(NegotiationStateRetry) + return nil + } else if t.negotiationState == NegotiationStateRetry { + // already set to retry, we can safely skip this attempt + return nil + } + + ensureICERestart := func(options *webrtc.OfferOptions) *webrtc.OfferOptions { + if options == nil { + options = &webrtc.OfferOptions{} + } + options.ICERestart = true + return options + } + + t.lock.Lock() + if t.previousAnswer != nil { + t.previousAnswer = nil + options = ensureICERestart(options) + t.params.Logger.Infow("ice restart due to previous answer") + } + t.lock.Unlock() + + if t.restartAtNextOffer { + t.restartAtNextOffer = false + options = ensureICERestart(options) + t.params.Logger.Infow("ice restart at next offer") + } + + if options != nil && options.ICERestart { + t.clearLocalDescriptionSent() + } + + offer, err := t.pc.CreateOffer(options) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "create").Add(1) + return errors.Wrap(err, "create offer failed") + } + + preferTCP := t.preferTCP.Load() + if preferTCP { + t.params.Logger.Infow("local offer (unfiltered)", "sdp", offer.SDP) + } + + err = t.pc.SetLocalDescription(offer) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "local_description").Add(1) + return errors.Wrap(err, "setting local description failed") + } + + // + // Filter after setting local description as pion expects the offer + // to match between CreateOffer and SetLocalDescription. + // Filtered offer is sent to remote so that remote does not + // see filtered candidates. + // + offer = t.filterCandidates(offer, preferTCP) + if preferTCP { + t.params.Logger.Infow("local offer (filtered)", "sdp", offer.SDP) + } + + // indicate waiting for remote + t.setNegotiationState(NegotiationStateRemote) + + t.negotiationPending = make(map[livekit.ParticipantID]bool) + + t.setupSignalStateCheckTimer() + + if onOffer := t.getOnOffer(); onOffer != nil { + if err := onOffer(offer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) + return errors.Wrap(err, "could not send offer") + } + + prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) + return t.localDescriptionSent() + } + + return ErrNoOfferHandler +} + +func (t *PCTransport) handleSendOffer(e *event) error { + return t.createAndSendOffer(nil) +} + +func (t *PCTransport) handleRemoteDescriptionReceived(e *event) error { + sd := e.data.(*webrtc.SessionDescription) + if sd.Type == webrtc.SDPTypeOffer { + return t.handleRemoteOfferReceived(sd) + } else { + return t.handleRemoteAnswerReceived(sd) + } +} + +func (t *PCTransport) isRemoteOfferRestartICE(sd *webrtc.SessionDescription) (string, bool, error) { + parsed, err := sd.Unmarshal() + if err != nil { + return "", false, err + } + user, pwd, err := lksdp.ExtractICECredential(parsed) + if err != nil { + return "", false, err + } + + credential := fmt.Sprintf("%s:%s", user, pwd) + // ice credential changed, remote offer restart ice + restartICE := t.currentOfferIceCredential != "" && t.currentOfferIceCredential != credential + return credential, restartICE, nil +} + +func (t *PCTransport) setRemoteDescription(sd webrtc.SessionDescription) error { + // filter before setting remote description so that pion does not see filtered remote candidates + preferTCP := t.preferTCP.Load() + if preferTCP { + t.params.Logger.Infow("remote description (unfiltered)", "type", sd.Type, "sdp", sd.SDP) + } + sd = t.filterCandidates(sd, preferTCP) + if preferTCP { + t.params.Logger.Infow("remote description (filtered)", "type", sd.Type, "sdp", sd.SDP) + } + + if err := t.pc.SetRemoteDescription(sd); err != nil { + sdpType := "offer" + if sd.Type == webrtc.SDPTypeAnswer { + sdpType = "answer" + } + prometheus.ServiceOperationCounter.WithLabelValues(sdpType, "error", "remote_description").Add(1) + return errors.Wrap(err, "setting remote description failed") + } + + for _, c := range t.pendingRemoteCandidates { + if err := t.pc.AddICECandidate(*c); err != nil { + return errors.Wrap(err, "add ice candidate failed") + } + } + t.pendingRemoteCandidates = nil + + return nil +} + +func (t *PCTransport) createAndSendAnswer() error { + enableDTX := false + if onGetDTX := t.getOnGetDTX(); onGetDTX != nil { + enableDTX = onGetDTX() + } + t.configureReceiverDTX(enableDTX) + + answer, err := t.pc.CreateAnswer(nil) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "create").Add(1) + return errors.Wrap(err, "create answer failed") + } + + preferTCP := t.preferTCP.Load() + if preferTCP { + t.params.Logger.Infow("local answer (unfiltered)", "sdp", answer.SDP) + } + + if err = t.pc.SetLocalDescription(answer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "local_description").Add(1) + return errors.Wrap(err, "setting local description failed") + } + + // + // Filter after setting local description as pion expects the answer + // to match between CreateAnswer and SetLocalDescription. + // Filtered answer is sent to remote so that remote does not + // see filtered candidates. + // + answer = t.filterCandidates(answer, preferTCP) + if preferTCP { + t.params.Logger.Infow("local answer (filtered)", "sdp", answer.SDP) + } + + if onAnswer := t.getOnAnswer(); onAnswer != nil { + if err := onAnswer(answer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "write_message").Add(1) + return errors.Wrap(err, "could not send answer") + } + + prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) + return t.localDescriptionSent() + } + + return ErrNoAnswerHandler +} + +func (t *PCTransport) handleRemoteOfferReceived(sd *webrtc.SessionDescription) error { + iceCredential, offerRestartICE, err := t.isRemoteOfferRestartICE(sd) + if err != nil { + return errors.Wrap(err, "check remote offer restart ice failed") + } + + if offerRestartICE && t.pendingRestartIceOffer == nil { + t.clearLocalDescriptionSent() + } + + if offerRestartICE && t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { + t.params.Logger.Debugw("remote offer restart ice while ice gathering") + t.pendingRestartIceOffer = sd + return nil + } + + if err := t.setRemoteDescription(*sd); err != nil { + return err + } + + if t.currentOfferIceCredential == "" || offerRestartICE { + t.currentOfferIceCredential = iceCredential + } + + return t.createAndSendAnswer() +} + +func (t *PCTransport) handleRemoteAnswerReceived(sd *webrtc.SessionDescription) error { + if err := t.setRemoteDescription(*sd); err != nil { + return err + } + + t.clearSignalStateCheckTimer() + + if t.negotiationState == NegotiationStateRetry { + t.setNegotiationState(NegotiationStateNone) + + t.params.Logger.Debugw("re-negotiate after receiving answer") + return t.createAndSendOffer(nil) + } + + t.setNegotiationState(NegotiationStateNone) + return nil +} + +func (t *PCTransport) handleICERestart(e *event) error { + if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + t.params.Logger.Warnw("trying to restart ICE on closed peer connection", nil) + return nil + } + + // if restart is requested, and we are not ready, then continue afterwards + if t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { + t.params.Logger.Debugw("restart ICE after gathering") + t.restartAfterGathering = true + return nil + } + + if t.negotiationState == NegotiationStateNone { + return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) + } + + currentSD := t.pc.CurrentRemoteDescription() + if currentSD == nil { + // restart without current remote description, send current local description again to try recover + offer := t.pc.LocalDescription() + if offer == nil { + // it should not happen, log just in case + t.params.Logger.Warnw("ice restart without local offer", nil) + return ErrIceRestartWithoutLocalSDP + } else { + t.params.Logger.Infow("deferring ice restart to next offer") + t.setNegotiationState(NegotiationStateRetry) + t.restartAtNextOffer = true + if onOffer := t.getOnOffer(); onOffer != nil { + err := onOffer(*offer) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) + } else { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) + } + return err + } + return ErrNoOfferHandler + } + } else { + // recover by re-applying the last answer + t.params.Logger.Infow("recovering from client negotiation state on ICE restart") + if err := t.pc.SetRemoteDescription(*currentSD); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "remote_description").Add(1) + return errors.Wrap(err, "set remote description failed") + } else { + t.setNegotiationState(NegotiationStateNone) + return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) + } + } +} diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 9d5071294..e604c294a 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -1,18 +1,18 @@ package rtc import ( + "fmt" "strings" - "sync/atomic" "testing" "time" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" - - "github.com/livekit/protocol/livekit" + "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/protocol/livekit" ) func TestMissingAnswerDuringICERestart(t *testing.T) { @@ -26,55 +26,40 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { require.NoError(t, err) _, err = transportA.pc.CreateDataChannel("test", nil) require.NoError(t, err) - transportB, err := NewPCTransport(params) + + paramsB := params + paramsB.Target = livekit.SignalTarget_SUBSCRIBER + transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE handleICEExchange(t, transportA, transportB) - // set offer/answer - handleOffer := handleOfferFunc(t, transportA, transportB) - transportA.OnOffer(handleOffer) - - // first establish connection - require.NoError(t, transportA.CreateAndSendOffer(nil)) - - // ensure we are connected the first time - testutils.WithTimeout(t, func() string { - if transportA.pc.ICEConnectionState() != webrtc.ICEConnectionStateConnected { - return "transportA did not become connected" - } - - if transportB.pc.ICEConnectionState() != webrtc.ICEConnectionStateConnected { - return "transportB did not become connected" - } - return "" - }) + connectTransports(t, transportA, transportB, false, 1, 1) require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) - // offer again, but missed - transportA.OnOffer(func(sd webrtc.SessionDescription) {}) - require.NoError(t, transportA.CreateAndSendOffer(nil)) - require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState()) - require.Equal(t, negotiationStateClient, transportA.negotiationState) - - // now restart ICE - t.Logf("creating offer with ICE restart") - transportA.OnOffer(handleOffer) - require.NoError(t, transportA.CreateAndSendOffer(&webrtc.OfferOptions{ - ICERestart: true, - })) - - testutils.WithTimeout(t, func() string { - if transportA.pc.ICEConnectionState() != webrtc.ICEConnectionStateConnected { - return "transportA did not reconnect after ICE restart" - } - if transportB.pc.ICEConnectionState() != webrtc.ICEConnectionStateConnected { - return "transportB did not reconnect after ICE restart" - } - return "" + var negotiationState atomic.Value + transportA.OnNegotiationStateChanged(func(state NegotiationState) { + negotiationState.Store(state) }) + + // offer again, but missed + var offerReceived atomic.Bool + transportA.OnOffer(func(sd webrtc.SessionDescription) error { + require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState()) + require.Equal(t, NegotiationStateRemote, negotiationState.Load().(NegotiationState)) + offerReceived.Store(true) + return nil + }) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return offerReceived.Load() + }, 10*time.Second, time.Millisecond*10, "transportA offer not received") + + connectTransports(t, transportA, transportB, true, 1, 1) + require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) + require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) } func TestNegotiationTiming(t *testing.T) { @@ -88,51 +73,63 @@ func TestNegotiationTiming(t *testing.T) { require.NoError(t, err) _, err = transportA.pc.CreateDataChannel("test", nil) require.NoError(t, err) + transportB, err := NewPCTransport(params) require.NoError(t, err) + require.False(t, transportA.IsEstablished()) require.False(t, transportB.IsEstablished()) handleICEExchange(t, transportA, transportB) offer := atomic.Value{} - transportA.OnOffer(func(sd webrtc.SessionDescription) { + transportA.OnOffer(func(sd webrtc.SessionDescription) error { offer.Store(&sd) + return nil + }) + + var negotiationState atomic.Value + transportA.OnNegotiationStateChanged(func(state NegotiationState) { + negotiationState.Store(state) }) // initial offer - require.NoError(t, transportA.CreateAndSendOffer(nil)) - require.Equal(t, negotiationStateClient, transportA.negotiationState) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return negotiationState.Load().(NegotiationState) == NegotiationStateRemote + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRemote") // second try, should've flipped transport status to retry - require.NoError(t, transportA.CreateAndSendOffer(nil)) - require.Equal(t, negotiationRetry, transportA.negotiationState) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return negotiationState.Load().(NegotiationState) == NegotiationStateRetry + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") // third try, should've stayed at retry - require.NoError(t, transportA.CreateAndSendOffer(nil)) - require.Equal(t, negotiationRetry, transportA.negotiationState) + transportA.Negotiate(true) + time.Sleep(100 * time.Millisecond) // some time to process the negotiate event + require.Eventually(t, func() bool { + return negotiationState.Load().(NegotiationState) == NegotiationStateRetry + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") time.Sleep(5 * time.Millisecond) actualOffer, ok := offer.Load().(*webrtc.SessionDescription) - require.True(t, ok) - require.NoError(t, transportB.SetRemoteDescription(*actualOffer)) - answer, err := transportB.pc.CreateAnswer(nil) - require.NoError(t, err) - require.NoError(t, transportB.pc.SetLocalDescription(answer)) - require.NoError(t, transportA.SetRemoteDescription(answer)) - testutils.WithTimeout(t, func() string { - if !transportA.IsEstablished() { - return "transportA is not established" - } - if !transportB.IsEstablished() { - return "transportB is not established" - } - return "" + transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + transportA.HandleRemoteDescription(answer) + return nil }) + transportB.HandleRemoteDescription(*actualOffer) + + require.Eventually(t, func() bool { + return transportA.IsEstablished() + }, 10*time.Second, time.Millisecond*10, "transportA is not established") + require.Eventually(t, func() bool { + return transportB.IsEstablished() + }, 10*time.Second, time.Millisecond*10, "transportB is not established") // it should still be negotiating again - require.Equal(t, negotiationStateClient, transportA.negotiationState) + require.Equal(t, NegotiationStateRemote, negotiationState.Load().(NegotiationState)) offer2, ok := offer.Load().(*webrtc.SessionDescription) require.True(t, ok) require.False(t, offer2 == actualOffer) @@ -149,42 +146,57 @@ func TestFirstOfferMissedDuringICERestart(t *testing.T) { require.NoError(t, err) _, err = transportA.pc.CreateDataChannel("test", nil) require.NoError(t, err) - transportB, err := NewPCTransport(params) - require.NoError(t, err) - //first offer missed - transportA.OnOffer(func(sd webrtc.SessionDescription) {}) - require.NoError(t, transportA.CreateAndSendOffer(nil)) + paramsB := params + paramsB.Target = livekit.SignalTarget_SUBSCRIBER + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) // exchange ICE handleICEExchange(t, transportA, transportB) + // first offer missed + var firstOfferReceived atomic.Bool + transportA.OnOffer(func(sd webrtc.SessionDescription) error { + firstOfferReceived.Store(true) + return nil + }) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return firstOfferReceived.Load() + }, 10*time.Second, 10*time.Millisecond, "first offer not received") + // set offer/answer with restart ICE, will negotiate twice, // first one is recover from missed offer // second one is restartICE - handleOffer := handleOfferFunc(t, transportA, transportB) - var offerCount int32 - transportA.OnOffer(func(sd webrtc.SessionDescription) { - atomic.AddInt32(&offerCount, 1) + transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + transportA.HandleRemoteDescription(answer) + return nil + }) + + var offerCount atomic.Int32 + transportA.OnOffer(func(sd webrtc.SessionDescription) error { + offerCount.Inc() + // the second offer is a ice restart offer, so we wait transportB complete the ice gathering if transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { require.Eventually(t, func() bool { return transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateComplete }, 10*time.Second, time.Millisecond*10) } - handleOffer(sd) + + transportB.HandleRemoteDescription(sd) + return nil }) // first establish connection - require.NoError(t, transportA.CreateAndSendOffer(&webrtc.OfferOptions{ - ICERestart: true, - })) + transportA.ICERestart() // ensure we are connected require.Eventually(t, func() bool { return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && - atomic.LoadInt32(&offerCount) == 2 + offerCount.Load() == 2 }, testutils.ConnectTimeout, 10*time.Millisecond, "transport did not connect") transportA.Close() transportB.Close() @@ -201,50 +213,62 @@ func TestFirstAnwserMissedDuringICERestart(t *testing.T) { require.NoError(t, err) _, err = transportA.pc.CreateDataChannel("test", nil) require.NoError(t, err) - transportB, err := NewPCTransport(params) + + paramsB := params + paramsB.Target = livekit.SignalTarget_SUBSCRIBER + transportB, err := NewPCTransport(paramsB) require.NoError(t, err) - //first anwser missed - transportA.OnOffer(func(sd webrtc.SessionDescription) { - require.NoError(t, transportB.SetRemoteDescription(sd)) - answer, err := transportB.pc.CreateAnswer(nil) - require.NoError(t, err) - require.NoError(t, transportB.pc.SetLocalDescription(answer)) - }) // exchange ICE handleICEExchange(t, transportA, transportB) - require.NoError(t, transportA.CreateAndSendOffer(nil)) + // first anwser missed + var firstAnswerReceived atomic.Bool + transportB.OnAnswer(func(sd webrtc.SessionDescription) error { + firstAnswerReceived.Store(true) + return nil + }) + transportA.OnOffer(func(sd webrtc.SessionDescription) error { + transportB.HandleRemoteDescription(sd) + return nil + }) + + transportA.Negotiate(true) require.Eventually(t, func() bool { return transportB.pc.SignalingState() == webrtc.SignalingStateStable - }, time.Second, 10*time.Millisecond) + }, time.Second, 10*time.Millisecond, "transportB signaling state did not go to stable") // set offer/answer with restart ICE, will negotiate twice, // first one is recover from missed offer // second one is restartICE - handleOffer := handleOfferFunc(t, transportA, transportB) - var offerCount int32 - transportA.OnOffer(func(sd webrtc.SessionDescription) { - atomic.AddInt32(&offerCount, 1) + transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + transportA.HandleRemoteDescription(answer) + return nil + }) + + var offerCount atomic.Int32 + transportA.OnOffer(func(sd webrtc.SessionDescription) error { + offerCount.Inc() + // the second offer is a ice restart offer, so we wait transportB complete the ice gathering if transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { require.Eventually(t, func() bool { return transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateComplete }, 10*time.Second, time.Millisecond*10) } - handleOffer(sd) + + transportB.HandleRemoteDescription(sd) + return nil }) // first establish connection - require.NoError(t, transportA.CreateAndSendOffer(&webrtc.OfferOptions{ - ICERestart: true, - })) + transportA.ICERestart() // ensure we are connected require.Eventually(t, func() bool { return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && - atomic.LoadInt32(&offerCount) == 2 + offerCount.Load() == 2 }, testutils.ConnectTimeout, 10*time.Millisecond, "transport did not connect") transportA.Close() transportB.Close() @@ -260,14 +284,22 @@ func TestNegotiationFailed(t *testing.T) { transportA, err := NewPCTransport(params) require.NoError(t, err) - transportA.OnOffer(func(sd webrtc.SessionDescription) {}) - var failed int32 - transportA.OnNegotiationFailed(func() { - atomic.AddInt32(&failed, 1) + transportA.OnICECandidate(func(candidate *webrtc.ICECandidate) error { + if candidate == nil { + return nil + } + t.Logf("got ICE candidate from A: %v", candidate) + return nil }) - transportA.CreateAndSendOffer(nil) + + transportA.OnOffer(func(sd webrtc.SessionDescription) error { return nil }) + var failed atomic.Int32 + transportA.OnNegotiationFailed(func() { + failed.Inc() + }) + transportA.Negotiate(true) require.Eventually(t, func() bool { - return atomic.LoadInt32(&failed) == 1 + return failed.Load() == 1 }, negotiationFailedTimeout+time.Second, 10*time.Millisecond, "negotiation failed") } @@ -304,7 +336,7 @@ func TestFilteringCandidates(t *testing.T) { // should not filter out UDP candidates if TCP is not preferred offer = *transport.pc.LocalDescription() - filteredOffer := transport.filterCandidates(offer) + filteredOffer := transport.filterCandidates(offer, false) require.EqualValues(t, offer.SDP, filteredOffer.SDP) parsed, err := offer.Unmarshal() @@ -382,7 +414,7 @@ func TestFilteringCandidates(t *testing.T) { require.Equal(t, 2, tcp) transport.SetPreferTCP(true) - filteredOffer = transport.filterCandidates(offer) + filteredOffer = transport.filterCandidates(offer, true) parsed, err = filteredOffer.Unmarshal() require.NoError(t, err) udp, tcp = getNumTransportTypeCandidates(parsed) @@ -390,33 +422,59 @@ func TestFilteringCandidates(t *testing.T) { require.Equal(t, 2, tcp) } -func handleOfferFunc(t *testing.T, current, other *PCTransport) func(sd webrtc.SessionDescription) { - return func(sd webrtc.SessionDescription) { - t.Logf("handling offer") - t.Logf("setting other remote description") - require.NoError(t, other.SetRemoteDescription(sd)) - answer, err := other.pc.CreateAnswer(nil) - require.NoError(t, err) - require.NoError(t, other.pc.SetLocalDescription(answer)) - - t.Logf("setting answer on current") - require.NoError(t, current.SetRemoteDescription(answer)) - } -} - func handleICEExchange(t *testing.T, a, b *PCTransport) { - a.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + a.OnICECandidate(func(candidate *webrtc.ICECandidate) error { if candidate == nil { - return + return nil } t.Logf("got ICE candidate from A: %v", candidate) - require.NoError(t, b.AddICECandidate(candidate.ToJSON())) + b.AddICECandidate(candidate.ToJSON()) + return nil }) - b.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + b.OnICECandidate(func(candidate *webrtc.ICECandidate) error { if candidate == nil { - return + return nil } t.Logf("got ICE candidate from B: %v", candidate) - require.NoError(t, a.AddICECandidate(candidate.ToJSON())) + a.AddICECandidate(candidate.ToJSON()) + return nil }) } + +func connectTransports(t *testing.T, offerer, answerer *PCTransport, isICERestart bool, expectedOfferCount int32, expectedAnswerCount int32) { + var offerCount atomic.Int32 + var answerCount atomic.Int32 + answerer.OnAnswer(func(answer webrtc.SessionDescription) error { + answerCount.Inc() + offerer.HandleRemoteDescription(answer) + return nil + }) + + offerer.OnOffer(func(offer webrtc.SessionDescription) error { + offerCount.Inc() + answerer.HandleRemoteDescription(offer) + return nil + }) + + if isICERestart { + offerer.ICERestart() + } else { + offerer.Negotiate(true) + } + + require.Eventually(t, func() bool { + return offerCount.Load() == expectedOfferCount + }, 10*time.Second, time.Millisecond*10, fmt.Sprintf("offer count mismatch, expected: %d, actual: %d", expectedOfferCount, offerCount.Load())) + + require.Eventually(t, func() bool { + return offerer.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected + }, 10*time.Second, time.Millisecond*10, "offerer did not become connected") + + require.Eventually(t, func() bool { + return answerCount.Load() == expectedAnswerCount + }, 10*time.Second, time.Millisecond*10, fmt.Sprintf("answer count mismatch, expected: %d, actual: %d", expectedAnswerCount, answerCount.Load())) + + require.Eventually(t, func() bool { + return answerer.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected + }, 10*time.Second, time.Millisecond*10, "answerer did not become connected") +} diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 428bc204b..0b018f67f 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -3,18 +3,17 @@ package rtc import ( "strings" "sync" - "sync/atomic" "github.com/pion/rtcp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/pkg/errors" + "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/telemetry" - "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -48,8 +47,6 @@ type TransportManager struct { pendingDataChannelsPublisher []*livekit.DataChannelInfo lastPublisherAnswer atomic.Value - onPublisherGetDTX func() bool - onPublisherInitialConnected func() onSubscriberInitialConnected func() onPrimaryTransportInitialConnected func() @@ -95,7 +92,6 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro return nil, err } t.publisher = publisher - t.publisher.OnRemoteDescriptionSettled(t.createPublisherAnswerAndSend) t.publisher.OnInitialConnected(func() { if t.onPublisherInitialConnected != nil { t.onPublisherInitialConnected() @@ -155,18 +151,18 @@ func (t *TransportManager) Close() { t.subscriber.Close() } -func (t *TransportManager) OnPublisherICECandidate(f func(c *webrtc.ICECandidate)) { +func (t *TransportManager) OnPublisherICECandidate(f func(c *webrtc.ICECandidate) error) { t.publisher.OnICECandidate(f) } func (t *TransportManager) OnPublisherGetDTX(f func() bool) { - t.onPublisherGetDTX = f + t.publisher.OnGetDTX(f) } -func (t *TransportManager) OnPublisherAnswer(f func(answer webrtc.SessionDescription)) { - t.publisher.OnAnswer(func(sd webrtc.SessionDescription) { +func (t *TransportManager) OnPublisherAnswer(f func(answer webrtc.SessionDescription) error) { + t.publisher.OnAnswer(func(sd webrtc.SessionDescription) error { t.lastPublisherAnswer.Store(sd) - f(sd) + return f(sd) }) } @@ -190,19 +186,15 @@ func (t *TransportManager) GetPublisherRTPReceiver(mid string) *webrtc.RTPReceiv return t.publisher.GetRTPReceiver(mid) } -func (t *TransportManager) PublisherLocalDescriptionSent() { - t.publisher.LocalDescriptionSent() -} - func (t *TransportManager) WritePublisherRTCP(pkts []rtcp.Packet) error { return t.publisher.WriteRTCP(pkts) } -func (t *TransportManager) OnSubscriberICECandidate(f func(c *webrtc.ICECandidate)) { +func (t *TransportManager) OnSubscriberICECandidate(f func(c *webrtc.ICECandidate) error) { t.subscriber.OnICECandidate(f) } -func (t *TransportManager) OnSubscriberOffer(f func(offer webrtc.SessionDescription)) { +func (t *TransportManager) OnSubscriberOffer(f func(offer webrtc.SessionDescription) error) { t.subscriber.OnOffer(f) } @@ -210,18 +202,10 @@ func (t *TransportManager) OnSubscriberInitialConnected(f func()) { t.onSubscriberInitialConnected = f } -func (t *TransportManager) OnSubscriberNegotiationFailed(f func()) { - t.subscriber.OnNegotiationFailed(f) -} - func (t *TransportManager) OnSubscriberStreamStateChange(f func(update *sfu.StreamStateUpdate) error) { t.subscriber.OnStreamStateChange(f) } -func (t *TransportManager) SubscriberLocalDescriptionSent() { - t.subscriber.LocalDescriptionSent() -} - func (t *TransportManager) HasSubscriberEverConnected() bool { return t.subscriber.HasEverConnected() } @@ -254,6 +238,11 @@ func (t *TransportManager) OnAnyTransportFailed(f func()) { t.onAnyTransportFailed = f } +func (t *TransportManager) OnAnyTransportNegotiationFailed(f func()) { + t.publisher.OnNegotiationFailed(f) + t.subscriber.OnNegotiationFailed(f) +} + func (t *TransportManager) AddSubscribedTrack(subTrack types.SubscribedTrack) { t.subscriber.AddTrackToStreamAllocator(subTrack) } @@ -362,71 +351,46 @@ func (t *TransportManager) GetLastUnmatchedMediaForOffer(offer webrtc.SessionDes return } -func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, shouldPend bool) error { +func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, shouldPend bool) { t.lock.Lock() if shouldPend { t.pendingOfferPublisher = &offer t.lock.Unlock() - return nil + return } t.lock.Unlock() - return t.publisher.SetRemoteDescription(offer) + t.publisher.HandleRemoteDescription(offer) } -func (t *TransportManager) ProcessPendingPublisherOffer() error { +func (t *TransportManager) ProcessPendingPublisherOffer() { t.lock.Lock() pendingOffer := t.pendingOfferPublisher t.pendingOfferPublisher = nil t.lock.Unlock() if pendingOffer != nil { - return t.HandleOffer(*pendingOffer, false) + t.HandleOffer(*pendingOffer, false) } - - return nil -} - -func (t *TransportManager) createPublisherAnswerAndSend() error { - enableDTX := false - if t.onPublisherGetDTX != nil { - enableDTX = t.onPublisherGetDTX() - } - err := t.publisher.CreateAndSendAnswer(enableDTX) - if err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "create").Add(1) - return errors.Wrap(err, "could not create answer") - } - - return nil } // HandleAnswer handles a client answer response, with subscriber PC, server initiates the // offer and client answers -func (t *TransportManager) HandleAnswer(answer webrtc.SessionDescription) error { - if answer.Type != webrtc.SDPTypeAnswer { - return ErrUnexpectedOffer - } - +func (t *TransportManager) HandleAnswer(answer webrtc.SessionDescription) { t.params.Logger.Infow("received answer", "transport", livekit.SignalTarget_SUBSCRIBER) - if err := t.subscriber.SetRemoteDescription(answer); err != nil { - return errors.Wrap(err, "could not set answer") - } - - return nil + t.subscriber.HandleRemoteDescription(answer) } // AddICECandidate adds candidates for remote peer -func (t *TransportManager) AddICECandidate(candidate webrtc.ICECandidateInit, target livekit.SignalTarget) error { +func (t *TransportManager) AddICECandidate(candidate webrtc.ICECandidateInit, target livekit.SignalTarget) { switch target { case livekit.SignalTarget_PUBLISHER: - return t.publisher.AddICECandidate(candidate) + t.publisher.AddICECandidate(candidate) case livekit.SignalTarget_SUBSCRIBER: - return t.subscriber.AddICECandidate(candidate) + t.subscriber.AddICECandidate(candidate) default: err := errors.New("unknown signal target") t.params.Logger.Errorw("ice candidate for unknown signal target", err, "target", target) - return err } } @@ -442,14 +406,12 @@ func (t *TransportManager) IsNegotiationPending(publisherID livekit.ParticipantI return t.subscriber.IsNegotiationPending(publisherID) } -func (t *TransportManager) ICERestart(iceConfig *types.IceConfig) error { +func (t *TransportManager) ICERestart(iceConfig *types.IceConfig) { if iceConfig != nil { t.SetICEConfig(*iceConfig) } - return t.subscriber.CreateAndSendOffer(&webrtc.OfferOptions{ - ICERestart: true, - }) + t.subscriber.ICERestart() } func (t *TransportManager) OnICEConfigChanged(f func(iceConfig types.IceConfig)) { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 9ec0d78a4..33ec9004c 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -225,16 +225,16 @@ type LocalParticipant interface { CanPublishData() bool // PeerConnection - AddICECandidate(candidate webrtc.ICECandidateInit, target livekit.SignalTarget) error - HandleOffer(sdp webrtc.SessionDescription) error + AddICECandidate(candidate webrtc.ICECandidateInit, target livekit.SignalTarget) + HandleOffer(sdp webrtc.SessionDescription) AddTrack(req *livekit.AddTrackRequest) SetTrackMuted(trackID livekit.TrackID, muted bool, fromAdmin bool) - HandleAnswer(sdp webrtc.SessionDescription) error + HandleAnswer(sdp webrtc.SessionDescription) Negotiate(force bool) AddNegotiationPending(publisherID livekit.ParticipantID) IsNegotiationPending(publisherID livekit.ParticipantID) bool - ICERestart(iceConfig *IceConfig) error + ICERestart(iceConfig *IceConfig) AddTrackToSubscriber(trackLocal webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) AddTransceiverFromTrackToSubscriber(trackLocal webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) RemoveTrackFromSubscriber(sender *webrtc.RTPSender) error diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 02a7b71f7..a3b04be25 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -15,18 +15,12 @@ import ( ) type FakeLocalParticipant struct { - AddICECandidateStub func(webrtc.ICECandidateInit, livekit.SignalTarget) error + AddICECandidateStub func(webrtc.ICECandidateInit, livekit.SignalTarget) addICECandidateMutex sync.RWMutex addICECandidateArgsForCall []struct { arg1 webrtc.ICECandidateInit arg2 livekit.SignalTarget } - addICECandidateReturns struct { - result1 error - } - addICECandidateReturnsOnCall map[int]struct { - result1 error - } AddNegotiationPendingStub func(livekit.ParticipantID) addNegotiationPendingMutex sync.RWMutex addNegotiationPendingArgsForCall []struct { @@ -293,28 +287,16 @@ type FakeLocalParticipant struct { getSubscribedTracksReturnsOnCall map[int]struct { result1 []types.SubscribedTrack } - HandleAnswerStub func(webrtc.SessionDescription) error + HandleAnswerStub func(webrtc.SessionDescription) handleAnswerMutex sync.RWMutex handleAnswerArgsForCall []struct { arg1 webrtc.SessionDescription } - handleAnswerReturns struct { - result1 error - } - handleAnswerReturnsOnCall map[int]struct { - result1 error - } - HandleOfferStub func(webrtc.SessionDescription) error + HandleOfferStub func(webrtc.SessionDescription) handleOfferMutex sync.RWMutex handleOfferArgsForCall []struct { arg1 webrtc.SessionDescription } - handleOfferReturns struct { - result1 error - } - handleOfferReturnsOnCall map[int]struct { - result1 error - } HiddenStub func() bool hiddenMutex sync.RWMutex hiddenArgsForCall []struct { @@ -325,17 +307,11 @@ type FakeLocalParticipant struct { hiddenReturnsOnCall map[int]struct { result1 bool } - ICERestartStub func(*types.IceConfig) error + ICERestartStub func(*types.IceConfig) iCERestartMutex sync.RWMutex iCERestartArgsForCall []struct { arg1 *types.IceConfig } - iCERestartReturns struct { - result1 error - } - iCERestartReturnsOnCall map[int]struct { - result1 error - } IDStub func() livekit.ParticipantID iDMutex sync.RWMutex iDArgsForCall []struct { @@ -758,24 +734,18 @@ type FakeLocalParticipant struct { invocationsMutex sync.RWMutex } -func (fake *FakeLocalParticipant) AddICECandidate(arg1 webrtc.ICECandidateInit, arg2 livekit.SignalTarget) error { +func (fake *FakeLocalParticipant) AddICECandidate(arg1 webrtc.ICECandidateInit, arg2 livekit.SignalTarget) { fake.addICECandidateMutex.Lock() - ret, specificReturn := fake.addICECandidateReturnsOnCall[len(fake.addICECandidateArgsForCall)] fake.addICECandidateArgsForCall = append(fake.addICECandidateArgsForCall, struct { arg1 webrtc.ICECandidateInit arg2 livekit.SignalTarget }{arg1, arg2}) stub := fake.AddICECandidateStub - fakeReturns := fake.addICECandidateReturns fake.recordInvocation("AddICECandidate", []interface{}{arg1, arg2}) fake.addICECandidateMutex.Unlock() if stub != nil { - return stub(arg1, arg2) + fake.AddICECandidateStub(arg1, arg2) } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 } func (fake *FakeLocalParticipant) AddICECandidateCallCount() int { @@ -784,7 +754,7 @@ func (fake *FakeLocalParticipant) AddICECandidateCallCount() int { return len(fake.addICECandidateArgsForCall) } -func (fake *FakeLocalParticipant) AddICECandidateCalls(stub func(webrtc.ICECandidateInit, livekit.SignalTarget) error) { +func (fake *FakeLocalParticipant) AddICECandidateCalls(stub func(webrtc.ICECandidateInit, livekit.SignalTarget)) { fake.addICECandidateMutex.Lock() defer fake.addICECandidateMutex.Unlock() fake.AddICECandidateStub = stub @@ -797,29 +767,6 @@ func (fake *FakeLocalParticipant) AddICECandidateArgsForCall(i int) (webrtc.ICEC return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeLocalParticipant) AddICECandidateReturns(result1 error) { - fake.addICECandidateMutex.Lock() - defer fake.addICECandidateMutex.Unlock() - fake.AddICECandidateStub = nil - fake.addICECandidateReturns = struct { - result1 error - }{result1} -} - -func (fake *FakeLocalParticipant) AddICECandidateReturnsOnCall(i int, result1 error) { - fake.addICECandidateMutex.Lock() - defer fake.addICECandidateMutex.Unlock() - fake.AddICECandidateStub = nil - if fake.addICECandidateReturnsOnCall == nil { - fake.addICECandidateReturnsOnCall = make(map[int]struct { - result1 error - }) - } - fake.addICECandidateReturnsOnCall[i] = struct { - result1 error - }{result1} -} - func (fake *FakeLocalParticipant) AddNegotiationPending(arg1 livekit.ParticipantID) { fake.addNegotiationPendingMutex.Lock() fake.addNegotiationPendingArgsForCall = append(fake.addNegotiationPendingArgsForCall, struct { @@ -2204,23 +2151,17 @@ func (fake *FakeLocalParticipant) GetSubscribedTracksReturnsOnCall(i int, result }{result1} } -func (fake *FakeLocalParticipant) HandleAnswer(arg1 webrtc.SessionDescription) error { +func (fake *FakeLocalParticipant) HandleAnswer(arg1 webrtc.SessionDescription) { fake.handleAnswerMutex.Lock() - ret, specificReturn := fake.handleAnswerReturnsOnCall[len(fake.handleAnswerArgsForCall)] fake.handleAnswerArgsForCall = append(fake.handleAnswerArgsForCall, struct { arg1 webrtc.SessionDescription }{arg1}) stub := fake.HandleAnswerStub - fakeReturns := fake.handleAnswerReturns fake.recordInvocation("HandleAnswer", []interface{}{arg1}) fake.handleAnswerMutex.Unlock() if stub != nil { - return stub(arg1) + fake.HandleAnswerStub(arg1) } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 } func (fake *FakeLocalParticipant) HandleAnswerCallCount() int { @@ -2229,7 +2170,7 @@ func (fake *FakeLocalParticipant) HandleAnswerCallCount() int { return len(fake.handleAnswerArgsForCall) } -func (fake *FakeLocalParticipant) HandleAnswerCalls(stub func(webrtc.SessionDescription) error) { +func (fake *FakeLocalParticipant) HandleAnswerCalls(stub func(webrtc.SessionDescription)) { fake.handleAnswerMutex.Lock() defer fake.handleAnswerMutex.Unlock() fake.HandleAnswerStub = stub @@ -2242,46 +2183,17 @@ func (fake *FakeLocalParticipant) HandleAnswerArgsForCall(i int) webrtc.SessionD return argsForCall.arg1 } -func (fake *FakeLocalParticipant) HandleAnswerReturns(result1 error) { - fake.handleAnswerMutex.Lock() - defer fake.handleAnswerMutex.Unlock() - fake.HandleAnswerStub = nil - fake.handleAnswerReturns = struct { - result1 error - }{result1} -} - -func (fake *FakeLocalParticipant) HandleAnswerReturnsOnCall(i int, result1 error) { - fake.handleAnswerMutex.Lock() - defer fake.handleAnswerMutex.Unlock() - fake.HandleAnswerStub = nil - if fake.handleAnswerReturnsOnCall == nil { - fake.handleAnswerReturnsOnCall = make(map[int]struct { - result1 error - }) - } - fake.handleAnswerReturnsOnCall[i] = struct { - result1 error - }{result1} -} - -func (fake *FakeLocalParticipant) HandleOffer(arg1 webrtc.SessionDescription) error { +func (fake *FakeLocalParticipant) HandleOffer(arg1 webrtc.SessionDescription) { fake.handleOfferMutex.Lock() - ret, specificReturn := fake.handleOfferReturnsOnCall[len(fake.handleOfferArgsForCall)] fake.handleOfferArgsForCall = append(fake.handleOfferArgsForCall, struct { arg1 webrtc.SessionDescription }{arg1}) stub := fake.HandleOfferStub - fakeReturns := fake.handleOfferReturns fake.recordInvocation("HandleOffer", []interface{}{arg1}) fake.handleOfferMutex.Unlock() if stub != nil { - return stub(arg1) + fake.HandleOfferStub(arg1) } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 } func (fake *FakeLocalParticipant) HandleOfferCallCount() int { @@ -2290,7 +2202,7 @@ func (fake *FakeLocalParticipant) HandleOfferCallCount() int { return len(fake.handleOfferArgsForCall) } -func (fake *FakeLocalParticipant) HandleOfferCalls(stub func(webrtc.SessionDescription) error) { +func (fake *FakeLocalParticipant) HandleOfferCalls(stub func(webrtc.SessionDescription)) { fake.handleOfferMutex.Lock() defer fake.handleOfferMutex.Unlock() fake.HandleOfferStub = stub @@ -2303,29 +2215,6 @@ func (fake *FakeLocalParticipant) HandleOfferArgsForCall(i int) webrtc.SessionDe return argsForCall.arg1 } -func (fake *FakeLocalParticipant) HandleOfferReturns(result1 error) { - fake.handleOfferMutex.Lock() - defer fake.handleOfferMutex.Unlock() - fake.HandleOfferStub = nil - fake.handleOfferReturns = struct { - result1 error - }{result1} -} - -func (fake *FakeLocalParticipant) HandleOfferReturnsOnCall(i int, result1 error) { - fake.handleOfferMutex.Lock() - defer fake.handleOfferMutex.Unlock() - fake.HandleOfferStub = nil - if fake.handleOfferReturnsOnCall == nil { - fake.handleOfferReturnsOnCall = make(map[int]struct { - result1 error - }) - } - fake.handleOfferReturnsOnCall[i] = struct { - result1 error - }{result1} -} - func (fake *FakeLocalParticipant) Hidden() bool { fake.hiddenMutex.Lock() ret, specificReturn := fake.hiddenReturnsOnCall[len(fake.hiddenArgsForCall)] @@ -2379,23 +2268,17 @@ func (fake *FakeLocalParticipant) HiddenReturnsOnCall(i int, result1 bool) { }{result1} } -func (fake *FakeLocalParticipant) ICERestart(arg1 *types.IceConfig) error { +func (fake *FakeLocalParticipant) ICERestart(arg1 *types.IceConfig) { fake.iCERestartMutex.Lock() - ret, specificReturn := fake.iCERestartReturnsOnCall[len(fake.iCERestartArgsForCall)] fake.iCERestartArgsForCall = append(fake.iCERestartArgsForCall, struct { arg1 *types.IceConfig }{arg1}) stub := fake.ICERestartStub - fakeReturns := fake.iCERestartReturns fake.recordInvocation("ICERestart", []interface{}{arg1}) fake.iCERestartMutex.Unlock() if stub != nil { - return stub(arg1) + fake.ICERestartStub(arg1) } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 } func (fake *FakeLocalParticipant) ICERestartCallCount() int { @@ -2404,7 +2287,7 @@ func (fake *FakeLocalParticipant) ICERestartCallCount() int { return len(fake.iCERestartArgsForCall) } -func (fake *FakeLocalParticipant) ICERestartCalls(stub func(*types.IceConfig) error) { +func (fake *FakeLocalParticipant) ICERestartCalls(stub func(*types.IceConfig)) { fake.iCERestartMutex.Lock() defer fake.iCERestartMutex.Unlock() fake.ICERestartStub = stub @@ -2417,29 +2300,6 @@ func (fake *FakeLocalParticipant) ICERestartArgsForCall(i int) *types.IceConfig return argsForCall.arg1 } -func (fake *FakeLocalParticipant) ICERestartReturns(result1 error) { - fake.iCERestartMutex.Lock() - defer fake.iCERestartMutex.Unlock() - fake.ICERestartStub = nil - fake.iCERestartReturns = struct { - result1 error - }{result1} -} - -func (fake *FakeLocalParticipant) ICERestartReturnsOnCall(i int, result1 error) { - fake.iCERestartMutex.Lock() - defer fake.iCERestartMutex.Unlock() - fake.ICERestartStub = nil - if fake.iCERestartReturnsOnCall == nil { - fake.iCERestartReturnsOnCall = make(map[int]struct { - result1 error - }) - } - fake.iCERestartReturnsOnCall[i] = struct { - result1 error - }{result1} -} - func (fake *FakeLocalParticipant) ID() livekit.ParticipantID { fake.iDMutex.Lock() ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] diff --git a/pkg/sfu/streamallocator.go b/pkg/sfu/streamallocator.go index cbf01c263..191bca6f8 100644 --- a/pkg/sfu/streamallocator.go +++ b/pkg/sfu/streamallocator.go @@ -83,7 +83,7 @@ func (s State) String() string { type Signal int const ( - SignalAllocateTrack = iota + SignalAllocateTrack Signal = iota SignalAllocateAllTracks SignalAdjustState SignalEstimate diff --git a/test/client/client.go b/test/client/client.go index 086f90654..52b973493 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -25,11 +25,6 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" ) -const ( - lossyDataChannel = "_lossy" - reliableDataChannel = "_reliable" -) - type RTCClient struct { id livekit.ParticipantID conn *websocket.Conn @@ -42,20 +37,15 @@ type RTCClient struct { wsLock sync.Mutex ctx context.Context cancel context.CancelFunc - connected atomic.Bool - iceConnected atomic.Bool me *webrtc.MediaEngine // optional, populated only when receiving tracks subscribedTracks map[livekit.ParticipantID][]*webrtc.TrackRemote localParticipant *livekit.ParticipantInfo remoteParticipants map[livekit.ParticipantID]*livekit.ParticipantInfo - reliableDC *webrtc.DataChannel - reliableDCSub *webrtc.DataChannel - lossyDC *webrtc.DataChannel - lossyDCSub *webrtc.DataChannel - publisherConnected atomic.Bool - publisherNegotiated atomic.Bool - pongReceivedAt atomic.Int64 + subscriberAsPrimary atomic.Bool + publisherFullyEstablished atomic.Bool + subscriberFullyEstablished atomic.Bool + pongReceivedAt atomic.Int64 // tracks waiting to be acked, cid => trackInfo pendingPublishedTracks map[string]*livekit.TrackInfo @@ -144,15 +134,20 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { Mime: "video/h264", }, } + // + // The signal targets are from point of view of server. + // From client side, they are flipped, + // i. e. the publisher transport on client side has SUBSCRIBER signal target (i. e. publisher is offerer). + // Same applies for subscriber transport also + // c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ - Target: livekit.SignalTarget_PUBLISHER, + Target: livekit.SignalTarget_SUBSCRIBER, Config: &conf, EnabledCodecs: codecs, }) if err != nil { return nil, err } - // intentionally use publisher transport to have codecs pre-registered c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ Target: livekit.SignalTarget_PUBLISHER, Config: &conf, @@ -162,87 +157,75 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { return nil, err } + c.publisher.OnICECandidate(func(ic *webrtc.ICECandidate) error { + if ic == nil { + return nil + } + return c.SendIceCandidate(ic, livekit.SignalTarget_PUBLISHER) + }) + c.publisher.OnOffer(c.onOffer) + c.publisher.OnFullyEstablished(func() { + logger.Debugw("publisher fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) + c.publisherFullyEstablished.Store(true) + }) + ordered := true - c.reliableDC, err = c.publisher.PeerConnection().CreateDataChannel(reliableDataChannel, - &webrtc.DataChannelInit{Ordered: &ordered}, - ) - if err != nil { + if err := c.publisher.CreateDataChannel(rtc.ReliableDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { return nil, err } maxRetransmits := uint16(0) - c.lossyDC, err = c.publisher.PeerConnection().CreateDataChannel(lossyDataChannel, - &webrtc.DataChannelInit{Ordered: &ordered, MaxRetransmits: &maxRetransmits}, - ) - if err != nil { + if err := c.publisher.CreateDataChannel(rtc.LossyDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &maxRetransmits, + }); err != nil { return nil, err } - c.publisher.PeerConnection().OnICECandidate(func(ic *webrtc.ICECandidate) { + c.subscriber.OnICECandidate(func(ic *webrtc.ICECandidate) error { if ic == nil { - return + return nil } - _ = c.SendIceCandidate(ic, livekit.SignalTarget_PUBLISHER) + return c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) }) - c.subscriber.PeerConnection().OnICECandidate(func(ic *webrtc.ICECandidate) { - if ic == nil { - return - } - _ = c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) - }) - - c.subscriber.PeerConnection().OnTrack(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + c.subscriber.OnTrack(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { go c.processTrack(track) }) - c.subscriber.PeerConnection().OnDataChannel(func(channel *webrtc.DataChannel) { - if channel.Label() == reliableDataChannel { - c.reliableDCSub = channel - } else if channel.Label() == lossyDataChannel { - c.lossyDCSub = channel - } else { - return + c.subscriber.OnDataPacket(c.handleDataMessage) + c.subscriber.OnInitialConnected(func() { + logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) + + c.lock.Lock() + defer c.lock.Unlock() + for _, tw := range c.pendingTrackWriters { + if err := tw.Start(); err != nil { + logger.Errorw("track writer error", err) + } } - channel.OnMessage(c.handleDataMessage) - }) - c.publisher.OnOffer(c.onOffer) + c.pendingTrackWriters = nil - c.subscriber.PeerConnection().OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - logger.Debugw("subscriber ICE state has changed", "state", connectionState.String(), - "participant", c.localParticipant.Identity) - if connectionState == webrtc.ICEConnectionStateConnected { - // flush peers - c.lock.Lock() - defer c.lock.Unlock() - for _, tw := range c.pendingTrackWriters { - if err := tw.Start(); err != nil { - logger.Errorw("track writer error", err) - } - } - - initialConnect := !c.iceConnected.Load() - c.pendingTrackWriters = nil - c.iceConnected.Store(true) - - if initialConnect && c.OnConnected != nil { - go c.OnConnected() - } + if c.OnConnected != nil { + go c.OnConnected() } }) - - c.publisher.PeerConnection().OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { - logger.Infow("publisher ICE state changed", "state", state.String(), - "participant", c.localParticipant.Identity) - - if state == webrtc.ICEConnectionStateConnected { - c.publisherConnected.Store(true) - // check if publisher triggered negotiate (!subscriberPrimary) - if c.publisherNegotiated.Load() { - c.iceConnected.Store(true) - } - } else { - c.publisherConnected.Store(false) - } + c.subscriber.OnFullyEstablished(func() { + logger.Debugw("subscriber fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) + c.subscriberFullyEstablished.Store(true) + }) + c.subscriber.OnAnswer(func(answer webrtc.SessionDescription) error { + // send remote an answer + logger.Infow("sending subscriber answer", + "participant", c.localParticipant.Identity, + // "sdp", answer, + ) + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Answer{ + Answer: rtc.ToProtoSessionDescription(answer), + }, + }) }) return c, nil @@ -281,8 +264,10 @@ func (c *RTCClient) Run() error { c.lock.Unlock() // if publish only, negotiate if !msg.Join.SubscriberPrimary { - c.publisherNegotiated.Store(true) + c.subscriberAsPrimary.Store(false) c.publisher.Negotiate(false) + } else { + c.subscriberAsPrimary.Store(true) } logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity) @@ -290,27 +275,22 @@ func (c *RTCClient) Run() error { // logger.Debugw("received server answer", // "participant", c.localParticipant.Identity, // "answer", msg.Answer.Sdp) - _ = c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer)) + c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer)) case *livekit.SignalResponse_Offer: logger.Infow("received server offer", "participant", c.localParticipant.Identity, ) desc := rtc.FromProtoSessionDescription(msg.Offer) - if err := c.handleOffer(desc); err != nil { - return err - } + c.handleOffer(desc) case *livekit.SignalResponse_Trickle: candidateInit, err := rtc.FromProtoTrickle(msg.Trickle) if err != nil { return err } if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER { - err = c.publisher.AddICECandidate(candidateInit) + c.publisher.AddICECandidate(candidateInit) } else { - err = c.subscriber.AddICECandidate(candidateInit) - } - if err != nil { - return err + c.subscriber.AddICECandidate(candidateInit) } case *livekit.SignalResponse_Update: c.lock.Lock() @@ -340,7 +320,7 @@ func (c *RTCClient) Run() error { c.lock.Lock() sender := c.trackSenders[sid] if sender != nil { - if err := c.publisher.PeerConnection().RemoveTrack(sender); err != nil { + if err := c.publisher.RemoveTrack(sender); err != nil { logger.Errorw("Could not unpublish track", err) } c.publisher.Negotiate(false) @@ -366,8 +346,14 @@ func (c *RTCClient) WaitUntilConnected() error { } return fmt.Errorf("%s could not connect after timeout", id) case <-time.After(10 * time.Millisecond): - if c.iceConnected.Load() { - return nil + if c.subscriberAsPrimary.Load() { + if c.subscriberFullyEstablished.Load() { + return nil + } + } else { + if c.publisherFullyEstablished.Load() { + return nil + } } } } @@ -430,8 +416,8 @@ func (c *RTCClient) Stop() { Leave: &livekit.LeaveRequest{}, }, }) - c.connected.Store(false) - c.iceConnected.Store(false) + c.publisherFullyEstablished.Store(false) + c.subscriberFullyEstablished.Store(false) _ = c.conn.Close() c.publisher.Close() c.subscriber.Close() @@ -477,6 +463,14 @@ func (c *RTCClient) SendIceCandidate(ic *webrtc.ICECandidate, target livekit.Sig }) } +func (c *RTCClient) hasPrimaryEverConnected() bool { + if c.subscriberAsPrimary.Load() { + return c.subscriber.HasEverConnected() + } else { + return c.publisher.HasEverConnected() + } +} + func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string) (writer *TrackWriter, err error) { trackType := livekit.TrackType_AUDIO if track.Kind() == webrtc.RTPCodecTypeVideo { @@ -511,8 +505,9 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string) c.lock.Lock() defer c.lock.Unlock() - sender, err := c.publisher.PeerConnection().AddTrack(track) + sender, _, err := c.publisher.AddTrack(track) if err != nil { + logger.Errorw("add track failed", err, "trackID", ti.Sid, "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) return } c.localTracks[ti.Sid] = track @@ -520,8 +515,8 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string) c.publisher.Negotiate(false) writer = NewTrackWriter(c.ctx, track, path) - // write tracks only after ICE connectivity - if c.iceConnected.Load() { + // write tracks only after connection established + if c.hasPrimaryEverConnected() { err = writer.Start() } else { c.pendingTrackWriters = append(c.pendingTrackWriters, writer) @@ -590,16 +585,8 @@ func (c *RTCClient) PublishData(data []byte, kind livekit.DataPacket_Kind) error User: &livekit.UserPacket{Payload: data}, }, } - payload, err := proto.Marshal(dp) - if err != nil { - return err - } - if kind == livekit.DataPacket_RELIABLE { - return c.reliableDC.Send(payload) - } else { - return c.lossyDC.Send(payload) - } + return c.publisher.SendDataPacket(dp) } func (c *RTCClient) GetPublishedTrackIDs() []string { @@ -613,22 +600,12 @@ func (c *RTCClient) GetPublishedTrackIDs() []string { } func (c *RTCClient) ensurePublisherConnected() error { - if c.publisherConnected.Load() { + if c.publisher.HasEverConnected() { return nil } - if c.publisher.PeerConnection().ConnectionState() == webrtc.PeerConnectionStateNew { - // start negotiating - c.publisher.Negotiate(false) - } - - dcOpen := atomic.NewBool(false) - c.reliableDC.OnOpen(func() { - dcOpen.Store(true) - }) - if c.reliableDC.ReadyState() == webrtc.DataChannelStateOpen { - dcOpen.Store(true) - } + // start negotiating + c.publisher.Negotiate(false) // wait until connected, increase wait time since it takes more than 10s sometimes on GH ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -638,16 +615,16 @@ func (c *RTCClient) ensurePublisherConnected() error { case <-ctx.Done(): return fmt.Errorf("could not connect publisher after timeout") case <-time.After(10 * time.Millisecond): - if c.publisherConnected.Load() && dcOpen.Load() { + if c.publisher.HasEverConnected() { return nil } } } } -func (c *RTCClient) handleDataMessage(msg webrtc.DataChannelMessage) { +func (c *RTCClient) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) { dp := &livekit.DataPacket{} - err := proto.Unmarshal(msg.Data, dp) + err := proto.Unmarshal(data, dp) if err != nil { return } @@ -659,54 +636,22 @@ func (c *RTCClient) handleDataMessage(msg webrtc.DataChannelMessage) { } // handles a server initiated offer, handle on subscriber PC -func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) error { - if err := c.subscriber.SetRemoteDescription(desc); err != nil { - return err - } - - // if we received an offer, we'd have to answer - answer, err := c.subscriber.PeerConnection().CreateAnswer(nil) - if err != nil { - return err - } - - if err := c.subscriber.PeerConnection().SetLocalDescription(answer); err != nil { - return err - } - - // send remote an answer - logger.Infow("sending subscriber answer", - "participant", c.localParticipant.Identity, - // "sdp", answer, - ) - return c.SendRequest(&livekit.SignalRequest{ - Message: &livekit.SignalRequest_Answer{ - Answer: rtc.ToProtoSessionDescription(answer), - }, - }) +func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) { + c.subscriber.HandleRemoteDescription(desc) } // the client handles answer on the publisher PC -func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription) error { +func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription) { logger.Infow("handling server answer", "participant", c.localParticipant.Identity) // remote answered the offer, establish connection - err := c.publisher.SetRemoteDescription(desc) - if err != nil { - return err - } - - if c.connected.Swap(true) { - // already connected - return nil - } - return nil + c.publisher.HandleRemoteDescription(desc) } -func (c *RTCClient) onOffer(offer webrtc.SessionDescription) { +func (c *RTCClient) onOffer(offer webrtc.SessionDescription) error { if c.localParticipant != nil { logger.Infow("starting negotiation", "participant", c.localParticipant.Identity) } - _ = c.SendRequest(&livekit.SignalRequest{ + return c.SendRequest(&livekit.SignalRequest{ Message: &livekit.SignalRequest_Offer{ Offer: rtc.ToProtoSessionDescription(offer), }, @@ -786,5 +731,5 @@ func (c *RTCClient) SendNacks(count int) { } c.lock.Unlock() - _ = c.subscriber.PeerConnection().WriteRTCP(packets) + _ = c.subscriber.WriteRTCP(packets) } diff --git a/test/multinode_roomservice_test.go b/test/multinode_roomservice_test.go index dded4cb25..f822ebdb1 100644 --- a/test/multinode_roomservice_test.go +++ b/test/multinode_roomservice_test.go @@ -126,7 +126,6 @@ func TestMultiNodeMutePublishedTrack(t *testing.T) { defer c1.Stop() waitUntilConnected(t, c1) - // c1 and c2 publishing, c3 just receiving writers := publishTracksForClients(t, c1) defer stopWriters(writers...)