From aeec75edeb96271b64c0fda84a7e98429bc452c4 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Thu, 6 Feb 2025 11:56:49 +0800 Subject: [PATCH] H265 support and codec regression (#3358) * H265 support and codec regression Support H265 codec. Add optional codec regression for subscribers that don't support advanced codecs like H265, AV1, VP9. * restart forwarder on upstream codec change * tests * Renegotiate new codec if client doesn't support change * Add option to disable codec regression --------- Co-authored-by: boks1971 --- go.mod | 2 +- go.sum | 4 +- pkg/rtc/clientinfo.go | 4 + pkg/rtc/dynacast/dynacastmanager.go | 38 +++- pkg/rtc/dynacast/dynacastmanager_test.go | 106 ++++++++++- pkg/rtc/dynacast/dynacastquality.go | 51 +++++ pkg/rtc/mediaengine.go | 8 + pkg/rtc/mediatrack.go | 54 +++++- pkg/rtc/mediatrackreceiver.go | 159 +++++++++++++--- pkg/rtc/mediatracksubscriptions.go | 1 + pkg/rtc/participant.go | 53 ++++-- pkg/rtc/types/interfaces.go | 1 + .../typesfakes/fake_local_participant.go | 65 +++++++ pkg/rtc/wrappedreceiver.go | 65 ++++++- pkg/sfu/buffer/buffer.go | 153 ++++++++++++--- pkg/sfu/buffer/buffer_test.go | 110 +++++++++++ pkg/sfu/buffer/fps.go | 174 +++++++++++++++++- pkg/sfu/buffer/fps_test.go | 127 ++++++++++++- pkg/sfu/buffer/helpers.go | 39 +++- pkg/sfu/downtrack.go | 170 +++++++++++++---- pkg/sfu/forwarder.go | 92 ++++++--- pkg/sfu/receiver.go | 108 ++++++++--- pkg/sfu/redprimaryreceiver.go | 4 + pkg/sfu/redreceiver.go | 4 + pkg/sfu/track_remote.go | 5 +- pkg/sfu/utils/mimetype.go | 3 + pkg/sfu/videolayerselector/simulcast.go | 26 ++- 27 files changed, 1424 insertions(+), 202 deletions(-) diff --git a/go.mod b/go.mod index cf3d4f17b..5c3ae99fd 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20241220010243-a2bdee945564 - github.com/livekit/protocol v1.32.2-0.20250205043618-3d2a520b8e34 + 
github.com/livekit/protocol v1.32.2-0.20250206022155-07992dd19e2c github.com/livekit/psrpc v0.6.1-0.20250204212339-6de8b05bfcff github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 0896b24ca..417db69d6 100644 --- a/go.sum +++ b/go.sum @@ -169,8 +169,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20241220010243-a2bdee945564 h1:GX7KF/V9ExmcfT/2Bdia8aROjkxrgx7WpyH7w9MB4J4= github.com/livekit/mediatransportutil v0.0.0-20241220010243-a2bdee945564/go.mod h1:36s+wwmU3O40IAhE+MjBWP3W71QRiEE9SfooSBvtBqY= -github.com/livekit/protocol v1.32.2-0.20250205043618-3d2a520b8e34 h1:yMjtBcMYnZwlc+GSU56OCrzT0fv7TgZ6xhb3QlHyrps= -github.com/livekit/protocol v1.32.2-0.20250205043618-3d2a520b8e34/go.mod h1:9PQOu9w06M+14UDIhbmPeRRti5N4kq6n3R5XHDCzN5k= +github.com/livekit/protocol v1.32.2-0.20250206022155-07992dd19e2c h1:zBpzvlKkqShdd+9LFTl4H2ELV6is+K/XCRIsuphfJ+w= +github.com/livekit/protocol v1.32.2-0.20250206022155-07992dd19e2c/go.mod h1:9PQOu9w06M+14UDIhbmPeRRti5N4kq6n3R5XHDCzN5k= github.com/livekit/psrpc v0.6.1-0.20250204212339-6de8b05bfcff h1:1P84qlSggoKa60H20mAUXUkzjckHGl172ilzg5OJkho= github.com/livekit/psrpc v0.6.1-0.20250204212339-6de8b05bfcff/go.mod h1:X5WtEZ7OnEs72Fi5/J+i0on3964F1aynQpCalcgMqRo= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= diff --git a/pkg/rtc/clientinfo.go b/pkg/rtc/clientinfo.go index 2785edd08..463e32db7 100644 --- a/pkg/rtc/clientinfo.go +++ b/pkg/rtc/clientinfo.go @@ -58,6 +58,10 @@ func (c ClientInfo) FireTrackByRTPPacket() bool { return c.isGo() } +func (c ClientInfo) SupportsCodecChange() bool { + return c.ClientInfo != nil && c.ClientInfo.Sdk != livekit.ClientInfo_GO && c.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN +} + func (c ClientInfo) 
CanHandleReconnectResponse() bool { if c.Sdk == livekit.ClientInfo_JS { // JS handles Reconnect explicitly in 1.6.3, prior to 1.6.4 it could not handle unknown responses diff --git a/pkg/rtc/dynacast/dynacastmanager.go b/pkg/rtc/dynacast/dynacastmanager.go index 1320c2934..92ba192c5 100644 --- a/pkg/rtc/dynacast/dynacastmanager.go +++ b/pkg/rtc/dynacast/dynacastmanager.go @@ -38,6 +38,7 @@ type DynacastManager struct { params DynacastManagerParams lock sync.RWMutex + regressedCodec map[string]struct{} dynacastQuality map[string]*DynacastQuality // mime type => DynacastQuality maxSubscribedQuality map[string]livekit.VideoQuality committedMaxSubscribedQuality map[string]livekit.VideoQuality @@ -58,6 +59,7 @@ func NewDynacastManager(params DynacastManagerParams) *DynacastManager { } d := &DynacastManager{ params: params, + regressedCodec: make(map[string]struct{}), dynacastQuality: make(map[string]*DynacastQuality), maxSubscribedQuality: make(map[string]livekit.VideoQuality), committedMaxSubscribedQuality: make(map[string]livekit.VideoQuality), @@ -85,6 +87,38 @@ func (d *DynacastManager) AddCodec(mime string) { d.getOrCreateDynacastQuality(mime) } +func (d *DynacastManager) HandleCodecRegression(fromMime, toMime string) { + fromDq := d.getOrCreateDynacastQuality(fromMime) + + d.lock.Lock() + if d.isClosed { + d.lock.Unlock() + return + } + + normalizedFromMime, normalizedToMime := strings.ToLower(fromMime), strings.ToLower(toMime) + if fromDq == nil { + // should not happen as we have added the codec on setup receiver + d.params.Logger.Warnw("regression from codec not found", nil, "mime", normalizedFromMime) + d.lock.Unlock() + return + } + d.regressedCodec[normalizedFromMime] = struct{}{} + d.maxSubscribedQuality[normalizedFromMime] = livekit.VideoQuality_OFF + + // if the new codec is not added, notify the publisher to start publishing + if _, ok := d.maxSubscribedQuality[normalizedToMime]; !ok { + d.maxSubscribedQuality[normalizedToMime] = 
livekit.VideoQuality_HIGH + } + + d.lock.Unlock() + d.update(false) + + fromDq.Stop() + ToDq := d.getOrCreateDynacastQuality(normalizedToMime) + fromDq.RegressTo(ToDq) +} + func (d *DynacastManager) Restart() { d.lock.Lock() d.committedMaxSubscribedQuality = make(map[string]livekit.VideoQuality) @@ -183,7 +217,9 @@ func (d *DynacastManager) getDynacastQualitiesLocked() []*DynacastQuality { func (d *DynacastManager) updateMaxQualityForMime(mime string, maxQuality livekit.VideoQuality) { d.lock.Lock() - d.maxSubscribedQuality[mime] = maxQuality + if _, ok := d.regressedCodec[mime]; !ok { + d.maxSubscribedQuality[mime] = maxQuality + } d.lock.Unlock() d.update(false) diff --git a/pkg/rtc/dynacast/dynacastmanager_test.go b/pkg/rtc/dynacast/dynacastmanager_test.go index 759d230c5..f01b2edf9 100644 --- a/pkg/rtc/dynacast/dynacastmanager_test.go +++ b/pkg/rtc/dynacast/dynacastmanager_test.go @@ -29,14 +29,7 @@ import ( ) func TestSubscribedMaxQuality(t *testing.T) { - subscribedCodecsAsString := func(c1 []*livekit.SubscribedCodec) string { - sort.Slice(c1, func(i, j int) bool { return c1[i].Codec < c1[j].Codec }) - var s1 string - for _, c := range c1 { - s1 += c.String() - } - return s1 - } + t.Run("subscribers muted", func(t *testing.T) { dm := NewDynacastManager(DynacastManagerParams{}) var lock sync.Mutex @@ -285,3 +278,100 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) }) } + +func TestCodecRegression(t *testing.T) { + dm := NewDynacastManager(DynacastManagerParams{}) + var lock sync.Mutex + actualSubscribedQualities := make([]*livekit.SubscribedCodec, 0) + dm.OnSubscribedMaxQualityChange(func(subscribedQualities []*livekit.SubscribedCodec, _maxSubscribedQualities []types.SubscribedCodecQuality) { + lock.Lock() + actualSubscribedQualities = subscribedQualities + lock.Unlock() + }) + + dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeAV1, livekit.VideoQuality_HIGH) + + expectedSubscribedQualities := 
[]*livekit.SubscribedCodec{ + { + Codec: strings.ToLower(webrtc.MimeTypeAV1), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + dm.HandleCodecRegression(webrtc.MimeTypeAV1, webrtc.MimeTypeVP8) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: strings.ToLower(webrtc.MimeTypeAV1), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: strings.ToLower(webrtc.MimeTypeVP8), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // av1 quality change should be forwarded to vp8 + // av1 quality change of node should be ignored + dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ + {CodecMime: webrtc.MimeTypeAV1, Quality: livekit.VideoQuality_HIGH}, + }) + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: strings.ToLower(webrtc.MimeTypeAV1), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, 
+ {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: strings.ToLower(webrtc.MimeTypeVP8), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) +} + +func subscribedCodecsAsString(c1 []*livekit.SubscribedCodec) string { + sort.Slice(c1, func(i, j int) bool { return c1[i].Codec < c1[j].Codec }) + var s1 string + for _, c := range c1 { + s1 += c.String() + } + return s1 +} diff --git a/pkg/rtc/dynacast/dynacastquality.go b/pkg/rtc/dynacast/dynacastquality.go index 53e9f979d..957874b36 100644 --- a/pkg/rtc/dynacast/dynacastquality.go +++ b/pkg/rtc/dynacast/dynacastquality.go @@ -42,6 +42,7 @@ type DynacastQuality struct { maxSubscriberNodeQuality map[livekit.NodeID]livekit.VideoQuality maxSubscribedQuality livekit.VideoQuality maxQualityTimer *time.Timer + regressTo *DynacastQuality onSubscribedMaxQualityChange func(mimeType string, maxSubscribedQuality livekit.VideoQuality) } @@ -67,6 +68,8 @@ func (d *DynacastQuality) Stop() { } func (d *DynacastQuality) OnSubscribedMaxQualityChange(f func(mimeType string, maxSubscribedQuality livekit.VideoQuality)) { + d.lock.Lock() + defer d.lock.Unlock() d.onSubscribedMaxQualityChange = f } @@ -79,6 +82,12 @@ func (d *DynacastQuality) NotifySubscriberMaxQuality(subscriberID livekit.Partic ) d.lock.Lock() + if r := d.regressTo; r != nil { + d.lock.Unlock() + r.NotifySubscriberMaxQuality(subscriberID, quality) + return + } + if quality == livekit.VideoQuality_OFF { delete(d.maxSubscriberQuality, subscriberID) } else { @@ -98,6 +107,13 @@ func (d 
*DynacastQuality) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, ) d.lock.Lock() + if r := d.regressTo; r != nil { + // the downstream node will synthesize correct quality notify (its dynacast manager has codec regression), just ignore it + d.lock.Unlock() + r.params.Logger.Debugw("ignoring node quality change, regressed to another dynacast quality", "mime", d.params.MimeType) + return + } + if quality == livekit.VideoQuality_OFF { delete(d.maxSubscriberNodeQuality, nodeID) } else { @@ -108,6 +124,41 @@ func (d *DynacastQuality) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, d.updateQualityChange(false) } +func (d *DynacastQuality) RegressTo(other *DynacastQuality) { + d.lock.Lock() + d.regressTo = other + maxSubscriberQuality := d.maxSubscriberQuality + maxSubscriberNodeQuality := d.maxSubscriberNodeQuality + d.maxSubscriberQuality = make(map[livekit.ParticipantID]livekit.VideoQuality) + d.maxSubscriberNodeQuality = make(map[livekit.NodeID]livekit.VideoQuality) + d.lock.Unlock() + + other.lock.Lock() + for subID, quality := range maxSubscriberQuality { + if otherQuality, ok := other.maxSubscriberQuality[subID]; ok { + // no QUALITY_OFF in the map + if quality > otherQuality { + other.maxSubscriberQuality[subID] = quality + } + } else { + other.maxSubscriberQuality[subID] = quality + } + } + + for nodeID, quality := range maxSubscriberNodeQuality { + if otherQuality, ok := other.maxSubscriberNodeQuality[nodeID]; ok { + // no QUALITY_OFF in the map + if quality > otherQuality { + other.maxSubscriberNodeQuality[nodeID] = quality + } + } else { + other.maxSubscriberNodeQuality[nodeID] = quality + } + } + other.lock.Unlock() + other.Restart() +} + func (d *DynacastQuality) reset() { d.lock.Lock() d.initialized = false diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 88261c5fa..81aab0176 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -129,6 +129,14 @@ func registerCodecs(me *webrtc.MediaEngine, codecs 
[]*livekit.Codec, rtcpFeedbac }, PayloadType: 35, }, + { + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeH265, + ClockRate: 90000, + RTCPFeedback: rtcpFeedback.Video, + }, + PayloadType: 116, + }, } { if filterOutH264HighProfile && codec.RTPCodecCapability.SDPFmtpLine == h264HighProfileFmtp { continue diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 78bc3ee35..bcb4c1457 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -54,6 +54,9 @@ type MediaTrack struct { lock sync.RWMutex rttFromXR atomic.Bool + + enableRegression bool + regressionTargetCodec string } type MediaTrackParams struct { @@ -81,17 +84,25 @@ func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { params: params, } + // TODO: disable codec regression until simulcast-codec clients knows that + if ti.BackupCodecPolicy == livekit.BackupCodecPolicy_REGRESSION && len(ti.Codecs) > 1 { + t.enableRegression = true + t.regressionTargetCodec = ti.Codecs[1].MimeType + t.params.Logger.Debugw("track enabled codec regression", "regressionCodec", t.regressionTargetCodec) + } + t.MediaTrackReceiver = NewMediaTrackReceiver(MediaTrackReceiverParams{ - MediaTrack: t, - IsRelayed: false, - ParticipantID: params.ParticipantID, - ParticipantIdentity: params.ParticipantIdentity, - ParticipantVersion: params.ParticipantVersion, - ReceiverConfig: params.ReceiverConfig, - SubscriberConfig: params.SubscriberConfig, - AudioConfig: params.AudioConfig, - Telemetry: params.Telemetry, - Logger: params.Logger, + MediaTrack: t, + IsRelayed: false, + ParticipantID: params.ParticipantID, + ParticipantIdentity: params.ParticipantIdentity, + ParticipantVersion: params.ParticipantVersion, + ReceiverConfig: params.ReceiverConfig, + SubscriberConfig: params.SubscriberConfig, + AudioConfig: params.AudioConfig, + Telemetry: params.Telemetry, + Logger: params.Logger, + RegressionTargetCodec: t.regressionTargetCodec, }, ti) if ti.Type == livekit.TrackType_AUDIO { 
@@ -123,6 +134,9 @@ func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { ) }, ) + t.MediaTrackReceiver.OnCodecRegression(func(old, new webrtc.RTPCodecParameters) { + t.dynacastManager.HandleCodecRegression(old.MimeType, new.MimeType) + }) } return t @@ -335,6 +349,10 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe } wr = newWR newCodec = true + + newWR.AddOnCodecStateChange(func(codec webrtc.RTPCodecParameters, state sfu.ReceiverCodecState) { + t.MediaTrackReceiver.HandleReceiverCodecChange(newWR, codec, state) + }) } t.lock.Unlock() @@ -363,6 +381,22 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe t.MediaTrackReceiver.SetLayerSsrc(mime, track.RID(), uint32(track.SSRC())) + if newCodec && t.enableRegression && strings.EqualFold(mime, t.regressionTargetCodec) { + t.params.Logger.Infow("regression target codec received", "codec", mime) + for _, c := range ti.Codecs { + if strings.EqualFold(c.MimeType, mime) { + continue + } + + t.params.Logger.Debugw("suspending codec for codec regression", "codec", c.MimeType) + if r := t.MediaTrackReceiver.Receiver(c.MimeType); r != nil { + if rtcreceiver, ok := r.(*sfu.WebRTCReceiver); ok { + rtcreceiver.SetCodecState(sfu.ReceiverCodecStateSuspended) + } + } + } + } + buff.Bind(receiver.GetParameters(), track.Codec().RTPCodecCapability, bitrates) // if subscriber request fps before fps calculated, update them after fps updated. 
diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 4be8c83ca..970dba5a4 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -75,24 +75,59 @@ func (m mediaTrackReceiverState) String() string { type simulcastReceiver struct { sfu.TrackReceiver - priority int + priority int + lock sync.Mutex + regressTo sfu.TrackReceiver } func (r *simulcastReceiver) Priority() int { return r.priority } +func (r *simulcastReceiver) AddDownTrack(track sfu.TrackSender) error { + r.lock.Lock() + if rt := r.regressTo; rt != nil { + r.lock.Unlock() + // AddDownTrack could be called in downtrack.OnBinding callback, use a goroutine to avoid deadlock + go track.SetReceiver(rt) + return nil + } + err := r.TrackReceiver.AddDownTrack(track) + r.lock.Unlock() + return err +} + +func (r *simulcastReceiver) RegressTo(receiver sfu.TrackReceiver) { + r.lock.Lock() + r.regressTo = receiver + dts := r.GetDownTracks() + r.lock.Unlock() + + for _, dt := range dts { + dt.SetReceiver(receiver) + } +} + +func (r *simulcastReceiver) IsRegressed() bool { + r.lock.Lock() + defer r.lock.Unlock() + return r.regressTo != nil +} + +// ----------------------------------------------------- + type MediaTrackReceiverParams struct { - MediaTrack types.MediaTrack - IsRelayed bool - ParticipantID livekit.ParticipantID - ParticipantIdentity livekit.ParticipantIdentity - ParticipantVersion uint32 - ReceiverConfig ReceiverConfig - SubscriberConfig DirectionConfig - AudioConfig sfu.AudioConfig - Telemetry telemetry.TelemetryService - Logger logger.Logger + MediaTrack types.MediaTrack + IsRelayed bool + ParticipantID livekit.ParticipantID + ParticipantIdentity livekit.ParticipantIdentity + ParticipantVersion uint32 + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + AudioConfig sfu.AudioConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger + RegressionTargetCodec string } type MediaTrackReceiver struct { @@ -108,6 +143,7 @@ type 
MediaTrackReceiver struct { onSetupReceiver func(mime string) onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) onClose []func(isExpectedToResume bool) + onCodecRegression func(old, new webrtc.RTPCodecParameters) *MediaTrackSubscriptions } @@ -160,23 +196,21 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority receivers := slices.Clone(t.receivers) // codec position maybe taken by DummyReceiver, check and upgrade to WebRTCReceiver - receiverToAdd := receiver - idx := -1 - for i, r := range receivers { + var existingReceiver bool + for _, r := range receivers { if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) { - idx = i + existingReceiver = true + if d, ok := r.TrackReceiver.(*DummyReceiver); ok { + d.Upgrade(receiver) + } else { + t.params.Logger.Errorw("receiver already exists, setup failed", nil, "mime", receiver.Codec().MimeType) + } break } } - if idx != -1 { - if d, ok := receivers[idx].TrackReceiver.(*DummyReceiver); ok { - d.Upgrade(receiver) - receiverToAdd = d - } - // replace receiver - receivers = slices.Delete(receivers, idx, idx+1) + if !existingReceiver { + receivers = append(receivers, &simulcastReceiver{TrackReceiver: receiver, priority: priority}) } - receivers = append(receivers, &simulcastReceiver{TrackReceiver: receiverToAdd, priority: priority}) sort.Slice(receivers, func(i, j int) bool { return receivers[i].Priority() < receivers[j].Priority() @@ -220,6 +254,81 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority } } +func (t *MediaTrackReceiver) HandleReceiverCodecChange(r sfu.TrackReceiver, codec webrtc.RTPCodecParameters, state sfu.ReceiverCodecState) { + // TODO: we only support codec regress to backup codec now, so the receiver will not be available + // once fallback / regression happens. + // We will support codec upgrade in the future then the primary receiver will be available again if + // all subscribers of the track negotiate it. 
+ if state == sfu.ReceiverCodecStateNormal { + return + } + + t.lock.Lock() + // codec regression, find backup codec and switch all downtracks to it + var ( + oldReceiver *simulcastReceiver + backupCodecReceiver sfu.TrackReceiver + ) + for _, receiver := range t.receivers { + if receiver.TrackReceiver == r { + oldReceiver = receiver + continue + } + if d, ok := receiver.TrackReceiver.(*DummyReceiver); ok && d.Receiver() == r { + oldReceiver = receiver + continue + } + + if strings.EqualFold(receiver.Codec().MimeType, t.params.RegressionTargetCodec) { + backupCodecReceiver = receiver.TrackReceiver + } + + if oldReceiver != nil && backupCodecReceiver != nil { + break + } + } + + if oldReceiver == nil { + // should not happen + t.params.Logger.Errorw("could not find primary receiver for codec", nil, "codec", codec.MimeType) + t.lock.Unlock() + return + } + + if oldReceiver.IsRegressed() { + t.params.Logger.Infow("codec already regressed", "codec", codec.MimeType) + t.lock.Unlock() + return + } + + if backupCodecReceiver == nil { + t.params.Logger.Infow("no backup codec found, can't regress codec") + t.lock.Unlock() + return + } + + t.params.Logger.Infow("regressing codec", "from", codec.MimeType, "to", backupCodecReceiver.Codec().MimeType) + + // remove old codec from potential codecs; slices.Delete returns the shortened slice, so the result must be assigned back + for i, c := range t.potentialCodecs { + if strings.EqualFold(c.MimeType, codec.MimeType) { + t.potentialCodecs = slices.Delete(t.potentialCodecs, i, i+1) + break + } + } + onCodecRegression := t.onCodecRegression + t.lock.Unlock() + oldReceiver.RegressTo(backupCodecReceiver) + + if onCodecRegression != nil { + onCodecRegression(codec, backupCodecReceiver.Codec()) + } +} + +func (t *MediaTrackReceiver) OnCodecRegression(f func(old, new webrtc.RTPCodecParameters)) { + t.onCodecRegression = f +} + func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParameters, headers 
[]webrtc.RTPHeaderExtensionParameter) { // The potential codecs have not published yet, so we can't get the actual 
Extensions, the client/browser uses same extensions // for all video codecs so we assume they will have same extensions as the primary codec. @@ -447,6 +556,10 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su } for _, receiver := range receivers { + if receiver.IsRegressed() { + continue + } + codec := receiver.Codec() var found bool for _, pc := range potentialCodecs { diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 1933f69e4..8538db467 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -143,6 +143,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * Logger: LoggerWithTrack(sub.GetLogger().WithComponent(sutils.ComponentSub), trackID, t.params.IsRelayed), RTCPWriter: sub.WriteSubscriberRTCP, DisableSenderReportPassThrough: sub.GetDisableSenderReportPassThrough(), + SupportsCodecChange: sub.SupportsCodecChange(), }) if err != nil { return nil, err diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index d1efcc04c..0e5615288 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -163,6 +163,7 @@ type ParticipantParams struct { DataChannelMaxBufferedAmount uint64 DatachannelSlowThreshold int FireOnTrackBySdp bool + DisableCodecRegression bool } type ParticipantImpl struct { @@ -1737,12 +1738,10 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver return } - codec := rtcTrack.Codec() + var codec webrtc.RTPCodecParameters var fromSdp bool - // track fired by sdp - if rtcTrack.Codec().PayloadType == 0 { - codecs := rtpReceiver.GetParameters().Codecs - if len(codecs) == 0 || (rtcTrack.Kind() == webrtc.RTPCodecTypeVideo && p.params.ClientInfo.FireTrackByRTPPacket()) { + if rtcTrack.Kind() == webrtc.RTPCodecTypeVideo && p.params.ClientInfo.FireTrackByRTPPacket() { + if rtcTrack.Codec().PayloadType == 0 { go func() { // wait for the first packet to determine the codec 
bytes := make([]byte, 1500) @@ -1757,12 +1756,20 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver }() return } + codec = rtcTrack.Codec() + } else { + // track fired by sdp + codecs := rtpReceiver.GetParameters().Codecs + if len(codecs) == 0 { + p.pubLogger.Errorw("no negotiated codecs for track, track will be ignored", nil, "trackID", rtcTrack.ID(), "StreamID", rtcTrack.StreamID()) + return + } codec = codecs[0] fromSdp = true } + p.params.Logger.Debugw("onMediaTrack", "codec", codec, "payloadType", codec.PayloadType, "fromSdp", fromSdp, "parameters", rtpReceiver.GetParameters()) var track sfu.TrackRemote = sfu.NewTrackRemoteFromSdp(rtcTrack, codec) - publishedTrack, isNewTrack := p.mediaTrackReceived(track, rtpReceiver) if publishedTrack == nil { p.pendingTracksLock.Lock() @@ -2195,19 +2202,25 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l return ti } + backupCodecPolicy := req.BackupCodecPolicy + if backupCodecPolicy == livekit.BackupCodecPolicy_REGRESSION && p.params.DisableCodecRegression { + backupCodecPolicy = livekit.BackupCodecPolicy_SIMULCAST + } + ti := &livekit.TrackInfo{ - Type: req.Type, - Name: req.Name, - Width: req.Width, - Height: req.Height, - Muted: req.Muted, - DisableDtx: req.DisableDtx, - Source: req.Source, - Layers: req.Layers, - DisableRed: req.DisableRed, - Stereo: req.Stereo, - Encryption: req.Encryption, - Stream: req.Stream, + Type: req.Type, + Name: req.Name, + Width: req.Width, + Height: req.Height, + Muted: req.Muted, + DisableDtx: req.DisableDtx, + Source: req.Source, + Layers: req.Layers, + DisableRed: req.DisableRed, + Stereo: req.Stereo, + Encryption: req.Encryption, + Stream: req.Stream, + BackupCodecPolicy: backupCodecPolicy, } if req.Stereo { ti.AudioFeatures = append(ti.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO) @@ -3107,6 +3120,10 @@ func (p *ParticipantImpl) HandleMetrics(senderParticipantID livekit.ParticipantI return nil } +func (p 
*ParticipantImpl) SupportsCodecChange() bool { + return p.params.ClientInfo.SupportsCodecChange() +} + // ---------------------------------------------- func codecsFromMediaDescription(m *sdp.MediaDescription) (out []sdp.Codec, err error) { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 6741a0a45..eeaedc52b 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -390,6 +390,7 @@ type LocalParticipant interface { // has been reached. If the timeout expires, it will return an error. WaitUntilSubscribed(timeout time.Duration) error StopAndGetSubscribedTracksForwarderState() map[livekit.TrackID]*livekit.RTPForwarderState + SupportsCodecChange() bool // returns list of participant identities that the current participant is subscribed to GetSubscribedParticipants() []livekit.ParticipantID diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 092121663..f71d519c3 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -995,6 +995,16 @@ type FakeLocalParticipant struct { arg2 livekit.TrackID arg3 bool } + SupportsCodecChangeStub func() bool + supportsCodecChangeMutex sync.RWMutex + supportsCodecChangeArgsForCall []struct { + } + supportsCodecChangeReturns struct { + result1 bool + } + supportsCodecChangeReturnsOnCall map[int]struct { + result1 bool + } SupportsSyncStreamIDStub func() bool supportsSyncStreamIDMutex sync.RWMutex supportsSyncStreamIDArgsForCall []struct { @@ -6468,6 +6478,59 @@ func (fake *FakeLocalParticipant) SubscriptionPermissionUpdateArgsForCall(i int) return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 } +func (fake *FakeLocalParticipant) SupportsCodecChange() bool { + fake.supportsCodecChangeMutex.Lock() + ret, specificReturn := fake.supportsCodecChangeReturnsOnCall[len(fake.supportsCodecChangeArgsForCall)] + fake.supportsCodecChangeArgsForCall = 
append(fake.supportsCodecChangeArgsForCall, struct { + }{}) + stub := fake.SupportsCodecChangeStub + fakeReturns := fake.supportsCodecChangeReturns + fake.recordInvocation("SupportsCodecChange", []interface{}{}) + fake.supportsCodecChangeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeCallCount() int { + fake.supportsCodecChangeMutex.RLock() + defer fake.supportsCodecChangeMutex.RUnlock() + return len(fake.supportsCodecChangeArgsForCall) +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeCalls(stub func() bool) { + fake.supportsCodecChangeMutex.Lock() + defer fake.supportsCodecChangeMutex.Unlock() + fake.SupportsCodecChangeStub = stub +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeReturns(result1 bool) { + fake.supportsCodecChangeMutex.Lock() + defer fake.supportsCodecChangeMutex.Unlock() + fake.SupportsCodecChangeStub = nil + fake.supportsCodecChangeReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeReturnsOnCall(i int, result1 bool) { + fake.supportsCodecChangeMutex.Lock() + defer fake.supportsCodecChangeMutex.Unlock() + fake.SupportsCodecChangeStub = nil + if fake.supportsCodecChangeReturnsOnCall == nil { + fake.supportsCodecChangeReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.supportsCodecChangeReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) SupportsSyncStreamID() bool { fake.supportsSyncStreamIDMutex.Lock() ret, specificReturn := fake.supportsSyncStreamIDReturnsOnCall[len(fake.supportsSyncStreamIDArgsForCall)] @@ -7671,6 +7734,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.subscriptionPermissionMutex.RUnlock() fake.subscriptionPermissionUpdateMutex.RLock() defer fake.subscriptionPermissionUpdateMutex.RUnlock() + 
fake.supportsCodecChangeMutex.RLock() + defer fake.supportsCodecChangeMutex.RUnlock() fake.supportsSyncStreamIDMutex.RLock() defer fake.supportsSyncStreamIDMutex.RUnlock() fake.supportsTransceiverReuseMutex.RLock() diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index f035361e8..1f8fc81f2 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -21,6 +21,7 @@ import ( "github.com/pion/webrtc/v4" "go.uber.org/atomic" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" "github.com/livekit/protocol/livekit" @@ -47,14 +48,13 @@ type WrappedReceiver struct { params WrappedReceiverParams receivers []sfu.TrackReceiver codecs []webrtc.RTPCodecParameters - determinedCodec webrtc.RTPCodecCapability onReadyCallbacks []func() } func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { sfuReceivers := make([]sfu.TrackReceiver, 0, len(params.Receivers)) for _, r := range params.Receivers { - sfuReceivers = append(sfuReceivers, r.TrackReceiver) + sfuReceivers = append(sfuReceivers, r) } codecs := params.UpstreamCodecs @@ -94,7 +94,6 @@ func (r *WrappedReceiver) StreamID() string { // DetermineReceiver determines the receiver of negotiated codec and return ready state of the receiver func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) bool { r.lock.Lock() - r.determinedCodec = codec var trackReceiver sfu.TrackReceiver for _, receiver := range r.receivers { @@ -130,8 +129,10 @@ func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) boo trackReceiver.AddOnReady(f) } - if d, ok := trackReceiver.(*DummyReceiver); ok { - return d.IsReady() + if s, ok := trackReceiver.(*simulcastReceiver); ok { + if d, ok := s.TrackReceiver.(*DummyReceiver); ok { + return d.IsReady() + } } return true } @@ -174,9 +175,10 @@ type DummyReceiver struct { codec webrtc.RTPCodecParameters headerExtensions []webrtc.RTPHeaderExtensionParameter - downTrackLock sync.Mutex - downTracks 
map[livekit.ParticipantID]sfu.TrackSender - onReadyCallbacks []func() + downTrackLock sync.Mutex + downTracks map[livekit.ParticipantID]sfu.TrackSender + onReadyCallbacks []func() + onCodecStateChange []func(webrtc.RTPCodecParameters, sfu.ReceiverCodecState) settingsLock sync.Mutex maxExpectedLayerValid bool @@ -214,12 +216,18 @@ func (d *DummyReceiver) Upgrade(receiver sfu.TrackReceiver) { d.downTracks = make(map[livekit.ParticipantID]sfu.TrackSender) onReadyCallbacks := d.onReadyCallbacks d.onReadyCallbacks = nil + codecChange := d.onCodecStateChange + d.onCodecStateChange = nil d.downTrackLock.Unlock() for _, f := range onReadyCallbacks { receiver.AddOnReady(f) } + for _, f := range codecChange { + receiver.AddOnCodecStateChange(f) + } + d.settingsLock.Lock() if d.maxExpectedLayerValid { receiver.SetMaxExpectedSpatialLayer(d.maxExpectedLayer) @@ -336,6 +344,16 @@ func (d *DummyReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { } } +func (d *DummyReceiver) GetDownTracks() []sfu.TrackSender { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return r.GetDownTracks() + } + return maps.Values(d.downTracks) +} + func (d *DummyReceiver) DebugInfo() map[string]interface{} { if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { return r.DebugInfo() @@ -420,6 +438,27 @@ func (d *DummyReceiver) IsReady() bool { return d.receiver.Load() != nil } +func (d *DummyReceiver) AddOnCodecStateChange(f func(codec webrtc.RTPCodecParameters, state sfu.ReceiverCodecState)) { + var receiver sfu.TrackReceiver + d.downTrackLock.Lock() + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + receiver = r + } else { + d.onCodecStateChange = append(d.onCodecStateChange, f) + } + d.downTrackLock.Unlock() + if receiver != nil { + receiver.AddOnCodecStateChange(f) + } +} + +func (d *DummyReceiver) CodecState() sfu.ReceiverCodecState { + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return 
r.CodecState() + } + return sfu.ReceiverCodecStateNormal +} + // -------------------------------------------- type DummyRedReceiver struct { @@ -464,6 +503,16 @@ func (d *DummyRedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { } } +func (d *DummyRedReceiver) GetDownTracks() []sfu.TrackSender { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + return r.GetDownTracks() + } + return maps.Values(d.downTracks) +} + func (d *DummyRedReceiver) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { return r.ReadRTP(buf, layer, esn) diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index bc106d13e..95233ddaa 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -17,6 +17,7 @@ package buffer import ( "encoding/binary" "errors" + "fmt" "io" "strings" "sync" @@ -79,7 +80,6 @@ type Buffer struct { maxVideoPkts int maxAudioPkts int codecType webrtc.RTPCodecType - payloadType uint8 extPackets deque.Deque[*ExtPacket] pPackets []pendingPacket closeOnce sync.Once @@ -90,7 +90,11 @@ type Buffer struct { audioLevelExtID uint8 bound bool closed atomic.Bool - mime string + + rtpParameters webrtc.RTPParameters + payloadType uint8 + rtxPayloadType uint8 + mime string snRangeMap *utils.RangeMap[uint64, uint64] @@ -119,6 +123,7 @@ type Buffer struct { onRtcpSenderReport func() onFpsChanged func() onFinalRtpStats func(*livekit.RTPStats) + onCodecChange func(webrtc.RTPCodecParameters) // logger logger logger.Logger @@ -214,6 +219,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.clockRate = codec.ClockRate b.lastReport = mono.UnixNano() b.mime = strings.ToLower(codec.MimeType) + b.rtpParameters = params for _, codecParameter := range params.Codecs { if strings.EqualFold(codecParameter.MimeType, codec.MimeType) { b.payloadType = uint8(codecParameter.PayloadType) @@ -226,23 
+232,23 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.payloadType = uint8(params.Codecs[0].PayloadType) } + // find RTX payload type + for _, codec := range params.Codecs { + if strings.EqualFold(codec.MimeType, "video/rtx") && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", b.payloadType)) { + b.rtxPayloadType = uint8(codec.PayloadType) + break + } + } + for _, ext := range params.HeaderExtensions { switch ext.URI { case dd.ExtensionURI: - if IsSvcCodec(codec.MimeType) || strings.EqualFold(codec.MimeType, webrtc.MimeTypeVP8) { - if b.ddExtID != 0 { - b.logger.Warnw("multiple dependency descriptor extensions found", nil, "id", ext.ID, "previous", b.ddExtID) - continue - } - b.ddExtID = uint8(ext.ID) - frc := NewFrameRateCalculatorDD(b.clockRate, b.logger) - for i := range b.frameRateCalculator { - b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) - } - b.ddParser = NewDependencyDescriptorParser(b.ddExtID, b.logger, func(spatial, temporal int32) { - frc.SetMaxLayer(spatial, temporal) - }) + if b.ddExtID != 0 { + b.logger.Warnw("multiple dependency descriptor extensions found", nil, "id", ext.ID, "previous", b.ddExtID) + continue } + b.ddExtID = uint8(ext.ID) + b.createDDParserAndFrameRateCalculator(codec.MimeType) case sdp.AudioLevelURI: b.audioLevelExtID = uint8(ext.ID) @@ -262,16 +268,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.codecType = webrtc.RTPCodecTypeVideo b.bucket = bucket.NewBucket[uint64](InitPacketBufferSizeVideo) if b.frameRateCalculator[0] == nil { - if strings.EqualFold(codec.MimeType, webrtc.MimeTypeVP8) { - b.frameRateCalculator[0] = NewFrameRateCalculatorVP8(b.clockRate, b.logger) - } - - if strings.EqualFold(codec.MimeType, webrtc.MimeTypeVP9) { - frc := NewFrameRateCalculatorVP9(b.clockRate, b.logger) - for i := range b.frameRateCalculator { - b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) - } - } + 
b.createFrameRateCalculator(codec.MimeType) } if bitrates > 0 { pps := bitrates / 8 / 1200 @@ -310,6 +307,40 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.bound = true } +func (b *Buffer) OnCodecChange(fn func(webrtc.RTPCodecParameters)) { + b.Lock() + b.onCodecChange = fn + b.Unlock() +} + +func (b *Buffer) createDDParserAndFrameRateCalculator(mime string) { + if IsSvcCodec(mime) || strings.EqualFold(mime, webrtc.MimeTypeVP8) { + frc := NewFrameRateCalculatorDD(b.clockRate, b.logger) + for i := range b.frameRateCalculator { + b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) + } + b.ddParser = NewDependencyDescriptorParser(b.ddExtID, b.logger, func(spatial, temporal int32) { + frc.SetMaxLayer(spatial, temporal) + }) + } +} + +func (b *Buffer) createFrameRateCalculator(mime string) { + switch { + case strings.EqualFold(mime, webrtc.MimeTypeVP8): + b.frameRateCalculator[0] = NewFrameRateCalculatorVP8(b.clockRate, b.logger) + + case strings.EqualFold(mime, webrtc.MimeTypeVP9): + frc := NewFrameRateCalculatorVP9(b.clockRate, b.logger) + for i := range b.frameRateCalculator { + b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) + } + + case strings.EqualFold(mime, webrtc.MimeTypeH265): + b.frameRateCalculator[0] = NewFrameRateCalculatorH26x(b.clockRate, b.logger) + } +} + // Write adds an RTP Packet, ordering is not guaranteed, newer packets may arrive later func (b *Buffer) Write(pkt []byte) (n int, err error) { var rtpPacket rtp.Packet @@ -364,7 +395,6 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) { return } - b.payloadType = rtpPacket.PayloadType b.calc(pkt, &rtpPacket, now, false) b.Unlock() b.readCond.Broadcast() @@ -397,6 +427,11 @@ func (b *Buffer) writeRTX(rtxPkt *rtp.Packet, arrivalTime int64) (n int, err err return } + if rtxPkt.PayloadType != b.rtxPayloadType { + b.logger.Debugw("unexpected rtx payload type", "expected", b.rtxPayloadType, "actual", 
rtxPkt.PayloadType) + return + } + if b.rtxPktBuf == nil { b.rtxPktBuf = make([]byte, bucket.MaxPktSize) } @@ -593,6 +628,10 @@ func (b *Buffer) calc(rawPkt []byte, rtpPacket *rtp.Packet, arrivalTime int64, i return } + if !flowState.IsOutOfOrder && rtpPacket.PayloadType != b.payloadType && b.codecType == webrtc.RTPCodecTypeVideo { + b.handleCodecChange(rtpPacket.PayloadType) + } + // add to RTX buffer using sequence number after accounting for dropped padding only packets snAdjustment, err := b.snRangeMap.GetValue(flowState.ExtSequenceNumber) if err != nil { @@ -712,6 +751,55 @@ func (b *Buffer) doFpsCalc(ep *ExtPacket) { } } +func (b *Buffer) handleCodecChange(newPT uint8) { + var ( + codecFound, rtxFound bool + rtxPt uint8 + newCodec webrtc.RTPCodecParameters + ) + for _, codec := range b.rtpParameters.Codecs { + if !codecFound && uint8(codec.PayloadType) == newPT { + newCodec = codec + codecFound = true + } + + if strings.EqualFold(codec.MimeType, "video/rtx") && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", newPT)) { + rtxFound = true + rtxPt = uint8(codec.PayloadType) + } + + if codecFound && rtxFound { + break + } + } + if !codecFound { + b.logger.Errorw("could not find codec for new payload type", nil, "pt", newPT, "rtpParameters", b.rtpParameters) + return + } + b.logger.Infow("codec changed", + "oldPayload", b.payloadType, "newPayload", newPT, + "oldRtxPayload", b.rtxPayloadType, "newRtxPayload", rtxPt, + "oldMime", b.mime, "newMime", newCodec.MimeType) + b.payloadType = newPT + b.rtxPayloadType = rtxPt + b.mime = strings.ToLower(newCodec.MimeType) + b.frameRateCalculated = false + + if b.ddExtID != 0 { + b.createDDParserAndFrameRateCalculator(b.mime) + } + + if b.frameRateCalculator[0] == nil { + b.createFrameRateCalculator(b.mime) + } + + b.bucket.ResyncOnNextPacket() + + if f := b.onCodecChange; f != nil { + go f(newCodec) + } +} + func (b *Buffer) updateStreamState(p *rtp.Packet, arrivalTime int64) rtpstats.RTPFlowState { flowState := 
b.rtpStats.Update( arrivalTime, @@ -825,6 +913,19 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat case utils.MimeTypeAV1: ep.KeyFrame = IsAV1KeyFrame(rtpPacket.Payload) + + case utils.MimeTypeH265: + if ep.DependencyDescriptor == nil { + if len(rtpPacket.Payload) < 2 { + b.logger.Warnw("invalid H265 packet", nil) + return nil + } + ep.VideoLayer = VideoLayer{ + Temporal: int32(rtpPacket.Payload[1]&0x07) - 1, + } + ep.Spatial = InvalidLayerSpatial + } + ep.KeyFrame = IsH265KeyFrame(rtpPacket.Payload) } if ep.KeyFrame { diff --git a/pkg/sfu/buffer/buffer_test.go b/pkg/sfu/buffer/buffer_test.go index ad8a553b8..55aaebb7d 100644 --- a/pkg/sfu/buffer/buffer_test.go +++ b/pkg/sfu/buffer/buffer_test.go @@ -28,6 +28,17 @@ import ( "github.com/livekit/mediatransportutil/pkg/nack" ) +var h265Codec = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: "video/h265", + ClockRate: 90000, + RTCPFeedback: []webrtc.RTCPFeedback{{ + Type: "nack", + }}, + }, + PayloadType: 116, +} + var vp8Codec = webrtc.RTPCodecParameters{ RTPCodecCapability: webrtc.RTPCodecCapability{ MimeType: "video/vp8", @@ -314,6 +325,105 @@ func TestFractionLostReport(t *testing.T) { wg.Wait() } +func TestCodecChange(t *testing.T) { + // codec change before bind + buff := NewBuffer(123, 1, 1) + require.NotNil(t, buff) + changedCodec := make(chan webrtc.RTPCodecParameters, 1) + buff.OnCodecChange(func(rp webrtc.RTPCodecParameters) { + select { + case changedCodec <- rp: + default: + t.Fatalf("codec change not consumed") + } + }) + + h265Pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 116, + SequenceNumber: 1, + Timestamp: 1, + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + buf, err := h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + + select { + case <-changedCodec: + t.Fatalf("unexpected codec change") + case <-time.After(100 * 
time.Millisecond): + } + + buff.Bind(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{vp8Codec, h265Codec}, + }, vp8Codec.RTPCodecCapability, 0) + + select { + case c := <-changedCodec: + require.Equal(t, h265Codec, c) + + case <-time.After(1 * time.Second): + t.Fatalf("expected codec change") + } + + // codec change after bind + vp8Pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: 3, + Timestamp: 3, + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + buf, err = vp8Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + + select { + case c := <-changedCodec: + require.Equal(t, vp8Codec, c) + + case <-time.After(1 * time.Second): + t.Fatalf("expected codec change") + } + + // out of order pkts can't cause codec change + h265Pkt.SequenceNumber = 2 + h265Pkt.Timestamp = 2 + buf, err = h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + select { + case <-changedCodec: + t.Fatalf("unexpected codec change") + case <-time.After(100 * time.Millisecond): + } + + // unknown codec should not cause change + h265Pkt.SequenceNumber = 4 + h265Pkt.Timestamp = 4 + h265Pkt.PayloadType = 117 + buf, err = h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + select { + case <-changedCodec: + t.Fatalf("unexpected codec change") + case <-time.After(100 * time.Millisecond): + } +} + func BenchmarkMemcpu(b *testing.B) { buf := make([]byte, 1500*1500*10) buf2 := make([]byte, 1500*1500*20) diff --git a/pkg/sfu/buffer/fps.go b/pkg/sfu/buffer/fps.go index 3135ad4c3..5622198f1 100644 --- a/pkg/sfu/buffer/fps.go +++ b/pkg/sfu/buffer/fps.go @@ -17,14 +17,16 @@ package buffer import ( "container/list" - "github.com/livekit/protocol/logger" "github.com/pion/rtp/codecs" + + "github.com/livekit/protocol/logger" ) var minFramesForCalculation = [...]int{8, 
15, 40, 60} type frameInfo struct { - seq uint16 + startSeq uint16 + endSeq uint16 ts uint32 fn uint16 spatial int32 @@ -79,7 +81,7 @@ func (f *frameRateCalculatorVPx) RecvPacket(ep *ExtPacket, fn uint16) bool { } if f.baseFrame == nil { - f.baseFrame = &frameInfo{seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn} + f.baseFrame = &frameInfo{ts: ep.Packet.Timestamp, fn: fn} f.fnReceived[0] = f.baseFrame f.firstFrames[temporal] = f.baseFrame return false @@ -102,7 +104,6 @@ func (f *frameRateCalculatorVPx) RecvPacket(ep *ExtPacket, fn uint16) bool { } fi := &frameInfo{ - seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn, temporal: temporal, @@ -373,7 +374,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { fn := ep.DependencyDescriptor.Descriptor.FrameNumber if f.baseFrame == nil { - f.baseFrame = &frameInfo{seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn} + f.baseFrame = &frameInfo{ts: ep.Packet.Timestamp, fn: fn} f.fnReceived[0] = f.baseFrame f.firstFrames[spatial][temporal] = f.baseFrame f.secondFrames[spatial][temporal] = f.baseFrame @@ -406,7 +407,6 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { } fi := &frameInfo{ - seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn, temporal: temporal, @@ -438,12 +438,13 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { switch { case val == dependFrame: break insertFrame - case val < dependFrame: + case sn16LT(val, dependFrame): chain.InsertAfter(dependFrame, e) break insertFrame default: if e == chain.Front() { chain.PushFront(dependFrame) + break insertFrame } } } @@ -570,3 +571,162 @@ func (f *FrameRateCalculatorForDDLayer) GetFrameRate() []float32 { } // ----------------------------------------------- + +type FrameRateCalculatorH26x struct { + frameRates [DefaultMaxLayerTemporal + 1]float32 + clockRate uint32 + logger logger.Logger + fnReceived *list.List + baseFrame *frameInfo + completed bool +} + +func 
NewFrameRateCalculatorH26x(clockRate uint32, logger logger.Logger) *FrameRateCalculatorH26x { + return &FrameRateCalculatorH26x{ + clockRate: clockRate, + logger: logger, + } +} + +func (f *FrameRateCalculatorH26x) Completed() bool { + return f.completed +} + +func (f *FrameRateCalculatorH26x) RecvPacket(ep *ExtPacket) bool { + if f.completed { + return true + } + + if ep.Temporal >= int32(len(f.frameRates)) { + f.logger.Warnw("invalid temporal layer", nil, "temporal", ep.Temporal) + return false + } + + temporal := ep.Temporal + if temporal < 0 { + temporal = 0 + } + + if f.baseFrame == nil { + f.baseFrame = &frameInfo{ + startSeq: ep.Packet.SequenceNumber, + endSeq: ep.Packet.SequenceNumber, + ts: ep.Packet.Timestamp, + temporal: temporal, + } + f.fnReceived = list.New() + f.fnReceived.PushBack(f.baseFrame) + return false + } + + if sn16LTOrEqual(ep.Packet.SequenceNumber, f.baseFrame.startSeq) { + return false + } + +insertFrame: + for e := f.fnReceived.Back(); e != nil; e = e.Prev() { + frame := e.Value.(*frameInfo) + switch { + case frame.ts == ep.Packet.Timestamp: + if sn16LT(frame.endSeq, ep.Packet.SequenceNumber) { + frame.endSeq = ep.Packet.SequenceNumber + } + if sn16LT(ep.Packet.SequenceNumber, frame.startSeq) { + frame.startSeq = ep.Packet.SequenceNumber + } + break insertFrame + case sn32LT(frame.ts, ep.Packet.Timestamp): + f.fnReceived.InsertAfter(&frameInfo{ + startSeq: ep.Packet.SequenceNumber, + endSeq: ep.Packet.SequenceNumber, + ts: ep.Packet.Timestamp, + temporal: temporal, + }, e) + break insertFrame + default: + if e == f.fnReceived.Front() { + f.fnReceived.PushFront(&frameInfo{ + startSeq: ep.Packet.SequenceNumber, + endSeq: ep.Packet.SequenceNumber, + ts: ep.Packet.Timestamp, + temporal: temporal, + }) + break insertFrame + } + } + } + + return f.calc() +} + +func (f *FrameRateCalculatorH26x) calc() bool { + frameCounts := make([]int, DefaultMaxLayerTemporal+1) + var totalFrameCount int + var tsDuration int + cur := f.fnReceived.Front() + for 
{ + next := cur.Next() + if next == nil { + break + } + ff := cur.Value.(*frameInfo) + nf := next.Value.(*frameInfo) + if nf.startSeq-ff.endSeq == 1 { + totalFrameCount++ + tsDuration += int(nf.ts - ff.ts) + for i := int(nf.temporal); i < len(frameCounts); i++ { + frameCounts[i]++ + } + } else { + // reset to find continuous frames + totalFrameCount = 0 + for i := range frameCounts { + frameCounts[i] = 0 + } + tsDuration = 0 + } + + // received enough continuous frames, calculate fps + if totalFrameCount >= minFramesForCalculation[DefaultMaxLayerTemporal] { + for currentTemporal := int32(0); currentTemporal <= DefaultMaxLayerTemporal; currentTemporal++ { + count := frameCounts[currentTemporal] + if currentTemporal > 0 && count == frameCounts[currentTemporal-1] { + // no frames for this temporal layer + f.frameRates[currentTemporal] = 0 + } else { + f.frameRates[currentTemporal] = float32(f.clockRate) / float32(tsDuration) * float32(count) + } + } + f.logger.Debugw("fps changed", "fps", f.GetFrameRate()) + f.completed = true + f.reset() + return true + } + + cur = next + } + + return false +} + +func (f *FrameRateCalculatorH26x) reset() { + f.fnReceived.Init() + f.baseFrame = nil +} + +func (f *FrameRateCalculatorH26x) GetFrameRate() []float32 { + return f.frameRates[:] +} + +// ----------------------------------------------- +func sn16LT(a, b uint16) bool { + return a-b > 0x8000 +} + +func sn16LTOrEqual(a, b uint16) bool { + return a == b || a-b > 0x8000 +} + +func sn32LT(a, b uint32) bool { + return a-b > 0x80000000 +} diff --git a/pkg/sfu/buffer/fps_test.go b/pkg/sfu/buffer/fps_test.go index 39dd93df3..f560d8d12 100644 --- a/pkg/sfu/buffer/fps_test.go +++ b/pkg/sfu/buffer/fps_test.go @@ -55,10 +55,16 @@ func (f *testFrameInfo) toDD() *ExtPacket { }, VideoLayer: VideoLayer{Spatial: int32(f.spatial), Temporal: int32(f.temporal)}, } - } -func createFrames(startFrameNumber uint16, startTs uint32, totalFramesPerSpatial int, fps [][]float32, spatialDependency bool) 
[][]*testFrameInfo { +func (f *testFrameInfo) toH26x() *ExtPacket { + return &ExtPacket{ + Packet: &rtp.Packet{Header: f.header}, + VideoLayer: VideoLayer{Spatial: InvalidLayerSpatial, Temporal: int32(f.temporal)}, + } +} + +func createFrames(startFrameNumber uint16, startTs uint32, startSeq uint16, totalFramesPerSpatial int, fps [][]float32, spatialDependency bool) [][]*testFrameInfo { spatials := len(fps) temporals := len(fps[0]) frames := make([][]*testFrameInfo, spatials) @@ -85,7 +91,7 @@ func createFrames(startFrameNumber uint16, startTs uint32, totalFramesPerSpatial for i := 0; i < totalFramesPerSpatial; i++ { for s := 0; s < spatials; s++ { frame := &testFrameInfo{ - header: rtp.Header{Timestamp: currentTs[s]}, + header: rtp.Header{Timestamp: currentTs[s], SequenceNumber: startSeq}, framenumber: fn, spatial: s, } @@ -101,6 +107,7 @@ func createFrames(startFrameNumber uint16, startTs uint32, totalFramesPerSpatial currentTs[s] += tsStep[s][temporals-1] frames[s] = append(frames[s], frame) fn++ + startSeq++ for fidx := len(frames[s]) - 1; fidx >= 0; fidx-- { cf := frames[s][fidx] @@ -135,6 +142,7 @@ func verifyFps(t *testing.T, expect, got []float32) { type testcase struct { startTs uint32 + startSeq uint16 startFrameNumber uint16 fps [][]float32 spatialDependency bool @@ -167,7 +175,7 @@ func TestFpsVP8(t *testing.T) { vp8calcs := make([]*FrameRateCalculatorVP8, len(fps)) for i := range vp8calcs { vp8calcs[i] = NewFrameRateCalculatorVP8(90000, logger.GetLogger()) - frames = append(frames, createFrames(c.startFrameNumber, c.startTs, 200, [][]float32{fps[i]}, false)[0]) + frames = append(frames, createFrames(c.startFrameNumber, c.startTs, 10, 200, [][]float32{fps[i]}, false)[0]) } var frameratesGot bool @@ -198,7 +206,7 @@ func TestFpsVP8(t *testing.T) { vp8calcs := make([]*FrameRateCalculatorVP8, len(fps)) for i := range vp8calcs { vp8calcs[i] = NewFrameRateCalculatorVP8(90000, logger.GetLogger()) - frames = append(frames, createFrames(100, 12345678, 300, 
[][]float32{fps[i]}, false)[0]) + frames = append(frames, createFrames(100, 12345678, 10, 300, [][]float32{fps[i]}, false)[0]) for j := 5; j < 130; j++ { if j%2 == 0 { frames[i][j] = frames[i][j-1] @@ -255,7 +263,7 @@ func TestFpsDD(t *testing.T) { testCase := c t.Run(name, func(t *testing.T) { fps := testCase.fps - frames := createFrames(c.startFrameNumber, c.startTs, 500, fps, testCase.spatialDependency) + frames := createFrames(c.startFrameNumber, c.startTs, 10, 500, fps, testCase.spatialDependency) ddcalc := NewFrameRateCalculatorDD(90000, logger.GetLogger()) ddcalc.SetMaxLayer(int32(len(fps)-1), int32(len(fps[0])-1)) ddcalcs := make([]FrameRateCalculator, len(fps)) @@ -288,7 +296,7 @@ func TestFpsDD(t *testing.T) { t.Run("packet lost and duplicate", func(t *testing.T) { fps := [][]float32{{7.5, 15, 30}, {7.5, 15, 30}, {7.5, 15, 30}} - frames := createFrames(100, 12345678, 500, fps, true) + frames := createFrames(100, 12345678, 10, 500, fps, true) ddcalc := NewFrameRateCalculatorDD(90000, logger.GetLogger()) ddcalc.SetMaxLayer(int32(len(fps)-1), int32(len(fps[0])-1)) ddcalcs := make([]FrameRateCalculator, len(fps)) @@ -322,5 +330,108 @@ func TestFpsDD(t *testing.T) { verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) } }) - +} + +func TestFpsH26x(t *testing.T) { + cases := map[string]testcase{ + "normal": { + startTs: 12345678, + startSeq: 100, + startFrameNumber: 100, + fps: [][]float32{{5, 10, 15}, {5, 10, 15}, {7.5, 15, 30}}, + }, + "frame number and timestamp wrap": { + startTs: (uint32(1) << 31) - 10, + startSeq: (uint16(1) << 15) - 10, + startFrameNumber: (uint16(1) << 15) - 10, + fps: [][]float32{{5, 10, 15}, {5, 10, 15}, {7.5, 15, 30}}, + }, + "2 temporal layers": { + startTs: 12345678, + startFrameNumber: 100, + fps: [][]float32{{7.5, 15}, {7.5, 15}, {15, 30}}, + }, + } + + for name, c := range cases { + testCase := c + t.Run(name, func(t *testing.T) { + fps := testCase.fps + frames := make([][]*testFrameInfo, 0) + h26xcalcs := 
make([]*FrameRateCalculatorH26x, len(fps)) + for i := range h26xcalcs { + h26xcalcs[i] = NewFrameRateCalculatorH26x(90000, logger.GetLogger()) + frames = append(frames, createFrames(c.startFrameNumber, c.startTs, c.startSeq, 200, [][]float32{fps[i]}, false)[0]) + } + + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if h26xcalcs[s].RecvPacket(f.toH26x()) { + frameratesGot = true + for _, calc := range h26xcalcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range h26xcalcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) + } + + t.Run("packet lost and duplicate", func(t *testing.T) { + fps := [][]float32{{7.5, 15, 30}, {7.5, 15, 30}, {7.5, 15, 30}} + frames := make([][]*testFrameInfo, 0, len(fps)) + h26xcalcs := make([]FrameRateCalculator, len(fps)) + for i := range fps { + frames = append(frames, createFrames(100, 12345678, 10, 500, [][]float32{fps[i]}, false)[0]) + h26xcalcs[i] = NewFrameRateCalculatorH26x(90000, logger.GetLogger()) + for j := 5; j < 130; j++ { + if j%2 == 0 { + frames[i][j] = frames[i][j-1] + } + } + for j := 130; j < 230; j++ { + if j%3 == 0 { + frames[i][j] = nil + } + } + for j := 230; j < 330; j++ { + if j%2 == 0 { + frames[i][j], frames[i][j-1] = frames[i][j-1], frames[i][j] + } + } + } + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if f == nil { + continue + } + if h26xcalcs[s].RecvPacket(f.toH26x()) { + frameratesGot = true + for _, calc := range h26xcalcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range h26xcalcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) } diff --git a/pkg/sfu/buffer/helpers.go b/pkg/sfu/buffer/helpers.go index c6838ff4e..d8c862ddd 100644 
--- a/pkg/sfu/buffer/helpers.go +++ b/pkg/sfu/buffer/helpers.go @@ -18,8 +18,9 @@ import ( "encoding/binary" "errors" - "github.com/livekit/protocol/logger" "github.com/pion/rtp/codecs" + + "github.com/livekit/protocol/logger" ) var ( @@ -415,4 +416,40 @@ func IsAV1KeyFrame(payload []byte) bool { } } +func IsH265KeyFrame(payload []byte) (kf bool) { + if len(payload) < 2 { + return false + } + naluType := (payload[0] & 0x7E) >> 1 + switch { + case naluType == 33 || naluType == 34: + return true + case naluType == 48: // AP + idx := 2 + for idx < len(payload)-2 { + // TODO: check the DONL field (controled by sprop-max-don-diff) + size := binary.BigEndian.Uint16(payload[idx:]) + idx += 2 + if idx >= len(payload) { + return false + } + naluType = (payload[idx] & 0x7E) >> 1 + if naluType == 33 || naluType == 34 { + return true + } + idx += int(size) + } + return false + + case naluType == 49: // FU + if len(payload) < 3 { + return false + } + naluType = (payload[2] & 0x7E) >> 1 + return naluType == 33 || naluType == 34 + default: + return false + } +} + // ------------------------------------- diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 44eb56088..4505c76f2 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -68,6 +68,7 @@ type TrackSender interface { publisherSRData *livekit.RTCPSenderReportState, ) error Resync() + SetReceiver(TrackReceiver) } // ------------------------------------------------------------------- @@ -237,6 +238,7 @@ type DowntrackParams struct { Trailer []byte RTCPWriter func([]rtcp.Packet) error DisableSenderReportPassThrough bool + SupportsCodecChange bool } // DownTrack implements TrackLocal, is the track used to write packets @@ -251,14 +253,16 @@ type DownTrack struct { params DowntrackParams id livekit.TrackID kind webrtc.RTPCodecType - mime string ssrc uint32 ssrcRTX uint32 - payloadType uint8 - payloadTypeRTX uint8 + payloadType atomic.Uint32 + payloadTypeRTX atomic.Uint32 sequencer *sequencer rtxSequenceNumber 
atomic.Uint64 + receiverLock sync.RWMutex + receiver TrackReceiver + forwarder *Forwarder upstreamCodecs []webrtc.RTPCodecParameters @@ -362,6 +366,7 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { maxLayerNotifierCh: make(chan string, 1), keyFrameRequesterCh: make(chan struct{}, 1), createdAt: time.Now().UnixNano(), + receiver: params.Receiver, } d.bindState.Store(bindStateUnbound) d.params.Logger = params.Logger.WithValues( @@ -542,8 +547,8 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.ssrc = uint32(t.SSRC()) d.ssrcRTX = uint32(t.SSRCRetransmission()) - d.payloadType = uint8(codec.PayloadType) - d.payloadTypeRTX = uint8(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters)) + d.payloadType.Store(uint32(codec.PayloadType)) + d.payloadTypeRTX.Store(uint32(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters))) logFields = append( logFields, "payloadType", d.payloadType, @@ -553,7 +558,6 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.params.Logger.Debugw("DownTrack.Bind", logFields...) 
d.writeStream = t.WriteStream() - d.mime = strings.ToLower(codec.MimeType) if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, d.ssrc).(*buffer.RTCPReader); rr != nil { rr.OnPacket(func(pkt []byte) { d.handleRTCP(pkt) @@ -578,7 +582,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.setBindStateLocked(bindStateBound) d.bindLock.Unlock() - d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.params.Receiver.HeaderExtensions()) + d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) d.connectionStats.Start(codec.MimeType, isFECEnabled) d.params.Logger.Debugw("downtrack bound") } @@ -615,8 +619,12 @@ func (d *DownTrack) setBindStateLocked(state bindState) { } func (d *DownTrack) handleReceiverReady() { - d.params.Logger.Debugw("downtrack receiver ready") d.bindLock.Lock() + if d.isReceiverReady { + d.bindLock.Unlock() + return + } + d.params.Logger.Debugw("downtrack receiver ready") d.isReceiverReady = true doBind := d.bindOnReceiverReady d.bindOnReceiverReady = nil @@ -627,6 +635,63 @@ func (d *DownTrack) handleReceiverReady() { } } +func (d *DownTrack) handleUpstreamCodecChange(mime string) { + d.bindLock.Lock() + if strings.EqualFold(d.codec.MimeType, mime) { + d.bindLock.Unlock() + return + } + + if !d.params.SupportsCodecChange { + d.bindLock.Unlock() + d.params.Logger.Infow("client doesn't support codec change, renegotiate new codec") + go d.Close() + return + } + + oldPT, oldRtxPT, oldCodec := d.payloadType.Load(), d.payloadTypeRTX.Load(), d.codec + + var codec webrtc.RTPCodecParameters + for _, c := range d.upstreamCodecs { + if !strings.EqualFold(c.MimeType, mime) { + continue + } + + matchCodec, err := utils.CodecParametersFuzzySearch(c, d.negotiatedCodecParameters) + if err == nil { + codec = matchCodec + break + } + } + + if codec.MimeType == "" { + // codec not found, should not happen since the upstream codec should only fall back to higher compatibility (vp8) + 
d.params.Logger.Errorw( + "can't find matched codec for new upstream payload type", nil, + "upstreamCodecs", d.upstreamCodecs, + "remoteParameters", d.negotiatedCodecParameters, + "mime", mime, + ) + d.bindLock.Unlock() + return + } + + d.payloadType.Store(uint32(codec.PayloadType)) + d.payloadTypeRTX.Store(uint32(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters))) + d.codec = codec.RTPCodecCapability + d.bindLock.Unlock() + + d.params.Logger.Infow( + "upstream codec changed", + "oldPT", oldPT, "newPT", d.payloadType.Load(), + "oldRTXPT", oldRtxPT, "newRTXPT", d.payloadTypeRTX.Load(), + "oldCodec", oldCodec, "newCodec", codec.RTPCodecCapability, + ) + + d.forwarder.Restart() + d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) +} + // Unbind implements the teardown logic when the track is no longer needed. This happens // because a track has been stopped. func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { @@ -687,6 +752,38 @@ func (d *DownTrack) SubscriberID() livekit.ParticipantID { return livekit.ParticipantID(fmt.Sprintf("%s:%d", d.params.SubID, d.createdAt)) } +func (d *DownTrack) Receiver() TrackReceiver { + d.receiverLock.RLock() + defer d.receiverLock.RUnlock() + return d.receiver +} + +func (d *DownTrack) SetReceiver(r TrackReceiver) { + d.params.Logger.Debugw("downtrack set receiver", "codec", r.Codec()) + d.bindLock.Lock() + if d.IsClosed() { + d.bindLock.Unlock() + return + } + + d.receiverLock.Lock() + old := d.receiver + d.receiver = r + d.receiverLock.Unlock() + + old.DeleteDownTrack(d.SubscriberID()) + if err := r.AddDownTrack(d); err != nil { + d.params.Logger.Warnw("failed to add downtrack to receiver", err) + } + d.bindLock.Unlock() + + r.AddOnReady(d.handleReceiverReady) + d.handleUpstreamCodecChange(r.Codec().MimeType) + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnSubscribedLayerChanged(d, d.forwarder.MaxLayer()) + } +} + // Sets RTP header extensions for 
this track func (d *DownTrack) SetRTPHeaderExtensions(rtpHeaderExtensions []webrtc.RTPHeaderExtensionParameter) { isBWEEnabled := true @@ -799,7 +896,7 @@ func (d *DownTrack) keyFrameRequester() { locked, layer := d.forwarder.CheckSync() if !locked && layer != buffer.InvalidLayerSpatial && d.writable.Load() { d.params.Logger.Debugw("sending PLI for layer lock", "layer", layer) - d.params.Receiver.SendPLI(layer, false) + d.Receiver().SendPLI(layer, false) d.rtpStats.UpdateLayerLockPliAndTime(1) } } @@ -1051,7 +1148,7 @@ func (d *DownTrack) WritePaddingRTP(bytesToSend int, paddingOnMute bool, forceMa Version: 2, Padding: true, Marker: false, - PayloadType: d.payloadType, + PayloadType: uint8(d.payloadType.Load()), SequenceNumber: uint16(snts[i].extSequenceNumber), Timestamp: uint32(snts[i].extTimestamp), SSRC: d.ssrc, @@ -1203,7 +1300,7 @@ func (d *DownTrack) CloseWithFlush(flush bool) { d.params.Logger.Debugw("closing sender", "kind", d.kind) } d.setBindStateLocked(bindStateUnbound) - d.params.Receiver.DeleteDownTrack(d.SubscriberID()) + d.Receiver().DeleteDownTrack(d.SubscriberID()) if d.rtcpReader != nil && flush { d.params.Logger.Debugw("downtrack close rtcp reader") @@ -1215,6 +1312,7 @@ func (d *DownTrack) CloseWithFlush(flush bool) { d.rtcpReaderRTX.Close() d.rtcpReaderRTX.OnPacket(nil) } + mime := d.codec.MimeType d.bindLock.Unlock() d.connectionStats.Close() @@ -1223,7 +1321,7 @@ func (d *DownTrack) CloseWithFlush(flush bool) { d.rtpStatsRTX.Stop() d.params.Logger.Debugw("rtp stats", "direction", "downstream", - "mime", d.mime, + "mime", mime, "ssrc", d.ssrc, "stats", d.rtpStats, "statsRTX", d.rtpStatsRTX, @@ -1448,17 +1546,17 @@ func (d *DownTrack) IsDeficient() bool { } func (d *DownTrack) BandwidthRequested() int64 { - _, brs := d.params.Receiver.GetLayeredBitrate() + _, brs := d.Receiver().GetLayeredBitrate() return d.forwarder.BandwidthRequested(brs) } func (d *DownTrack) DistanceToDesired() float64 { - al, brs := d.params.Receiver.GetLayeredBitrate() + 
al, brs := d.Receiver().GetLayeredBitrate() return d.forwarder.DistanceToDesired(al, brs) } func (d *DownTrack) AllocateOptimal(allowOvershoot bool, hold bool) VideoAllocation { - al, brs := d.params.Receiver.GetLayeredBitrate() + al, brs := d.Receiver().GetLayeredBitrate() allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot, hold) d.postKeyFrameRequestEvent() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) @@ -1466,7 +1564,7 @@ func (d *DownTrack) AllocateOptimal(allowOvershoot bool, hold bool) VideoAllocat } func (d *DownTrack) ProvisionalAllocatePrepare() { - al, brs := d.params.Receiver.GetLayeredBitrate() + al, brs := d.Receiver().GetLayeredBitrate() d.forwarder.ProvisionalAllocatePrepare(al, brs) } @@ -1508,7 +1606,7 @@ func (d *DownTrack) ProvisionalAllocateCommit() VideoAllocation { } func (d *DownTrack) AllocateNextHigher(availableChannelCapacity int64, allowOvershoot bool) (VideoAllocation, bool) { - al, brs := d.params.Receiver.GetLayeredBitrate() + al, brs := d.Receiver().GetLayeredBitrate() allocation, available := d.forwarder.AllocateNextHigher(availableChannelCapacity, al, brs, allowOvershoot) d.postKeyFrameRequestEvent() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) @@ -1516,7 +1614,7 @@ func (d *DownTrack) AllocateNextHigher(availableChannelCapacity int64, allowOver } func (d *DownTrack) GetNextHigherTransition(allowOvershoot bool) (VideoTransition, bool) { - availableLayers, brs := d.params.Receiver.GetLayeredBitrate() + availableLayers, brs := d.Receiver().GetLayeredBitrate() transition, available := d.forwarder.GetNextHigherTransition(brs, allowOvershoot) d.params.Logger.Debugw( "stream: get next higher layer", @@ -1529,7 +1627,7 @@ func (d *DownTrack) GetNextHigherTransition(allowOvershoot bool) (VideoTransitio } func (d *DownTrack) Pause() VideoAllocation { - al, brs := d.params.Receiver.GetLayeredBitrate() + al, brs := 
d.Receiver().GetLayeredBitrate() allocation := d.forwarder.Pause(al, brs) d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) return allocation @@ -1581,15 +1679,16 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan return } + mime := strings.ToLower(d.Codec().MimeType) var getBlankFrame func(bool) ([]byte, error) switch { - case strings.EqualFold(d.mime, webrtc.MimeTypeOpus): + case strings.EqualFold(mime, webrtc.MimeTypeOpus): getBlankFrame = d.getOpusBlankFrame - case strings.EqualFold(d.mime, MimeTypeAudioRed): + case strings.EqualFold(mime, MimeTypeAudioRed): getBlankFrame = d.getOpusRedBlankFrame - case strings.EqualFold(d.mime, webrtc.MimeTypeVP8): + case strings.EqualFold(mime, webrtc.MimeTypeVP8): getBlankFrame = d.getVP8BlankFrame - case strings.EqualFold(d.mime, webrtc.MimeTypeH264): + case strings.EqualFold(mime, webrtc.MimeTypeH264): getBlankFrame = d.getH264BlankFrame default: close(done) @@ -1597,7 +1696,7 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan } frameRate := uint32(30) - if d.mime == strings.ToLower(webrtc.MimeTypeOpus) || d.mime == strings.ToLower(MimeTypeAudioRed) { + if mime == strings.ToLower(webrtc.MimeTypeOpus) || mime == strings.ToLower(MimeTypeAudioRed) { frameRate = 50 } @@ -1628,7 +1727,7 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan Version: 2, Padding: false, Marker: true, - PayloadType: d.payloadType, + PayloadType: uint8(d.payloadType.Load()), SequenceNumber: uint16(snts[i].extSequenceNumber), Timestamp: uint32(snts[i].extTimestamp), SSRC: d.ssrc, @@ -1760,7 +1859,7 @@ func (d *DownTrack) handleRTCP(bytes []byte) { if pliOnce { if layer != buffer.InvalidLayerSpatial { d.params.Logger.Debugw("sending PLI RTCP", "layer", layer) - d.params.Receiver.SendPLI(layer, false) + d.Receiver().SendPLI(layer, false) d.isNACKThrottled.Store(true) d.rtpStats.UpdatePliTime() pliOnce = 
false @@ -1963,11 +2062,11 @@ func (d *DownTrack) retransmitPacket(epm *extPacketMeta, sourcePkt []byte, isPro } rtxOffset := 0 var rtxExtSequenceNumber uint64 - if d.payloadTypeRTX != 0 && d.ssrcRTX != 0 { + if rtxPT := d.payloadTypeRTX.Load(); rtxPT != 0 && d.ssrcRTX != 0 { rtxExtSequenceNumber = d.rtxSequenceNumber.Inc() rtxOffset = 2 - hdr.PayloadType = d.payloadTypeRTX + hdr.PayloadType = uint8(rtxPT) hdr.SequenceNumber = uint16(rtxExtSequenceNumber) hdr.SSRC = d.ssrcRTX } @@ -2083,7 +2182,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { nackAcks++ pktBuff := *src - n, err := d.params.Receiver.ReadRTP(pktBuff, uint8(epm.layer), epm.sourceSeqNo) + n, err := d.Receiver().ReadRTP(pktBuff, uint8(epm.layer), epm.sourceSeqNo) if err != nil { if err == io.EOF { break @@ -2105,7 +2204,8 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { } func (d *DownTrack) WriteProbePackets(bytesToSend int, usePadding bool) int { - if d.payloadTypeRTX == 0 || d.ssrcRTX == 0 { + rtxPT := uint8(d.payloadTypeRTX.Load()) + if rtxPT == 0 || d.ssrcRTX == 0 { return d.WritePaddingRTP(bytesToSend, false, false) } @@ -2132,7 +2232,7 @@ func (d *DownTrack) WriteProbePackets(bytesToSend int, usePadding bool) int { Version: 2, Padding: true, Marker: false, - PayloadType: d.payloadTypeRTX, + PayloadType: rtxPT, SequenceNumber: uint16(rtxExtSequenceNumber), Timestamp: 0, SSRC: d.ssrcRTX, @@ -2181,7 +2281,7 @@ func (d *DownTrack) WriteProbePackets(bytesToSend int, usePadding bool) int { } pktBuff := *src - n, err := d.params.Receiver.ReadRTP(pktBuff, uint8(epm.layer), epm.sourceSeqNo) + n, err := d.Receiver().ReadRTP(pktBuff, uint8(epm.layer), epm.sourceSeqNo) if err != nil { if err == io.EOF { break @@ -2216,7 +2316,7 @@ func (d *DownTrack) getTranslatedPayloadType(src uint8) uint8 { if d.isRED && src == d.upstreamPrimaryPT && d.primaryPT != 0 { return d.primaryPT } - return d.payloadType + return uint8(d.payloadType.Load()) } func (d *DownTrack) DebugInfo() 
map[string]interface{} { @@ -2314,7 +2414,7 @@ func (d *DownTrack) sendPaddingOnMute() { if d.kind == webrtc.RTPCodecTypeVideo { d.sendPaddingOnMuteForVideo() - } else if d.mime == strings.ToLower(webrtc.MimeTypeOpus) { + } else if strings.EqualFold(d.Codec().MimeType, webrtc.MimeTypeOpus) { d.sendSilentFrameOnMuteForOpus() } } @@ -2357,7 +2457,7 @@ func (d *DownTrack) sendSilentFrameOnMuteForOpus() { Version: 2, Padding: false, Marker: true, - PayloadType: d.payloadType, + PayloadType: uint8(d.payloadType.Load()), SequenceNumber: uint16(snts[i].extSequenceNumber), Timestamp: uint32(snts[i].extTimestamp), SSRC: d.ssrc, diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index 111b4d80f..e0e9f590d 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -220,15 +220,16 @@ type Forwarder struct { pubMuted bool resumeBehindThreshold float64 - started bool - preStartTime time.Time - extFirstTS uint64 - lastSSRC uint32 - lastSwitchExtIncomingTS uint64 - referenceLayerSpatial int32 - dummyStartTSOffset uint64 - refInfos [buffer.DefaultMaxLayerSpatial + 1]refInfo - refIsSVC bool + started bool + preStartTime time.Time + extFirstTS uint64 + lastSSRC uint32 + lastReferencePayloadType int8 + lastSwitchExtIncomingTS uint64 + referenceLayerSpatial int32 + dummyStartTSOffset uint64 + refInfos [buffer.DefaultMaxLayerSpatial + 1]refInfo + refIsSVC bool provisional *VideoAllocationProvisional @@ -248,15 +249,16 @@ func NewForwarder( rtpStats *rtpstats.RTPStatsSender, ) *Forwarder { f := &Forwarder{ - kind: kind, - logger: logger, - skipReferenceTS: skipReferenceTS, - rtpStats: rtpStats, - referenceLayerSpatial: buffer.InvalidLayerSpatial, - lastAllocation: VideoAllocationDefault, - rtpMunger: NewRTPMunger(logger), - vls: videolayerselector.NewNull(logger), - codecMunger: codecmunger.NewNull(logger), + kind: kind, + logger: logger, + skipReferenceTS: skipReferenceTS, + rtpStats: rtpStats, + referenceLayerSpatial: buffer.InvalidLayerSpatial, + lastAllocation: 
VideoAllocationDefault, + lastReferencePayloadType: -1, + rtpMunger: NewRTPMunger(logger), + vls: videolayerselector.NewNull(logger), + codecMunger: codecmunger.NewNull(logger), } if f.kind == webrtc.RTPCodecTypeVideo { @@ -297,8 +299,9 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ f.lock.Lock() defer f.lock.Unlock() - if f.codec.MimeType != "" { - return + codecChanged := f.codec.MimeType != "" && f.codec.MimeType != codec.MimeType + if codecChanged { + f.logger.Debugw("forwarder codec changed", "from", f.codec.MimeType, "to", codec.MimeType) } f.codec = codec @@ -315,15 +318,21 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ case "video/vp8": f.codecMunger = codecmunger.NewVP8FromNull(f.codecMunger, f.logger) if f.vls != nil { - f.vls = videolayerselector.NewSimulcastFromNull(f.vls) + if vls := videolayerselector.NewSimulcastFromOther(f.vls); vls != nil { + f.vls = vls + } else { + f.logger.Errorw("failed to create simulcast on codec change", nil) + } } else { f.vls = videolayerselector.NewSimulcast(f.logger) } f.vls.SetTemporalLayerSelector(temporallayerselector.NewVP8(f.logger)) case "video/h264": + fallthrough + case "video/h265": if f.vls != nil { - f.vls = videolayerselector.NewSimulcastFromNull(f.vls) + f.vls = videolayerselector.NewSimulcastFromOther(f.vls) } else { f.vls = videolayerselector.NewSimulcast(f.logger) } @@ -357,7 +366,7 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ } } else { if f.vls != nil { - f.vls = videolayerselector.NewSimulcastFromNull(f.vls) + f.vls = videolayerselector.NewSimulcastFromOther(f.vls) } else { f.vls = videolayerselector.NewSimulcast(f.logger) } @@ -1566,6 +1575,22 @@ func (f *Forwarder) CheckSync() (bool, int32) { return true, layer } +func (f *Forwarder) Restart() { + f.lock.Lock() + defer f.lock.Unlock() + + f.resyncLocked() + f.setTargetLayer(buffer.InvalidLayer, buffer.InvalidLayerSpatial) + 
f.referenceLayerSpatial = buffer.InvalidLayerSpatial + f.lastReferencePayloadType = -1 + + for layer := 0; layer < len(f.refInfos); layer++ { + f.refInfos[layer] = refInfo{} + } + f.lastSwitchExtIncomingTS = 0 + f.refIsSVC = false +} + func (f *Forwarder) FilterRTX(nacks []uint16) (filtered []uint16, disallowedLayers [buffer.DefaultMaxLayerSpatial + 1]bool) { f.lock.RLock() defer f.lock.RUnlock() @@ -1738,6 +1763,12 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e // potentially happening very quickly. Erroring out and waiting for a layer for which a sender report has been // received will calculate a better offset, but may result in initial adaptation to take a bit longer depending // on how often publisher/remote side sends RTCP sender report. + f.logger.Debugw( + "could not get ref layer timestamp", + "referenceLayerSpatial", f.referenceLayerSpatial, + "layer", layer, + "error", err, + ) return err } } @@ -1757,6 +1788,10 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e tsExt, err := f.rtpStats.GetExpectedRTPTimestamp(switchingAt) if err == nil { extExpectedTS = tsExt + if f.lastReferencePayloadType == -1 { + f.dummyStartTSOffset = extExpectedTS - uint64(refTS) + extRefTS = extExpectedTS + } } else { if !f.preStartTime.IsZero() { timeSinceFirst := time.Since(f.preStartTime) @@ -1834,6 +1869,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e // AVSYNC-TODO: Consider some forcing function to do the switch // (like "have waited for too long for layer switch, nothing available, switch to whatever is available" kind of condition). 
logTransition("layer switch, reference too far behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) + return errors.New("switch point too far behind") } @@ -1918,9 +1954,12 @@ func (f *Forwarder) getTranslationParamsCommon(extPkt *buffer.ExtPacket, layer i f.vls.Rollback() return nil } - f.logger.Debugw("switching feed", - "from", f.lastSSRC, - "to", extPkt.Packet.SSRC, + f.logger.Debugw( + "switching feed", + "fromSSRC", f.lastSSRC, + "toSSRC", extPkt.Packet.SSRC, + "fromPayloadType", f.lastReferencePayloadType, + "toPayloadType", extPkt.Packet.PayloadType, "layer", layer, "refInfos", logger.ObjectSlice(f.refInfos[:]), "lastSwitchExtIncomingTS", f.lastSwitchExtIncomingTS, @@ -1929,6 +1968,7 @@ func (f *Forwarder) getTranslationParamsCommon(extPkt *buffer.ExtPacket, layer i "maxLayer", f.vls.GetMax(), ) f.lastSSRC = extPkt.Packet.SSRC + f.lastReferencePayloadType = int8(extPkt.Packet.PayloadType) f.lastSwitchExtIncomingTS = extPkt.ExtTimestamp } diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 216ea5ca8..70c8d434c 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -84,10 +84,21 @@ type AudioLevelHandle func(level uint8, duration uint32) type Bitrates [buffer.DefaultMaxLayerSpatial + 1][buffer.DefaultMaxLayerTemporal + 1]int64 +type ReceiverCodecState int + +const ( + ReceiverCodecStateNormal ReceiverCodecState = iota + ReceiverCodecStateSuspended + ReceiverCodecStateInvalid +) + // TrackReceiver defines an interface receive media from remote peer type TrackReceiver interface { TrackID() livekit.TrackID StreamID() string + + // returns the initial codec of the receiver, it is determined by the track's codec + // and will not change if the codec changes during the session (publisher changes codec) Codec() webrtc.RTPCodecParameters HeaderExtensions() []webrtc.RTPHeaderExtensionParameter IsClosed() bool @@ -104,6 +115,7 @@ type TrackReceiver interface { AddDownTrack(track TrackSender) error DeleteDownTrack(participantID 
livekit.ParticipantID) + GetDownTracks() []TrackSender DebugInfo() map[string]interface{} @@ -123,6 +135,9 @@ type TrackReceiver interface { // AddOnReady adds a function to be called when the receiver is ready, the callback // could be called immediately if the receiver is ready when the callback is added AddOnReady(func()) + + AddOnCodecStateChange(func(webrtc.RTPCodecParameters, ReceiverCodecState)) + CodecState() ReceiverCodecState } type redPktWriteFunc func(pkt *buffer.ExtPacket, spatialLayer int32) int @@ -134,18 +149,21 @@ type WebRTCReceiver struct { pliThrottleConfig PLIThrottleConfig audioConfig AudioConfig - trackID livekit.TrackID - streamID string - kind webrtc.RTPCodecType - receiver *webrtc.RTPReceiver - codec webrtc.RTPCodecParameters - isSVC bool - isRED bool - onCloseHandler func() - closeOnce sync.Once - closed atomic.Bool - useTrackers bool - trackInfo atomic.Pointer[livekit.TrackInfo] + trackID livekit.TrackID + streamID string + kind webrtc.RTPCodecType + receiver *webrtc.RTPReceiver + codec webrtc.RTPCodecParameters + codecState ReceiverCodecState + codecStateLock sync.Mutex + onCodecStateChange []func(webrtc.RTPCodecParameters, ReceiverCodecState) + isSVC bool + isRED bool + onCloseHandler func() + closeOnce sync.Once + closed atomic.Bool + useTrackers bool + trackInfo atomic.Pointer[livekit.TrackInfo] onRTCP func([]rtcp.Packet) @@ -228,15 +246,16 @@ func NewWebRTCReceiver( opts ...ReceiverOpts, ) *WebRTCReceiver { w := &WebRTCReceiver{ - logger: logger, - receiver: receiver, - trackID: livekit.TrackID(track.ID()), - streamID: track.StreamID(), - codec: track.Codec(), - kind: track.Kind(), - onRTCP: onRTCP, - isSVC: buffer.IsSvcCodec(track.Codec().MimeType), - isRED: buffer.IsRedCodec(track.Codec().MimeType), + logger: logger, + receiver: receiver, + trackID: livekit.TrackID(track.ID()), + streamID: track.StreamID(), + codec: track.Codec(), + codecState: ReceiverCodecStateNormal, + kind: track.Kind(), + onRTCP: onRTCP, + isSVC: 
buffer.IsSvcCodec(track.Codec().MimeType), + isRED: buffer.IsRedCodec(track.Codec().MimeType), } for _, opt := range opts { @@ -382,6 +401,10 @@ func (w *WebRTCReceiver) AddUpTrack(track TrackRemote, buff *buffer.Buffer) erro }) }) + if w.Kind() == webrtc.RTPCodecTypeVideo && layer == 0 { + buff.OnCodecChange(w.handleCodecChange) + } + var duration time.Duration switch layer { case 2: @@ -454,6 +477,10 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error { return nil } +func (w *WebRTCReceiver) GetDownTracks() []TrackSender { + return w.downTrackSpreader.GetDownTracks() +} + func (w *WebRTCReceiver) notifyMaxExpectedLayer(layer int32) { ti := w.TrackInfo() if ti == nil { @@ -727,6 +754,11 @@ func (w *WebRTCReceiver) forwardRTP(layer int32, buff *buffer.Buffer) { return } + if pkt.Packet.PayloadType != uint8(w.codec.PayloadType) { + // drop packets as we don't support codec fallback directly + continue + } + spatialLayer := layer if pkt.Spatial >= 0 { // svc packet, take spatial layer info from packet @@ -860,6 +892,40 @@ func (w *WebRTCReceiver) AddOnReady(fn func()) { fn() } +func (w *WebRTCReceiver) handleCodecChange(newCodec webrtc.RTPCodecParameters) { + // we don't support the codec fallback directly, set the codec state to invalid once it happens + w.SetCodecState(ReceiverCodecStateInvalid) +} + +func (w *WebRTCReceiver) AddOnCodecStateChange(f func(webrtc.RTPCodecParameters, ReceiverCodecState)) { + w.codecStateLock.Lock() + w.onCodecStateChange = append(w.onCodecStateChange, f) + w.codecStateLock.Unlock() +} + +func (w *WebRTCReceiver) CodecState() ReceiverCodecState { + w.codecStateLock.Lock() + defer w.codecStateLock.Unlock() + + return w.codecState +} + +func (w *WebRTCReceiver) SetCodecState(state ReceiverCodecState) { + w.codecStateLock.Lock() + if w.codecState == state || w.codecState == ReceiverCodecStateInvalid { + w.codecStateLock.Unlock() + return + } + + w.codecState = state + fns := w.onCodecStateChange + w.codecStateLock.Unlock() + 
+ for _, f := range fns { + f(w.codec, state) + } +} + // ----------------------------------------------------------- // closes all track senders in parallel, returns when all are closed diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index db47a7e94..6023a89ab 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -121,6 +121,10 @@ func (r *RedPrimaryReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) r.logger.Debugw("red primary receiver downtrack deleted", "subscriberID", subscriberID) } +func (r *RedPrimaryReceiver) GetDownTracks() []TrackSender { + return r.downTrackSpreader.GetDownTracks() +} + func (r *RedPrimaryReceiver) ResyncDownTracks() { r.downTrackSpreader.Broadcast(func(dt TrackSender) { dt.Resync() diff --git a/pkg/sfu/redreceiver.go b/pkg/sfu/redreceiver.go index e4afaa5ca..7fbd9a6eb 100644 --- a/pkg/sfu/redreceiver.go +++ b/pkg/sfu/redreceiver.go @@ -112,6 +112,10 @@ func (r *RedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { r.logger.Debugw("red receiver downtrack deleted", "subscriberID", subscriberID) } +func (r *RedReceiver) GetDownTracks() []TrackSender { + return r.downTrackSpreader.GetDownTracks() +} + func (r *RedReceiver) ResyncDownTracks() { r.downTrackSpreader.Broadcast(func(dt TrackSender) { dt.Resync() diff --git a/pkg/sfu/track_remote.go b/pkg/sfu/track_remote.go index ffdd959c4..bd6d27dc1 100644 --- a/pkg/sfu/track_remote.go +++ b/pkg/sfu/track_remote.go @@ -28,8 +28,5 @@ func NewTrackRemoteFromSdp(track *webrtc.TrackRemote, codec webrtc.RTPCodecParam } func (t *TrackRemoteFromSdp) Codec() webrtc.RTPCodecParameters { - if t.TrackRemote.PayloadType() == 0 { - return t.sdpCodec - } - return t.TrackRemote.Codec() + return t.sdpCodec } diff --git a/pkg/sfu/utils/mimetype.go b/pkg/sfu/utils/mimetype.go index 0908e1eec..cf7735b88 100644 --- a/pkg/sfu/utils/mimetype.go +++ b/pkg/sfu/utils/mimetype.go @@ -22,6 +22,7 @@ const ( MimeTypeVP9 MimeTypeH264 MimeTypeAV1 
+ MimeTypeH265 ) func MatchMimeType(mimeType string) MimeType { @@ -87,6 +88,8 @@ func MatchMimeType(mimeType string) MimeType { switch mimeType[9] { case '4': return MimeTypeH264 + case '5': + return MimeTypeH265 } } } diff --git a/pkg/sfu/videolayerselector/simulcast.go b/pkg/sfu/videolayerselector/simulcast.go index 6b20787b7..83f0884b4 100644 --- a/pkg/sfu/videolayerselector/simulcast.go +++ b/pkg/sfu/videolayerselector/simulcast.go @@ -29,9 +29,29 @@ func NewSimulcast(logger logger.Logger) *Simulcast { } } -func NewSimulcastFromNull(vls VideoLayerSelector) *Simulcast { - return &Simulcast{ - Base: vls.(*Null).Base, +func NewSimulcastFromOther(vls VideoLayerSelector) *Simulcast { + switch vls := vls.(type) { + case *Null: + return &Simulcast{ + Base: vls.Base, + } + case *Simulcast: + return &Simulcast{ + Base: vls.Base, + } + + case *DependencyDescriptor: + return &Simulcast{ + Base: vls.Base, + } + + case *VP9: + return &Simulcast{ + Base: vls.Base, + } + + default: + return nil } }