diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 662caf8d6..fb57e78b0 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -70,7 +70,7 @@ type TrackSender interface { ) error Resync() SetReceiver(TrackReceiver) - ReceiverRestart() + ReceiverRestart(TrackReceiver) } // ------------------------------------------------------------------- @@ -1635,14 +1635,18 @@ func (d *DownTrack) Resync() { d.forwarder.Resync() } -func (d *DownTrack) ReceiverRestart() { +func (d *DownTrack) ReceiverRestart(rcvr TrackReceiver) { + if rcvr.Mime() != d.Receiver().Mime() { + d.params.Logger.Infow("upstream receiver restart - skipped", "mime", d.Receiver().Mime().String(), "newMime", rcvr.Mime().String()) + return + } + d.bindLock.Lock() codec := d.codec.Load().(webrtc.RTPCodecCapability) d.bindLock.Unlock() - d.params.Logger.Infow("upstream receiver restart") - receiver := d.Receiver() + d.params.Logger.Infow("upstream receiver restart", "mime", receiver.Mime().String()) d.forwarder.Restart() d.forwarder.DetermineCodec(codec, receiver.HeaderExtensions(), receiver.VideoLayerMode()) } diff --git a/pkg/sfu/receiver_base.go b/pkg/sfu/receiver_base.go index e84478026..ac878aaef 100644 --- a/pkg/sfu/receiver_base.go +++ b/pkg/sfu/receiver_base.go @@ -373,9 +373,12 @@ func (r *ReceiverBase) restartInternal(reason string, isDetected bool) { return } r.restartInProgress = true + + // 2. advance forwarder generation + r.forwardersGeneration.Inc() r.bufferMu.Unlock() - // 2. restart all the buffers + // 3. restart all the buffers // if a stream was detected, skip external restart // // NOTE: The case of external restart and detected restart (which usually comes from one buffer) @@ -392,28 +395,28 @@ func (r *ReceiverBase) restartInternal(reason string, isDetected bool) { } } - // 3. wait for the forwarders to finish - r.stopForwarderGeneration() + // 4. wait for the forwarders to finish + r.waitForForwardersStop() - // 4. reset stream tracker + // 5. reset stream tracker r.streamTrackerManager.RemoveAllTrackers() - // 5. signal attached downtracks to resync so that they can have proper sequencing on a receiver restart + // 6. signal attached downtracks to resync so that they can have proper sequencing on a receiver restart r.downTrackSpreader.Broadcast(func(dt TrackSender) { - dt.ReceiverRestart() + dt.ReceiverRestart(r) }) if rt := r.loadREDTransformer(); rt != nil { rt.OnStreamRestart() } - // 6. move forwarder generation ahead + // 7. move forwarder generation ahead r.startForwarderGeneration() r.bufferMu.Lock() - // 7. release restart hold + // 8. release restart hold r.restartInProgress = false - // 8. restart forwarders + // 9. restart forwarders for layer, buff := range r.buffers { if buff == nil { continue @@ -851,9 +854,8 @@ func (r *ReceiverBase) startForwarderGeneration() { r.forwardersWaitGroup = &sync.WaitGroup{} } -func (r *ReceiverBase) stopForwarderGeneration() { +func (r *ReceiverBase) waitForForwardersStop() { r.bufferMu.Lock() - r.forwardersGeneration.Inc() forwarderWaitGroup := r.forwardersWaitGroup r.bufferMu.Unlock() @@ -869,8 +871,9 @@ func (r *ReceiverBase) startForwarderForBufferLocked(layer int32, buff buffer.Bu r.forwardersWaitGroup.Add(1) - r.params.Logger.Debugw("starting forwarder", "layer", layer) - go r.forwardRTP(layer, buff, r.forwardersGeneration.Load(), r.forwardersWaitGroup) + forwarderGeneration := r.forwardersGeneration.Load() + r.params.Logger.Debugw("starting forwarder", "layer", layer, "forwarderGeneration", forwarderGeneration) + go r.forwardRTP(layer, buff, forwarderGeneration, r.forwardersWaitGroup) } func (r *ReceiverBase) forwardRTP( @@ -917,7 +920,12 @@ func (r *ReceiverBase) forwardRTP( } pktBuf := make([]byte, bucket.RTPMaxPktSize) - r.params.Logger.Debugw("starting forwarding", "layer", layer, "forwarderGeneration", forwarderGeneration) + r.params.Logger.Debugw( + "starting forwarding", + "layer", layer, + "forwarderGeneration", forwarderGeneration, + "forwardersGeneration", r.forwardersGeneration.Load(), + ) for r.forwardersGeneration.Load() == forwarderGeneration { extPkt, err = buff.ReadExtended(pktBuf) diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index e05448131..8f362612e 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -145,7 +145,7 @@ func (r *RedPrimaryReceiver) ResyncDownTracks() { func (r *RedPrimaryReceiver) OnStreamRestart() { r.downTrackSpreader.Broadcast(func(dt TrackSender) { - dt.ReceiverRestart() + dt.ReceiverRestart(r) }) } diff --git a/pkg/sfu/redreceiver.go b/pkg/sfu/redreceiver.go index 9c93a3f26..33af02f41 100644 --- a/pkg/sfu/redreceiver.go +++ b/pkg/sfu/redreceiver.go @@ -142,7 +142,7 @@ func (r *RedReceiver) ResyncDownTracks() { func (r *RedReceiver) OnStreamRestart() { r.downTrackSpreader.Broadcast(func(dt TrackSender) { - dt.ReceiverRestart() + dt.ReceiverRestart(r) }) }