diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 96cef2b8e..ef5e22a98 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -484,7 +484,9 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, return codec, nil } - // Bind is called under RTPSender.mu lock, call the RTPSender.GetParameters in goroutine to avoid deadlock + // Bind is called under RTPSender.mu lock, + // call the RTPSender.GetParameters (which setRTPHeaderExtensions invokes) + // in goroutine to avoid deadlock go d.setRTPHeaderExtensions() doBind := func() { @@ -777,25 +779,24 @@ func (d *DownTrack) SetReceiver(r TrackReceiver) { // Sets RTP header extensions for this track func (d *DownTrack) setRTPHeaderExtensions() { - d.bindLock.Lock() - defer d.bindLock.Unlock() - sal := d.getStreamAllocatorListener() if sal == nil { return } - - var extensions []webrtc.RTPHeaderExtensionParameter - if tr := d.transceiver.Load(); tr != nil { - if sender := tr.Sender(); sender != nil { - extensions = sender.GetParameters().HeaderExtensions - d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions) - } - } - isBWEEnabled := sal.IsBWEEnabled(d) bweType := sal.BWEType() + tr := d.transceiver.Load() + if tr == nil { + return + } + var extensions []webrtc.RTPHeaderExtensionParameter + if sender := tr.Sender(); sender != nil { + extensions = sender.GetParameters().HeaderExtensions + d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions) + } + + d.bindLock.Lock() for _, ext := range extensions { switch ext.URI { case sdp.ABSSendTimeURI: @@ -818,6 +819,7 @@ func (d *DownTrack) setRTPHeaderExtensions() { d.absCaptureTimeExtID = ext.ID } } + d.bindLock.Unlock() } // Kind controls if this TrackLocal is audio or video @@ -840,6 +842,7 @@ func (d *DownTrack) SSRCRTX() uint32 { func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { d.transceiver.Store(transceiver) + d.setRTPHeaderExtensions() } func (d *DownTrack) GetTransceiver() *webrtc.RTPTransceiver { @@ -957,7 +960,11 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { copy(payload, tp.codecBytes) n := copy(payload[len(tp.codecBytes):], extPkt.Packet.Payload[tp.incomingHeaderSize:]) if n != len(extPkt.Packet.Payload[tp.incomingHeaderSize:]) { - d.params.Logger.Errorw("payload overflow", nil, "want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]), "have", n) + d.params.Logger.Errorw( + "payload overflow", nil, + "want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]), + "have", n, + ) PacketFactory.Put(poolEntity) return ErrPayloadOverflow } diff --git a/pkg/sfu/forwardstats.go b/pkg/sfu/forwardstats.go index fc1403c84..cf55dfb64 100644 --- a/pkg/sfu/forwardstats.go +++ b/pkg/sfu/forwardstats.go @@ -49,7 +49,7 @@ func (s *ForwardStats) Update(arrival, left int64) (int64, bool) { return transit, isHighForwardingLatency } -func (s *ForwardStats) getStats(shortDuration time.Duration) (time.Duration, time.Duration, time.Duration, time.Duration) { +func (s *ForwardStats) GetStats(shortDuration time.Duration) (time.Duration, time.Duration, time.Duration, time.Duration) { s.lock.Lock() wLong := s.latency.Summarize() wShort := s.latency.SummarizeLast(shortDuration) @@ -93,7 +93,7 @@ func (s *ForwardStats) report(reportInterval time.Duration) { return case <-ticker.C: - latencyLong, jitterLong, latencyShort, jitterShort := s.getStats(reportInterval) + latencyLong, jitterLong, latencyShort, jitterShort := s.GetStats(reportInterval) prometheus.RecordForwardJitter(uint32(jitterShort.Microseconds()), uint32(jitterLong.Microseconds())) prometheus.RecordForwardLatency(uint32(latencyShort.Microseconds()), uint32(latencyLong.Microseconds())) }