diff --git a/go.mod b/go.mod index c09a99e8b..f923bfef9 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20240625074155-301bb4a816b7 - github.com/livekit/protocol v1.19.1-0.20240627173058-82786f41fdb6 + github.com/livekit/protocol v1.19.2-0.20240705134535-94a2cfe2f1ee github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 4a54999e8..964384eec 100644 --- a/go.sum +++ b/go.sum @@ -167,8 +167,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20240625074155-301bb4a816b7 h1:F1L8inJoynwIAYpZENNYS+1xHJMF5RFRorsnAlcxfSY= github.com/livekit/mediatransportutil v0.0.0-20240625074155-301bb4a816b7/go.mod h1:jwKUCmObuiEDH0iiuJHaGMXwRs3RjrB4G6qqgkr/5oE= -github.com/livekit/protocol v1.19.1-0.20240627173058-82786f41fdb6 h1:XtyV+MqHqXTuNLXz5TUjYtNg0gvTVw9web/YuXD9+3c= -github.com/livekit/protocol v1.19.1-0.20240627173058-82786f41fdb6/go.mod h1:bNjJi+8frdvC84xG0CJ/7VfVvqerLg2MzjOks0ucyC4= +github.com/livekit/protocol v1.19.2-0.20240705134535-94a2cfe2f1ee h1:J1U5fqAB5wJ4+Dl/DAf43Eiw+syyLTKAJoGuUj3rjQI= +github.com/livekit/protocol v1.19.2-0.20240705134535-94a2cfe2f1ee/go.mod h1:bNjJi+8frdvC84xG0CJ/7VfVvqerLg2MzjOks0ucyC4= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a h1:EQAHmcYEGlc6V517cQ3Iy0+jHgP6+tM/B4l2vGuLpQo= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 85368df95..ec5030ca3 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -40,9 +40,10 @@ import ( // MediaTrack represents a WebRTC track that needs to be forwarded // Implements MediaTrack and PublishedTrack interface type MediaTrack struct { - params MediaTrackParams - numUpTracks atomic.Uint32 - buffer *buffer.Buffer + params MediaTrackParams + numUpTracks atomic.Uint32 + buffer *buffer.Buffer + everSubscribed atomic.Bool *MediaTrackReceiver *MediaLossProxy @@ -55,22 +56,23 @@ type MediaTrack struct { } type MediaTrackParams struct { - SignalCid string - SdpCid string - ParticipantID livekit.ParticipantID - ParticipantIdentity livekit.ParticipantIdentity - ParticipantVersion uint32 - BufferFactory *buffer.Factory - ReceiverConfig ReceiverConfig - SubscriberConfig DirectionConfig - PLIThrottleConfig config.PLIThrottleConfig - AudioConfig config.AudioConfig - VideoConfig config.VideoConfig - Telemetry telemetry.TelemetryService - Logger logger.Logger - SimTracks map[uint32]SimulcastTrackInfo - OnRTCP func([]rtcp.Packet) - ForwardStats *sfu.ForwardStats + SignalCid string + SdpCid string + ParticipantID livekit.ParticipantID + ParticipantIdentity livekit.ParticipantIdentity + ParticipantVersion uint32 + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + PLIThrottleConfig config.PLIThrottleConfig + AudioConfig config.AudioConfig + VideoConfig config.VideoConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger + SimTracks map[uint32]SimulcastTrackInfo + OnRTCP func([]rtcp.Packet) + ForwardStats *sfu.ForwardStats + OnTrackEverSubscribed func(livekit.TrackID) } func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { @@ -283,6 +285,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra sfu.WithLoadBalanceThreshold(20), sfu.WithStreamTrackers(), sfu.WithForwardStats(t.params.ForwardStats), + sfu.WithEverHasDowntrackAdded(t.handleReceiverEverAddDowntrack), ) newWR.OnCloseHandler(func() { t.MediaTrackReceiver.SetClosing() @@ -430,3 +433,9 @@ func (t *MediaTrack) SetMuted(muted bool) { t.MediaTrackReceiver.SetMuted(muted) } + +func (t *MediaTrack) handleReceiverEverAddDowntrack() { + if !t.everSubscribed.Swap(true) && t.params.OnTrackEverSubscribed != nil { + go t.params.OnTrackEverSubscribed(t.ID()) + } +} diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 1e113bc14..1d3824033 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -2140,22 +2140,23 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *livekit.TrackInfo) *MediaTrack { mt := NewMediaTrack(MediaTrackParams{ - SignalCid: signalCid, - SdpCid: sdpCid, - ParticipantID: p.params.SID, - ParticipantIdentity: p.params.Identity, - ParticipantVersion: p.version.Load(), - BufferFactory: p.params.Config.BufferFactory, - ReceiverConfig: p.params.Config.Receiver, - AudioConfig: p.params.AudioConfig, - VideoConfig: p.params.VideoConfig, - Telemetry: p.params.Telemetry, - Logger: LoggerWithTrack(p.pubLogger, livekit.TrackID(ti.Sid), false), - SubscriberConfig: p.params.Config.Subscriber, - PLIThrottleConfig: p.params.PLIThrottleConfig, - SimTracks: p.params.SimTracks, - OnRTCP: p.postRtcp, - ForwardStats: p.params.ForwardStats, + SignalCid: signalCid, + SdpCid: sdpCid, + ParticipantID: p.params.SID, + ParticipantIdentity: p.params.Identity, + ParticipantVersion: p.version.Load(), + BufferFactory: p.params.Config.BufferFactory, + ReceiverConfig: p.params.Config.Receiver, + AudioConfig: p.params.AudioConfig, + VideoConfig: p.params.VideoConfig, + Telemetry: p.params.Telemetry, + Logger: LoggerWithTrack(p.pubLogger, livekit.TrackID(ti.Sid), false), + SubscriberConfig: p.params.Config.Subscriber, + PLIThrottleConfig: p.params.PLIThrottleConfig, + SimTracks: p.params.SimTracks, + OnRTCP: p.postRtcp, + ForwardStats: p.params.ForwardStats, + OnTrackEverSubscribed: p.sendTrackHasBeenSubscribed, }, ti) mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go index 6f84eadd9..35dc7148d 100644 --- a/pkg/rtc/participant_signal.go +++ b/pkg/rtc/participant_signal.go @@ -286,6 +286,17 @@ func (p *ParticipantImpl) sendTrackUnpublished(trackID livekit.TrackID) { }) } +func (p *ParticipantImpl) sendTrackHasBeenSubscribed(trackID livekit.TrackID) { + _ = p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_TrackSubscribed{ + TrackSubscribed: &livekit.TrackSubscribed{ + TrackSid: string(trackID), + }, + }, + }) + p.params.Logger.Debugw("track has been subscribed", "trackID", trackID) +} + func (p *ParticipantImpl) writeMessage(msg *livekit.SignalResponse) error { if p.IsDisconnected() || (!p.IsReady() && msg.GetJoin() == nil) { return nil diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 1fd0db1c6..b0ccc9431 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -120,8 +120,10 @@ type WebRTCReceiver struct { connectionStats *connectionquality.ConnectionStats - onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) - onMaxLayerChange func(maxLayer int32) + onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) + onMaxLayerChange func(maxLayer int32) + downtrackEverAdded atomic.Bool + onDowntrackEverAdded func() primaryReceiver atomic.Pointer[RedPrimaryReceiver] redReceiver atomic.Pointer[RedReceiver] @@ -193,6 +195,13 @@ func WithForwardStats(forwardStats *ForwardStats) ReceiverOpts { } } +func WithEverHasDowntrackAdded(f func()) ReceiverOpts { + return func(w *WebRTCReceiver) *WebRTCReceiver { + w.onDowntrackEverAdded = f + return w + } +} + // NewWebRTCReceiver creates a new webrtc track receiver func NewWebRTCReceiver( receiver *webrtc.RTPReceiver, @@ -429,9 +438,16 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error { w.downTrackSpreader.Store(track) w.logger.Debugw("downtrack added", "subscriberID", track.SubscriberID()) + w.handleDowntrackAdded() return nil } +func (w *WebRTCReceiver) handleDowntrackAdded() { + if !w.downtrackEverAdded.Swap(true) && w.onDowntrackEverAdded != nil { + w.onDowntrackEverAdded() + } +} + func (w *WebRTCReceiver) notifyMaxExpectedLayer(layer int32) { ti := w.TrackInfo() if ti == nil { @@ -810,6 +826,7 @@ func (w *WebRTCReceiver) GetPrimaryReceiverForRed() TrackReceiver { w.bufferMu.Lock() w.redPktWriter = pr.ForwardRTP w.bufferMu.Unlock() + w.handleDowntrackAdded() } } return w.primaryReceiver.Load() @@ -829,6 +846,7 @@ func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver { w.bufferMu.Lock() w.redPktWriter = pr.ForwardRTP w.bufferMu.Unlock() + w.handleDowntrackAdded() } } return w.redReceiver.Load()