From 044f6cec408fa35f02d000df01b573455fe9991f Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 15 Sep 2023 21:39:03 +0530 Subject: [PATCH 1/8] Reduce packet meta data cache - part 1 (#2073) * Reduce packet meta data cache - part 1 Packet meta data cache takes a good amount of space. That cache is 8K entries deep and each entry is 8 bytes. So, that takes 64KB per RTP stream. It is mostly needed for down stream to line up with receiver reports. So, removing cache from up stream (RTPStatsReceiver) as part 1. Will look at optimising the down stream in part 2. * Remove caching from RTPStatsReceiver * clean up a bit more * maintain history and fix test --- pkg/sfu/buffer/rtpstats_base.go | 143 ++++++++++++++--------- pkg/sfu/buffer/rtpstats_receiver.go | 83 +++++++++++-- pkg/sfu/buffer/rtpstats_receiver_test.go | 65 ++--------- pkg/sfu/buffer/rtpstats_sender.go | 61 ++++------ 4 files changed, 196 insertions(+), 156 deletions(-) diff --git a/pkg/sfu/buffer/rtpstats_base.go b/pkg/sfu/buffer/rtpstats_base.go index f02e9241b..3fad1907c 100644 --- a/pkg/sfu/buffer/rtpstats_base.go +++ b/pkg/sfu/buffer/rtpstats_base.go @@ -89,17 +89,34 @@ type RTPDeltaInfo struct { } type snapshot struct { - startTime time.Time - extStartSN uint64 - packetsDuplicate uint64 - bytesDuplicate uint64 - headerBytesDuplicate uint64 - packetsLostOverridden uint64 - nacks uint32 - plis uint32 - firs uint32 - maxRtt uint32 - maxJitter float64 + isValid bool + + startTime time.Time + + extStartSN uint64 + bytes uint64 + headerBytes uint64 + + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + + packetsDuplicate uint64 + bytesDuplicate uint64 + headerBytesDuplicate uint64 + + packetsOutOfOrder uint64 + + packetsLost uint64 + + frames uint32 + + nacks uint32 + plis uint32 + firs uint32 + + maxRtt uint32 + maxJitter float64 } type snInfo struct { @@ -153,8 +170,7 @@ type rtpStatsBase struct { packetsOutOfOrder uint64 - packetsLost uint64 - packetsLostOverridden uint64 + 
packetsLost uint64 frames uint32 @@ -189,7 +205,7 @@ type rtpStatsBase struct { srNewest *RTCPSenderReportData nextSnapshotID uint32 - snapshots map[uint32]*snapshot + snapshots []snapshot } func newRTPStatsBase(params RTPStatsParams) *rtpStatsBase { @@ -197,7 +213,7 @@ func newRTPStatsBase(params RTPStatsParams) *rtpStatsBase { params: params, logger: params.Logger, nextSnapshotID: cFirstSnapshotID, - snapshots: make(map[uint32]*snapshot), + snapshots: make([]snapshot, 2), } } @@ -273,10 +289,8 @@ func (r *rtpStatsBase) seed(from *rtpStatsBase) bool { } r.nextSnapshotID = from.nextSnapshotID - for id, ss := range from.snapshots { - ssCopy := *ss - r.snapshots[id] = &ssCopy - } + r.snapshots = make([]snapshot, cap(from.snapshots)) + copy(r.snapshots, from.snapshots) return true } @@ -295,11 +309,14 @@ func (r *rtpStatsBase) newSnapshotID(extStartSN uint64) uint32 { id := r.nextSnapshotID r.nextSnapshotID++ + if cap(r.snapshots) < int(r.nextSnapshotID) { + snapshots := make([]snapshot, r.nextSnapshotID) + copy(snapshots, r.snapshots) + r.snapshots = snapshots + } + if r.initialized { - r.snapshots[id] = &snapshot{ - startTime: time.Now(), - extStartSN: extStartSN, - } + r.snapshots[id] = r.initSnapshot(time.Now(), extStartSN) } return id } @@ -551,21 +568,25 @@ func (r *rtpStatsBase) deltaInfo(snapshotID uint32, extStartSN uint64, extHighes } } - intervalStats := r.getIntervalStats(then.extStartSN, now.extStartSN, extHighestSN) + packetsLost := uint32(now.packetsLost - then.packetsLost) + if int32(packetsLost) < 0 { + packetsLost = 0 + } return &RTPDeltaInfo{ StartTime: startTime, Duration: endTime.Sub(startTime), - Packets: uint32(packetsExpected - intervalStats.packetsPadding), - Bytes: intervalStats.bytes, - HeaderBytes: intervalStats.headerBytes, + Packets: uint32(packetsExpected - (now.packetsPadding - then.packetsPadding)), + Bytes: now.bytes - then.bytes, + HeaderBytes: now.headerBytes - then.headerBytes, PacketsDuplicate: uint32(now.packetsDuplicate - 
then.packetsDuplicate), BytesDuplicate: now.bytesDuplicate - then.bytesDuplicate, HeaderBytesDuplicate: now.headerBytesDuplicate - then.headerBytesDuplicate, - PacketsPadding: uint32(intervalStats.packetsPadding), - BytesPadding: intervalStats.bytesPadding, - HeaderBytesPadding: intervalStats.headerBytesPadding, - PacketsLost: uint32(intervalStats.packetsLost), - Frames: intervalStats.frames, + PacketsPadding: uint32(now.packetsPadding - then.packetsPadding), + BytesPadding: now.bytesPadding - then.bytesPadding, + HeaderBytesPadding: now.headerBytesPadding - then.headerBytesPadding, + PacketsLost: packetsLost, + PacketsOutOfOrder: uint32(now.packetsOutOfOrder - then.packetsOutOfOrder), + Frames: now.frames - then.frames, RttMax: then.maxRtt, JitterMax: then.maxJitter / float64(r.params.ClockRate) * 1e6, Nacks: now.nacks - then.nacks, @@ -894,31 +915,15 @@ func (r *rtpStatsBase) getAndResetSnapshot(snapshotID uint32, extStartSN uint64, } then := r.snapshots[snapshotID] - if then == nil { - then = &snapshot{ - startTime: r.startTime, - extStartSN: extStartSN, - } + if !then.isValid { + then = r.initSnapshot(r.startTime, extStartSN) r.snapshots[snapshotID] = then } // snapshot now - r.snapshots[snapshotID] = &snapshot{ - startTime: time.Now(), - extStartSN: extHighestSN + 1, - packetsDuplicate: r.packetsDuplicate, - bytesDuplicate: r.bytesDuplicate, - headerBytesDuplicate: r.headerBytesDuplicate, - nacks: r.nacks, - plis: r.plis, - firs: r.firs, - maxJitter: r.jitter, - maxRtt: r.rtt, - } - // make a copy so that it can be used independently - now := *r.snapshots[snapshotID] - - return then, &now + now := r.getSnapshot(time.Now(), extHighestSN+1) + r.snapshots[snapshotID] = now + return &then, &now } func (r *rtpStatsBase) getDrift(extStartTS, extHighestTS uint64) (packetDrift *livekit.RTPDrift, reportDrift *livekit.RTPDrift) { @@ -975,6 +980,38 @@ func (r *rtpStatsBase) updateGapHistogram(gap int) { } } +func (r *rtpStatsBase) initSnapshot(startTime time.Time, 
extStartSN uint64) snapshot { + return snapshot{ + isValid: true, + startTime: time.Now(), + extStartSN: extStartSN, + } +} + +func (r *rtpStatsBase) getSnapshot(startTime time.Time, extStartSN uint64) snapshot { + return snapshot{ + isValid: true, + startTime: time.Now(), + extStartSN: extStartSN, + bytes: r.bytes, + headerBytes: r.headerBytes, + packetsPadding: r.packetsPadding, + bytesPadding: r.bytesPadding, + headerBytesPadding: r.headerBytesPadding, + packetsDuplicate: r.packetsDuplicate, + bytesDuplicate: r.bytesDuplicate, + headerBytesDuplicate: r.headerBytesDuplicate, + packetsLost: r.packetsLost, + packetsOutOfOrder: r.packetsOutOfOrder, + frames: r.frames, + nacks: r.nacks, + plis: r.plis, + firs: r.firs, + maxRtt: r.rtt, + maxJitter: r.jitter, + } +} + // ---------------------------------- func AggregateRTPStats(statsList []*livekit.RTPStats) *livekit.RTPStats { diff --git a/pkg/sfu/buffer/rtpstats_receiver.go b/pkg/sfu/buffer/rtpstats_receiver.go index eeb8e5ef5..0d41e73d7 100644 --- a/pkg/sfu/buffer/rtpstats_receiver.go +++ b/pkg/sfu/buffer/rtpstats_receiver.go @@ -24,6 +24,10 @@ import ( "github.com/livekit/protocol/livekit" ) +const ( + cHistorySize = 2048 +) + type RTPFlowState struct { IsNotHandled bool @@ -47,6 +51,8 @@ type RTPStatsReceiver struct { sequenceNumber *utils.WrapAround[uint16, uint64] timestamp *utils.WrapAround[uint32, uint64] + + history [cHistorySize / 64]uint64 } func NewRTPStatsReceiver(params RTPStatsParams) *RTPStatsReceiver { @@ -107,10 +113,7 @@ func (r *RTPStatsReceiver) Update( // initialize snapshots if any for i := uint32(cFirstSnapshotID); i < r.nextSnapshotID; i++ { - r.snapshots[i] = &snapshot{ - startTime: r.startTime, - extStartSN: r.sequenceNumber.GetExtendedStart(), - } + r.snapshots[i] = r.initSnapshot(r.startTime, r.sequenceNumber.GetExtendedStart()) } r.logger.Debugw( @@ -170,14 +173,14 @@ func (r *RTPStatsReceiver) Update( ) } - if !r.isSnInfoLost(resSN.ExtendedVal, resSN.PreExtendedHighest) { + if 
!r.isLost(resSN.ExtendedVal, resSN.PreExtendedHighest) { r.bytesDuplicate += pktSize r.headerBytesDuplicate += uint64(hdrSize) r.packetsDuplicate++ flowState.IsDuplicate = true } else { r.packetsLost-- - r.setSnInfo(resSN.ExtendedVal, resSN.PreExtendedHighest, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), marker, true) + r.setHistory(resSN.ExtendedVal, resSN.PreExtendedHighest) } flowState.IsOutOfOrder = true @@ -188,10 +191,10 @@ func (r *RTPStatsReceiver) Update( r.updateGapHistogram(int(gapSN)) // update missing sequence numbers - r.clearSnInfos(resSN.PreExtendedHighest+1, resSN.ExtendedVal) + r.clearHistory(resSN.PreExtendedHighest+1, resSN.ExtendedVal, resSN.PreExtendedHighest) r.packetsLost += uint64(gapSN - 1) - r.setSnInfo(resSN.ExtendedVal, resSN.PreExtendedHighest, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), marker, false) + r.setHistory(resSN.ExtendedVal, resSN.PreExtendedHighest) if timestamp != uint32(resTS.PreExtendedHighest) { // update only on first packet as same timestamp could be in multiple packets. @@ -409,8 +412,10 @@ func (r *RTPStatsReceiver) GetRtcpReceptionReport(ssrc uint32, proxyFracLost uin return nil } - intervalStats := r.getIntervalStats(then.extStartSN, now.extStartSN, extHighestSN) - packetsLost := intervalStats.packetsLost + packetsLost := uint32(now.packetsLost - then.packetsLost) + if int32(packetsLost) < 0 { + packetsLost = 0 + } lossRate := float32(packetsLost) / float32(packetsExpected) fracLost := uint8(lossRate * 256.0) if proxyFracLost > fracLost { @@ -468,4 +473,62 @@ func (r *RTPStatsReceiver) ToProto() *livekit.RTPStats { ) } +func (r *RTPStatsReceiver) getOutOfOrderHistorySlot(esn uint64, ehsn uint64) (int, int) { + diff := int64(ehsn - esn) + if diff >= cHistorySize || diff < 0 { + // too old OR too new (i. e. 
ahead of highest) + return -1, -1 + } + + return int(esn) % len(r.history), int(esn & 63) +} + +func (r *RTPStatsReceiver) getHistorySlot(esn uint64, ehsn uint64) (int, int) { + if int64(esn-ehsn) < 0 { + return r.getOutOfOrderHistorySlot(esn, ehsn) + } + + return int(esn) % len(r.history), int(esn & 63) +} + +func (r *RTPStatsReceiver) setHistory(esn uint64, ehsn uint64) { + slot, offset := r.getHistorySlot(esn, ehsn) + if slot < 0 { + return + } + + r.history[slot] |= (1 << offset) +} + +func (r *RTPStatsReceiver) clearHistory(extStartInclusive uint64, extEndExclusive uint64, ehsn uint64) { + if extEndExclusive <= extStartInclusive { + return + } + + slot, offset := r.getHistorySlot(extStartInclusive, ehsn) + if slot < 0 { + return + } + for esn := extStartInclusive; esn != extEndExclusive; esn++ { + r.history[slot] &= ^(1 << offset) + offset++ + if offset > 63 { + offset -= 64 + slot++ + if slot >= len(r.history) { + slot -= len(r.history) + } + } + } +} + +func (r *RTPStatsReceiver) isLost(esn uint64, ehsn uint64) bool { + slot, offset := r.getHistorySlot(esn, ehsn) + if slot < 0 { + return false + } + + return r.history[slot]&(1< Date: Fri, 15 Sep 2023 12:55:54 -0700 Subject: [PATCH 2/8] Skip SendDataPacket logging on transport failure (#2074) That's a sign of peer connection failure, we do not need to log these --- pkg/rtc/room.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index d04eb5cf8..1589ada96 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -1331,7 +1331,8 @@ func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, dp utils.ParallelExec(destParticipants, dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { err := op.SendDataPacket(dp, dpData) - if err != nil && !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, sctp.ErrStreamClosed) { + if err != nil && !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, sctp.ErrStreamClosed) && + !errors.Is(err, 
ErrTransportFailure) { op.GetLogger().Infow("send data packet error", "error", err) } }) From 9c2ad54146598e80315841e6707e8e01e0d0c347 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 16 Sep 2023 01:57:34 +0530 Subject: [PATCH 3/8] Clean up debug logs (#2076) --- pkg/sfu/rtpmunger.go | 27 ++++++--------------------- pkg/sfu/sequencer.go | 5 ++--- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/pkg/sfu/rtpmunger.go b/pkg/sfu/rtpmunger.go index d8b672a9d..c30b8a53b 100644 --- a/pkg/sfu/rtpmunger.go +++ b/pkg/sfu/rtpmunger.go @@ -121,7 +121,7 @@ func (r *RTPMunger) SetLastSnTs(extPkt *buffer.ExtPacket) { r.extLastSN = extPkt.ExtSequenceNumber r.extSecondLastSN = r.extLastSN - 1 - r.updateSnOffset("init") + r.updateSnOffset() r.extLastTS = extPkt.ExtTimestamp } @@ -130,7 +130,7 @@ func (r *RTPMunger) UpdateSnTsOffsets(extPkt *buffer.ExtPacket, snAdjust uint64, r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1 r.snRangeMap.ClearAndResetValue(extPkt.ExtSequenceNumber - r.extLastSN - snAdjust) - r.updateSnOffset("switch") + r.updateSnOffset() r.tsOffset = extPkt.ExtTimestamp - r.extLastTS - tsAdjust } @@ -156,7 +156,7 @@ func (r *RTPMunger) PacketDropped(extPkt *buffer.ExtPacket) { } r.extLastSN = r.extSecondLastSN - r.updateSnOffset("drop") + r.updateSnOffset() } func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationParamsRTP, error) { @@ -197,15 +197,6 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara if diff < 0 { // out-of-order, look up sequence number offset cache snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber) - r.logger.Debugw( - "out-of-order packet", - "extHighestIncomingSN", r.extHighestIncomingSN, - "extLastSN", r.extLastSN, - "extSequenceNumber", extPkt.ExtSequenceNumber, - "snOffset", snOffset, - "error", err, - "outgoingSN", extPkt.ExtSequenceNumber-snOffset, - ) if err != nil { return &TranslationParamsRTP{ snOrdering: SequenceNumberOrderingOutOfOrder, @@ 
-227,7 +218,7 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN) } - r.updateSnOffset("pad-drop") + r.updateSnOffset() return &TranslationParamsRTP{ snOrdering: SequenceNumberOrderingContiguous, @@ -298,7 +289,7 @@ func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate r.extSecondLastSN = extLastSN - 1 r.extLastSN = extLastSN r.snRangeMap.DecValue(r.extHighestIncomingSN, uint64(num)) - r.updateSnOffset("pad") + r.updateSnOffset() r.tsOffset -= extLastTS - r.extLastTS r.extLastTS = extLastTS @@ -314,16 +305,10 @@ func (r *RTPMunger) IsOnFrameBoundary() bool { return r.lastMarker } -func (r *RTPMunger) updateSnOffset(cause string) { +func (r *RTPMunger) updateSnOffset() { snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) if err != nil { r.logger.Errorw("could not get sequence number offset", err) } r.snOffset = snOffset - r.logger.Debugw( - "updating sequence number offset", - "cause", cause, - "extHighestIncomingSN", r.extHighestIncomingSN, - "snOffset", r.snOffset, - ) } diff --git a/pkg/sfu/sequencer.go b/pkg/sfu/sequencer.go index 14280d0a0..9bf53de74 100644 --- a/pkg/sfu/sequencer.go +++ b/pkg/sfu/sequencer.go @@ -144,8 +144,8 @@ func (s *sequencer) push( s.extHighestSN = extModifiedSN } else { if diff < -int64(s.size) { - s.logger.Debugw( - "old packet, cannot be sequenced", + s.logger.Warnw( + "old packet, cannot be sequenced", nil, "extHighestSN", s.extHighestSN, "extIncomingSN", extIncomingSN, "extModifiedSN", extModifiedSN, @@ -189,7 +189,6 @@ func (s *sequencer) pushPadding(extStartSNInclusive uint64, extEndSNInclusive ui s.Lock() defer s.Unlock() - s.logger.Debugw("sequencer padding", "extHighestSN", s.extHighestSN, "startSN", extStartSNInclusive, "endSN", extEndSNInclusive) if s.snRangeMap == nil { return } From f29887dcd071683a67584e1e3ba2633a865fa76c Mon Sep 17 00:00:00 2001 From: Raja Subramanian 
Date: Sat, 16 Sep 2023 02:03:50 +0530 Subject: [PATCH 4/8] Use bit map. (#2075) * Use bit map. Also, duplicate packet detection is important for dropping padding only packets at the publisher side itself. In the last PR, I mentioned that it is only for stats. * clean up * Update deps --- go.mod | 2 +- go.sum | 4 +- pkg/sfu/buffer/rtpstats_receiver.go | 83 +++++------------------- pkg/sfu/buffer/rtpstats_receiver_test.go | 10 +-- 4 files changed, 25 insertions(+), 74 deletions(-) diff --git a/go.mod b/go.mod index 0df7de2d2..95aec396b 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20230906055425-e81fd5f6fb3f - github.com/livekit/protocol v1.7.3-0.20230911160509-47d330eafb32 + github.com/livekit/protocol v1.7.3-0.20230915202328-cf9f95141e0e github.com/livekit/psrpc v0.3.3 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 6119003f1..ae9768dd2 100644 --- a/go.sum +++ b/go.sum @@ -127,8 +127,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20230906055425-e81fd5f6fb3f h1:b4ri7hQESRSzJWzXXcmANG2hJ4HTj5LM01Ekm8lnQmg= github.com/livekit/mediatransportutil v0.0.0-20230906055425-e81fd5f6fb3f/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY= -github.com/livekit/protocol v1.7.3-0.20230911160509-47d330eafb32 h1:5PdmCpGGXA2hz1pKGgKSJYTjmk3Kkm+kNiW5NOFARCI= -github.com/livekit/protocol v1.7.3-0.20230911160509-47d330eafb32/go.mod h1:zbh0QPUcLGOeZeIO/VeigwWWbudz4Lv+Px94FnVfQH0= +github.com/livekit/protocol v1.7.3-0.20230915202328-cf9f95141e0e h1:WEet0iH/JazBFNhhH+YuZHtXpKefb7mnbCC2al3peyA= +github.com/livekit/protocol v1.7.3-0.20230915202328-cf9f95141e0e/go.mod
h1:zbh0QPUcLGOeZeIO/VeigwWWbudz4Lv+Px94FnVfQH0= github.com/livekit/psrpc v0.3.3 h1:+lltbuN39IdaynXhLLxRShgYqYsRMWeeXKzv60oqyWo= github.com/livekit/psrpc v0.3.3/go.mod h1:n6JntEg+zT6Ji8InoyTpV7wusPNwGqqtxmHlkNhDN0U= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= diff --git a/pkg/sfu/buffer/rtpstats_receiver.go b/pkg/sfu/buffer/rtpstats_receiver.go index 0d41e73d7..2b409c990 100644 --- a/pkg/sfu/buffer/rtpstats_receiver.go +++ b/pkg/sfu/buffer/rtpstats_receiver.go @@ -22,6 +22,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/utils" "github.com/livekit/protocol/livekit" + protoutils "github.com/livekit/protocol/utils" ) const ( @@ -52,7 +53,7 @@ type RTPStatsReceiver struct { timestamp *utils.WrapAround[uint32, uint64] - history [cHistorySize / 64]uint64 + history *protoutils.Bitmap[uint64] } func NewRTPStatsReceiver(params RTPStatsParams) *RTPStatsReceiver { @@ -60,6 +61,7 @@ func NewRTPStatsReceiver(params RTPStatsParams) *RTPStatsReceiver { rtpStatsBase: newRTPStatsBase(params), sequenceNumber: utils.NewWrapAround[uint16, uint64](), timestamp: utils.NewWrapAround[uint32, uint64](), + history: protoutils.NewBitmap[uint64](cHistorySize), } } @@ -173,14 +175,16 @@ func (r *RTPStatsReceiver) Update( ) } - if !r.isLost(resSN.ExtendedVal, resSN.PreExtendedHighest) { - r.bytesDuplicate += pktSize - r.headerBytesDuplicate += uint64(hdrSize) - r.packetsDuplicate++ - flowState.IsDuplicate = true - } else { - r.packetsLost-- - r.setHistory(resSN.ExtendedVal, resSN.PreExtendedHighest) + if r.isInRange(resSN.ExtendedVal, resSN.PreExtendedHighest) { + if r.history.IsSet(resSN.ExtendedVal) { + r.bytesDuplicate += pktSize + r.headerBytesDuplicate += uint64(hdrSize) + r.packetsDuplicate++ + flowState.IsDuplicate = true + } else { + r.packetsLost-- + r.history.Set(resSN.ExtendedVal) + } } flowState.IsOutOfOrder = true @@ -191,10 +195,10 @@ func (r *RTPStatsReceiver) Update( r.updateGapHistogram(int(gapSN)) // update missing sequence 
numbers - r.clearHistory(resSN.PreExtendedHighest+1, resSN.ExtendedVal, resSN.PreExtendedHighest) + r.history.ClearRange(resSN.PreExtendedHighest+1, resSN.ExtendedVal-1) r.packetsLost += uint64(gapSN - 1) - r.setHistory(resSN.ExtendedVal, resSN.PreExtendedHighest) + r.history.Set(resSN.ExtendedVal) if timestamp != uint32(resTS.PreExtendedHighest) { // update only on first packet as same timestamp could be in multiple packets. @@ -473,62 +477,9 @@ func (r *RTPStatsReceiver) ToProto() *livekit.RTPStats { ) } -func (r *RTPStatsReceiver) getOutOfOrderHistorySlot(esn uint64, ehsn uint64) (int, int) { +func (r *RTPStatsReceiver) isInRange(esn uint64, ehsn uint64) bool { diff := int64(ehsn - esn) - if diff >= cHistorySize || diff < 0 { - // too old OR too new (i. e. ahead of highest) - return -1, -1 - } - - return int(esn) % len(r.history), int(esn & 63) -} - -func (r *RTPStatsReceiver) getHistorySlot(esn uint64, ehsn uint64) (int, int) { - if int64(esn-ehsn) < 0 { - return r.getOutOfOrderHistorySlot(esn, ehsn) - } - - return int(esn) % len(r.history), int(esn & 63) -} - -func (r *RTPStatsReceiver) setHistory(esn uint64, ehsn uint64) { - slot, offset := r.getHistorySlot(esn, ehsn) - if slot < 0 { - return - } - - r.history[slot] |= (1 << offset) -} - -func (r *RTPStatsReceiver) clearHistory(extStartInclusive uint64, extEndExclusive uint64, ehsn uint64) { - if extEndExclusive <= extStartInclusive { - return - } - - slot, offset := r.getHistorySlot(extStartInclusive, ehsn) - if slot < 0 { - return - } - for esn := extStartInclusive; esn != extEndExclusive; esn++ { - r.history[slot] &= ^(1 << offset) - offset++ - if offset > 63 { - offset -= 64 - slot++ - if slot >= len(r.history) { - slot -= len(r.history) - } - } - } -} - -func (r *RTPStatsReceiver) isLost(esn uint64, ehsn uint64) bool { - slot, offset := r.getHistorySlot(esn, ehsn) - if slot < 0 { - return false - } - - return r.history[slot]&(1<<offset) == 0 + return diff >= 0 && diff < cHistorySize } // ---------------------------------- diff --git
a/pkg/sfu/buffer/rtpstats_receiver_test.go b/pkg/sfu/buffer/rtpstats_receiver_test.go index 3fb648ba0..34fea0f4b 100644 --- a/pkg/sfu/buffer/rtpstats_receiver_test.go +++ b/pkg/sfu/buffer/rtpstats_receiver_test.go @@ -224,7 +224,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.Equal(t, uint64(sequenceNumber-1), flowState.LossStartInclusive) require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) require.Equal(t, uint64(17), r.packetsLost) - require.True(t, r.isLost(uint64(sequenceNumber)-1, r.sequenceNumber.GetExtendedHighest())) + require.False(t, r.history.IsSet(uint64(sequenceNumber)-1)) // out-of-order sequenceNumber-- @@ -242,7 +242,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.False(t, flowState.HasLoss) require.Equal(t, uint64(16), r.packetsLost) require.Equal(t, uint64(4), r.packetsOutOfOrder) - require.False(t, r.isLost(uint64(sequenceNumber), r.sequenceNumber.GetExtendedHighest())) + require.True(t, r.history.IsSet(uint64(sequenceNumber))) // padding only sequenceNumber += 2 @@ -259,9 +259,9 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.False(t, flowState.HasLoss) require.Equal(t, uint64(16), r.packetsLost) require.Equal(t, uint64(4), r.packetsOutOfOrder) - require.False(t, r.isLost(uint64(sequenceNumber), r.sequenceNumber.GetExtendedHighest())) - require.False(t, r.isLost(uint64(sequenceNumber)-1, r.sequenceNumber.GetExtendedHighest())) - require.False(t, r.isLost(uint64(sequenceNumber)-2, r.sequenceNumber.GetExtendedHighest())) + require.True(t, r.history.IsSet(uint64(sequenceNumber))) + require.True(t, r.history.IsSet(uint64(sequenceNumber)-1)) + require.True(t, r.history.IsSet(uint64(sequenceNumber)-2)) r.Stop() } From 340906267f63adb56bb2865aebee17b3b5b89a23 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Sat, 16 Sep 2023 00:15:04 -0700 Subject: [PATCH 5/8] Reduce ghost participant disconnect timeout (#2077) It's been reported that "ghost" participants, those that did not terminate cleanly, 
hang around the room for too long after they disappear. Evaluating our timeouts a bit, it seems that we are really conservative in waiting for participants to disconnect. This PR cuts down the disconnect timeout from 50s to 20s, a 30s reduction. --- pkg/rtc/participant.go | 10 +++++++--- pkg/rtc/transport.go | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 334a37086..39cce9073 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -56,7 +56,7 @@ const ( sdBatchSize = 30 rttUpdateInterval = 5 * time.Second - disconnectCleanupDuration = 15 * time.Second + disconnectCleanupDuration = 5 * time.Second migrationWaitDuration = 3 * time.Second ) @@ -561,7 +561,9 @@ func (p *ParticipantImpl) HandleSignalSourceClose() { p.TransportManager.SetSignalSourceValid(false) if !p.TransportManager.HasPublisherEverConnected() && !p.TransportManager.HasSubscriberEverConnected() { - p.params.Logger.Infow("closing disconnected participant") + p.params.Logger.Infow("closing disconnected participant", + "reason", types.ParticipantCloseReasonJoinFailed, + ) _ = p.Close(false, types.ParticipantCloseReasonJoinFailed, false) } } @@ -1402,7 +1404,9 @@ func (p *ParticipantImpl) setupDisconnectTimer() { if p.IsClosed() || p.IsDisconnected() { return } - p.params.Logger.Infow("closing disconnected participant") + p.params.Logger.Infow("closing disconnected participant", + "reason", types.ParticipantCloseReasonPeerConnectionDisconnected, + ) _ = p.Close(true, types.ParticipantCloseReasonPeerConnectionDisconnected, false) }) p.lock.Unlock() diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index b671649cd..33ba7ade2 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -60,7 +60,7 @@ const ( dtlsRetransmissionInterval = 100 * time.Millisecond iceDisconnectedTimeout = 10 * time.Second // compatible for ice-lite with firefox client - iceFailedTimeout = 25 * time.Second // pion's default + 
iceFailedTimeout = 5 * time.Second // time between disconnected and failed iceKeepaliveInterval = 2 * time.Second // pion's default minTcpICEConnectTimeout = 5 * time.Second From 97048a923c2278dcd9ca46fad605ed852ac4c6f8 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 16 Sep 2023 18:54:18 +0530 Subject: [PATCH 6/8] Reducing rtp stats memory consumption - part 2 (#2078) * WIP commit * move a struct to sender only * Snapshot intervals * make receiver history 4K too --- pkg/sfu/buffer/rtpstats_base.go | 151 +--------- pkg/sfu/buffer/rtpstats_receiver.go | 9 +- pkg/sfu/buffer/rtpstats_sender.go | 441 +++++++++++++++++++++------- 3 files changed, 352 insertions(+), 249 deletions(-) diff --git a/pkg/sfu/buffer/rtpstats_base.go b/pkg/sfu/buffer/rtpstats_base.go index 3fad1907c..d25341c8b 100644 --- a/pkg/sfu/buffer/rtpstats_base.go +++ b/pkg/sfu/buffer/rtpstats_base.go @@ -30,8 +30,6 @@ const ( cGapHistogramNumBins = 101 cNumSequenceNumbers = 65536 cFirstSnapshotID = 1 - cSnInfoSize = 8192 - cSnInfoMask = cSnInfoSize - 1 cFirstPacketTimeAdjustWindow = 2 * time.Minute cFirstPacketTimeAdjustThreshold = 5 * time.Second @@ -53,18 +51,6 @@ func RTPDriftToString(r *livekit.RTPDrift) string { // ------------------------------------------------------- -type intervalStats struct { - packets uint64 - bytes uint64 - headerBytes uint64 - packetsPadding uint64 - bytesPadding uint64 - headerBytesPadding uint64 - packetsLost uint64 - packetsOutOfOrder uint64 - frames uint32 -} - type RTPDeltaInfo struct { StartTime time.Time Duration time.Duration @@ -119,14 +105,6 @@ type snapshot struct { maxJitter float64 } -type snInfo struct { - hdrSize uint16 - pktSize uint16 - isPaddingOnly bool - marker bool - isOutOfOrder bool -} - type RTCPSenderReportData struct { RTPTimestamp uint32 RTPTimestampExt uint64 @@ -177,8 +155,6 @@ type rtpStatsBase struct { jitter float64 maxJitter float64 - snInfos [cSnInfoSize]snInfo - gapHistogram [cGapHistogramNumBins]uint32 nacks uint32 @@ -251,8 
+227,6 @@ func (r *rtpStatsBase) seed(from *rtpStatsBase) bool { r.jitter = from.jitter r.maxJitter = from.maxJitter - r.snInfos = from.snInfos - r.gapHistogram = from.gapHistogram r.nacks = from.nacks @@ -309,14 +283,14 @@ func (r *rtpStatsBase) newSnapshotID(extStartSN uint64) uint32 { id := r.nextSnapshotID r.nextSnapshotID++ - if cap(r.snapshots) < int(r.nextSnapshotID) { - snapshots := make([]snapshot, r.nextSnapshotID) + if cap(r.snapshots) < int(r.nextSnapshotID-cFirstSnapshotID) { + snapshots := make([]snapshot, r.nextSnapshotID-cFirstSnapshotID) copy(snapshots, r.snapshots) r.snapshots = snapshots } if r.initialized { - r.snapshots[id] = r.initSnapshot(time.Now(), extStartSN) + r.snapshots[id-cFirstSnapshotID] = r.initSnapshot(time.Now(), extStartSN) } return id } @@ -467,7 +441,8 @@ func (r *rtpStatsBase) UpdateRtt(rtt uint32) { r.maxRtt = rtt } - for _, s := range r.snapshots { + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] if rtt > s.maxRtt { s.maxRtt = rtt } @@ -545,7 +520,6 @@ func (r *rtpStatsBase) getTotalPacketsPrimary(extStartSN, extHighestSN uint64) u func (r *rtpStatsBase) deltaInfo(snapshotID uint32, extStartSN uint64, extHighestSN uint64) *RTPDeltaInfo { then, now := r.getAndResetSnapshot(snapshotID, extStartSN, extHighestSN) - if now == nil || then == nil { return nil } @@ -772,107 +746,6 @@ func (r *rtpStatsBase) toProto( return p } -func (r *rtpStatsBase) getSnInfoOutOfOrderSlot(esn uint64, ehsn uint64) int { - offset := int64(ehsn - esn) - if offset >= cSnInfoSize || offset < 0 { - // too old OR too new (i. e. 
ahead of highest) - return -1 - } - - return int(esn & cSnInfoMask) -} - -func (r *rtpStatsBase) setSnInfo(esn uint64, ehsn uint64, pktSize uint16, hdrSize uint16, payloadSize uint16, marker bool, isOutOfOrder bool) { - var slot int - if int64(esn-ehsn) < 0 { - slot = r.getSnInfoOutOfOrderSlot(esn, ehsn) - if slot < 0 { - return - } - } else { - slot = int(esn & cSnInfoMask) - } - - snInfo := &r.snInfos[slot] - snInfo.pktSize = pktSize - snInfo.hdrSize = hdrSize - snInfo.isPaddingOnly = payloadSize == 0 - snInfo.marker = marker - snInfo.isOutOfOrder = isOutOfOrder -} - -func (r *rtpStatsBase) clearSnInfos(extStartInclusive uint64, extEndExclusive uint64) { - if extEndExclusive <= extStartInclusive { - return - } - - for esn := extStartInclusive; esn != extEndExclusive; esn++ { - snInfo := &r.snInfos[esn&cSnInfoMask] - snInfo.pktSize = 0 - snInfo.hdrSize = 0 - snInfo.isPaddingOnly = false - snInfo.marker = false - } -} - -func (r *rtpStatsBase) isSnInfoLost(esn uint64, ehsn uint64) bool { - slot := r.getSnInfoOutOfOrderSlot(esn, ehsn) - if slot < 0 { - return false - } - - return r.snInfos[slot].pktSize == 0 -} - -func (r *rtpStatsBase) getIntervalStats(extStartInclusive uint64, extEndExclusive uint64, ehsn uint64) (intervalStats intervalStats) { - packetsNotFound := uint32(0) - processESN := func(esn uint64, ehsn uint64) { - slot := r.getSnInfoOutOfOrderSlot(esn, ehsn) - if slot < 0 { - packetsNotFound++ - return - } - - snInfo := &r.snInfos[slot] - switch { - case snInfo.pktSize == 0: - intervalStats.packetsLost++ - - case snInfo.isPaddingOnly: - intervalStats.packetsPadding++ - intervalStats.bytesPadding += uint64(snInfo.pktSize) - intervalStats.headerBytesPadding += uint64(snInfo.hdrSize) - - default: - intervalStats.packets++ - intervalStats.bytes += uint64(snInfo.pktSize) - intervalStats.headerBytes += uint64(snInfo.hdrSize) - if snInfo.isOutOfOrder { - intervalStats.packetsOutOfOrder++ - } - } - - if snInfo.marker { - intervalStats.frames++ - } - } - - for 
esn := extStartInclusive; esn != extEndExclusive; esn++ { - processESN(esn, ehsn) - } - - if packetsNotFound != 0 { - r.logger.Errorw( - "could not find some packets", nil, - "start", extStartInclusive, - "end", extEndExclusive, - "count", packetsNotFound, - "highestSN", ehsn, - ) - } - return -} - func (r *rtpStatsBase) updateJitter(ets uint64, packetTime time.Time) float64 { // Do not update jitter on multiple packets of same frame. // All packets of a frame have the same time stamp. @@ -896,7 +769,8 @@ func (r *rtpStatsBase) updateJitter(ets uint64, packetTime time.Time) float64 { r.maxJitter = r.jitter } - for _, s := range r.snapshots { + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] if r.jitter > s.maxJitter { s.maxJitter = r.jitter } @@ -914,15 +788,16 @@ func (r *rtpStatsBase) getAndResetSnapshot(snapshotID uint32, extStartSN uint64, return nil, nil } - then := r.snapshots[snapshotID] + idx := snapshotID - cFirstSnapshotID + then := r.snapshots[idx] if !then.isValid { then = r.initSnapshot(r.startTime, extStartSN) - r.snapshots[snapshotID] = then + r.snapshots[idx] = then } // snapshot now now := r.getSnapshot(time.Now(), extHighestSN+1) - r.snapshots[snapshotID] = now + r.snapshots[idx] = now return &then, &now } @@ -983,7 +858,7 @@ func (r *rtpStatsBase) updateGapHistogram(gap int) { func (r *rtpStatsBase) initSnapshot(startTime time.Time, extStartSN uint64) snapshot { return snapshot{ isValid: true, - startTime: time.Now(), + startTime: startTime, extStartSN: extStartSN, } } @@ -991,7 +866,7 @@ func (r *rtpStatsBase) initSnapshot(startTime time.Time, extStartSN uint64) snap func (r *rtpStatsBase) getSnapshot(startTime time.Time, extStartSN uint64) snapshot { return snapshot{ isValid: true, - startTime: time.Now(), + startTime: startTime, extStartSN: extStartSN, bytes: r.bytes, headerBytes: r.headerBytes, diff --git a/pkg/sfu/buffer/rtpstats_receiver.go b/pkg/sfu/buffer/rtpstats_receiver.go index 
2b409c990..fff35b2f5 100644 --- a/pkg/sfu/buffer/rtpstats_receiver.go +++ b/pkg/sfu/buffer/rtpstats_receiver.go @@ -26,7 +26,7 @@ import ( ) const ( - cHistorySize = 2048 + cHistorySize = 4096 ) type RTPFlowState struct { @@ -69,7 +69,7 @@ func (r *RTPStatsReceiver) NewSnapshotId() uint32 { r.lock.Lock() defer r.lock.Unlock() - return r.newSnapshotID(r.sequenceNumber.GetExtendedStart()) + return r.newSnapshotID(r.sequenceNumber.GetExtendedHighest()) } func (r *RTPStatsReceiver) Update( @@ -114,7 +114,7 @@ func (r *RTPStatsReceiver) Update( resTS = r.timestamp.Update(timestamp) // initialize snapshots if any - for i := uint32(cFirstSnapshotID); i < r.nextSnapshotID; i++ { + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { r.snapshots[i] = r.initSnapshot(r.startTime, r.sequenceNumber.GetExtendedStart()) } @@ -154,7 +154,8 @@ func (r *RTPStatsReceiver) Update( r.packetsLost += resSN.PreExtendedStart - resSN.ExtendedVal extStartSN := r.sequenceNumber.GetExtendedStart() - for _, s := range r.snapshots { + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] if s.extStartSN == resSN.PreExtendedStart { s.extStartSN = extStartSN } diff --git a/pkg/sfu/buffer/rtpstats_sender.go b/pkg/sfu/buffer/rtpstats_sender.go index 670032e83..5a66c65d4 100644 --- a/pkg/sfu/buffer/rtpstats_sender.go +++ b/pkg/sfu/buffer/rtpstats_sender.go @@ -25,11 +25,91 @@ import ( "github.com/livekit/protocol/livekit" ) +const ( + cSnInfoSize = 4096 + cSnInfoMask = cSnInfoSize - 1 +) + +type snInfoFlag byte + +const ( + snInfoFlagMarker snInfoFlag = 1 << iota + snInfoFlagPadding + snInfoFlagOutOfOrder +) + +type snInfo struct { + pktSize uint16 + hdrSize uint8 + flags snInfoFlag +} + +// ------------------------------------------------------------------- + +type intervalStats struct { + packets uint64 + bytes uint64 + headerBytes uint64 + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + packetsLost uint64 + packetsOutOfOrder 
uint64 + frames uint32 +} + +func (is *intervalStats) aggregate(other *intervalStats) { + if is == nil || other == nil { + return + } + + is.packets += other.packets + is.bytes += other.bytes + is.headerBytes += other.headerBytes + is.packetsPadding += other.packetsPadding + is.bytesPadding += other.bytesPadding + is.headerBytesPadding += other.headerBytesPadding + is.packetsLost += other.packetsLost + is.packetsOutOfOrder += other.packetsOutOfOrder + is.frames += other.frames +} + +// ------------------------------------------------------------------- + type senderSnapshot struct { - snapshot - extStartSNFromRR uint64 - packetsLostFromRR uint64 - maxJitterFromRR float64 + isValid bool + + startTime time.Time + + extStartSN uint64 + bytes uint64 + headerBytes uint64 + + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + + packetsDuplicate uint64 + bytesDuplicate uint64 + headerBytesDuplicate uint64 + + packetsOutOfOrder uint64 + + packetsLostFeed uint64 + packetsLost uint64 + + frames uint32 + + nacks uint32 + plis uint32 + firs uint32 + + maxRtt uint32 + maxJitterFeed float64 + maxJitter float64 + + extLastRRSN uint64 + intervalStats intervalStats } type RTPStatsSender struct { @@ -50,6 +130,8 @@ type RTPStatsSender struct { jitterFromRR float64 maxJitterFromRR float64 + snInfos [cSnInfoSize]snInfo + nextSenderSnapshotID uint32 senderSnapshots []senderSnapshot } @@ -85,6 +167,8 @@ func (r *RTPStatsSender) Seed(from *RTPStatsSender) { r.jitterFromRR = from.jitterFromRR r.maxJitterFromRR = from.maxJitterFromRR + r.snInfos = from.snInfos + r.nextSenderSnapshotID = from.nextSenderSnapshotID r.senderSnapshots = make([]senderSnapshot, cap(from.senderSnapshots)) copy(r.senderSnapshots, from.senderSnapshots) @@ -94,7 +178,7 @@ func (r *RTPStatsSender) NewSnapshotId() uint32 { r.lock.Lock() defer r.lock.Unlock() - return r.newSnapshotID(r.extStartSN) + return r.newSnapshotID(r.extHighestSN) } func (r *RTPStatsSender) NewSenderSnapshotId() uint32 { @@ 
-104,17 +188,14 @@ func (r *RTPStatsSender) NewSenderSnapshotId() uint32 { id := r.nextSenderSnapshotID r.nextSenderSnapshotID++ - if cap(r.senderSnapshots) < int(r.nextSenderSnapshotID) { - senderSnapshots := make([]senderSnapshot, r.nextSenderSnapshotID) + if cap(r.senderSnapshots) < int(r.nextSenderSnapshotID-cFirstSnapshotID) { + senderSnapshots := make([]senderSnapshot, r.nextSenderSnapshotID-cFirstSnapshotID) copy(senderSnapshots, r.senderSnapshots) r.senderSnapshots = senderSnapshots } if r.initialized { - r.senderSnapshots[id] = senderSnapshot{ - snapshot: r.initSnapshot(time.Now(), r.extStartSN), - extStartSNFromRR: r.extStartSN, - } + r.senderSnapshots[id-cFirstSnapshotID] = r.initSenderSnapshot(time.Now(), r.extHighestSN) } return id } @@ -155,22 +236,11 @@ func (r *RTPStatsSender) Update( r.extHighestTS = extTimestamp // initialize snapshots if any - for i := uint32(cFirstSnapshotID); i < r.nextSnapshotID; i++ { - r.snapshots[i] = snapshot{ - isValid: true, - startTime: r.startTime, - extStartSN: r.extStartSN, - } + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + r.snapshots[i] = r.initSnapshot(r.startTime, r.extStartSN) } - for i := uint32(cFirstSnapshotID); i < r.nextSenderSnapshotID; i++ { - r.senderSnapshots[i] = senderSnapshot{ - snapshot: snapshot{ - isValid: true, - startTime: r.startTime, - extStartSN: r.extStartSN, - }, - extStartSNFromRR: r.extStartSN, - } + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + r.senderSnapshots[i] = r.initSenderSnapshot(r.startTime, r.extStartSN) } r.logger.Debugw( @@ -195,14 +265,19 @@ func (r *RTPStatsSender) Update( r.packetsLost += r.extStartSN - extSequenceNumber // adjust start of snapshots - for _, s := range r.snapshots { + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] if s.extStartSN == r.extStartSN { s.extStartSN = extSequenceNumber } } - for _, s := range r.senderSnapshots { + for i := uint32(0); i < 
r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + s := &r.senderSnapshots[i] if s.extStartSN == r.extStartSN { s.extStartSN = extSequenceNumber + if s.extLastRRSN == (r.extStartSN - 1) { + s.extLastRRSN = extSequenceNumber - 1 + } } } @@ -224,7 +299,7 @@ func (r *RTPStatsSender) Update( isDuplicate = true } else { r.packetsLost-- - r.setSnInfo(extSequenceNumber, r.extHighestSN, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), marker, true) + r.setSnInfo(extSequenceNumber, r.extHighestSN, uint16(pktSize), uint8(hdrSize), uint16(payloadSize), marker, true) } } else { // in-order // update gap histogram @@ -234,7 +309,7 @@ func (r *RTPStatsSender) Update( r.clearSnInfos(r.extHighestSN+1, extSequenceNumber) r.packetsLost += uint64(gapSN - 1) - r.setSnInfo(extSequenceNumber, r.extHighestSN, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), marker, false) + r.setSnInfo(extSequenceNumber, r.extHighestSN, uint16(pktSize), uint8(hdrSize), uint16(payloadSize), marker, false) if extTimestamp != r.extHighestTS { // update only on first packet as same timestamp could be in multiple packets. 
@@ -259,9 +334,10 @@ func (r *RTPStatsSender) Update( } jitter := r.updateJitter(extTimestamp, packetTime) - for _, s := range r.senderSnapshots { - if jitter > s.maxJitter { - s.maxJitter = jitter + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + s := &r.senderSnapshots[i] + if jitter > s.maxJitterFeed { + s.maxJitterFeed = jitter } } } @@ -307,46 +383,7 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt } } - if r.lastRRTime.IsZero() || r.extHighestSNFromRR <= extHighestSNFromRR { - r.extHighestSNFromRR = extHighestSNFromRR - - packetsLostFromRR := r.packetsLostFromRR&0xFFFF_FFFF_0000_0000 + uint64(rr.TotalLost) - if (rr.TotalLost-r.lastRR.TotalLost) < (1<<31) && rr.TotalLost < r.lastRR.TotalLost { - packetsLostFromRR += (1 << 32) - } - r.packetsLostFromRR = packetsLostFromRR - - if isRttChanged { - r.rtt = rtt - if rtt > r.maxRtt { - r.maxRtt = rtt - } - } - - r.jitterFromRR = float64(rr.Jitter) - if r.jitterFromRR > r.maxJitterFromRR { - r.maxJitterFromRR = r.jitterFromRR - } - - // update snapshots - for _, s := range r.snapshots { - if isRttChanged && rtt > s.maxRtt { - s.maxRtt = rtt - } - } - for _, s := range r.senderSnapshots { - if isRttChanged && rtt > s.maxRtt { - s.maxRtt = rtt - } - - if r.jitterFromRR > s.maxJitterFromRR { - s.maxJitterFromRR = r.jitterFromRR - } - } - - r.lastRRTime = time.Now() - r.lastRR = rr - } else { + if !r.lastRRTime.IsZero() && r.extHighestSNFromRR > extHighestSNFromRR { r.logger.Debugw( fmt.Sprintf("receiver report potentially out of order, highestSN: existing: %d, received: %d", r.extHighestSNFromRR, extHighestSNFromRR), "lastRRTime", r.lastRRTime, @@ -354,7 +391,57 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt "sinceLastRR", time.Since(r.lastRRTime), "receivedRR", rr, ) + return } + + r.extHighestSNFromRR = extHighestSNFromRR + + packetsLostFromRR := r.packetsLostFromRR&0xFFFF_FFFF_0000_0000 + uint64(rr.TotalLost) + if 
(rr.TotalLost-r.lastRR.TotalLost) < (1<<31) && rr.TotalLost < r.lastRR.TotalLost { + packetsLostFromRR += (1 << 32) + } + r.packetsLostFromRR = packetsLostFromRR + + if isRttChanged { + r.rtt = rtt + if rtt > r.maxRtt { + r.maxRtt = rtt + } + } + + r.jitterFromRR = float64(rr.Jitter) + if r.jitterFromRR > r.maxJitterFromRR { + r.maxJitterFromRR = r.jitterFromRR + } + + // update snapshots + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] + if isRttChanged && rtt > s.maxRtt { + s.maxRtt = rtt + } + } + + extLastRRSN := r.extHighestSNFromRR + (r.extStartSN & 0xFFFF_FFFF_FFFF_0000) + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + s := &r.senderSnapshots[i] + if isRttChanged && rtt > s.maxRtt { + s.maxRtt = rtt + } + + if r.jitterFromRR > s.maxJitter { + s.maxJitter = r.jitterFromRR + } + + // on every RR, calculate delta since last RR using packet metadata cache + is := r.getIntervalStats(s.extLastRRSN+1, extLastRRSN+1, r.extHighestSN) + eis := &s.intervalStats + eis.aggregate(&is) + s.extLastRRSN = extLastRRSN + } + + r.lastRRTime = time.Now() + r.lastRR = rr return } @@ -492,11 +579,11 @@ func (r *RTPStatsSender) DeltaInfoSender(senderSnapshotID uint32) *RTPDeltaInfo startTime := then.startTime endTime := now.startTime - packetsExpected := now.extStartSNFromRR - then.extStartSNFromRR + packetsExpected := uint32(now.extStartSN - then.extStartSN) if packetsExpected > cNumSequenceNumbers { r.logger.Warnw( "too many packets expected in delta (sender)", - fmt.Errorf("start: %d, end: %d, expected: %d", then.extStartSNFromRR, now.extStartSNFromRR, packetsExpected), + fmt.Errorf("start: %d, end: %d, expected: %d", then.extStartSN, now.extStartSN, packetsExpected), ) return nil } @@ -505,29 +592,31 @@ func (r *RTPStatsSender) DeltaInfoSender(senderSnapshotID uint32) *RTPDeltaInfo return nil } - intervalStats := r.getIntervalStats(then.extStartSNFromRR, now.extStartSNFromRR, r.extHighestSN) - packetsLost := 
now.packetsLostFromRR - then.packetsLostFromRR + packetsLost := uint32(now.packetsLost - then.packetsLost) if int32(packetsLost) < 0 { packetsLost = 0 } - + packetsLostFeed := uint32(now.packetsLostFeed - then.packetsLostFeed) + if int32(packetsLostFeed) < 0 { + packetsLostFeed = 0 + } if packetsLost > packetsExpected { r.logger.Warnw( "unexpected number of packets lost", fmt.Errorf( - "start: %d, end: %d, expected: %d, lost: report: %d, interval: %d", - then.extStartSNFromRR, - now.extStartSNFromRR, + "start: %d, end: %d, expected: %d, lost: report: %d, feed: %d", + then.extStartSN, + now.extStartSN, packetsExpected, - now.packetsLostFromRR-then.packetsLostFromRR, - intervalStats.packetsLost, + packetsLost, + packetsLostFeed, ), ) packetsLost = packetsExpected } // discount jitter from publisher side + internal processing - maxJitter := then.maxJitterFromRR - then.maxJitter + maxJitter := then.maxJitter - then.maxJitterFeed if maxJitter < 0.0 { maxJitter = 0.0 } @@ -536,19 +625,19 @@ func (r *RTPStatsSender) DeltaInfoSender(senderSnapshotID uint32) *RTPDeltaInfo return &RTPDeltaInfo{ StartTime: startTime, Duration: endTime.Sub(startTime), - Packets: uint32(packetsExpected - intervalStats.packetsPadding), - Bytes: intervalStats.bytes, - HeaderBytes: intervalStats.headerBytes, + Packets: packetsExpected - uint32(now.packetsPadding-then.packetsPadding), + Bytes: now.bytes - then.bytes, + HeaderBytes: now.headerBytes - then.headerBytes, PacketsDuplicate: uint32(now.packetsDuplicate - then.packetsDuplicate), BytesDuplicate: now.bytesDuplicate - then.bytesDuplicate, HeaderBytesDuplicate: now.headerBytesDuplicate - then.headerBytesDuplicate, - PacketsPadding: uint32(intervalStats.packetsPadding), - BytesPadding: intervalStats.bytesPadding, - HeaderBytesPadding: intervalStats.headerBytesPadding, - PacketsLost: uint32(packetsLost), - PacketsMissing: uint32(intervalStats.packetsLost), - PacketsOutOfOrder: uint32(intervalStats.packetsOutOfOrder), - Frames: 
intervalStats.frames, + PacketsPadding: uint32(now.packetsPadding - then.packetsPadding), + BytesPadding: now.bytesPadding - then.bytesPadding, + HeaderBytesPadding: now.headerBytesPadding - then.headerBytesPadding, + PacketsLost: packetsLost, + PacketsMissing: packetsLostFeed, + PacketsOutOfOrder: uint32(now.packetsOutOfOrder - then.packetsOutOfOrder), + Frames: now.frames - then.frames, RttMax: then.maxRtt, JitterMax: maxJitterTime, Nacks: now.nacks - then.nacks, @@ -584,24 +673,162 @@ func (r *RTPStatsSender) getAndResetSenderSnapshot(senderSnapshotID uint32) (*se return nil, nil } - then := r.senderSnapshots[senderSnapshotID] + idx := senderSnapshotID - cFirstSnapshotID + then := r.senderSnapshots[idx] if !then.isValid { - then = senderSnapshot{ - snapshot: r.initSnapshot(r.startTime, r.extStartSN), - extStartSNFromRR: r.extStartSN, - } - r.senderSnapshots[senderSnapshotID] = then + then = r.initSenderSnapshot(r.startTime, r.extStartSN) + r.senderSnapshots[idx] = then } // snapshot now - now := senderSnapshot{ - snapshot: r.getSnapshot(r.lastRRTime, r.extHighestSN+1), - extStartSNFromRR: r.extHighestSNFromRR + (r.extStartSN & 0xFFFF_FFFF_FFFF_0000) + 1, - packetsLostFromRR: r.packetsLostFromRR, - maxJitterFromRR: r.jitterFromRR, - } - r.senderSnapshots[senderSnapshotID] = now + now := r.getSenderSnapshot(r.lastRRTime, &then) + r.senderSnapshots[idx] = now return &then, &now } +func (r *RTPStatsSender) initSenderSnapshot(startTime time.Time, extStartSN uint64) senderSnapshot { + return senderSnapshot{ + isValid: true, + startTime: startTime, + extStartSN: extStartSN, + extLastRRSN: extStartSN - 1, + } +} + +func (r *RTPStatsSender) getSenderSnapshot(startTime time.Time, s *senderSnapshot) senderSnapshot { + if s == nil { + return senderSnapshot{} + } + + return senderSnapshot{ + isValid: true, + startTime: startTime, + extStartSN: s.extLastRRSN + 1, + bytes: s.bytes + s.intervalStats.bytes, + headerBytes: s.headerBytes + s.intervalStats.headerBytes, + 
packetsPadding: s.packetsPadding + s.intervalStats.packetsPadding, + bytesPadding: s.bytesPadding + s.intervalStats.bytesPadding, + headerBytesPadding: s.headerBytesPadding + s.intervalStats.headerBytesPadding, + packetsDuplicate: r.packetsDuplicate, + bytesDuplicate: r.bytesDuplicate, + headerBytesDuplicate: r.headerBytesDuplicate, + packetsLostFeed: r.packetsLost, + packetsOutOfOrder: s.packetsOutOfOrder + s.intervalStats.packetsOutOfOrder, + frames: s.frames + s.intervalStats.frames, + nacks: r.nacks, + plis: r.plis, + firs: r.firs, + maxRtt: r.rtt, + maxJitterFeed: r.jitter, + maxJitter: r.jitterFromRR, + extLastRRSN: s.extLastRRSN, + } +} + +func (r *RTPStatsSender) getSnInfoOutOfOrderSlot(esn uint64, ehsn uint64) int { + offset := int64(ehsn - esn) + if offset >= cSnInfoSize || offset < 0 { + // too old OR too new (i. e. ahead of highest) + return -1 + } + + return int(esn & cSnInfoMask) +} + +func (r *RTPStatsSender) setSnInfo(esn uint64, ehsn uint64, pktSize uint16, hdrSize uint8, payloadSize uint16, marker bool, isOutOfOrder bool) { + var slot int + if int64(esn-ehsn) < 0 { + slot = r.getSnInfoOutOfOrderSlot(esn, ehsn) + if slot < 0 { + return + } + } else { + slot = int(esn & cSnInfoMask) + } + + snInfo := &r.snInfos[slot] + snInfo.pktSize = pktSize + snInfo.hdrSize = hdrSize + if marker { + snInfo.flags |= snInfoFlagMarker + } + if payloadSize == 0 { + snInfo.flags |= snInfoFlagPadding + } + if isOutOfOrder { + snInfo.flags |= snInfoFlagOutOfOrder + } +} + +func (r *RTPStatsSender) clearSnInfos(extStartInclusive uint64, extEndExclusive uint64) { + if extEndExclusive <= extStartInclusive { + return + } + + for esn := extStartInclusive; esn != extEndExclusive; esn++ { + snInfo := &r.snInfos[esn&cSnInfoMask] + snInfo.pktSize = 0 + snInfo.hdrSize = 0 + snInfo.flags = 0 + } +} + +func (r *RTPStatsSender) isSnInfoLost(esn uint64, ehsn uint64) bool { + slot := r.getSnInfoOutOfOrderSlot(esn, ehsn) + if slot < 0 { + return false + } + + return 
r.snInfos[slot].pktSize == 0 +} + +func (r *RTPStatsSender) getIntervalStats(extStartInclusive uint64, extEndExclusive uint64, ehsn uint64) (intervalStats intervalStats) { + packetsNotFound := uint32(0) + processESN := func(esn uint64, ehsn uint64) { + slot := r.getSnInfoOutOfOrderSlot(esn, ehsn) + if slot < 0 { + packetsNotFound++ + return + } + + snInfo := &r.snInfos[slot] + switch { + case snInfo.pktSize == 0: + intervalStats.packetsLost++ + + case snInfo.flags&snInfoFlagPadding != 0: + intervalStats.packetsPadding++ + intervalStats.bytesPadding += uint64(snInfo.pktSize) + intervalStats.headerBytesPadding += uint64(snInfo.hdrSize) + + default: + intervalStats.packets++ + intervalStats.bytes += uint64(snInfo.pktSize) + intervalStats.headerBytes += uint64(snInfo.hdrSize) + if (snInfo.flags & snInfoFlagOutOfOrder) != 0 { + intervalStats.packetsOutOfOrder++ + } + } + + if (snInfo.flags & snInfoFlagMarker) != 0 { + intervalStats.frames++ + } + } + + for esn := extStartInclusive; esn != extEndExclusive; esn++ { + processESN(esn, ehsn) + } + + if packetsNotFound != 0 { + r.logger.Errorw( + "could not find some packets", nil, + "start", extStartInclusive, + "end", extEndExclusive, + "count", packetsNotFound, + "highestSN", ehsn, + ) + } + return +} + // ------------------------------------------------------------------- From 019ad88b08c2038498763e0bb55e5d417474145e Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sun, 17 Sep 2023 14:00:09 +0530 Subject: [PATCH 7/8] Do not force reconnect on resume if there is a pending track (#2081) * Do not force reconnect on resume if there is a pending track * move GetPendingTrack -> LocalParticipant --- pkg/rtc/participant.go | 13 ++++ pkg/rtc/room.go | 4 + pkg/rtc/types/interfaces.go | 3 +- .../typesfakes/fake_local_participant.go | 74 +++++++++++++++++++ 4 files changed, 93 insertions(+), 1 deletion(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 39cce9073..07ae3036b 100644 --- a/pkg/rtc/participant.go 
+++ b/pkg/rtc/participant.go @@ -1637,6 +1637,19 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l return ti } +func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.TrackInfo { + p.pendingTracksLock.RLock() + defer p.pendingTracksLock.RUnlock() + + for _, t := range p.pendingTracks { + if livekit.TrackID(t.trackInfos[0].Sid) == trackID { + return t.trackInfos[0] + } + } + + return nil +} + func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) { p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", ti.String()) _ = p.writeMessage(&livekit.SignalResponse{ diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 1589ada96..92f50ca01 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -581,6 +581,10 @@ func (r *Room) SyncState(participant types.LocalParticipant, state *livekit.Sync break } } + if !found { + // is there a pending track? + found = participant.GetPendingTrack(livekit.TrackID(ti.Sid)) != nil + } if !found { pLogger.Warnw("unknown track during resume", nil, "trackID", ti.Sid) shouldReconnect = true diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index e5e44d992..0b0e003ad 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -245,7 +245,7 @@ type Participant interface { SetMetadata(metadata string) IsPublisher() bool - GetPublishedTrack(sid livekit.TrackID) MediaTrack + GetPublishedTrack(trackID livekit.TrackID) MediaTrack GetPublishedTracks() []MediaTrack RemovePublishedTrack(track MediaTrack, willBeResumed bool, shouldClose bool) @@ -315,6 +315,7 @@ type LocalParticipant interface { GetICEConnectionType() ICEConnectionType GetBufferFactory() *buffer.Factory GetPlayoutDelayConfig() *livekit.PlayoutDelay + GetPendingTrack(trackID livekit.TrackID) *livekit.TrackInfo SetResponseSink(sink routing.MessageSink) CloseSignalConnection(reason SignallingCloseReason) diff --git 
a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index b5338c33a..fedebacf5 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -263,6 +263,17 @@ type FakeLocalParticipant struct { getPacerReturnsOnCall map[int]struct { result1 pacer.Pacer } + GetPendingTrackStub func(livekit.TrackID) *livekit.TrackInfo + getPendingTrackMutex sync.RWMutex + getPendingTrackArgsForCall []struct { + arg1 livekit.TrackID + } + getPendingTrackReturns struct { + result1 *livekit.TrackInfo + } + getPendingTrackReturnsOnCall map[int]struct { + result1 *livekit.TrackInfo + } GetPlayoutDelayConfigStub func() *livekit.PlayoutDelay getPlayoutDelayConfigMutex sync.RWMutex getPlayoutDelayConfigArgsForCall []struct { @@ -2169,6 +2180,67 @@ func (fake *FakeLocalParticipant) GetPacerReturnsOnCall(i int, result1 pacer.Pac }{result1} } +func (fake *FakeLocalParticipant) GetPendingTrack(arg1 livekit.TrackID) *livekit.TrackInfo { + fake.getPendingTrackMutex.Lock() + ret, specificReturn := fake.getPendingTrackReturnsOnCall[len(fake.getPendingTrackArgsForCall)] + fake.getPendingTrackArgsForCall = append(fake.getPendingTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.GetPendingTrackStub + fakeReturns := fake.getPendingTrackReturns + fake.recordInvocation("GetPendingTrack", []interface{}{arg1}) + fake.getPendingTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPendingTrackCallCount() int { + fake.getPendingTrackMutex.RLock() + defer fake.getPendingTrackMutex.RUnlock() + return len(fake.getPendingTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPendingTrackCalls(stub func(livekit.TrackID) *livekit.TrackInfo) { + fake.getPendingTrackMutex.Lock() + defer fake.getPendingTrackMutex.Unlock() + fake.GetPendingTrackStub 
= stub +} + +func (fake *FakeLocalParticipant) GetPendingTrackArgsForCall(i int) livekit.TrackID { + fake.getPendingTrackMutex.RLock() + defer fake.getPendingTrackMutex.RUnlock() + argsForCall := fake.getPendingTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) GetPendingTrackReturns(result1 *livekit.TrackInfo) { + fake.getPendingTrackMutex.Lock() + defer fake.getPendingTrackMutex.Unlock() + fake.GetPendingTrackStub = nil + fake.getPendingTrackReturns = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetPendingTrackReturnsOnCall(i int, result1 *livekit.TrackInfo) { + fake.getPendingTrackMutex.Lock() + defer fake.getPendingTrackMutex.Unlock() + fake.GetPendingTrackStub = nil + if fake.getPendingTrackReturnsOnCall == nil { + fake.getPendingTrackReturnsOnCall = make(map[int]struct { + result1 *livekit.TrackInfo + }) + } + fake.getPendingTrackReturnsOnCall[i] = struct { + result1 *livekit.TrackInfo + }{result1} +} + func (fake *FakeLocalParticipant) GetPlayoutDelayConfig() *livekit.PlayoutDelay { fake.getPlayoutDelayConfigMutex.Lock() ret, specificReturn := fake.getPlayoutDelayConfigReturnsOnCall[len(fake.getPlayoutDelayConfigArgsForCall)] @@ -5829,6 +5901,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.getLoggerMutex.RUnlock() fake.getPacerMutex.RLock() defer fake.getPacerMutex.RUnlock() + fake.getPendingTrackMutex.RLock() + defer fake.getPendingTrackMutex.RUnlock() fake.getPlayoutDelayConfigMutex.RLock() defer fake.getPlayoutDelayConfigMutex.RUnlock() fake.getPublishedTrackMutex.RLock() From 0b0431b765bcd8ba776aacf0b136db8a9b7a68d2 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Sun, 17 Sep 2023 10:08:35 -0700 Subject: [PATCH 8/8] Per-session TURN credentials (#2080) Switching to using session specific TURN credentials instead of shared credentials per Room. 
Also eliminates need to load Room from Redis during TURN authentication --- pkg/service/auth.go | 1 + pkg/service/roommanager.go | 77 +++++++++++++++++++++++++------------- pkg/service/turn.go | 55 ++++++++++++++++++++++----- pkg/service/wire.go | 3 +- pkg/service/wire_gen.go | 5 ++- 5 files changed, 104 insertions(+), 37 deletions(-) diff --git a/pkg/service/auth.go b/pkg/service/auth.go index af940e83a..1b5829e30 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -38,6 +38,7 @@ var ( ErrPermissionDenied = errors.New("permissions denied") ErrMissingAuthorization = errors.New("invalid authorization header. Must start with " + bearerPrefix) ErrInvalidAuthorizationToken = errors.New("invalid authorization token") + ErrInvalidAPIKey = errors.New("invalid API key") ) // authentication middleware diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index ec5f21516..2d7bdda81 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -66,6 +66,7 @@ type RoomManager struct { clientConfManager clientconfiguration.ClientConfigurationManager egressLauncher rtc.EgressLauncher versionGenerator utils.TimedVersionGenerator + turnAuthHandler *TURNAuthHandler rooms map[livekit.RoomName]*rtc.Room @@ -81,6 +82,7 @@ func NewLocalRoomManager( clientConfManager clientconfiguration.ClientConfigurationManager, egressLauncher rtc.EgressLauncher, versionGenerator utils.TimedVersionGenerator, + turnAuthHandler *TURNAuthHandler, ) (*RoomManager, error) { rtcConf, err := rtc.NewWebRTCConfig(conf) if err != nil { @@ -97,6 +99,7 @@ func NewLocalRoomManager( clientConfManager: clientConfManager, egressLauncher: egressLauncher, versionGenerator: versionGenerator, + turnAuthHandler: turnAuthHandler, rooms: make(map[livekit.RoomName]*rtc.Room), @@ -244,6 +247,10 @@ func (r *RoomManager) StartSession( return nil } + // should not error out, error is logged in iceServersForParticipant even if it fails + // since this is used for TURN server credentials, we 
don't want to fail the request even if there's no TURN for the session + apiKey, _, _ := r.getFirstKeyPair() + participant := room.GetParticipant(pi.Identity) if participant != nil { // When reconnecting, it means WS has interrupted but underlying peer connection is still ok in this state, @@ -286,8 +293,9 @@ func (r *RoomManager) StartSession( participant, requestSource, responseSink, - r.iceServersForRoom( - protoRoom, + r.iceServersForParticipant( + apiKey, + participant, iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TLS, ), pi.ReconnectReason, @@ -411,7 +419,8 @@ func (r *RoomManager) StartSession( opts := rtc.ParticipantOptions{ AutoSubscribe: pi.AutoSubscribe, } - if err = room.Join(participant, requestSource, &opts, r.iceServersForRoom(protoRoom, iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TLS)); err != nil { + iceServers := r.iceServersForParticipant(apiKey, participant, iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TLS) + if err = room.Join(participant, requestSource, &opts, iceServers); err != nil { pLogger.Errorw("could not join room", err) _ = participant.Close(true, types.ParticipantCloseReasonJoinFailed, false) return err @@ -684,7 +693,7 @@ func (r *RoomManager) handleRTCMessage(ctx context.Context, roomName livekit.Roo } } -func (r *RoomManager) iceServersForRoom(ri *livekit.Room, tlsOnly bool) []*livekit.ICEServer { +func (r *RoomManager) iceServersForParticipant(apiKey string, participant types.LocalParticipant, tlsOnly bool) []*livekit.ICEServer { var iceServers []*livekit.ICEServer rtcConf := r.config.RTC @@ -705,11 +714,19 @@ func (r *RoomManager) iceServersForRoom(ri *livekit.Room, tlsOnly bool) []*livek urls = append(urls, fmt.Sprintf("turns:%s:443?transport=tcp", r.config.TURN.Domain)) } if len(urls) > 0 { - iceServers = append(iceServers, &livekit.ICEServer{ - Urls: urls, - Username: ri.Name, - Credential: ri.TurnPassword, - }) + username := r.turnAuthHandler.CreateUsername(apiKey, 
participant.ID()) + password, err := r.turnAuthHandler.CreatePassword(apiKey, participant.ID()) + if err != nil { + participant.GetLogger().Warnw("could not create turn password", err) + hasSTUN = false + } else { + logger.Infow("created TURN password", "username", username, "password", password) + iceServers = append(iceServers, &livekit.ICEServer{ + Urls: urls, + Username: username, + Credential: password, + }) + } } } @@ -746,23 +763,26 @@ func (r *RoomManager) iceServersForRoom(ri *livekit.Room, tlsOnly bool) []*livek } func (r *RoomManager) refreshToken(participant types.LocalParticipant) error { - for key, secret := range r.config.Keys { - grants := participant.ClaimGrants() - token := auth.NewAccessToken(key, secret) - token.SetName(grants.Name). - SetIdentity(string(participant.Identity())). - SetValidFor(tokenDefaultTTL). - SetMetadata(grants.Metadata). - AddGrant(grants.Video) - jwt, err := token.ToJWT() - if err == nil { - err = participant.SendRefreshToken(jwt) - } - if err != nil { - return err - } - break + key, secret, err := r.getFirstKeyPair() + if err != nil { + return err } + + grants := participant.ClaimGrants() + token := auth.NewAccessToken(key, secret) + token.SetName(grants.Name). + SetIdentity(string(participant.Identity())). + SetValidFor(tokenDefaultTTL). + SetMetadata(grants.Metadata). + AddGrant(grants.Video) + jwt, err := token.ToJWT() + if err == nil { + err = participant.SendRefreshToken(jwt) + } + if err != nil { + return err + } + return nil } @@ -786,6 +806,13 @@ func (r *RoomManager) getIceConfig(participant types.LocalParticipant) *livekit. 
return iceConfigCacheEntry.iceConfig } +func (r *RoomManager) getFirstKeyPair() (string, string, error) { + for key, secret := range r.config.Keys { + return key, secret, nil + } + return "", "", errors.New("no API keys configured") +} + // ------------------------------------ func iceServerForStunServers(servers []string) *livekit.ICEServer { diff --git a/pkg/service/turn.go b/pkg/service/turn.go index 08d971fb4..a4b91e212 100644 --- a/pkg/service/turn.go +++ b/pkg/service/turn.go @@ -15,14 +15,18 @@ package service import ( - "context" + "crypto/sha256" "crypto/tls" + "fmt" "net" "strconv" + "strings" + "github.com/jxskiss/base62" "github.com/pion/turn/v2" "github.com/pkg/errors" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/logger/pionlogger" @@ -142,14 +146,47 @@ func NewTurnServer(conf *config.Config, authHandler turn.AuthHandler, standalone return turn.NewServer(serverConfig) } -func newTurnAuthHandler(roomStore ObjectStore) turn.AuthHandler { - return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - // room id should be the username, create a hashed room id - rm, _, err := roomStore.LoadRoom(context.Background(), livekit.RoomName(username), false) - if err != nil { - return nil, false - } +func getTURNAuthHandlerFunc(handler *TURNAuthHandler) turn.AuthHandler { + return handler.HandleAuth +} - return turn.GenerateAuthKey(username, LivekitRealm, rm.TurnPassword), true +type TURNAuthHandler struct { + keyProvider auth.KeyProvider +} + +func NewTURNAuthHandler(keyProvider auth.KeyProvider) *TURNAuthHandler { + return &TURNAuthHandler{ + keyProvider: keyProvider, } } + +func (h *TURNAuthHandler) CreateUsername(apiKey string, pID livekit.ParticipantID) string { + return base62.EncodeToString([]byte(fmt.Sprintf("%s|%s", apiKey, pID))) +} + +func (h *TURNAuthHandler) CreatePassword(apiKey string, pID livekit.ParticipantID) (string, error) { + 
secret := h.keyProvider.GetSecret(apiKey) + if secret == "" { + return "", ErrInvalidAPIKey + } + keyInput := fmt.Sprintf("%s|%s", secret, pID) + sum := sha256.Sum256([]byte(keyInput)) + return base62.EncodeToString(sum[:]), nil +} + +func (h *TURNAuthHandler) HandleAuth(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + decoded, err := base62.DecodeString(username) + if err != nil { + return nil, false + } + parts := strings.Split(string(decoded), "|") + if len(parts) != 2 { + return nil, false + } + password, err := h.CreatePassword(parts[0], livekit.ParticipantID(parts[1])) + if err != nil { + logger.Warnw("could not create TURN password", err, "username", username) + return nil, false + } + return turn.GenerateAuthKey(username, LivekitRealm, password), true +} diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 955f2ecef..9d8fbf394 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -73,7 +73,8 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewDefaultSignalServer, routing.NewSignalClient, NewLocalRoomManager, - newTurnAuthHandler, + NewTURNAuthHandler, + getTURNAuthHandlerFunc, newInProcessTurnServer, utils.NewDefaultTimedVersionGenerator, NewLivekitServer, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index b051fb954..942d693fc 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -87,7 +87,8 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, telemetryService) clientConfigurationManager := createClientConfiguration() timedVersionGenerator := utils.NewDefaultTimedVersionGenerator() - roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcEgressLauncher, timedVersionGenerator) + turnAuthHandler := NewTURNAuthHandler(keyProvider) + roomManager, err := 
NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler) if err != nil { return nil, err } @@ -95,7 +96,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - authHandler := newTurnAuthHandler(objectStore) + authHandler := getTURNAuthHandlerFunc(turnAuthHandler) server, err := newInProcessTurnServer(conf, authHandler) if err != nil { return nil, err