diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index d6509e89b..38cf31569 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -385,6 +385,7 @@ type DownTrack struct { blankFramesGeneration atomic.Uint32 connectionStats *connectionquality.ConnectionStats + onStatsUpdate atomic.Value // func(d *DownTrack, stat *livekit.AnalyticsStat) isNACKThrottled atomic.Bool @@ -471,6 +472,9 @@ func NewDownTrack(params DownTrackParams) (*DownTrack, error) { }) d.connectionStats.OnStatsUpdate(func(_cs *connectionquality.ConnectionStats, stat *livekit.AnalyticsStat) { d.params.Listener.OnStatsUpdate(stat) + if fn, ok := d.onStatsUpdate.Load().(func(*DownTrack, *livekit.AnalyticsStat)); ok && fn != nil { + fn(d, stat) + } }) if d.kind == webrtc.RTPCodecTypeVideo { @@ -2484,6 +2488,13 @@ func (d *DownTrack) GetConnectionScoreAndQuality() (float32, livekit.ConnectionQ return d.connectionStats.GetScoreAndQuality() } +// OnStatsUpdate registers an additional callback that fires alongside the +// configured DownTrackListener whenever connection-quality stats are produced. +// Intended for tests and observers; the production listener path is unaffected. +func (d *DownTrack) OnStatsUpdate(fn func(d *DownTrack, stat *livekit.AnalyticsStat)) { + d.onStatsUpdate.Store(fn) +} + func (d *DownTrack) GetTrackStats() *livekit.RTPStats { return rtpstats.ReconcileRTPStatsWithRTX(d.rtpStats.ToProto(), d.rtpStatsRTX.ToProto()) } diff --git a/test/singlenode_test.go b/test/singlenode_test.go index d1cbf152c..177dbbaaa 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -22,6 +22,7 @@ import ( "net/http" "reflect" "strings" + "sync" "testing" "time" @@ -39,6 +40,8 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/datachannel" "github.com/livekit/livekit-server/pkg/testutils" testclient "github.com/livekit/livekit-server/test/client" @@ -236,6 +239,217 @@ func TestSinglePublisher(t *testing.T) { } } +func TestConnectionStats(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + s, finish := setupSingleNodeTest("TestConnectionStats") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + defer func() { + c1.Stop() + c2.Stop() + }() + + // both clients publish audio + video + t1, err := c1.AddStaticTrack("audio/opus", "audio", "c1audio") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "c1video") + require.NoError(t, err) + defer t2.Stop() + + t3, err := c2.AddStaticTrack("audio/opus", "audio", "c2audio") + require.NoError(t, err) + defer t3.Stop() + t4, err := c2.AddStaticTrack("video/vp8", "video", "c2video") + require.NoError(t, err) + defer t4.Stop() + + // wait for cross-subscriptions: each client should receive 2 tracks from the other + testutils.WithTimeout(t, func() string { + if len(c1.SubscribedTracks()[c2.ID()]) != 2 { + return "c1 did not subscribe to both tracks from c2" + } + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 did not subscribe to both tracks from c1" + } + return "" + }) + + room := s.RoomManager().GetRoom(context.Background(), testRoom) + require.NotNil(t, room) + + // hook the upstream WebRTCReceiver.OnStatsUpdate and downstream DownTrack.OnStatsUpdate + // callbacks so we can verify the AnalyticsStat delivered through each carries valid + // delta data. MediaTrack.Receivers() returns one entry per potential codec; only those + // matching the actually published codec are *sfu.WebRTCReceiver, the rest are + // placeholder *rtc.DummyReceiver instances that we skip. + type statCapture struct { + lock sync.Mutex + stat *livekit.AnalyticsStat + } + receiverCaptures := make(map[livekit.TrackID]*statCapture) + downTrackCaptures := make(map[livekit.ParticipantIdentity]map[livekit.TrackID]*statCapture) + for _, identity := range []livekit.ParticipantIdentity{"c1", "c2"} { + p := room.GetParticipant(identity) + require.NotNil(t, p, "participant %s not found", identity) + for _, mt := range p.GetPublishedTracks() { + rc := &statCapture{} + receiverCaptures[mt.ID()] = rc + var hooked int + for _, r := range mt.Receivers() { + if dr, ok := r.(*rtc.DummyReceiver); ok { + underlying := dr.Receiver() + if underlying == nil { + continue + } + r = underlying + } + wr, ok := r.(*sfu.WebRTCReceiver) + if !ok { + continue + } + wr.OnStatsUpdate(func(_ *sfu.WebRTCReceiver, stat *livekit.AnalyticsStat) { + rc.lock.Lock() + rc.stat = stat + rc.lock.Unlock() + }) + hooked++ + } + require.Greater(t, hooked, 0, "no live WebRTCReceiver found for published track %s", mt.ID()) + } + + dtCaps := make(map[livekit.TrackID]*statCapture) + downTrackCaptures[identity] = dtCaps + for _, st := range p.GetSubscribedTracks() { + dt := st.DownTrack() + require.NotNil(t, dt, "subscribed track %s has no DownTrack", st.ID()) + dc := &statCapture{} + dtCaps[st.ID()] = dc + dt.OnStatsUpdate(func(_ *sfu.DownTrack, stat *livekit.AnalyticsStat) { + dc.lock.Lock() + dc.stat = stat + dc.lock.Unlock() + }) + } + } + + validateAnalyticsStat := func(stat *livekit.AnalyticsStat) string { + if stat == nil { + return "stat nil" + } + if len(stat.Streams) == 0 { + return "stat has no streams" + } + var totalPackets uint32 + var totalBytes uint64 + for _, s := range stat.Streams { + totalPackets += s.PrimaryPackets + totalBytes += s.PrimaryBytes + } + if totalPackets == 0 { + return "stat has no packets across streams" + } + if totalBytes == 0 { + return "stat has no bytes across streams" + } + return "" + } + + // wait for cumulative + delta + OnStatsUpdate-derived stats. the + // connection-quality update interval is 5s, so allow plenty of time for + // the receiver OnStatsUpdate callback to fire at least once and for + // the downstream connection-quality scorer to compute a real score. + testutils.WithTimeout(t, func() string { + for _, identity := range []livekit.ParticipantIdentity{"c1", "c2"} { + p := room.GetParticipant(identity) + if p == nil { + return fmt.Sprintf("participant %s not found", identity) + } + + // upstream (publisher) cumulative stats + published := p.GetPublishedTracks() + if len(published) != 2 { + return fmt.Sprintf("%s expected 2 published tracks, got %d", identity, len(published)) + } + for _, mt := range published { + lmt, ok := mt.(types.LocalMediaTrack) + if !ok { + return fmt.Sprintf("%s published track %s is not a LocalMediaTrack", identity, mt.ID()) + } + stats := lmt.GetTrackStats() + if stats == nil { + return fmt.Sprintf("%s upstream cumulative stats nil for track %s", identity, mt.ID()) + } + if stats.Packets == 0 { + return fmt.Sprintf("%s upstream cumulative stats has no packets for track %s", identity, mt.ID()) + } + if stats.Bytes == 0 { + return fmt.Sprintf("%s upstream cumulative stats has no bytes for track %s", identity, mt.ID()) + } + + // upstream delta stats fed into the receiver OnStatsUpdate path + rc, ok := receiverCaptures[mt.ID()] + if !ok { + return fmt.Sprintf("%s missing receiver capture for track %s", identity, mt.ID()) + } + rc.lock.Lock() + stat := rc.stat + rc.lock.Unlock() + if msg := validateAnalyticsStat(stat); msg != "" { + return fmt.Sprintf("%s upstream OnStatsUpdate %s for track %s", identity, msg, mt.ID()) + } + } + + // downstream (subscriber) cumulative stats and DownTrack OnStatsUpdate + // delta stats captured from the listener path + subscribed := p.GetSubscribedTracks() + if len(subscribed) != 2 { + return fmt.Sprintf("%s expected 2 subscribed tracks, got %d", identity, len(subscribed)) + } + for _, st := range subscribed { + dt := st.DownTrack() + if dt == nil { + return fmt.Sprintf("%s subscribed track %s has no DownTrack", identity, st.ID()) + } + stats := dt.GetTrackStats() + if stats == nil { + return fmt.Sprintf("%s downstream cumulative stats nil for track %s", identity, st.ID()) + } + if stats.Packets == 0 { + return fmt.Sprintf("%s downstream cumulative stats has no packets for track %s", identity, st.ID()) + } + if stats.Bytes == 0 { + return fmt.Sprintf("%s downstream cumulative stats has no bytes for track %s", identity, st.ID()) + } + + // downstream delta stats fed into the DownTrack OnStatsUpdate path + dc, ok := downTrackCaptures[identity][st.ID()] + if !ok { + return fmt.Sprintf("%s missing DownTrack capture for track %s", identity, st.ID()) + } + dc.lock.Lock() + stat := dc.stat + dc.lock.Unlock() + if msg := validateAnalyticsStat(stat); msg != "" { + return fmt.Sprintf("%s downstream OnStatsUpdate %s for track %s", identity, msg, st.ID()) + } + } + } + return "" + }, 15*time.Second) + }) + } +} + func Test_WhenAutoSubscriptionDisabled_ClientShouldNotReceiveAnyPublishedTracks(t *testing.T) { if testing.Short() { t.SkipNow()