From 0be241eed8836f14e8a53e40ef246ec2f9f95609 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Sun, 28 Jan 2024 21:35:25 -0800 Subject: [PATCH] refactor transport callbacks as interface (#2423) * refactor transport callbacks as interface * test --- pkg/rtc/participant.go | 104 ++- pkg/rtc/transport.go | 301 ++------- pkg/rtc/transport/handler.go | 55 ++ pkg/rtc/transport/negotiationstate.go | 26 + .../transport/transportfakes/fake_handler.go | 593 ++++++++++++++++++ pkg/rtc/transport_test.go | 124 ++-- pkg/rtc/transportmanager.go | 120 +--- test/client/client.go | 23 +- 8 files changed, 928 insertions(+), 418 deletions(-) create mode 100644 pkg/rtc/transport/handler.go create mode 100644 pkg/rtc/transport/negotiationstate.go create mode 100644 pkg/rtc/transport/transportfakes/fake_handler.go diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 0aa98e198..df60f272a 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -36,6 +36,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc/supervisor" + "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" @@ -1149,13 +1150,91 @@ func (p *ParticipantImpl) UpdateMediaRTT(rtt uint32) { } } +type AnyTransportHandler struct { + transport.UnimplementedHandler + p *ParticipantImpl +} + +func (h AnyTransportHandler) OnFailed(isShortLived bool) { + h.p.onAnyTransportFailed() +} + +func (h AnyTransportHandler) OnNegotiationFailed() { + h.p.onAnyTransportNegotiationFailed() +} + +func (h AnyTransportHandler) OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + return h.p.onICECandidate(c, target) +} + +type PublisherTransportHandler struct { + AnyTransportHandler +} + +func (h PublisherTransportHandler) OnAnswer(sd webrtc.SessionDescription) error { + return h.p.onPublisherAnswer(sd) +} + +func (h PublisherTransportHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + h.p.onMediaTrack(track, rtpReceiver) +} + +func (h PublisherTransportHandler) OnInitialConnected() { + h.p.onPublisherInitialConnected() +} + +func (h PublisherTransportHandler) OnDataPacket(kind livekit.DataPacket_Kind, data []byte) { + h.p.onDataMessage(kind, data) +} + +type SubscriberTransportHandler struct { + AnyTransportHandler +} + +func (h SubscriberTransportHandler) OnOffer(sd webrtc.SessionDescription) error { + return h.p.onSubscriberOffer(sd) +} + +func (h SubscriberTransportHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { + return h.p.onStreamStateChange(update) +} + +func (h SubscriberTransportHandler) OnInitialConnected() { + h.p.onSubscriberInitialConnected() +} + +type PrimaryTransportHandler struct { + transport.Handler + p *ParticipantImpl +} + +func (h PrimaryTransportHandler) OnInitialConnected() { + h.Handler.OnInitialConnected() + h.p.onPrimaryTransportInitialConnected() +} + +func (h PrimaryTransportHandler) OnFullyEstablished() { + h.p.onPrimaryTransportFullyEstablished() +} + func (p *ParticipantImpl) setupTransportManager() error { + ath := AnyTransportHandler{p: p} + var pth transport.Handler = PublisherTransportHandler{ath} + var sth transport.Handler = SubscriberTransportHandler{ath} + + subscriberAsPrimary := p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe() + if subscriberAsPrimary { + sth = PrimaryTransportHandler{sth, p} + } else { + pth = PrimaryTransportHandler{pth, p} + } + params := TransportManagerParams{ Identity: p.params.Identity, SID: p.params.SID, // primary connection does not change, canSubscribe can change if permission was updated // after the participant has joined - SubscriberAsPrimary: p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe(), + SubscriberAsPrimary: subscriberAsPrimary, Config: p.params.Config, ProtocolVersion: p.params.ProtocolVersion, CongestionControlConfig: p.params.CongestionControlConfig, @@ -1171,6 +1250,8 @@ func (p *ParticipantImpl) setupTransportManager() error { AllowPlayoutDelay: p.params.PlayoutDelay.GetEnabled(), DataChannelMaxBufferedAmount: p.params.DataChannelMaxBufferedAmount, Logger: p.params.Logger.WithComponent(sutils.ComponentTransport), + PublisherHandler: pth, + SubscriberHandler: sth, } if p.params.SyncStreams && p.params.PlayoutDelay.GetEnabled() && p.params.ClientInfo.isFirefox() { // we will disable playout delay for Firefox if the user is expecting @@ -1202,27 +1283,6 @@ func (p *ParticipantImpl) setupTransportManager() error { } }) - tm.OnPublisherICECandidate(func(c *webrtc.ICECandidate) error { - return p.onICECandidate(c, livekit.SignalTarget_PUBLISHER) - }) - tm.OnPublisherAnswer(p.onPublisherAnswer) - tm.OnPublisherTrack(p.onMediaTrack) - tm.OnPublisherInitialConnected(p.onPublisherInitialConnected) - - tm.OnSubscriberOffer(p.onSubscriberOffer) - tm.OnSubscriberICECandidate(func(c *webrtc.ICECandidate) error { - return p.onICECandidate(c, livekit.SignalTarget_SUBSCRIBER) - }) - tm.OnSubscriberInitialConnected(p.onSubscriberInitialConnected) - tm.OnSubscriberStreamStateChange(p.onStreamStateChange) - - tm.OnPrimaryTransportInitialConnected(p.onPrimaryTransportInitialConnected) - tm.OnPrimaryTransportFullyEstablished(p.onPrimaryTransportFullyEstablished) - tm.OnAnyTransportFailed(p.onAnyTransportFailed) - tm.OnAnyTransportNegotiationFailed(p.onAnyTransportNegotiationFailed) - - tm.OnDataMessage(p.onDataMessage) - tm.SetSubscriberAllowPause(p.params.SubscriberAllowPause) p.TransportManager = tm return nil diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index aa416b50c..c0e804d65 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -35,12 +35,12 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/pacer" "github.com/livekit/livekit-server/pkg/sfu/rtpextension" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" - "github.com/livekit/livekit-server/pkg/utils" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -77,9 +77,6 @@ var ( ErrIceRestartOnClosedPeerConnection = errors.New("ICE restart on closed peer connection") ErrNoTransceiver = errors.New("no transceiver") ErrNoSender = errors.New("no sender") - ErrNoICECandidateHandler = errors.New("no ICE candidate handler") - ErrNoOfferHandler = errors.New("no offer handler") - ErrNoAnswerHandler = errors.New("no answer handler") ErrMidNotFound = errors.New("mid not found") ) @@ -128,31 +125,6 @@ func (e event) String() string { // ------------------------------------------------------- -type NegotiationState int - -const ( - NegotiationStateNone NegotiationState = iota - // waiting for remote description - NegotiationStateRemote - // need to Negotiate again - NegotiationStateRetry -) - -func (n NegotiationState) String() string { - switch n { - case NegotiationStateNone: - return "NONE" - case NegotiationStateRemote: - return "WAITING_FOR_REMOTE" - case NegotiationStateRetry: - return "RETRY" - default: - return fmt.Sprintf("%d", int(n)) - } -} - -// ------------------------------------------------------- - type SimulcastTrackInfo struct { Mid string Rid string @@ -175,7 +147,6 @@ type PCTransport struct { reliableDCOpened bool lossyDC *webrtc.DataChannel lossyDCOpened bool - onDataPacket func(kind livekit.DataPacket_Kind, data []byte) iceStartedAt time.Time iceConnectedAt time.Time @@ -186,18 +157,10 @@ type PCTransport struct { resetShortConnOnICERestart atomic.Bool signalingRTT atomic.Uint32 // milliseconds - onFullyEstablished func() - debouncedNegotiate func(func()) debouncePending bool - onICECandidate func(c *webrtc.ICECandidate) error - onOffer func(offer webrtc.SessionDescription) error - onAnswer func(answer webrtc.SessionDescription) error - onInitialConnected func() - onFailed func(isShortLived bool) - onNegotiationStateChanged func(state NegotiationState) - onNegotiationFailed func() + onNegotiationStateChanged func(state transport.NegotiationState) // stream allocator for subscriber PC streamAllocator *streamallocator.StreamAllocator @@ -213,7 +176,7 @@ type PCTransport struct { preferTCP atomic.Bool isClosed atomic.Bool - eventsQueue *utils.OpsQueue + eventsQueue *sutils.OpsQueue // the following should be accessed only in event processing go routine cacheLocalCandidates bool @@ -221,7 +184,7 @@ type PCTransport struct { pendingRemoteCandidates []*webrtc.ICECandidateInit restartAfterGathering bool restartAtNextOffer bool - negotiationState NegotiationState + negotiationState transport.NegotiationState negotiateCounter atomic.Int32 signalStateCheckTimer *time.Timer currentOfferIceCredential string // ice user:pwd, for publish side ice restart checking @@ -231,6 +194,7 @@ type PCTransport struct { } type TransportParams struct { + Handler transport.Handler ParticipantID livekit.ParticipantID ParticipantIdentity livekit.ParticipantIdentity ProtocolVersion types.ProtocolVersion @@ -378,8 +342,8 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { t := &PCTransport{ params: params, debouncedNegotiate: debounce.New(negotiationFrequency), - negotiationState: NegotiationStateNone, - eventsQueue: utils.NewOpsQueue("transport", 64, false), + negotiationState: transport.NegotiationStateNone, + eventsQueue: sutils.NewOpsQueue("transport", 64, false), previousTrackDescription: make(map[string]*trackDescription), canReuseTransceiver: true, connectionDetails: types.NewICEConnectionDetails(params.Transport, params.Logger), @@ -389,6 +353,7 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { Config: params.CongestionControlConfig, Logger: params.Logger.WithComponent(sutils.ComponentCongestionControl), }) + t.streamAllocator.OnStreamStateChange(params.Handler.OnStreamStateChange) t.streamAllocator.Start() t.pacer = pacer.NewPassThrough(params.Logger) } @@ -419,6 +384,7 @@ func (t *PCTransport) createPeerConnection() error { t.pc.OnConnectionStateChange(t.onPeerConnectionStateChange) t.pc.OnDataChannel(t.onDataChannel) + t.pc.OnTrack(t.params.Handler.OnTrack) t.me = me @@ -608,9 +574,7 @@ func (t *PCTransport) handleConnectionFailed(forceShortConn bool) { t.params.Logger.Infow("force short ICE connection") } - if onFailed := t.getOnFailed(); onFailed != nil { - onFailed(isShort) - } + t.params.Handler.OnFailed(isShort) } func (t *PCTransport) onICEConnectionStateChange(state webrtc.ICEConnectionState) { @@ -644,9 +608,7 @@ func (t *PCTransport) onPeerConnectionStateChange(state webrtc.PeerConnectionSta t.clearConnTimer() isInitialConnection := t.setConnectedAt(time.Now()) if isInitialConnection { - if onInitialConnected := t.getOnInitialConnected(); onInitialConnected != nil { - onInitialConnected() - } + t.params.Handler.OnInitialConnected() t.maybeNotifyFullyEstablished() } @@ -666,9 +628,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.reliableDCOpened = true t.lock.Unlock() dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if onDataPacket := t.getOnDataPacket(); onDataPacket != nil { - onDataPacket(livekit.DataPacket_RELIABLE, msg.Data) - } + t.params.Handler.OnDataPacket(livekit.DataPacket_RELIABLE, msg.Data) }) t.maybeNotifyFullyEstablished() @@ -678,9 +638,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.lossyDCOpened = true t.lock.Unlock() dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if onDataPacket := t.getOnDataPacket(); onDataPacket != nil { - onDataPacket(livekit.DataPacket_LOSSY, msg.Data) - } + t.params.Handler.OnDataPacket(livekit.DataPacket_LOSSY, msg.Data) }) t.maybeNotifyFullyEstablished() @@ -691,9 +649,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { func (t *PCTransport) maybeNotifyFullyEstablished() { if t.isFullyEstablished() { - if onFullyEstablished := t.getOnFullyEstablished(); onFullyEstablished != nil { - onFullyEstablished() - } + t.params.Handler.OnFullyEstablished() } } @@ -975,134 +931,19 @@ func (t *PCTransport) HandleRemoteDescription(sd webrtc.SessionDescription) { }) } -func (t *PCTransport) OnICECandidate(f func(c *webrtc.ICECandidate) error) { - t.lock.Lock() - t.onICECandidate = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnICECandidate() func(c *webrtc.ICECandidate) error { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onICECandidate -} - -func (t *PCTransport) OnInitialConnected(f func()) { - t.lock.Lock() - t.onInitialConnected = f - t.lock.Unlock() - - if f != nil { - if t.pc.ConnectionState() == webrtc.PeerConnectionStateConnected { - go f() - } - } -} - -func (t *PCTransport) getOnInitialConnected() func() { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onInitialConnected -} - -func (t *PCTransport) OnFullyEstablished(f func()) { - t.lock.Lock() - t.onFullyEstablished = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnFullyEstablished() func() { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onFullyEstablished -} - -func (t *PCTransport) OnFailed(f func(isShortLived bool)) { - t.lock.Lock() - t.onFailed = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnFailed() func(isShortLived bool) { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onFailed -} - -func (t *PCTransport) OnTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { - t.pc.OnTrack(f) -} - -func (t *PCTransport) OnDataPacket(f func(kind livekit.DataPacket_Kind, data []byte)) { - t.lock.Lock() - t.onDataPacket = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnDataPacket() func(kind livekit.DataPacket_Kind, data []byte) { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onDataPacket -} - -// OnOffer is called when the PeerConnection starts negotiation and prepares an offer -func (t *PCTransport) OnOffer(f func(sd webrtc.SessionDescription) error) { - t.lock.Lock() - t.onOffer = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnOffer() func(sd webrtc.SessionDescription) error { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onOffer -} - -func (t *PCTransport) OnAnswer(f func(sd webrtc.SessionDescription) error) { - t.lock.Lock() - t.onAnswer = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnAnswer() func(sd webrtc.SessionDescription) error { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onAnswer -} - -func (t *PCTransport) OnNegotiationStateChanged(f func(state NegotiationState)) { +func (t *PCTransport) OnNegotiationStateChanged(f func(state transport.NegotiationState)) { t.lock.Lock() t.onNegotiationStateChanged = f t.lock.Unlock() } -func (t *PCTransport) getOnNegotiationStateChanged() func(state NegotiationState) { +func (t *PCTransport) getOnNegotiationStateChanged() func(state transport.NegotiationState) { t.lock.RLock() defer t.lock.RUnlock() return t.onNegotiationStateChanged } -func (t *PCTransport) OnNegotiationFailed(f func()) { - t.lock.Lock() - t.onNegotiationFailed = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnNegotiationFailed() func() { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onNegotiationFailed -} - func (t *PCTransport) Negotiate(force bool) { if t.isClosed.Load() { return @@ -1153,14 +994,6 @@ func (t *PCTransport) ResetShortConnOnICERestart() { t.resetShortConnOnICERestart.Store(true) } -func (t *PCTransport) OnStreamStateChange(f func(update *streamallocator.StreamStateUpdate) error) { - if t.streamAllocator == nil { - return - } - - t.streamAllocator.OnStreamStateChange(f) -} - func (t *PCTransport) AddTrackToStreamAllocator(subTrack types.SubscribedTrack) { if t.streamAllocator == nil { return @@ -1322,9 +1155,7 @@ func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDesc func (t *PCTransport) SetPreviousSdp(offer, answer *webrtc.SessionDescription) { // when there is no previous answer, cannot migrate, force a full reconnect if answer == nil { - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() return } @@ -1335,9 +1166,7 @@ func (t *PCTransport) SetPreviousSdp(offer, answer *webrtc.SessionDescription) { t.params.Logger.Errorw("initPCWithPreviousAnswer failed", err) t.lock.Unlock() - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() return } else if offer != nil { // in migration case, can't reuse transceiver before negotiated except track subscribed at previous node @@ -1384,9 +1213,7 @@ func (t *PCTransport) postEvent(event event) { if err != nil { if !t.isClosed.Load() { t.params.Logger.Errorw("error handling event", err, "event", event.String()) - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() } } }) @@ -1455,17 +1282,12 @@ func (t *PCTransport) localDescriptionSent() error { cachedLocalCandidates := t.cachedLocalCandidates t.cachedLocalCandidates = nil - if onICECandidate := t.getOnICECandidate(); onICECandidate != nil { - for _, c := range cachedLocalCandidates { - if err := onICECandidate(c); err != nil { - return err - } + for _, c := range cachedLocalCandidates { + if err := t.params.Handler.OnICECandidate(c, t.params.Transport); err != nil { + return err } - - return nil } - - return ErrNoICECandidateHandler + return nil } func (t *PCTransport) clearLocalDescriptionSent() { @@ -1498,11 +1320,7 @@ func (t *PCTransport) handleLocalICECandidate(e *event) error { return nil } - if onICECandidate := t.getOnICECandidate(); onICECandidate != nil { - return onICECandidate(c) - } - - return ErrNoICECandidateHandler + return t.params.Handler.OnICECandidate(c, t.params.Transport) } func (t *PCTransport) handleRemoteICECandidate(e *event) error { @@ -1531,7 +1349,7 @@ func (t *PCTransport) handleRemoteICECandidate(e *event) error { return nil } -func (t *PCTransport) setNegotiationState(state NegotiationState) { +func (t *PCTransport) setNegotiationState(state transport.NegotiationState) { t.negotiationState = state if onNegotiationStateChanged := t.getOnNegotiationStateChanged(); onNegotiationStateChanged != nil { onNegotiationStateChanged(t.negotiationState) @@ -1592,7 +1410,7 @@ func (t *PCTransport) setupSignalStateCheckTimer() { t.signalStateCheckTimer = time.AfterFunc(negotiationFailedTimeout, func() { t.clearSignalStateCheckTimer() - failed := t.negotiationState != NegotiationStateNone + failed := t.negotiationState != transport.NegotiationStateNone if t.negotiateCounter.Load() == negotiateVersion && failed && t.pc.ConnectionState() == webrtc.PeerConnectionStateConnected { t.params.Logger.Infow( @@ -1602,9 +1420,7 @@ func (t *PCTransport) setupSignalStateCheckTimer() { "remoteCurrent", t.pc.CurrentRemoteDescription(), "remotePending", t.pc.PendingRemoteDescription(), ) - if onNegotiationFailed := t.getOnNegotiationFailed(); onNegotiationFailed != nil { - onNegotiationFailed() - } + t.params.Handler.OnNegotiationFailed() } }) } @@ -1616,11 +1432,11 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { } // when there's an ongoing negotiation, let it finish and not disrupt its state - if t.negotiationState == NegotiationStateRemote { + if t.negotiationState == transport.NegotiationStateRemote { t.params.Logger.Debugw("skipping negotiation, trying again later") - t.setNegotiationState(NegotiationStateRetry) + t.setNegotiationState(transport.NegotiationStateRetry) return nil - } else if t.negotiationState == NegotiationStateRetry { + } else if t.negotiationState == transport.NegotiationStateRetry { // already set to retry, we can safely skip this attempt return nil } @@ -1690,21 +1506,17 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { } // indicate waiting for remote - t.setNegotiationState(NegotiationStateRemote) + t.setNegotiationState(transport.NegotiationStateRemote) t.setupSignalStateCheckTimer() - if onOffer := t.getOnOffer(); onOffer != nil { - if err := onOffer(offer); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) - return errors.Wrap(err, "could not send offer") - } - - prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) - return t.localDescriptionSent() + if err := t.params.Handler.OnOffer(offer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) + return errors.Wrap(err, "could not send offer") } - return ErrNoOfferHandler + prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) + return t.localDescriptionSent() } func (t *PCTransport) handleSendOffer(_ *event) error { @@ -1811,17 +1623,13 @@ func (t *PCTransport) createAndSendAnswer() error { t.params.Logger.Debugw("local answer (filtered)", "sdp", answer.SDP) } - if onAnswer := t.getOnAnswer(); onAnswer != nil { - if err := onAnswer(answer); err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "write_message").Add(1) - return errors.Wrap(err, "could not send answer") - } - - prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) - return t.localDescriptionSent() + if err := t.params.Handler.OnAnswer(answer); err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "write_message").Add(1) + return errors.Wrap(err, "could not send answer") } - return ErrNoAnswerHandler + prometheus.ServiceOperationCounter.WithLabelValues("answer", "success", "").Add(1) + return t.localDescriptionSent() } func (t *PCTransport) handleRemoteOfferReceived(sd *webrtc.SessionDescription) error { @@ -1868,14 +1676,14 @@ func (t *PCTransport) handleRemoteAnswerReceived(sd *webrtc.SessionDescription) } } - if t.negotiationState == NegotiationStateRetry { - t.setNegotiationState(NegotiationStateNone) + if t.negotiationState == transport.NegotiationStateRetry { + t.setNegotiationState(transport.NegotiationStateNone) t.params.Logger.Debugw("re-negotiate after receiving answer") return t.createAndSendOffer(nil) } - t.setNegotiationState(NegotiationStateNone) + t.setNegotiationState(transport.NegotiationStateNone) return nil } @@ -1896,7 +1704,7 @@ func (t *PCTransport) doICERestart() error { t.resetShortConn() } - if t.negotiationState == NegotiationStateNone { + if t.negotiationState == transport.NegotiationStateNone { return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) } @@ -1910,18 +1718,15 @@ func (t *PCTransport) doICERestart() error { return ErrIceRestartWithoutLocalSDP } else { t.params.Logger.Infow("deferring ice restart to next offer") - t.setNegotiationState(NegotiationStateRetry) + t.setNegotiationState(transport.NegotiationStateRetry) t.restartAtNextOffer = true - if onOffer := t.getOnOffer(); onOffer != nil { - err := onOffer(*offer) - if err != nil { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) - } else { - prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) - } - return err + err := t.params.Handler.OnOffer(*offer) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "write_message").Add(1) + } else { + prometheus.ServiceOperationCounter.WithLabelValues("offer", "success", "").Add(1) } - return ErrNoOfferHandler + return err } } else { // recover by re-applying the last answer @@ -1930,7 +1735,7 @@ func (t *PCTransport) doICERestart() error { prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "remote_description").Add(1) return errors.Wrap(err, "set remote description failed") } else { - t.setNegotiationState(NegotiationStateNone) + t.setNegotiationState(transport.NegotiationStateNone) return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) } } diff --git a/pkg/rtc/transport/handler.go b/pkg/rtc/transport/handler.go new file mode 100644 index 000000000..eab3de349 --- /dev/null +++ b/pkg/rtc/transport/handler.go @@ -0,0 +1,55 @@ +package transport + +import ( + "errors" + + "github.com/pion/webrtc/v3" + + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/protocol/livekit" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +var ( + ErrNoICECandidateHandler = errors.New("no ICE candidate handler") + ErrNoOfferHandler = errors.New("no offer handler") + ErrNoAnswerHandler = errors.New("no answer handler") +) + +//counterfeiter:generate . Handler +type Handler interface { + OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error + OnInitialConnected() + OnFullyEstablished() + OnFailed(isShortLived bool) + OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) + OnDataPacket(kind livekit.DataPacket_Kind, data []byte) + OnOffer(sd webrtc.SessionDescription) error + OnAnswer(sd webrtc.SessionDescription) error + OnNegotiationStateChanged(state NegotiationState) + OnNegotiationFailed() + OnStreamStateChange(update *streamallocator.StreamStateUpdate) error +} + +type UnimplementedHandler struct{} + +func (h UnimplementedHandler) OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + return ErrNoICECandidateHandler +} +func (h UnimplementedHandler) OnInitialConnected() {} +func (h UnimplementedHandler) OnFullyEstablished() {} +func (h UnimplementedHandler) OnFailed(isShortLived bool) {} +func (h UnimplementedHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) {} +func (h UnimplementedHandler) OnDataPacket(kind livekit.DataPacket_Kind, data []byte) {} +func (h UnimplementedHandler) OnOffer(sd webrtc.SessionDescription) error { + return ErrNoOfferHandler +} +func (h UnimplementedHandler) OnAnswer(sd webrtc.SessionDescription) error { + return ErrNoAnswerHandler +} +func (h UnimplementedHandler) OnNegotiationStateChanged(state NegotiationState) {} +func (h UnimplementedHandler) OnNegotiationFailed() {} +func (h UnimplementedHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { + return nil +} diff --git a/pkg/rtc/transport/negotiationstate.go b/pkg/rtc/transport/negotiationstate.go new file mode 100644 index 000000000..8f074f83f --- /dev/null +++ b/pkg/rtc/transport/negotiationstate.go @@ -0,0 +1,26 @@ +package transport + +import "fmt" + +type NegotiationState int + +const ( + NegotiationStateNone NegotiationState = iota + // waiting for remote description + NegotiationStateRemote + // need to Negotiate again + NegotiationStateRetry +) + +func (n NegotiationState) String() string { + switch n { + case NegotiationStateNone: + return "NONE" + case NegotiationStateRemote: + return "WAITING_FOR_REMOTE" + case NegotiationStateRetry: + return "RETRY" + default: + return fmt.Sprintf("%d", int(n)) + } +} diff --git a/pkg/rtc/transport/transportfakes/fake_handler.go b/pkg/rtc/transport/transportfakes/fake_handler.go new file mode 100644 index 000000000..dc7ad3c3e --- /dev/null +++ b/pkg/rtc/transport/transportfakes/fake_handler.go @@ -0,0 +1,593 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package transportfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/protocol/livekit" + webrtc "github.com/pion/webrtc/v3" +) + +type FakeHandler struct { + OnAnswerStub func(webrtc.SessionDescription) error + onAnswerMutex sync.RWMutex + onAnswerArgsForCall []struct { + arg1 webrtc.SessionDescription + } + onAnswerReturns struct { + result1 error + } + onAnswerReturnsOnCall map[int]struct { + result1 error + } + OnDataPacketStub func(livekit.DataPacket_Kind, []byte) + onDataPacketMutex sync.RWMutex + onDataPacketArgsForCall []struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + } + OnFailedStub func(bool) + onFailedMutex sync.RWMutex + onFailedArgsForCall []struct { + arg1 bool + } + OnFullyEstablishedStub func() + onFullyEstablishedMutex sync.RWMutex + onFullyEstablishedArgsForCall []struct { + } + OnICECandidateStub func(*webrtc.ICECandidate, livekit.SignalTarget) error + onICECandidateMutex sync.RWMutex + onICECandidateArgsForCall []struct { + arg1 *webrtc.ICECandidate + arg2 livekit.SignalTarget + } + onICECandidateReturns struct { + result1 error + } + onICECandidateReturnsOnCall map[int]struct { + result1 error + } + OnInitialConnectedStub func() + onInitialConnectedMutex sync.RWMutex + onInitialConnectedArgsForCall []struct { + } + OnNegotiationFailedStub func() + onNegotiationFailedMutex sync.RWMutex + onNegotiationFailedArgsForCall []struct { + } + OnNegotiationStateChangedStub func(transport.NegotiationState) + onNegotiationStateChangedMutex sync.RWMutex + onNegotiationStateChangedArgsForCall []struct { + arg1 transport.NegotiationState + } + OnOfferStub func(webrtc.SessionDescription) error + onOfferMutex sync.RWMutex + onOfferArgsForCall []struct { + arg1 webrtc.SessionDescription + } + onOfferReturns struct { + result1 error + } + onOfferReturnsOnCall map[int]struct { + result1 error + } + OnStreamStateChangeStub func(*streamallocator.StreamStateUpdate) error + onStreamStateChangeMutex sync.RWMutex + onStreamStateChangeArgsForCall []struct { + arg1 *streamallocator.StreamStateUpdate + } + onStreamStateChangeReturns struct { + result1 error + } + onStreamStateChangeReturnsOnCall map[int]struct { + result1 error + } + OnTrackStub func(*webrtc.TrackRemote, *webrtc.RTPReceiver) + onTrackMutex sync.RWMutex + onTrackArgsForCall []struct { + arg1 *webrtc.TrackRemote + arg2 *webrtc.RTPReceiver + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHandler) OnAnswer(arg1 webrtc.SessionDescription) error { + fake.onAnswerMutex.Lock() + ret, specificReturn := fake.onAnswerReturnsOnCall[len(fake.onAnswerArgsForCall)] + fake.onAnswerArgsForCall = append(fake.onAnswerArgsForCall, struct { + arg1 webrtc.SessionDescription + }{arg1}) + stub := fake.OnAnswerStub + fakeReturns := fake.onAnswerReturns + fake.recordInvocation("OnAnswer", []interface{}{arg1}) + fake.onAnswerMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnAnswerCallCount() int { + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + return len(fake.onAnswerArgsForCall) +} + +func (fake *FakeHandler) OnAnswerCalls(stub func(webrtc.SessionDescription) error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = stub +} + +func (fake *FakeHandler) OnAnswerArgsForCall(i int) webrtc.SessionDescription { + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + argsForCall := fake.onAnswerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnAnswerReturns(result1 error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = nil + fake.onAnswerReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnAnswerReturnsOnCall(i int, result1 error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = nil + if fake.onAnswerReturnsOnCall == nil { + fake.onAnswerReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onAnswerReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnDataPacket(arg1 livekit.DataPacket_Kind, arg2 []byte) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.onDataPacketMutex.Lock() + fake.onDataPacketArgsForCall = append(fake.onDataPacketArgsForCall, struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + }{arg1, arg2Copy}) + stub := fake.OnDataPacketStub + fake.recordInvocation("OnDataPacket", []interface{}{arg1, arg2Copy}) + fake.onDataPacketMutex.Unlock() + if stub != nil { + fake.OnDataPacketStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnDataPacketCallCount() int { + fake.onDataPacketMutex.RLock() + defer fake.onDataPacketMutex.RUnlock() + return len(fake.onDataPacketArgsForCall) +} + +func (fake *FakeHandler) OnDataPacketCalls(stub func(livekit.DataPacket_Kind, []byte)) { + fake.onDataPacketMutex.Lock() + defer fake.onDataPacketMutex.Unlock() + fake.OnDataPacketStub = stub +} + +func (fake *FakeHandler) OnDataPacketArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { + fake.onDataPacketMutex.RLock() + defer fake.onDataPacketMutex.RUnlock() + argsForCall := fake.onDataPacketArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnFailed(arg1 bool) { + fake.onFailedMutex.Lock() + fake.onFailedArgsForCall = append(fake.onFailedArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.OnFailedStub + fake.recordInvocation("OnFailed", []interface{}{arg1}) + fake.onFailedMutex.Unlock() + if stub != nil { + fake.OnFailedStub(arg1) + } +} + +func (fake *FakeHandler) OnFailedCallCount() int { + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + return len(fake.onFailedArgsForCall) +} + +func (fake *FakeHandler) OnFailedCalls(stub func(bool)) { + fake.onFailedMutex.Lock() + defer fake.onFailedMutex.Unlock() + fake.OnFailedStub = stub +} + +func (fake *FakeHandler) OnFailedArgsForCall(i int) bool { + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + argsForCall := fake.onFailedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnFullyEstablished() { + fake.onFullyEstablishedMutex.Lock() + fake.onFullyEstablishedArgsForCall = append(fake.onFullyEstablishedArgsForCall, struct { + }{}) + stub := fake.OnFullyEstablishedStub + fake.recordInvocation("OnFullyEstablished", []interface{}{}) + fake.onFullyEstablishedMutex.Unlock() + if stub != nil { + fake.OnFullyEstablishedStub() + } +} + +func (fake *FakeHandler) OnFullyEstablishedCallCount() int { + fake.onFullyEstablishedMutex.RLock() + defer fake.onFullyEstablishedMutex.RUnlock() + return len(fake.onFullyEstablishedArgsForCall) +} + +func (fake *FakeHandler) OnFullyEstablishedCalls(stub func()) { + fake.onFullyEstablishedMutex.Lock() + defer fake.onFullyEstablishedMutex.Unlock() + fake.OnFullyEstablishedStub = stub +} + +func (fake *FakeHandler) OnICECandidate(arg1 *webrtc.ICECandidate, arg2 livekit.SignalTarget) error { + fake.onICECandidateMutex.Lock() + ret, specificReturn := fake.onICECandidateReturnsOnCall[len(fake.onICECandidateArgsForCall)] + fake.onICECandidateArgsForCall = append(fake.onICECandidateArgsForCall, struct { + arg1 *webrtc.ICECandidate + arg2 livekit.SignalTarget + }{arg1, arg2}) + stub := fake.OnICECandidateStub + fakeReturns := fake.onICECandidateReturns + fake.recordInvocation("OnICECandidate", []interface{}{arg1, arg2}) + fake.onICECandidateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnICECandidateCallCount() int { + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + return len(fake.onICECandidateArgsForCall) +} + +func (fake *FakeHandler) OnICECandidateCalls(stub func(*webrtc.ICECandidate, livekit.SignalTarget) error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = stub +} + +func (fake *FakeHandler) OnICECandidateArgsForCall(i int) (*webrtc.ICECandidate, livekit.SignalTarget) { + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + argsForCall := fake.onICECandidateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnICECandidateReturns(result1 error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = nil + fake.onICECandidateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnICECandidateReturnsOnCall(i int, result1 error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = nil + if fake.onICECandidateReturnsOnCall == nil { + fake.onICECandidateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onICECandidateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnInitialConnected() { + fake.onInitialConnectedMutex.Lock() + fake.onInitialConnectedArgsForCall = append(fake.onInitialConnectedArgsForCall, struct { + }{}) + stub := fake.OnInitialConnectedStub + fake.recordInvocation("OnInitialConnected", []interface{}{}) + fake.onInitialConnectedMutex.Unlock() + if stub != nil { + fake.OnInitialConnectedStub() + } +} + +func (fake *FakeHandler) OnInitialConnectedCallCount() int { + fake.onInitialConnectedMutex.RLock() + defer fake.onInitialConnectedMutex.RUnlock() + return len(fake.onInitialConnectedArgsForCall) +} + +func (fake *FakeHandler) OnInitialConnectedCalls(stub func()) { + fake.onInitialConnectedMutex.Lock() + defer fake.onInitialConnectedMutex.Unlock() + fake.OnInitialConnectedStub = stub +} + +func (fake *FakeHandler) OnNegotiationFailed() { + fake.onNegotiationFailedMutex.Lock() + fake.onNegotiationFailedArgsForCall = append(fake.onNegotiationFailedArgsForCall, struct { + }{}) + stub := fake.OnNegotiationFailedStub + fake.recordInvocation("OnNegotiationFailed", []interface{}{}) + fake.onNegotiationFailedMutex.Unlock() + if stub != nil { + fake.OnNegotiationFailedStub() + } +} + +func (fake *FakeHandler) OnNegotiationFailedCallCount() int { + fake.onNegotiationFailedMutex.RLock() + defer fake.onNegotiationFailedMutex.RUnlock() + return len(fake.onNegotiationFailedArgsForCall) +} + +func (fake *FakeHandler) OnNegotiationFailedCalls(stub func()) { + fake.onNegotiationFailedMutex.Lock() + defer fake.onNegotiationFailedMutex.Unlock() + fake.OnNegotiationFailedStub = stub +} + +func (fake *FakeHandler) OnNegotiationStateChanged(arg1 transport.NegotiationState) { + fake.onNegotiationStateChangedMutex.Lock() + fake.onNegotiationStateChangedArgsForCall = append(fake.onNegotiationStateChangedArgsForCall, struct { + arg1 transport.NegotiationState + }{arg1}) + stub := fake.OnNegotiationStateChangedStub + fake.recordInvocation("OnNegotiationStateChanged", []interface{}{arg1}) + fake.onNegotiationStateChangedMutex.Unlock() + if stub != nil { + fake.OnNegotiationStateChangedStub(arg1) + } +} + +func (fake *FakeHandler) OnNegotiationStateChangedCallCount() int { + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + return len(fake.onNegotiationStateChangedArgsForCall) +} + +func (fake *FakeHandler) OnNegotiationStateChangedCalls(stub func(transport.NegotiationState)) { + fake.onNegotiationStateChangedMutex.Lock() + defer fake.onNegotiationStateChangedMutex.Unlock() + fake.OnNegotiationStateChangedStub = stub +} + +func (fake *FakeHandler) OnNegotiationStateChangedArgsForCall(i int) transport.NegotiationState { + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + argsForCall := fake.onNegotiationStateChangedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnOffer(arg1 webrtc.SessionDescription) error { + fake.onOfferMutex.Lock() + ret, specificReturn := fake.onOfferReturnsOnCall[len(fake.onOfferArgsForCall)] + fake.onOfferArgsForCall = append(fake.onOfferArgsForCall, struct { + arg1 webrtc.SessionDescription + }{arg1}) + stub := fake.OnOfferStub + fakeReturns := fake.onOfferReturns + fake.recordInvocation("OnOffer", []interface{}{arg1}) + fake.onOfferMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnOfferCallCount() int { + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + return len(fake.onOfferArgsForCall) +} + +func (fake *FakeHandler) OnOfferCalls(stub func(webrtc.SessionDescription) error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = stub +} + +func (fake *FakeHandler) OnOfferArgsForCall(i int) webrtc.SessionDescription { + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + argsForCall := fake.onOfferArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnOfferReturns(result1 error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = nil + fake.onOfferReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnOfferReturnsOnCall(i int, result1 error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = nil + if fake.onOfferReturnsOnCall == nil { + fake.onOfferReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onOfferReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnStreamStateChange(arg1 *streamallocator.StreamStateUpdate) error { + fake.onStreamStateChangeMutex.Lock() + ret, specificReturn := fake.onStreamStateChangeReturnsOnCall[len(fake.onStreamStateChangeArgsForCall)] + fake.onStreamStateChangeArgsForCall = append(fake.onStreamStateChangeArgsForCall, struct { + arg1 *streamallocator.StreamStateUpdate + }{arg1}) + stub := fake.OnStreamStateChangeStub + fakeReturns := fake.onStreamStateChangeReturns + fake.recordInvocation("OnStreamStateChange", []interface{}{arg1}) + fake.onStreamStateChangeMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnStreamStateChangeCallCount() int { + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + return len(fake.onStreamStateChangeArgsForCall) +} + +func (fake *FakeHandler) OnStreamStateChangeCalls(stub func(*streamallocator.StreamStateUpdate) error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = stub +} + +func (fake *FakeHandler) OnStreamStateChangeArgsForCall(i int) *streamallocator.StreamStateUpdate { + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + argsForCall := fake.onStreamStateChangeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnStreamStateChangeReturns(result1 error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = nil + fake.onStreamStateChangeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnStreamStateChangeReturnsOnCall(i int, result1 error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = nil + if fake.onStreamStateChangeReturnsOnCall == nil { + fake.onStreamStateChangeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onStreamStateChangeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnTrack(arg1 *webrtc.TrackRemote, arg2 *webrtc.RTPReceiver) { + fake.onTrackMutex.Lock() + fake.onTrackArgsForCall = append(fake.onTrackArgsForCall, struct { + arg1 *webrtc.TrackRemote + arg2 *webrtc.RTPReceiver + }{arg1, arg2}) + stub := fake.OnTrackStub + fake.recordInvocation("OnTrack", []interface{}{arg1, arg2}) + fake.onTrackMutex.Unlock() + if stub != nil { + fake.OnTrackStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnTrackCallCount() int { + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + return len(fake.onTrackArgsForCall) +} + +func (fake *FakeHandler) OnTrackCalls(stub func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + fake.onTrackMutex.Lock() + defer fake.onTrackMutex.Unlock() + fake.OnTrackStub = stub +} + +func (fake *FakeHandler) OnTrackArgsForCall(i int) (*webrtc.TrackRemote, *webrtc.RTPReceiver) { + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + argsForCall := fake.onTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + fake.onDataPacketMutex.RLock() + defer fake.onDataPacketMutex.RUnlock() + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + fake.onFullyEstablishedMutex.RLock() + defer fake.onFullyEstablishedMutex.RUnlock() + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + fake.onInitialConnectedMutex.RLock() + defer fake.onInitialConnectedMutex.RUnlock() + fake.onNegotiationFailedMutex.RLock() + defer fake.onNegotiationFailedMutex.RUnlock() + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ transport.Handler = new(FakeHandler) diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index e66670ea5..7bd1067b3 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -26,6 +26,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/atomic" + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" "github.com/livekit/livekit-server/pkg/testutils" "github.com/livekit/protocol/livekit" ) @@ -37,33 +39,39 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) - connectTransports(t, transportA, transportB, false, 1, 1) + connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) var negotiationState atomic.Value - transportA.OnNegotiationStateChanged(func(state NegotiationState) { + transportA.OnNegotiationStateChanged(func(state transport.NegotiationState) { negotiationState.Store(state) }) // offer again, but missed var offerReceived atomic.Bool - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState()) - require.Equal(t, NegotiationStateRemote, negotiationState.Load().(NegotiationState)) + require.Equal(t, transport.NegotiationStateRemote, negotiationState.Load().(transport.NegotiationState)) offerReceived.Store(true) return nil }) @@ -72,7 +80,7 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { return offerReceived.Load() }, 10*time.Second, time.Millisecond*10, "transportA offer not received") - connectTransports(t, transportA, transportB, true, 1, 1) + connectTransports(t, transportA, transportB, handlerA, handlerB, true, 1, 1) require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) @@ -87,70 +95,76 @@ func TestNegotiationTiming(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) _, err = transportA.pc.CreateDataChannel(LossyDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false - transportB, err := NewPCTransport(params) + transportB, err := NewPCTransport(paramsB) require.NoError(t, err) require.False(t, transportA.IsEstablished()) require.False(t, transportB.IsEstablished()) - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) offer := atomic.Value{} - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { offer.Store(&sd) return nil }) var negotiationState atomic.Value - transportA.OnNegotiationStateChanged(func(state NegotiationState) { + transportA.OnNegotiationStateChanged(func(state transport.NegotiationState) { negotiationState.Store(state) }) // initial offer transportA.Negotiate(true) require.Eventually(t, func() bool { - state, ok := negotiationState.Load().(NegotiationState) + state, ok := negotiationState.Load().(transport.NegotiationState) if !ok { return false } - return state == NegotiationStateRemote + return state == transport.NegotiationStateRemote }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRemote") // second try, should've flipped transport status to retry transportA.Negotiate(true) require.Eventually(t, func() bool { - state, ok := negotiationState.Load().(NegotiationState) + state, ok := negotiationState.Load().(transport.NegotiationState) if !ok { return false } - return state == NegotiationStateRetry + return state == transport.NegotiationStateRetry }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") // third try, should've stayed at retry transportA.Negotiate(true) time.Sleep(100 * time.Millisecond) // some time to process the negotiate event require.Eventually(t, func() bool { - state, ok := negotiationState.Load().(NegotiationState) + state, ok := negotiationState.Load().(transport.NegotiationState) if !ok { return false } - return state == NegotiationStateRetry + return state == transport.NegotiationStateRetry }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") time.Sleep(5 * time.Millisecond) actualOffer, ok := offer.Load().(*webrtc.SessionDescription) require.True(t, ok) - transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription) error { transportA.HandleRemoteDescription(answer) return nil }) @@ -164,7 +178,7 @@ func TestNegotiationTiming(t *testing.T) { }, 10*time.Second, time.Millisecond*10, "transportB is not established") // it should still be negotiating again - require.Equal(t, NegotiationStateRemote, negotiationState.Load().(NegotiationState)) + require.Equal(t, transport.NegotiationStateRemote, negotiationState.Load().(transport.NegotiationState)) offer2, ok := offer.Load().(*webrtc.SessionDescription) require.True(t, ok) require.False(t, offer2 == actualOffer) @@ -180,22 +194,28 @@ func TestFirstOfferMissedDuringICERestart(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) // first offer missed var firstOfferReceived atomic.Bool - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { firstOfferReceived.Store(true) return nil }) @@ -207,13 +227,13 @@ func TestFirstOfferMissedDuringICERestart(t *testing.T) { // set offer/answer with restart ICE, will negotiate twice, // first one is recover from missed offer // second one is restartICE - transportB.OnAnswer(func(answer webrtc.SessionDescription) error { + handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription) error { transportA.HandleRemoteDescription(answer) return nil }) var offerCount atomic.Int32 - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { offerCount.Inc() // the second offer is a ice restart offer, so we wait transportB complete the ice gathering @@ -248,22 +268,28 @@ func TestFirstAnswerMissedDuringICERestart(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) _, err = transportA.pc.CreateDataChannel(LossyDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) // first answer missed var firstAnswerReceived atomic.Bool - transportB.OnAnswer(func(sd webrtc.SessionDescription) error { + handlerB.OnAnswerCalls(func(sd webrtc.SessionDescription) error { if firstAnswerReceived.Load() { transportA.HandleRemoteDescription(sd) } else { @@ -272,7 +298,7 @@ func TestFirstAnswerMissedDuringICERestart(t *testing.T) { } return nil }) - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { transportB.HandleRemoteDescription(sd) return nil }) @@ -286,7 +312,7 @@ func TestFirstAnswerMissedDuringICERestart(t *testing.T) { // first one is recover from missed offer // second one is restartICE var offerCount atomic.Int32 - transportA.OnOffer(func(sd webrtc.SessionDescription) error { + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { offerCount.Inc() // the second offer is a ice restart offer, so we wait transportB complete the ice gathering @@ -321,26 +347,32 @@ func TestNegotiationFailed(t *testing.T) { Config: &WebRTCConfig{}, IsOfferer: true, } - transportA, err := NewPCTransport(params) + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) require.NoError(t, err) _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) require.NoError(t, err) paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB paramsB.IsOfferer = false transportB, err := NewPCTransport(paramsB) require.NoError(t, err) // exchange ICE - handleICEExchange(t, transportA, transportB) + handleICEExchange(t, transportA, transportB, handlerA, handlerB) // wait for transport to be connected before maiming the signalling channel - connectTransports(t, transportA, transportB, false, 1, 1) + connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) // reset OnOffer to force a negotiation failure - transportA.OnOffer(func(sd webrtc.SessionDescription) error { return nil }) + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription) error { return nil }) var failed atomic.Int32 - transportA.OnNegotiationFailed(func() { + handlerA.OnNegotiationFailedCalls(func() { failed.Inc() }) transportA.Negotiate(true) @@ -361,6 +393,7 @@ func TestFilteringCandidates(t *testing.T) { {Mime: webrtc.MimeTypeVP8}, {Mime: webrtc.MimeTypeH264}, }, + Handler: &transportfakes.FakeHandler{}, } transport, err := NewPCTransport(params) require.NoError(t, err) @@ -471,8 +504,8 @@ func TestFilteringCandidates(t *testing.T) { transport.Close() } -func handleICEExchange(t *testing.T, a, b *PCTransport) { - a.OnICECandidate(func(candidate *webrtc.ICECandidate) error { +func handleICEExchange(t *testing.T, a, b *PCTransport, ah, bh *transportfakes.FakeHandler) { + ah.OnICECandidateCalls(func(candidate *webrtc.ICECandidate, target livekit.SignalTarget) error { if candidate == nil { return nil } @@ -480,7 +513,7 @@ func handleICEExchange(t *testing.T, a, b *PCTransport) { b.AddICECandidate(candidate.ToJSON()) return nil }) - b.OnICECandidate(func(candidate *webrtc.ICECandidate) error { + bh.OnICECandidateCalls(func(candidate *webrtc.ICECandidate, target livekit.SignalTarget) error { if candidate == nil { return nil } @@ -490,16 +523,16 @@ func handleICEExchange(t *testing.T, a, b *PCTransport) { }) } -func connectTransports(t *testing.T, offerer, answerer *PCTransport, isICERestart bool, expectedOfferCount int32, expectedAnswerCount int32) { +func connectTransports(t *testing.T, offerer, answerer *PCTransport, offererHandler, answererHandler *transportfakes.FakeHandler, isICERestart bool, expectedOfferCount int32, expectedAnswerCount int32) { var offerCount atomic.Int32 var answerCount atomic.Int32 - answerer.OnAnswer(func(answer webrtc.SessionDescription) error { + answererHandler.OnAnswerCalls(func(answer webrtc.SessionDescription) error { answerCount.Inc() offerer.HandleRemoteDescription(answer) return nil }) - offerer.OnOffer(func(offer webrtc.SessionDescription) error { + offererHandler.OnOfferCalls(func(offer webrtc.SessionDescription) error { offerCount.Inc() answerer.HandleRemoteDescription(offer) return nil @@ -527,11 +560,11 @@ func connectTransports(t *testing.T, offerer, answerer *PCTransport, isICERestar return answerer.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected }, 10*time.Second, time.Millisecond*10, "answerer did not become connected") - transportsConnected := untilTransportsConnected(offerer, answerer) + transportsConnected := untilTransportsConnected(offererHandler, answererHandler) transportsConnected.Wait() } -func untilTransportsConnected(transports ...*PCTransport) *sync.WaitGroup { +func untilTransportsConnected(transports ...*transportfakes.FakeHandler) *sync.WaitGroup { var triggered sync.WaitGroup triggered.Add(len(transports)) @@ -545,7 +578,10 @@ func untilTransportsConnected(transports ...*PCTransport) *sync.WaitGroup { } } - t.OnInitialConnected(hdlr) + if t.OnInitialConnectedCallCount() != 0 { + hdlr() + } + t.OnInitialConnectedCalls(hdlr) } return &triggered } diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 7207ace3e..401fea8e0 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -29,10 +29,10 @@ import ( "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/pacer" - "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -47,6 +47,25 @@ const ( udpLossUnstableCountThreshold = 20 ) +type TransportManagerTransportHandler struct { + transport.Handler + t *TransportManager +} + +func (h TransportManagerTransportHandler) OnFailed(isShortLived bool) { + h.t.handleConnectionFailed(isShortLived) + h.Handler.OnFailed(isShortLived) +} + +type TransportManagerPublisherTransportHandler struct { + TransportManagerTransportHandler +} + +func (h TransportManagerPublisherTransportHandler) OnAnswer(sd webrtc.SessionDescription) error { + h.t.lastPublisherAnswer.Store(sd) + return h.Handler.OnAnswer(sd) +} + type TransportManagerParams struct { Identity livekit.ParticipantIdentity SID livekit.ParticipantID @@ -66,6 +85,8 @@ type TransportManagerParams struct { AllowPlayoutDelay bool DataChannelMaxBufferedAmount uint64 Logger logger.Logger + PublisherHandler transport.Handler + SubscriberHandler transport.Handler } type TransportManager struct { @@ -91,11 +112,6 @@ type TransportManager struct { udpLossUnstableCount uint32 signalingRTT, udpRTT uint32 - onPublisherInitialConnected func() - onSubscriberInitialConnected func() - onPrimaryTransportInitialConnected func() - onAnyTransportFailed func() - onICEConfigChanged func(iceConfig *livekit.ICEConfig) } @@ -122,25 +138,12 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro SimTracks: params.SimTracks, ClientInfo: params.ClientInfo, Transport: livekit.SignalTarget_PUBLISHER, + Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t}}, }) if err != nil { return nil, err } t.publisher = publisher - t.publisher.OnInitialConnected(func() { - if t.onPublisherInitialConnected != nil { - t.onPublisherInitialConnected() - } - if !t.params.SubscriberAsPrimary && t.onPrimaryTransportInitialConnected != nil { - t.onPrimaryTransportInitialConnected() - } - }) - t.publisher.OnFailed(func(isShortLived bool) { - t.handleConnectionFailed(isShortLived) - if t.onAnyTransportFailed != nil { - t.onAnyTransportFailed() - } - }) subscriber, err := NewPCTransport(TransportParams{ ParticipantID: params.SID, @@ -157,25 +160,12 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro AllowPlayoutDelay: params.AllowPlayoutDelay, DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, Transport: livekit.SignalTarget_SUBSCRIBER, + Handler: TransportManagerTransportHandler{params.SubscriberHandler, t}, }) if err != nil { return nil, err } t.subscriber = subscriber - t.subscriber.OnInitialConnected(func() { - if t.onSubscriberInitialConnected != nil { - t.onSubscriberInitialConnected() - } - if t.params.SubscriberAsPrimary && t.onPrimaryTransportInitialConnected != nil { - t.onPrimaryTransportInitialConnected() - } - }) - t.subscriber.OnFailed(func(isShortLived bool) { - t.handleConnectionFailed(isShortLived) - if t.onAnyTransportFailed != nil { - t.onAnyTransportFailed() - } - }) if !t.params.Migration { if err := t.createDataChannelsForSubscriber(nil); err != nil { return nil, err @@ -195,25 +185,6 @@ func (t *TransportManager) SubscriberClose() { t.subscriber.Close() } -func (t *TransportManager) OnPublisherICECandidate(f func(c *webrtc.ICECandidate) error) { - t.publisher.OnICECandidate(f) -} - -func (t *TransportManager) OnPublisherAnswer(f func(answer webrtc.SessionDescription) error) { - t.publisher.OnAnswer(func(sd webrtc.SessionDescription) error { - t.lastPublisherAnswer.Store(sd) - return f(sd) - }) -} - -func (t *TransportManager) OnPublisherInitialConnected(f func()) { - t.onPublisherInitialConnected = f -} - -func (t *TransportManager) OnPublisherTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { - t.publisher.OnTrack(f) -} - func (t *TransportManager) HasPublisherEverConnected() bool { return t.publisher.HasEverConnected() } @@ -234,22 +205,6 @@ func (t *TransportManager) WritePublisherRTCP(pkts []rtcp.Packet) error { return t.publisher.WriteRTCP(pkts) } -func (t *TransportManager) OnSubscriberICECandidate(f func(c *webrtc.ICECandidate) error) { - t.subscriber.OnICECandidate(f) -} - -func (t *TransportManager) OnSubscriberOffer(f func(offer webrtc.SessionDescription) error) { - t.subscriber.OnOffer(f) -} - -func (t *TransportManager) OnSubscriberInitialConnected(f func()) { - t.onSubscriberInitialConnected = f -} - -func (t *TransportManager) OnSubscriberStreamStateChange(f func(update *streamallocator.StreamStateUpdate) error) { - t.subscriber.OnStreamStateChange(f) -} - func (t *TransportManager) HasSubscriberEverConnected() bool { return t.subscriber.HasEverConnected() } @@ -274,23 +229,6 @@ func (t *TransportManager) GetSubscriberPacer() pacer.Pacer { return t.subscriber.GetPacer() } -func (t *TransportManager) OnPrimaryTransportInitialConnected(f func()) { - t.onPrimaryTransportInitialConnected = f -} - -func (t *TransportManager) OnPrimaryTransportFullyEstablished(f func()) { - t.getTransport(true).OnFullyEstablished(f) -} - -func (t *TransportManager) OnAnyTransportFailed(f func()) { - t.onAnyTransportFailed = f -} - -func (t *TransportManager) OnAnyTransportNegotiationFailed(f func()) { - t.publisher.OnNegotiationFailed(f) - t.subscriber.OnNegotiationFailed(f) -} - func (t *TransportManager) AddSubscribedTrack(subTrack types.SubscribedTrack) { t.subscriber.AddTrackToStreamAllocator(subTrack) } @@ -299,11 +237,6 @@ func (t *TransportManager) RemoveSubscribedTrack(subTrack types.SubscribedTrack) t.subscriber.RemoveTrackFromStreamAllocator(subTrack) } -func (t *TransportManager) OnDataMessage(f func(kind livekit.DataPacket_Kind, data []byte)) { - // upstream data always comes in via publisher peer connection irrespective of which is primary - t.publisher.OnDataPacket(f) -} - func (t *TransportManager) SendDataPacket(dp *livekit.DataPacket, data []byte) error { // downstream data is sent via primary peer connection return t.getTransport(true).SendDataPacket(dp, data) @@ -736,10 +669,7 @@ func (t *TransportManager) onMediaLossUpdate(loss uint8) { t.lock.Unlock() t.params.Logger.Infow("udp connection unstable, switch to tcp", "signalingRTT", t.signalingRTT) - t.handleConnectionFailed(true) - if t.onAnyTransportFailed != nil { - t.onAnyTransportFailed() - } + t.params.SubscriberHandler.OnFailed(true) return } } diff --git a/test/client/client.go b/test/client/client.go index fbe6954d9..e0be22e47 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -39,6 +39,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" "github.com/livekit/livekit-server/pkg/rtc/types" ) @@ -199,33 +200,37 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { // i. e. the publisher transport on client side has SUBSCRIBER signal target (i. e. publisher is offerer). // Same applies for subscriber transport also // + publisherHandler := &transportfakes.FakeHandler{} c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ Config: &conf, DirectionConfig: conf.Subscriber, EnabledCodecs: codecs, IsOfferer: true, IsSendSide: true, + Handler: publisherHandler, }) if err != nil { return nil, err } + subscriberHandler := &transportfakes.FakeHandler{} c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ Config: &conf, DirectionConfig: conf.Publisher, EnabledCodecs: codecs, + Handler: subscriberHandler, }) if err != nil { return nil, err } - c.publisher.OnICECandidate(func(ic *webrtc.ICECandidate) error { + publisherHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { if ic == nil { return nil } return c.SendIceCandidate(ic, livekit.SignalTarget_PUBLISHER) }) - c.publisher.OnOffer(c.onOffer) - c.publisher.OnFullyEstablished(func() { + publisherHandler.OnOfferCalls(c.onOffer) + publisherHandler.OnFullyEstablishedCalls(func() { logger.Debugw("publisher fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) c.publisherFullyEstablished.Store(true) }) @@ -245,17 +250,17 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { return nil, err } - c.subscriber.OnICECandidate(func(ic *webrtc.ICECandidate) error { + subscriberHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { if ic == nil { return nil } return c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) }) - c.subscriber.OnTrack(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + subscriberHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { go c.processTrack(track) }) - c.subscriber.OnDataPacket(c.handleDataMessage) - c.subscriber.OnInitialConnected(func() { + subscriberHandler.OnDataPacketCalls(c.handleDataMessage) + subscriberHandler.OnInitialConnectedCalls(func() { logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) c.lock.Lock() @@ -272,11 +277,11 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { go c.OnConnected() } }) - c.subscriber.OnFullyEstablished(func() { + subscriberHandler.OnFullyEstablishedCalls(func() { logger.Debugw("subscriber fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) c.subscriberFullyEstablished.Store(true) }) - c.subscriber.OnAnswer(func(answer webrtc.SessionDescription) error { + subscriberHandler.OnAnswerCalls(func(answer webrtc.SessionDescription) error { // send remote an answer logger.Infow("sending subscriber answer", "participant", c.localParticipant.Identity,