diff --git a/pkg/config/config.go b/pkg/config/config.go index c059156bf..b01da016f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -193,6 +193,7 @@ func NewConfig(confString string, c *cli.Context) (*Config, error) { Port: 7880, RTC: RTCConfig{ UseExternalIP: false, + UseICELite: true, TCPPort: 7881, UDPPort: 0, ICEPortRangeStart: 0, diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 91c98a629..c9cbff21f 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -245,7 +245,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // since sub will lock, run it in a goroutine to avoid deadlocks go func() { sub.AddSubscribedTrack(subTrack) - sub.Negotiate() + sub.Negotiate(false) }() t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto(), @@ -708,5 +708,5 @@ func (t *MediaTrackSubscriptions) downTrackClosed( } sub.RemoveSubscribedTrack(subTrack) - sub.Negotiate() + sub.Negotiate(false) } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 82571cf2f..3c797d0ce 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -134,7 +134,8 @@ type ParticipantImpl struct { onClose func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID) onClaimsChanged func(participant types.LocalParticipant) - activeCounter atomic.Int32 + activeCounter atomic.Int32 + firstConnected atomic.Bool } func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { @@ -672,9 +673,11 @@ func (p *ParticipantImpl) Close(sendLeave bool) error { return nil } -func (p *ParticipantImpl) Negotiate() { +// Negotiate subscriber SDP with client, if force is true, will cencel pending +// negotiate task and negotiate immediately +func (p *ParticipantImpl) Negotiate(force bool) { if p.MigrateState() != types.MigrateStateInit { - p.subscriber.Negotiate() + p.subscriber.Negotiate(force) } } @@ -857,6 +860,9 @@ func (p *ParticipantImpl) AddSubscribedTrack(subTrack types.SubscribedTrack) { p.lock.Unlock() subTrack.OnBind(func() { + if p.firstConnected.Load() { + subTrack.DownTrack().SetConnected() + } p.subscriber.AddTrack(subTrack) }) @@ -1114,6 +1120,9 @@ func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data [ func (p *ParticipantImpl) handlePrimaryStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateConnected { + if !p.firstConnected.Swap(true) { + p.setDowntracksConnected() + } prometheus.ServiceOperationCounter.WithLabelValues("ice_connection", "success", "").Add(1) if !p.hasPendingMigratedTrack() && p.MigrateState() == types.MigrateStateSync { p.SetMigrateState(types.MigrateStateComplete) @@ -1798,3 +1807,14 @@ func (p *ParticipantImpl) postRtcp(pkts []rtcp.Packet) { p.params.Logger.Warnw("rtcp channel full", nil) } } + +func (p *ParticipantImpl) setDowntracksConnected() { + p.lock.RLock() + defer p.lock.RUnlock() + + for _, t := range p.subscribedTracks { + if dt := t.DownTrack(); dt != nil { + dt.SetConnected() + } + } +} diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 568dbecc0..9e52d7949 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -307,7 +307,14 @@ func (r *Room) Join(participant types.LocalParticipant, opts *ParticipantOptions if participant.SubscriberAsPrimary() { // initiates sub connection as primary - participant.Negotiate() + if participant.ProtocolVersion().SupportFastStart() { + go func() { + r.subscribeToExistingTracks(participant) + participant.Negotiate(true) + }() + } else { + participant.Negotiate(true) + } } prometheus.ServiceOperationCounter.WithLabelValues("participant_join", "success", "").Add(1) @@ -462,7 +469,7 @@ func (r *Room) SetParticipantPermission(participant types.LocalParticipant, perm if r.subscribeToExistingTracks(participant) == 0 { // start negotiating even if there are other media tracks to subscribe // we'll need to set the participant up to receive data - participant.Negotiate() + participant.Negotiate(false) } } } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 48db5b0b9..05f62498b 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -288,12 +288,21 @@ func (t *PCTransport) OnOffer(f func(sd webrtc.SessionDescription)) { t.onOffer = f } -func (t *PCTransport) Negotiate() { - t.debouncedNegotiate(func() { +func (t *PCTransport) Negotiate(force bool) { + if force { + t.debouncedNegotiate(func() { + // no op to cancel pending negotiation + }) if err := t.CreateAndSendOffer(nil); err != nil { t.logger.Errorw("could not negotiate", err) } - }) + } else { + t.debouncedNegotiate(func() { + if err := t.CreateAndSendOffer(nil); err != nil { + t.logger.Errorw("could not negotiate", err) + } + }) + } } func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 0d2d54850..de4a9ac66 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -126,7 +126,7 @@ type LocalParticipant interface { SubscriberMediaEngine() *webrtc.MediaEngine SubscriberPC() *webrtc.PeerConnection HandleAnswer(sdp webrtc.SessionDescription) error - Negotiate() + Negotiate(force bool) ICERestart() error AddSubscribedTrack(st SubscribedTrack) RemoveSubscribedTrack(st SubscribedTrack) diff --git a/pkg/rtc/types/protocol_version.go b/pkg/rtc/types/protocol_version.go index accfb437c..071ea512e 100644 --- a/pkg/rtc/types/protocol_version.go +++ b/pkg/rtc/types/protocol_version.go @@ -47,3 +47,9 @@ func (v ProtocolVersion) SupportsICELite() bool { func (v ProtocolVersion) SupportsUnpublish() bool { return v > 6 } + +// SupportFastStart - if client supports fast start, server side will send media streams +// in the first offer +func (v ProtocolVersion) SupportFastStart() bool { + return v > 7 +} diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 5acdf813b..edabb8df7 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -308,9 +308,10 @@ type FakeLocalParticipant struct { migrateStateReturnsOnCall map[int]struct { result1 types.MigrateState } - NegotiateStub func() + NegotiateStub func(bool) negotiateMutex sync.RWMutex negotiateArgsForCall []struct { + arg1 bool } OnClaimsChangedStub func(func(types.LocalParticipant)) onClaimsChangedMutex sync.RWMutex @@ -2190,15 +2191,16 @@ func (fake *FakeLocalParticipant) MigrateStateReturnsOnCall(i int, result1 types }{result1} } -func (fake *FakeLocalParticipant) Negotiate() { +func (fake *FakeLocalParticipant) Negotiate(arg1 bool) { fake.negotiateMutex.Lock() fake.negotiateArgsForCall = append(fake.negotiateArgsForCall, struct { - }{}) + arg1 bool + }{arg1}) stub := fake.NegotiateStub - fake.recordInvocation("Negotiate", []interface{}{}) + fake.recordInvocation("Negotiate", []interface{}{arg1}) fake.negotiateMutex.Unlock() if stub != nil { - fake.NegotiateStub() + fake.NegotiateStub(arg1) } } @@ -2208,12 +2210,19 @@ func (fake *FakeLocalParticipant) NegotiateCallCount() int { return len(fake.negotiateArgsForCall) } -func (fake *FakeLocalParticipant) NegotiateCalls(stub func()) { +func (fake *FakeLocalParticipant) NegotiateCalls(stub func(bool)) { fake.negotiateMutex.Lock() defer fake.negotiateMutex.Unlock() fake.NegotiateStub = stub } +func (fake *FakeLocalParticipant) NegotiateArgsForCall(i int) bool { + fake.negotiateMutex.RLock() + defer fake.negotiateMutex.RUnlock() + argsForCall := fake.negotiateArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalParticipant) OnClaimsChanged(arg1 func(types.LocalParticipant)) { fake.onClaimsChangedMutex.Lock() fake.onClaimsChangedArgsForCall = append(fake.onClaimsChangedArgsForCall, struct { diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 2e90ff272..1d5fc0ed0 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -312,9 +312,9 @@ func (b *Buffer) SetPLIThrottle(duration int64) { b.pliThrottle = duration } -func (b *Buffer) SendPLI() { +func (b *Buffer) SendPLI(force bool) { b.RLock() - if b.rtpStats == nil || b.rtpStats.TimeSinceLastPli() < b.pliThrottle { + if (b.rtpStats == nil || b.rtpStats.TimeSinceLastPli() < b.pliThrottle) && !force { b.RUnlock() return } @@ -322,7 +322,7 @@ func (b *Buffer) SendPLI() { b.rtpStats.UpdatePliAndTime(1) b.RUnlock() - b.logger.Debugw("send pli", "ssrc", b.mediaSSRC) + b.logger.Debugw("send pli", "ssrc", b.mediaSSRC, "force", force) pli := []rtcp.Packet{ &rtcp.PictureLossIndication{SenderSSRC: rand.Uint32(), MediaSSRC: b.mediaSSRC}, } diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index c774829b6..4b52c2fdc 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -136,6 +136,7 @@ type DownTrack struct { receiverReportListeners []ReceiverReportListener listenerLock sync.RWMutex isClosed atomic.Bool + connected atomic.Bool rtpStats *buffer.RTPStats @@ -397,9 +398,11 @@ func (d *DownTrack) keyFrameRequester(generation uint32, layer int32) { ticker := time.NewTicker(time.Duration(interval) * time.Millisecond) defer ticker.Stop() for { - d.logger.Debugw("sending PLI for layer lock", "generation", generation, "layer", layer) - d.receiver.SendPLI(layer) - d.rtpStats.UpdateLayerLockPliAndTime(1) + if d.connected.Load() { + d.logger.Debugw("sending PLI for layer lock", "generation", generation, "layer", layer) + d.receiver.SendPLI(layer, false) + d.rtpStats.UpdateLayerLockPliAndTime(1) + } <-ticker.C @@ -480,7 +483,6 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { d.stopKeyFrameRequester() } - // too much log for switching target layer, only log key frame if !tp.switchingToTargetLayer { d.logger.Debugw("forwarding key frame", "layer", layer) } @@ -1065,7 +1067,7 @@ func (d *DownTrack) handleRTCP(bytes []byte) { targetLayers := d.forwarder.TargetLayers() if targetLayers != InvalidLayers { d.logger.Debugw("sending PLI RTCP", "layer", targetLayers.Spatial) - d.receiver.SendPLI(targetLayers.Spatial) + d.receiver.SendPLI(targetLayers.Spatial, false) d.isNACKThrottled.Store(true) d.rtpStats.UpdatePliTime() pliOnce = false @@ -1151,6 +1153,17 @@ func (d *DownTrack) handleRTCP(bytes []byte) { } } +func (d *DownTrack) SetConnected() { + if !d.connected.Swap(true) { + if d.bound.Load() && d.kind == webrtc.RTPCodecTypeVideo { + targetLayers := d.forwarder.TargetLayers() + if targetLayers != InvalidLayers { + d.receiver.SendPLI(targetLayers.Spatial, true) + } + } + } +} + func (d *DownTrack) retransmitPackets(nacks []uint16) { if d.sequencer == nil { return diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index a8f081899..9a4129678 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -43,7 +43,7 @@ type TrackReceiver interface { GetAudioLevel() (float64, bool) - SendPLI(layer int32) + SendPLI(layer int32, force bool) SetUpTrackPaused(paused bool) SetMaxExpectedSpatialLayer(layer int32) @@ -440,14 +440,14 @@ func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) { } } -func (w *WebRTCReceiver) SendPLI(layer int32) { +func (w *WebRTCReceiver) SendPLI(layer int32, force bool) { // TODO : should send LRR (Layer Refresh Request) instead of PLI buff := w.getBuffer(layer) if buff == nil { return } - buff.SendPLI() + buff.SendPLI(force) } func (w *WebRTCReceiver) SetRTCPCh(ch chan []rtcp.Packet) { diff --git a/test/client/client.go b/test/client/client.go index f62365ddf..eebd76b10 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -281,7 +281,7 @@ func (c *RTCClient) Run() error { // if publish only, negotiate if !msg.Join.SubscriberPrimary { c.publisherNegotiated.Store(true) - c.publisher.Negotiate() + c.publisher.Negotiate(false) } logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity) @@ -342,7 +342,7 @@ func (c *RTCClient) Run() error { if err := c.publisher.PeerConnection().RemoveTrack(sender); err != nil { logger.Errorw("Could not unpublish track", err) } - c.publisher.Negotiate() + c.publisher.Negotiate(false) } delete(c.trackSenders, sid) delete(c.localTracks, sid) @@ -502,7 +502,7 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string) } c.localTracks[ti.Sid] = track c.trackSenders[ti.Sid] = sender - c.publisher.Negotiate() + c.publisher.Negotiate(false) writer = NewTrackWriter(c.ctx, track, path) // write tracks only after ICE connectivity @@ -604,7 +604,7 @@ func (c *RTCClient) ensurePublisherConnected() error { if c.publisher.PeerConnection().ConnectionState() == webrtc.PeerConnectionStateNew { // start negotiating - c.publisher.Negotiate() + c.publisher.Negotiate(false) } dcOpen := atomic.NewBool(false)