diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 49013d923..cb8e08981 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -517,8 +517,8 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { return } - if tracker != nil && len(pkt.Packet.Payload) > 0 { - tracker.Observe(pkt.Packet.SequenceNumber, pkt.TemporalLayer, len(pkt.RawPacket)) + if tracker != nil { + tracker.Observe(pkt.Packet.SequenceNumber, pkt.TemporalLayer, len(pkt.RawPacket), len(pkt.Packet.Payload)) } w.downTrackMu.RLock() diff --git a/pkg/sfu/streamtracker.go b/pkg/sfu/streamtracker.go index 875955f46..65a608291 100644 --- a/pkg/sfu/streamtracker.go +++ b/pkg/sfu/streamtracker.go @@ -190,7 +190,7 @@ func (s *StreamTracker) SetPaused(paused bool) { } // Observe a packet that's received -func (s *StreamTracker) Observe(sn uint16, temporalLayer int32, pktSize int) { +func (s *StreamTracker) Observe(sn uint16, temporalLayer int32, pktSize int, payloadSize int) { s.lock.Lock() defer s.lock.Unlock() @@ -206,7 +206,7 @@ func (s *StreamTracker) Observe(sn uint16, temporalLayer int32, pktSize int) { s.countSinceLast = 1 s.lastBitrateReport = time.Now() - if temporalLayer >= 0 { + if temporalLayer >= 0 && payloadSize > 0 { s.bytesForBitrate[temporalLayer] += int64(pktSize) } @@ -223,7 +223,7 @@ func (s *StreamTracker) Observe(sn uint16, temporalLayer int32, pktSize int) { s.lastSN = sn s.countSinceLast++ - if temporalLayer >= 0 { + if temporalLayer >= 0 && payloadSize > 0 { s.bytesForBitrate[temporalLayer] += int64(pktSize) } } diff --git a/pkg/sfu/streamtracker_test.go b/pkg/sfu/streamtracker_test.go index 6b707ac64..b9fa11461 100644 --- a/pkg/sfu/streamtracker_test.go +++ b/pkg/sfu/streamtracker_test.go @@ -32,7 +32,7 @@ func TestStreamTracker(t *testing.T) { require.Equal(t, StreamStatusStopped, tracker.Status()) // observe first packet - tracker.Observe(1, 0, 0) + tracker.Observe(1, 0, 0, 0) testutils.WithTimeout(t, func() string { if callbackCalled.Load() { @@ -53,7 +53,7 @@ func TestStreamTracker(t *testing.T) { tracker.Start() require.Equal(t, StreamStatusStopped, tracker.Status()) - tracker.Observe(1, 0, 0) + tracker.Observe(1, 0, 0, 0) testutils.WithTimeout(t, func() string { if tracker.Status() == StreamStatusActive { return "" @@ -89,7 +89,7 @@ func TestStreamTracker(t *testing.T) { tracker.Start() require.Equal(t, StreamStatusStopped, tracker.Status()) - tracker.Observe(1, 0, 0) + tracker.Observe(1, 0, 0, 0) testutils.WithTimeout(t, func() string { if tracker.Status() == StreamStatusActive { return "" @@ -100,11 +100,11 @@ func TestStreamTracker(t *testing.T) { tracker.maybeSetStatus(StreamStatusStopped) - tracker.Observe(2, 0, 0) + tracker.Observe(2, 0, 0, 0) tracker.detectChanges() require.Equal(t, StreamStatusStopped, tracker.Status()) - tracker.Observe(3, 0, 0) + tracker.Observe(3, 0, 0, 0) tracker.detectChanges() require.Equal(t, StreamStatusActive, tracker.Status()) @@ -114,7 +114,7 @@ func TestStreamTracker(t *testing.T) { t.Run("changes to inactive when paused", func(t *testing.T) { tracker := newStreamTracker(5, 60, 500*time.Millisecond) tracker.Start() - tracker.Observe(1, 0, 0) + tracker.Observe(1, 0, 0, 0) testutils.WithTimeout(t, func() string { if tracker.Status() == StreamStatusActive { return "" @@ -140,7 +140,7 @@ func TestStreamTracker(t *testing.T) { require.Equal(t, StreamStatusStopped, tracker.Status()) // observe first packet - tracker.Observe(1, 0, 0) + tracker.Observe(1, 0, 0, 0) testutils.WithTimeout(t, func() string { if callbackCalled.Load() == 1 { @@ -154,10 +154,10 @@ func TestStreamTracker(t *testing.T) { require.Equal(t, uint32(1), callbackCalled.Load()) // observe a few more - tracker.Observe(2, 0, 0) - tracker.Observe(3, 0, 0) - tracker.Observe(4, 0, 0) - tracker.Observe(5, 0, 0) + tracker.Observe(2, 0, 0, 0) + tracker.Observe(3, 0, 0, 0) + tracker.Observe(4, 0, 0, 0) + tracker.Observe(5, 0, 0, 0) tracker.detectChanges() // should still be active @@ -168,7 +168,7 @@ func TestStreamTracker(t *testing.T) { require.Equal(t, StreamStatusStopped, tracker.Status()) // first packet after reset - tracker.Observe(1, 0, 0) + tracker.Observe(1, 0, 0, 0) testutils.WithTimeout(t, func() string { if callbackCalled.Load() == 2 {