From efa85221b3a8105ffe32b0b2af02366db1fdf1eb Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Mon, 2 Sep 2024 06:10:14 +0000 Subject: [PATCH] Negotiate downttrack for subscriber before receiver is ready (#2970) * Negotiate downttrack for subscriber before receiver is ready This change will save 1 round sdp negotiation time for subscribing to simulcast-codec or remote node track * solve comment * Fix simulcast-codec case --- pkg/rtc/mediaengine.go | 10 +- pkg/rtc/mediatracksubscriptions.go | 12 +- pkg/rtc/participant_internal_test.go | 2 +- pkg/rtc/wrappedreceiver.go | 168 +++++++++++++++++-- pkg/sfu/downtrack.go | 231 +++++++++++++++++++-------- pkg/sfu/receiver.go | 9 ++ 6 files changed, 341 insertions(+), 91 deletions(-) diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 4dedc085a..75b24e9fc 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -28,13 +28,13 @@ const ( videoRTXMimeType = "video/rtx" ) -var opusCodecCapability = webrtc.RTPCodecCapability{ +var OpusCodecCapability = webrtc.RTPCodecCapability{ MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2, SDPFmtpLine: "minptime=10;useinbandfec=1", } -var redCodecCapability = webrtc.RTPCodecCapability{ +var RedCodecCapability = webrtc.RTPCodecCapability{ MimeType: sfu.MimeTypeAudioRed, ClockRate: 48000, Channels: 2, @@ -46,7 +46,7 @@ var videoRTX = webrtc.RTPCodecCapability{ } func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedback RTCPFeedbackConfig, filterOutH264HighProfile bool) error { - opusCodec := opusCodecCapability + opusCodec := OpusCodecCapability opusCodec.RTCPFeedback = rtcpFeedback.Audio var opusPayload webrtc.PayloadType if IsCodecEnabled(codecs, opusCodec) { @@ -58,9 +58,9 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac return err } - if IsCodecEnabled(codecs, redCodecCapability) { + if IsCodecEnabled(codecs, RedCodecCapability) { if err := me.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: redCodecCapability, + RTPCodecCapability: RedCodecCapability, PayloadType: 63, }, webrtc.RTPCodecTypeAudio); err != nil { return err diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 20b73db55..3f212b6f1 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -23,6 +23,7 @@ import ( "github.com/pion/webrtc/v3" "go.uber.org/atomic" + "github.com/livekit/livekit-server/pkg/sfu/buffer" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -170,12 +171,21 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // Bind callback can happen from replaceTrack, so set it up early var reusingTransceiver atomic.Bool var dtState sfu.DownTrackState + downTrack.OnCodecNegotiated(func(codec webrtc.RTPCodecCapability) { + if !wr.DetermineReceiver(codec) { + if t.onSubscriberMaxQualityChange != nil { + go func() { + spatial := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.params.MediaTrack.ToProto()) + t.onSubscriberMaxQualityChange(subscriberID, codec, spatial) + }() + } + } + }) downTrack.OnBinding(func(err error) { if err != nil { go subTrack.Bound(err) return } - wr.DetermineReceiver(downTrack.Codec()) if reusingTransceiver.Load() { downTrack.SeedState(dtState) } diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 0f9825536..74fcc6090 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -621,7 +621,7 @@ func TestPreferAudioCodecForRed(t *testing.T) { me := webrtc.MediaEngine{} me.RegisterDefaultCodecs() require.NoError(t, me.RegisterCodec(webrtc.RTPCodecParameters{ - RTPCodecCapability: redCodecCapability, + RTPCodecCapability: RedCodecCapability, PayloadType: 63, }, webrtc.RTPCodecTypeAudio)) diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index 19cb3a694..cfde44114 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -42,10 +42,11 @@ type WrappedReceiverParams struct { type WrappedReceiver struct { sfu.TrackReceiver - params WrappedReceiverParams - receivers []sfu.TrackReceiver - codecs []webrtc.RTPCodecParameters - determinedCodec webrtc.RTPCodecCapability + params WrappedReceiverParams + receivers []sfu.TrackReceiver + codecs []webrtc.RTPCodecParameters + determinedCodec webrtc.RTPCodecCapability + onReadyCallbacks []func() } func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { @@ -59,13 +60,13 @@ func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { if strings.EqualFold(codecs[0].MimeType, sfu.MimeTypeAudioRed) { // if upstream is opus/red, then add opus to match clients that don't support red codecs = append(codecs, webrtc.RTPCodecParameters{ - RTPCodecCapability: opusCodecCapability, + RTPCodecCapability: OpusCodecCapability, PayloadType: 111, }) } else if !params.DisableRed && strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus) { // if upstream is opus only and red enabled, add red to match clients that support red codecs = append(codecs, webrtc.RTPCodecParameters{ - RTPCodecCapability: redCodecCapability, + RTPCodecCapability: RedCodecCapability, PayloadType: 63, }) // prefer red codec @@ -88,7 +89,8 @@ func (r *WrappedReceiver) StreamID() string { return r.params.StreamId } -func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) { +// DetermineReceiver determines the receiver of negotiated codec and return ready state of the receiver +func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) bool { r.determinedCodec = codec for _, receiver := range r.receivers { if c := receiver.Codec(); strings.EqualFold(c.MimeType, codec.MimeType) { @@ -109,6 +111,18 @@ func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) { r.TrackReceiver = r.receivers[0] } } + if r.TrackReceiver != nil { + for _, f := range r.onReadyCallbacks { + r.TrackReceiver.AddOnReady(f) + } + r.onReadyCallbacks = nil + + if d, ok := r.TrackReceiver.(*DummyReceiver); ok { + return d.IsReady() + } + return true + } + return false } func (r *WrappedReceiver) Codecs() []webrtc.RTPCodecParameters { @@ -123,6 +137,14 @@ func (r *WrappedReceiver) DeleteDownTrack(participantID livekit.ParticipantID) { } } +func (r *WrappedReceiver) AddOnReady(f func()) { + if r.TrackReceiver != nil { + r.TrackReceiver.AddOnReady(f) + } else { + r.onReadyCallbacks = append(r.onReadyCallbacks, f) + } +} + // -------------------------------------------- type DummyReceiver struct { @@ -132,8 +154,9 @@ type DummyReceiver struct { codec webrtc.RTPCodecParameters headerExtensions []webrtc.RTPHeaderExtensionParameter - downTrackLock sync.Mutex - downTracks map[livekit.ParticipantID]sfu.TrackSender + downTrackLock sync.Mutex + downTracks map[livekit.ParticipantID]sfu.TrackSender + onReadyCallbacks []func() settingsLock sync.Mutex maxExpectedLayerValid bool @@ -142,6 +165,8 @@ type DummyReceiver struct { paused bool baseTime time.Time + + redReceiver, primaryReceiver *DummyRedReceiver } func NewDummyReceiver(trackID livekit.TrackID, streamId string, codec webrtc.RTPCodecParameters, headerExtensions []webrtc.RTPHeaderExtensionParameter) *DummyReceiver { @@ -161,15 +186,23 @@ func (d *DummyReceiver) Receiver() sfu.TrackReceiver { } func (d *DummyReceiver) Upgrade(receiver sfu.TrackReceiver) { - d.receiver.CompareAndSwap(nil, receiver) + if !d.receiver.CompareAndSwap(nil, receiver) { + return + } d.downTrackLock.Lock() for _, t := range d.downTracks { receiver.AddDownTrack(t) } d.downTracks = make(map[livekit.ParticipantID]sfu.TrackSender) + onReadyCallbacks := d.onReadyCallbacks + d.onReadyCallbacks = nil d.downTrackLock.Unlock() + for _, f := range onReadyCallbacks { + receiver.AddOnReady(f) + } + d.settingsLock.Lock() if d.maxExpectedLayerValid { receiver.SetMaxExpectedSpatialLayer(d.maxExpectedLayer) @@ -180,6 +213,13 @@ func (d *DummyReceiver) Upgrade(receiver sfu.TrackReceiver) { receiver.SetUpTrackPaused(d.paused) } d.pausedValid = false + + if d.primaryReceiver != nil { + d.primaryReceiver.upgrade(receiver) + } + if d.redReceiver != nil { + d.redReceiver.upgrade(receiver) + } d.settingsLock.Unlock() } @@ -314,12 +354,28 @@ func (d *DummyReceiver) IsClosed() bool { } func (d *DummyReceiver) GetPrimaryReceiverForRed() sfu.TrackReceiver { - // DummyReceiver used for video, it should not have RED codec - return d + d.settingsLock.Lock() + defer d.settingsLock.Unlock() + + if d.primaryReceiver == nil { + d.primaryReceiver = NewDummyRedReceiver(d, false) + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + d.primaryReceiver.upgrade(r) + } + } + return d.primaryReceiver } func (d *DummyReceiver) GetRedReceiver() sfu.TrackReceiver { - return d + d.settingsLock.Lock() + defer d.settingsLock.Unlock() + if d.redReceiver == nil { + d.redReceiver = NewDummyRedReceiver(d, true) + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + d.redReceiver.upgrade(r) + } + } + return d.redReceiver } func (d *DummyReceiver) GetTrackStats() *livekit.RTPStats { @@ -335,3 +391,89 @@ func (d *DummyReceiver) GetMonotonicNowUnixNano() int64 { } return d.baseTime.Add(time.Since(d.baseTime)).UnixNano() } + +func (d *DummyReceiver) AddOnReady(f func()) { + var receiver sfu.TrackReceiver + d.downTrackLock.Lock() + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + receiver = r + } else { + d.onReadyCallbacks = append(d.onReadyCallbacks, f) + } + d.downTrackLock.Unlock() + if receiver != nil { + receiver.AddOnReady(f) + } +} + +func (d *DummyReceiver) IsReady() bool { + return d.receiver.Load() != nil +} + +// -------------------------------------------- + +type DummyRedReceiver struct { + *DummyReceiver + redReceiver atomic.Value // sfu.TrackReceiver + // indicates this receiver is for RED encoding receiver of primary codec OR + // primary decoding receiver of RED codec + isRedEncoding bool + + downTrackLock sync.Mutex + downTracks map[livekit.ParticipantID]sfu.TrackSender +} + +func NewDummyRedReceiver(d *DummyReceiver, isRedEncoding bool) *DummyRedReceiver { + return &DummyRedReceiver{ + DummyReceiver: d, + isRedEncoding: isRedEncoding, + downTracks: make(map[livekit.ParticipantID]sfu.TrackSender), + } +} + +func (d *DummyRedReceiver) AddDownTrack(track sfu.TrackSender) error { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + r.AddDownTrack(track) + } else { + d.downTracks[track.SubscriberID()] = track + } + return nil +} + +func (d *DummyRedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + r.DeleteDownTrack(subscriberID) + } else { + delete(d.downTracks, subscriberID) + } +} + +func (d *DummyRedReceiver) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + return r.ReadRTP(buf, layer, esn) + } + return 0, errors.New("no receiver") +} + +func (d *DummyRedReceiver) upgrade(receiver sfu.TrackReceiver) { + var redReceiver sfu.TrackReceiver + if d.isRedEncoding { + redReceiver = receiver.GetRedReceiver() + } else { + redReceiver = receiver.GetPrimaryReceiverForRed() + } + d.redReceiver.Store(redReceiver) + + d.downTrackLock.Lock() + for _, t := range d.downTracks { + redReceiver.AddDownTrack(t) + } + d.downTracks = make(map[livekit.ParticipantID]sfu.TrackSender) + d.downTrackLock.Unlock() +} diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 9b393818c..24b880852 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -256,15 +256,17 @@ type DownTrack struct { listenerLock sync.RWMutex receiverReportListeners []ReceiverReportListener - bindLock sync.Mutex - bound atomic.Bool - onBinding func(error) + bindLock sync.Mutex + bindState atomic.Value + onBinding func(error) + bindOnReceiverReady func() isClosed atomic.Bool connected atomic.Bool bindAndConnectedOnce atomic.Bool writable atomic.Bool writeStopped atomic.Bool + isReceiverReady bool rtpStats *buffer.RTPStatsSender @@ -305,10 +307,33 @@ type DownTrack struct { onMaxSubscribedLayerChanged func(dt *DownTrack, layer int32) onRttUpdate func(dt *DownTrack, rtt uint32) onCloseHandler func(isExpectedToResume bool) + onCodecNegotiated func(webrtc.RTPCodecCapability) createdAt int64 } +type bindState int + +const ( + bindStateUnbound bindState = iota + // downtrack negotiated, but waiting for receiver to be ready to start forwarding + bindStateWaitForReceiverReady + // downtrack is bound and ready to forward + bindStateBound +) + +func (bs bindState) String() string { + switch bs { + case bindStateUnbound: + return "unbound" + case bindStateWaitForReceiverReady: + return "waitForReceiverReady" + case bindStateBound: + return "bound" + } + return "unknown" +} + // NewDownTrack returns a DownTrack. func NewDownTrack(params DowntrackParams) (*DownTrack, error) { codecs := params.Codecs @@ -333,6 +358,7 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { keyFrameRequesterCh: make(chan struct{}, 1), createdAt: time.Now().UnixNano(), } + d.bindState.Store(bindStateUnbound) d.params.Logger = params.Logger.WithValues( "mime", codecs[0].MimeType, "subscriberID", d.SubscriberID(), @@ -373,17 +399,25 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { go d.maxLayerNotifierWorker() go d.keyFrameRequester() } + + d.params.Receiver.AddOnReady(d.handleReceiverReady) d.params.Logger.Debugw("downtrack created", "upstreamCodecs", d.upstreamCodecs) return d, nil } +func (d *DownTrack) OnCodecNegotiated(f func(webrtc.RTPCodecCapability)) { + d.bindLock.Lock() + d.onCodecNegotiated = f + d.bindLock.Unlock() +} + // Bind is called by the PeerConnection after negotiation is complete // This asserts that the code requested is supported by the remote peer. // If so it sets up all the state (SSRC and PayloadType) to have a call func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) { d.bindLock.Lock() - if d.bound.Load() { + if d.bindState.Load() != bindStateUnbound { d.bindLock.Unlock() return webrtc.RTPCodecParameters{}, ErrDownTrackAlreadyBound } @@ -406,24 +440,6 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, onBinding(err) } return webrtc.RTPCodecParameters{}, err - } else if strings.EqualFold(matchedUpstreamCodec.MimeType, "audio/red") { - d.isRED = true - for _, c := range d.upstreamCodecs { - // assume upstream primary codec is opus since we only support it for audio now - if strings.EqualFold(c.MimeType, "audio/opus") { - d.upstreamPrimaryPT = uint8(c.PayloadType) - break - } - } - if d.upstreamPrimaryPT == 0 { - d.params.Logger.Errorw("failed to find upstream primary opus payload type for RED", nil, "matchedCodec", codec, "upstreamCodec", d.upstreamCodecs) - } - - var primaryPT, secondaryPT int - if n, err := fmt.Sscanf(codec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 { - d.params.Logger.Errorw("failed to parse primary and secondary payload type for RED", err, "matchedCodec", codec) - } - d.primaryPT = uint8(primaryPT) } // if a downtrack is closed before bind, it already unsubscribed from client, don't do subsequent operation and return here. @@ -433,43 +449,6 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, return codec, nil } - logFields := []interface{}{ - "codecs", d.upstreamCodecs, - "matchCodec", codec, - "ssrc", t.SSRC(), - } - if d.isRED { - logFields = append(logFields, - "isRED", d.isRED, - "upstreamPrimaryPT", d.upstreamPrimaryPT, - "primaryPT", d.primaryPT, - ) - } - d.params.Logger.Debugw("DownTrack.Bind", - logFields..., - ) - - d.ssrc = uint32(t.SSRC()) - d.payloadType = uint8(codec.PayloadType) - d.writeStream = t.WriteStream() - d.mime = strings.ToLower(codec.MimeType) - if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, uint32(t.SSRC())).(*buffer.RTCPReader); rr != nil { - rr.OnPacket(func(pkt []byte) { - d.handleRTCP(pkt) - }) - d.rtcpReader = rr - } - - d.sequencer = newSequencer(d.params.MaxTrack, d.kind == webrtc.RTPCodecTypeVideo, d.params.Logger) - - d.codec = codec.RTPCodecCapability - if d.onBinding != nil { - d.onBinding(nil) - } - d.bound.Store(true) - d.onBindAndConnectedChange() - d.bindLock.Unlock() - // Bind is called under RTPSender.mu lock, call the RTPSender.GetParameters in goroutine to avoid deadlock go func() { if tr := d.transceiver.Load(); tr != nil { @@ -481,18 +460,129 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, } }() - d.forwarder.DetermineCodec(d.codec, d.params.Receiver.HeaderExtensions()) - d.params.Logger.Debugw("downtrack bound") + doBind := func() { + d.bindLock.Lock() + if d.IsClosed() { + d.bindLock.Unlock() + d.params.Logger.Debugw("DownTrack closed before bind") + return + } + if bs := d.bindState.Load(); bs != bindStateWaitForReceiverReady { + d.bindLock.Unlock() + d.params.Logger.Debugw("DownTrack.Bind: not in wait for receiver state", "state", bs) + return + } + + if strings.EqualFold(matchedUpstreamCodec.MimeType, "audio/red") { + d.isRED = true + for _, c := range d.upstreamCodecs { + // assume upstream primary codec is opus since we only support it for audio now + if strings.EqualFold(c.MimeType, "audio/opus") { + d.upstreamPrimaryPT = uint8(c.PayloadType) + break + } + } + if d.upstreamPrimaryPT == 0 { + d.params.Logger.Errorw("failed to find upstream primary opus payload type for RED", nil, "matchedCodec", codec, "upstreamCodec", d.upstreamCodecs) + } + + var primaryPT, secondaryPT int + if n, err := fmt.Sscanf(codec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 { + d.params.Logger.Errorw("failed to parse primary and secondary payload type for RED", err, "matchedCodec", codec) + } + d.primaryPT = uint8(primaryPT) + } + + logFields := []interface{}{ + "codecs", d.upstreamCodecs, + "matchCodec", codec, + "ssrc", t.SSRC(), + } + if d.isRED { + logFields = append(logFields, + "isRED", d.isRED, + "upstreamPrimaryPT", d.upstreamPrimaryPT, + "primaryPT", d.primaryPT, + ) + } + d.params.Logger.Debugw("DownTrack.Bind", + logFields..., + ) + + d.ssrc = uint32(t.SSRC()) + d.payloadType = uint8(codec.PayloadType) + d.writeStream = t.WriteStream() + d.mime = strings.ToLower(codec.MimeType) + if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, uint32(t.SSRC())).(*buffer.RTCPReader); rr != nil { + rr.OnPacket(func(pkt []byte) { + d.handleRTCP(pkt) + }) + d.rtcpReader = rr + } + + d.sequencer = newSequencer(d.params.MaxTrack, d.kind == webrtc.RTPCodecTypeVideo, d.params.Logger) + + d.codec = codec.RTPCodecCapability + if d.onBinding != nil { + d.onBinding(nil) + } + d.setBindStateLocked(bindStateBound) + d.bindLock.Unlock() + + d.forwarder.DetermineCodec(d.codec, d.params.Receiver.HeaderExtensions()) + d.params.Logger.Debugw("downtrack bound") + } + + isReceiverReady := d.isReceiverReady + if !isReceiverReady { + d.params.Logger.Debugw("downtrack bound: receiver not ready", "codec", codec) + d.bindOnReceiverReady = doBind + d.setBindStateLocked(bindStateWaitForReceiverReady) + } + + onCodecNegotiated := d.onCodecNegotiated + d.bindLock.Unlock() + + if onCodecNegotiated != nil { + onCodecNegotiated(codec.RTPCodecCapability) + } + + if isReceiverReady { + doBind() + } return codec, nil } +func (d *DownTrack) setBindStateLocked(state bindState) { + if d.bindState.Swap(state) == state { + return + } + + if state == bindStateBound || state == bindStateUnbound { + d.bindOnReceiverReady = nil + d.onBindAndConnectedChange() + } +} + +func (d *DownTrack) handleReceiverReady() { + d.params.Logger.Debugw("downtrack receiver ready") + d.bindLock.Lock() + d.isReceiverReady = true + doBind := d.bindOnReceiverReady + d.bindOnReceiverReady = nil + d.bindLock.Unlock() + + if doBind != nil { + doBind() + } +} + // Unbind implements the teardown logic when the track is no longer needed. This happens // because a track has been stopped. func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { d.bindLock.Lock() - d.bound.Store(false) - d.onBindAndConnectedChange() + d.setBindStateLocked(bindStateUnbound) d.bindLock.Unlock() return nil } @@ -1058,7 +1148,7 @@ func (d *DownTrack) CloseWithFlush(flush bool) { d.bindLock.Lock() d.params.Logger.Debugw("close down track", "flushBlankFrame", flush) - if d.bound.Load() { + if d.bindState.Load() == bindStateBound { d.forwarder.Mute(true, true) // write blank frames after disabling so that other frames do not interfere. @@ -1079,10 +1169,9 @@ func (d *DownTrack) CloseWithFlush(flush bool) { } } - d.bound.Store(false) - d.onBindAndConnectedChange() d.params.Logger.Debugw("closing sender", "kind", d.kind) } + d.setBindStateLocked(bindStateUnbound) d.params.Receiver.DeleteDownTrack(d.SubscriberID()) if d.rtcpReader != nil && flush { @@ -1391,7 +1480,7 @@ func (d *DownTrack) Resync() { func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk { transceiver := d.transceiver.Load() - if !d.bound.Load() || transceiver == nil { + if d.bindState.Load() != bindStateBound || transceiver == nil { return nil } return []rtcp.SourceDescriptionChunk{ @@ -1412,7 +1501,7 @@ func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChun } func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport { - if !d.bound.Load() { + if d.bindState.Load() != bindStateBound { return nil } @@ -1935,7 +2024,7 @@ func (d *DownTrack) DebugInfo() map[string]interface{} { "StreamID": d.params.StreamID, "SSRC": d.ssrc, "MimeType": d.codec.MimeType, - "Bound": d.bound.Load(), + "BindState": d.bindState.Load().(bindState), "Muted": d.forwarder.IsMuted(), "PubMuted": d.forwarder.IsPubMuted(), "CurrentSpatialLayer": d.forwarder.CurrentLayer().Spatial, @@ -1999,8 +2088,8 @@ func (d *DownTrack) onBindAndConnectedChange() { if d.writeStopped.Load() { return } - d.writable.Store(d.connected.Load() && d.bound.Load()) - if d.connected.Load() && d.bound.Load() && !d.bindAndConnectedOnce.Swap(true) { + d.writable.Store(d.connected.Load() && d.bindState.Load() == bindStateBound) + if d.connected.Load() && d.bindState.Load() == bindStateBound && !d.bindAndConnectedOnce.Swap(true) { if d.activePaddingOnMuteUpTrack.Load() { go d.sendPaddingOnMute() } diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index e87471d99..5f8c00ee3 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -85,6 +85,10 @@ type TrackReceiver interface { GetTrackStats() *livekit.RTPStats GetMonotonicNowUnixNano() int64 + + // AddOnReady adds a function to be called when the receiver is ready, the callback + // could be called immediately if the receiver is ready when the callback is added + AddOnReady(func()) } // WebRTCReceiver receives a media track @@ -841,6 +845,11 @@ func (w *WebRTCReceiver) GetMonotonicNowUnixNano() int64 { return w.baseTime.Add(time.Since(w.baseTime)).UnixNano() } +func (w *WebRTCReceiver) AddOnReady(fn func()) { + // webRTCReceiver is always ready after created + fn() +} + // ----------------------------------------------------------- // closes all track senders in parallel, returns when all are closed