diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 10d2eeeab..d42012e1e 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -297,10 +297,20 @@ func (t *MediaTrackReceiver) OnVideoLayerUpdate(f func(layers []*livekit.VideoLa func (t *MediaTrackReceiver) IsOpen() bool { t.lock.RLock() defer t.lock.RUnlock() - return t.state == mediaTrackReceiverStateOpen + if t.state != mediaTrackReceiverStateOpen { + return false + } + // If any one of the receivers has entered closed state, we would not consider the track open + for _, receiver := range t.receivers { + if receiver.IsClosed() { + return false + } + } + return true } func (t *MediaTrackReceiver) SetClosing() { + t.params.Logger.Infow("setting track to closing") t.lock.Lock() defer t.lock.Unlock() if t.state == mediaTrackReceiverStateOpen { diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index 839275485..84ac7108a 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -280,6 +280,13 @@ func (d *DummyReceiver) TrackInfo() *livekit.TrackInfo { return nil } +func (d *DummyReceiver) IsClosed() bool { + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return r.IsClosed() + } + return false +} + func (d *DummyReceiver) GetPrimaryReceiverForRed() sfu.TrackReceiver { // DummyReceiver used for video, it should not have RED codec return d diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 2ec94bff2..bd9b417cf 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -39,6 +39,7 @@ type TrackReceiver interface { StreamID() string Codec() webrtc.RTPCodecParameters HeaderExtensions() []webrtc.RTPHeaderExtensionParameter + IsClosed() bool ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) GetLayeredBitrate() ([]int32, Bitrates) @@ -249,6 +250,10 @@ func (w *WebRTCReceiver) GetConnectionScore() float32 { return w.connectionStats.GetScore() } +func (w *WebRTCReceiver) IsClosed() bool { + return w.closed.Load() +} + func (w *WebRTCReceiver) SetRTT(rtt uint32) { w.bufferMu.Lock() if w.rtt == rtt { diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index 8b889f88c..9789bb32e 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -88,6 +88,10 @@ func (r *RedPrimaryReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) r.downTrackSpreader.Free(subscriberID) } +func (r *RedPrimaryReceiver) IsClosed() bool { + return r.closed.Load() +} + func (r *RedPrimaryReceiver) CanClose() bool { return r.closed.Load() || r.downTrackSpreader.DownTrackCount() == 0 } diff --git a/pkg/sfu/redreceiver.go b/pkg/sfu/redreceiver.go index b5880cf66..b814940f4 100644 --- a/pkg/sfu/redreceiver.go +++ b/pkg/sfu/redreceiver.go @@ -89,6 +89,10 @@ func (r *RedReceiver) CanClose() bool { return r.closed.Load() || r.downTrackSpreader.DownTrackCount() == 0 } +func (r *RedReceiver) IsClosed() bool { + return r.closed.Load() +} + func (r *RedReceiver) Close() { r.closed.Store(true) for _, dt := range r.downTrackSpreader.ResetAndGetDownTracks() {