From 9551c52c85d64ba8b1632b918c734ea46aee0eb1 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Mon, 10 Feb 2025 10:44:15 +0530 Subject: [PATCH] Try 2 to consolidate mime type (#3407) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Normalize mime type and add utilities. An attempt to normalize mime type and avoid string compares remembering to do case insensitive search. Not the best solution. Open to ideas. But, define our own mime types (just in case Pion changes things and Pion also does not have red mime type defined which should be easy to add though) and tried to use it everywhere. But, as we get a bunch of callbacks and info from Pion, needed conversion in more places than I anticipated. And also makes it necessary to carry that cognitive load of what comes from Pion and needing to process it properly. * more locations * test * Paul feedback * MimeType type * more consolidation * Remove unused * test * test * mime type as int * use string method * Pass error details and timeouts. (#3402) * go mod tidy (#3408) * Rename CHANGELOG to CHANGELOG.md (#3391) Enables markdown features in this otherwise already markdown'ish formatted document * Update config.go to properly process bool env vars (#3382) Fixes issue https://github.com/livekit/livekit/issues/3381 * fix(deps): update go deps (#3341) Generated by renovateBot Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> * Use a Twirp server hook to send API call details to telemetry. (#3401) * Use a Twirp server hook to send API call details to telemetry. * mage generate and clean up * Add project_id * deps * - Redact requests - Do not store responses - Extract top level fields room_name, room_id, participant_identity, participant_id, track_id as appropriate - Store status as int * deps * Update pkg/sfu/mime/mimetype.go * Fix prefer codec test * handle down track mime changes --------- Co-authored-by: Denys Smirnov Co-authored-by: Philzen Co-authored-by: Pablo Fuente Pérez Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Paul Wells Co-authored-by: cnderrauber --- pkg/clientconfiguration/conf.go | 5 +- pkg/config/config.go | 16 +- pkg/rtc/dynacast/dynacastmanager.go | 62 ++-- pkg/rtc/dynacast/dynacastmanager_test.go | 87 +++--- pkg/rtc/dynacast/dynacastquality.go | 7 +- pkg/rtc/mediaengine.go | 32 +-- pkg/rtc/mediaengine_test.go | 13 +- pkg/rtc/mediatrack.go | 42 +-- pkg/rtc/mediatrackreceiver.go | 51 ++-- pkg/rtc/mediatracksubscriptions.go | 16 +- pkg/rtc/participant.go | 52 ++-- pkg/rtc/participant_internal_test.go | 30 +- pkg/rtc/participant_sdp.go | 20 +- pkg/rtc/subscribedtrack.go | 2 +- pkg/rtc/subscriptionmanager.go | 2 +- pkg/rtc/transport.go | 7 +- pkg/rtc/transport_test.go | 11 +- pkg/rtc/types/interfaces.go | 5 +- .../typesfakes/fake_local_media_track.go | 13 +- pkg/rtc/types/typesfakes/fake_media_track.go | 13 +- pkg/rtc/utils.go | 7 +- pkg/rtc/wrappedreceiver.go | 22 +- pkg/sfu/buffer/buffer.go | 72 ++--- pkg/sfu/connectionquality/connectionstats.go | 27 +- .../connectionquality/connectionstats_test.go | 24 +- pkg/sfu/connectionquality/scorer.go | 7 + pkg/sfu/downtrack.go | 50 ++-- pkg/sfu/forwarder.go | 40 +-- pkg/sfu/mime/mimetype.go | 271 ++++++++++++++++++ pkg/sfu/receiver.go | 14 +- pkg/sfu/redprimaryreceiver.go | 4 - pkg/sfu/utils/helpers.go | 6 +- pkg/sfu/utils/mimetype.go | 105 ------- pkg/sfu/utils/mimetype_test.go | 65 ----- pkg/telemetry/events.go | 13 +- .../telemetryfakes/fake_telemetry_service.go | 37 +-- pkg/telemetry/telemetryservice.go | 7 +- test/client/client.go | 10 +- test/client/trackwriter.go | 20 +- test/singlenode_test.go | 19 +- 40 files changed, 727 insertions(+), 579 deletions(-) create mode 100644 pkg/sfu/mime/mimetype.go delete mode 100644 pkg/sfu/utils/mimetype.go delete mode 100644 pkg/sfu/utils/mimetype_test.go diff --git a/pkg/clientconfiguration/conf.go b/pkg/clientconfiguration/conf.go index cc069fc5c..992cca92d 100644 --- a/pkg/clientconfiguration/conf.go +++ b/pkg/clientconfiguration/conf.go @@ -15,6 +15,7 @@ package clientconfiguration import ( + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -28,7 +29,7 @@ var StaticConfigurations = []ConfigurationItem{ { Match: &ScriptMatch{Expr: `c.browser == "safari"`}, Configuration: &livekit.ClientConfiguration{DisabledCodecs: &livekit.DisabledCodecs{Codecs: []*livekit.Codec{ - {Mime: "video/av1"}, + {Mime: mime.MimeTypeAV1.String()}, }}}, Merge: false, }, @@ -37,7 +38,7 @@ var StaticConfigurations = []ConfigurationItem{ ((c.browser == "firefox" || c.browser == "firefox mobile") && (c.os == "linux" || c.os == "android"))`}, Configuration: &livekit.ClientConfiguration{ DisabledCodecs: &livekit.DisabledCodecs{ - Publish: []*livekit.Codec{{Mime: "video/h264"}}, + Publish: []*livekit.Codec{{Mime: mime.MimeTypeH264.String()}}, }, }, Merge: false, diff --git a/pkg/config/config.go b/pkg/config/config.go index 2a56fe730..5c04c5f91 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -22,7 +22,6 @@ import ( "time" "github.com/mitchellh/go-homedir" - "github.com/pion/webrtc/v4" "github.com/pkg/errors" "github.com/urfave/cli/v2" "gopkg.in/yaml.v3" @@ -31,6 +30,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/bwe/remotebwe" "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/livekit" @@ -336,13 +336,13 @@ var DefaultConfig = Config{ Room: RoomConfig{ AutoCreate: true, EnabledCodecs: []CodecSpec{ - {Mime: webrtc.MimeTypeOpus}, - {Mime: sfu.MimeTypeAudioRed}, - {Mime: webrtc.MimeTypeVP8}, - {Mime: webrtc.MimeTypeH264}, - {Mime: webrtc.MimeTypeVP9}, - {Mime: webrtc.MimeTypeAV1}, - {Mime: webrtc.MimeTypeRTX}, + {Mime: mime.MimeTypeOpus.String()}, + {Mime: mime.MimeTypeRED.String()}, + {Mime: mime.MimeTypeVP8.String()}, + {Mime: mime.MimeTypeH264.String()}, + {Mime: mime.MimeTypeVP9.String()}, + {Mime: mime.MimeTypeAV1.String()}, + {Mime: mime.MimeTypeRTX.String()}, }, EmptyTimeout: 5 * 60, DepartureTimeout: 20, diff --git a/pkg/rtc/dynacast/dynacastmanager.go b/pkg/rtc/dynacast/dynacastmanager.go index 92ba192c5..b2ef9bc20 100644 --- a/pkg/rtc/dynacast/dynacastmanager.go +++ b/pkg/rtc/dynacast/dynacastmanager.go @@ -15,7 +15,6 @@ package dynacast import ( - "strings" "sync" "time" @@ -26,6 +25,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/utils" ) @@ -38,10 +38,10 @@ type DynacastManager struct { params DynacastManagerParams lock sync.RWMutex - regressedCodec map[string]struct{} - dynacastQuality map[string]*DynacastQuality // mime type => DynacastQuality - maxSubscribedQuality map[string]livekit.VideoQuality - committedMaxSubscribedQuality map[string]livekit.VideoQuality + regressedCodec map[mime.MimeType]struct{} + dynacastQuality map[mime.MimeType]*DynacastQuality + maxSubscribedQuality map[mime.MimeType]livekit.VideoQuality + committedMaxSubscribedQuality map[mime.MimeType]livekit.VideoQuality maxSubscribedQualityDebounce func(func()) maxSubscribedQualityDebouncePending bool @@ -59,10 +59,10 @@ func NewDynacastManager(params DynacastManagerParams) *DynacastManager { } d := &DynacastManager{ params: params, - regressedCodec: make(map[string]struct{}), - dynacastQuality: make(map[string]*DynacastQuality), - maxSubscribedQuality: make(map[string]livekit.VideoQuality), - committedMaxSubscribedQuality: make(map[string]livekit.VideoQuality), + regressedCodec: make(map[mime.MimeType]struct{}), + dynacastQuality: make(map[mime.MimeType]*DynacastQuality), + maxSubscribedQuality: make(map[mime.MimeType]livekit.VideoQuality), + committedMaxSubscribedQuality: make(map[mime.MimeType]livekit.VideoQuality), qualityNotifyOpQueue: utils.NewOpsQueue(utils.OpsQueueParams{ Name: "quality-notify", MinSize: 64, @@ -83,11 +83,11 @@ func (d *DynacastManager) OnSubscribedMaxQualityChange(f func(subscribedQualitie d.lock.Unlock() } -func (d *DynacastManager) AddCodec(mime string) { +func (d *DynacastManager) AddCodec(mime mime.MimeType) { d.getOrCreateDynacastQuality(mime) } -func (d *DynacastManager) HandleCodecRegression(fromMime, toMime string) { +func (d *DynacastManager) HandleCodecRegression(fromMime, toMime mime.MimeType) { fromDq := d.getOrCreateDynacastQuality(fromMime) d.lock.Lock() @@ -96,32 +96,31 @@ func (d *DynacastManager) HandleCodecRegression(fromMime, toMime string) { return } - normalizedFromMime, normalizedToMime := strings.ToLower(fromMime), strings.ToLower(toMime) if fromDq == nil { // should not happen as we have added the codec on setup receiver - d.params.Logger.Warnw("regression from codec not found", nil, "mime", normalizedFromMime) + d.params.Logger.Warnw("regression from codec not found", nil, "mime", fromMime) d.lock.Unlock() return } - d.regressedCodec[normalizedFromMime] = struct{}{} - d.maxSubscribedQuality[normalizedFromMime] = livekit.VideoQuality_OFF + d.regressedCodec[fromMime] = struct{}{} + d.maxSubscribedQuality[fromMime] = livekit.VideoQuality_OFF // if the new codec is not added, notify the publisher to start publishing - if _, ok := d.maxSubscribedQuality[normalizedToMime]; !ok { - d.maxSubscribedQuality[normalizedToMime] = livekit.VideoQuality_HIGH + if _, ok := d.maxSubscribedQuality[toMime]; !ok { + d.maxSubscribedQuality[toMime] = livekit.VideoQuality_HIGH } d.lock.Unlock() d.update(false) fromDq.Stop() - ToDq := d.getOrCreateDynacastQuality(normalizedToMime) + ToDq := d.getOrCreateDynacastQuality(toMime) fromDq.RegressTo(ToDq) } func (d *DynacastManager) Restart() { d.lock.Lock() - d.committedMaxSubscribedQuality = make(map[string]livekit.VideoQuality) + d.committedMaxSubscribedQuality = make(map[mime.MimeType]livekit.VideoQuality) dqs := d.getDynacastQualitiesLocked() d.lock.Unlock() @@ -136,7 +135,7 @@ func (d *DynacastManager) Close() { d.lock.Lock() dqs := d.getDynacastQualitiesLocked() - d.dynacastQuality = make(map[string]*DynacastQuality) + d.dynacastQuality = make(map[mime.MimeType]*DynacastQuality) d.isClosed = true d.lock.Unlock() @@ -169,7 +168,7 @@ func (d *DynacastManager) ForceQuality(quality livekit.VideoQuality) { d.enqueueSubscribedQualityChange() } -func (d *DynacastManager) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, mime string, quality livekit.VideoQuality) { +func (d *DynacastManager) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, mime mime.MimeType, quality livekit.VideoQuality) { dq := d.getOrCreateDynacastQuality(mime) if dq != nil { dq.NotifySubscriberMaxQuality(subscriberID, quality) @@ -185,7 +184,7 @@ func (d *DynacastManager) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, } } -func (d *DynacastManager) getOrCreateDynacastQuality(mime string) *DynacastQuality { +func (d *DynacastManager) getOrCreateDynacastQuality(mimeType mime.MimeType) *DynacastQuality { d.lock.Lock() defer d.lock.Unlock() @@ -193,21 +192,18 @@ func (d *DynacastManager) getOrCreateDynacastQuality(mime string) *DynacastQuali return nil } - normalizedMime := strings.ToLower(mime) - if dq := d.dynacastQuality[normalizedMime]; dq != nil { + if dq := d.dynacastQuality[mimeType]; dq != nil { return dq } dq := NewDynacastQuality(DynacastQualityParams{ - MimeType: normalizedMime, + MimeType: mimeType, Logger: d.params.Logger, }) - dq.OnSubscribedMaxQualityChange(func(mimeType string, maxQuality livekit.VideoQuality) { - d.updateMaxQualityForMime(mimeType, maxQuality) - }) + dq.OnSubscribedMaxQualityChange(d.updateMaxQualityForMime) dq.Start() - d.dynacastQuality[normalizedMime] = dq + d.dynacastQuality[mimeType] = dq return dq } @@ -215,7 +211,7 @@ func (d *DynacastManager) getDynacastQualitiesLocked() []*DynacastQuality { return maps.Values(d.dynacastQuality) } -func (d *DynacastManager) updateMaxQualityForMime(mime string, maxQuality livekit.VideoQuality) { +func (d *DynacastManager) updateMaxQualityForMime(mime mime.MimeType, maxQuality livekit.VideoQuality) { d.lock.Lock() if _, ok := d.regressedCodec[mime]; !ok { d.maxSubscribedQuality[mime] = maxQuality @@ -297,7 +293,7 @@ func (d *DynacastManager) update(force bool) { ) // commit change - d.committedMaxSubscribedQuality = make(map[string]livekit.VideoQuality, len(d.maxSubscribedQuality)) + d.committedMaxSubscribedQuality = make(map[mime.MimeType]livekit.VideoQuality, len(d.maxSubscribedQuality)) for mime, quality := range d.maxSubscribedQuality { d.committedMaxSubscribedQuality[mime] = quality } @@ -321,7 +317,7 @@ func (d *DynacastManager) enqueueSubscribedQualityChange() { if quality == livekit.VideoQuality_OFF { subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedCodec{ - Codec: mime, + Codec: mime.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -337,7 +333,7 @@ func (d *DynacastManager) enqueueSubscribedQualityChange() { }) } subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedCodec{ - Codec: mime, + Codec: mime.String(), Qualities: subscribedQualities, }) } diff --git a/pkg/rtc/dynacast/dynacastmanager_test.go b/pkg/rtc/dynacast/dynacastmanager_test.go index f01b2edf9..8a30d7e97 100644 --- a/pkg/rtc/dynacast/dynacastmanager_test.go +++ b/pkg/rtc/dynacast/dynacastmanager_test.go @@ -16,15 +16,14 @@ package dynacast import ( "sort" - "strings" "sync" "testing" "time" - "github.com/pion/webrtc/v4" "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -40,15 +39,15 @@ func TestSubscribedMaxQuality(t *testing.T) { lock.Unlock() }) - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_HIGH) - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeAV1, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeAV1, livekit.VideoQuality_HIGH) // mute all subscribers of vp8 - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_OFF) expectedSubscribedQualities := []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -56,7 +55,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -87,17 +86,17 @@ func TestSubscribedMaxQuality(t *testing.T) { lock.Unlock() }) - dm.maxSubscribedQuality = map[string]livekit.VideoQuality{ - strings.ToLower(webrtc.MimeTypeVP8): livekit.VideoQuality_LOW, - strings.ToLower(webrtc.MimeTypeAV1): livekit.VideoQuality_LOW, + dm.maxSubscribedQuality = map[mime.MimeType]livekit.VideoQuality{ + mime.MimeTypeVP8: livekit.VideoQuality_LOW, + mime.MimeTypeAV1: livekit.VideoQuality_LOW, } - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_HIGH) - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeVP8, livekit.VideoQuality_MEDIUM) - dm.NotifySubscriberMaxQuality("s3", webrtc.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_MEDIUM) expectedSubscribedQualities := []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -105,7 +104,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -121,11 +120,11 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // "s1" dropping to MEDIUM should disable HIGH layer - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_MEDIUM) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -133,7 +132,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -149,13 +148,13 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // "s1" , "s2" , "s3" dropping to LOW should disable HIGH & MEDIUM - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_LOW) - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeVP8, livekit.VideoQuality_LOW) - dm.NotifySubscriberMaxQuality("s3", webrtc.MimeTypeAV1, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_LOW) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -163,7 +162,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -179,7 +178,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // muting "s2" only should not disable all qualities of vp8, no change of expected qualities - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_OFF) time.Sleep(100 * time.Millisecond) require.Eventually(t, func() bool { @@ -190,12 +189,12 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // muting "s1" and s3 also should disable all qualities - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_OFF) - dm.NotifySubscriberMaxQuality("s3", webrtc.MimeTypeAV1, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_OFF) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -203,7 +202,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -219,11 +218,11 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // unmuting "s1" should enable vp8 previously set max quality - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_LOW) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -231,7 +230,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -248,13 +247,13 @@ func TestSubscribedMaxQuality(t *testing.T) { // a higher quality from a different node should trigger that quality dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ - {CodecMime: webrtc.MimeTypeVP8, Quality: livekit.VideoQuality_HIGH}, - {CodecMime: webrtc.MimeTypeAV1, Quality: livekit.VideoQuality_MEDIUM}, + {CodecMime: mime.MimeTypeVP8, Quality: livekit.VideoQuality_HIGH}, + {CodecMime: mime.MimeTypeAV1, Quality: livekit.VideoQuality_MEDIUM}, }) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -262,7 +261,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -289,11 +288,11 @@ func TestCodecRegression(t *testing.T) { lock.Unlock() }) - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeAV1, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeAV1, livekit.VideoQuality_HIGH) expectedSubscribedQualities := []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -308,11 +307,11 @@ func TestCodecRegression(t *testing.T) { return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) }, 10*time.Second, 100*time.Millisecond) - dm.HandleCodecRegression(webrtc.MimeTypeAV1, webrtc.MimeTypeVP8) + dm.HandleCodecRegression(mime.MimeTypeAV1, mime.MimeTypeVP8) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -320,7 +319,7 @@ func TestCodecRegression(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -337,13 +336,13 @@ func TestCodecRegression(t *testing.T) { // av1 quality change should be forwarded to vp8 // av1 quality change of node should be ignored - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeAV1, livekit.VideoQuality_MEDIUM) dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ - {CodecMime: webrtc.MimeTypeAV1, Quality: livekit.VideoQuality_HIGH}, + {CodecMime: mime.MimeTypeAV1, Quality: livekit.VideoQuality_HIGH}, }) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -351,7 +350,7 @@ func TestCodecRegression(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, diff --git a/pkg/rtc/dynacast/dynacastquality.go b/pkg/rtc/dynacast/dynacastquality.go index 957874b36..77b6aaad2 100644 --- a/pkg/rtc/dynacast/dynacastquality.go +++ b/pkg/rtc/dynacast/dynacastquality.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -27,7 +28,7 @@ const ( ) type DynacastQualityParams struct { - MimeType string + MimeType mime.MimeType Logger logger.Logger } @@ -44,7 +45,7 @@ type DynacastQuality struct { maxQualityTimer *time.Timer regressTo *DynacastQuality - onSubscribedMaxQualityChange func(mimeType string, maxSubscribedQuality livekit.VideoQuality) + onSubscribedMaxQualityChange func(mimeType mime.MimeType, maxSubscribedQuality livekit.VideoQuality) } func NewDynacastQuality(params DynacastQualityParams) *DynacastQuality { @@ -67,7 +68,7 @@ func (d *DynacastQuality) Stop() { d.stopMaxQualityTimer() } -func (d *DynacastQuality) OnSubscribedMaxQualityChange(f func(mimeType string, maxSubscribedQuality livekit.VideoQuality)) { +func (d *DynacastQuality) OnSubscribedMaxQualityChange(f func(mimeType mime.MimeType, maxSubscribedQuality livekit.VideoQuality)) { d.lock.Lock() defer d.lock.Unlock() d.onSubscribedMaxQualityChange = f diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 81aab0176..a6034468e 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -20,24 +20,24 @@ import ( "github.com/pion/webrtc/v4" - "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) var OpusCodecCapability = webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeOpus, + MimeType: mime.MimeTypeOpus.String(), ClockRate: 48000, Channels: 2, SDPFmtpLine: "minptime=10;useinbandfec=1", } var RedCodecCapability = webrtc.RTPCodecCapability{ - MimeType: sfu.MimeTypeAudioRed, + MimeType: mime.MimeTypeRED.String(), ClockRate: 48000, Channels: 2, SDPFmtpLine: "111/111", } var videoRTX = webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeRTX, + MimeType: mime.MimeTypeRTX.String(), ClockRate: 90000, } @@ -70,7 +70,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac for _, codec := range []webrtc.RTPCodecParameters{ { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeVP8, + MimeType: mime.MimeTypeVP8.String(), ClockRate: 90000, RTCPFeedback: rtcpFeedback.Video, }, @@ -78,7 +78,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeVP9, + MimeType: mime.MimeTypeVP9.String(), ClockRate: 90000, SDPFmtpLine: "profile-id=0", RTCPFeedback: rtcpFeedback.Video, @@ -87,7 +87,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeVP9, + MimeType: mime.MimeTypeVP9.String(), ClockRate: 90000, SDPFmtpLine: "profile-id=1", RTCPFeedback: rtcpFeedback.Video, @@ -96,7 +96,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, + MimeType: mime.MimeTypeH264.String(), ClockRate: 90000, SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", RTCPFeedback: rtcpFeedback.Video, @@ -105,7 +105,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, + MimeType: mime.MimeTypeH264.String(), ClockRate: 90000, SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", RTCPFeedback: rtcpFeedback.Video, @@ -114,7 +114,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, + MimeType: mime.MimeTypeH264.String(), ClockRate: 90000, SDPFmtpLine: h264HighProfileFmtp, RTCPFeedback: rtcpFeedback.Video, @@ -123,7 +123,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeAV1, + MimeType: mime.MimeTypeAV1.String(), ClockRate: 90000, RTCPFeedback: rtcpFeedback.Video, }, @@ -141,7 +141,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac if filterOutH264HighProfile && codec.RTPCodecCapability.SDPFmtpLine == h264HighProfileFmtp { continue } - if strings.EqualFold(codec.MimeType, webrtc.MimeTypeRTX) { + if mime.IsMimeTypeStringRTX(codec.MimeType) { continue } if IsCodecEnabled(codecs, codec.RTPCodecCapability) { @@ -151,7 +151,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac if rtxEnabled { if err := me.RegisterCodec(webrtc.RTPCodecParameters{ RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeRTX, + MimeType: mime.MimeTypeRTX.String(), ClockRate: 90000, SDPFmtpLine: fmt.Sprintf("apt=%d", codec.PayloadType), }, @@ -196,7 +196,7 @@ func createMediaEngine(codecs []*livekit.Codec, config DirectionConfig, filterOu func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool { for _, codec := range codecs { - if !strings.EqualFold(codec.Mime, cap.MimeType) { + if !mime.IsMimeTypeStringEqual(codec.Mime, cap.MimeType) { continue } if codec.FmtpLine == "" || strings.EqualFold(codec.FmtpLine, cap.SDPFmtpLine) { @@ -208,10 +208,10 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { for _, c := range enabledCodecs { - if strings.HasPrefix(c.Mime, "video/") { + if mime.IsMimeTypeStringVideo(c.Mime) { return c.Mime } } // no viable codec in the list of enabled codecs, fall back to the most widely supported codec - return webrtc.MimeTypeVP8 + return mime.MimeTypeVP8.String() } diff --git a/pkg/rtc/mediaengine_test.go b/pkg/rtc/mediaengine_test.go index d91090165..122861a87 100644 --- a/pkg/rtc/mediaengine_test.go +++ b/pkg/rtc/mediaengine_test.go @@ -20,21 +20,22 @@ import ( "github.com/pion/webrtc/v4" "github.com/stretchr/testify/require" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) func TestIsCodecEnabled(t *testing.T) { t.Run("empty fmtp requirement should match all", func(t *testing.T) { enabledCodecs := []*livekit.Codec{{Mime: "video/h264"}} - require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264, SDPFmtpLine: "special"})) - require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264})) - require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String(), SDPFmtpLine: "special"})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String()})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeVP8.String()})) }) t.Run("when fmtp is provided, require match", func(t *testing.T) { enabledCodecs := []*livekit.Codec{{Mime: "video/h264", FmtpLine: "special"}} - require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264, SDPFmtpLine: "special"})) - require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264})) - require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String(), SDPFmtpLine: "special"})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String()})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeVP8.String()})) }) } diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index bcb4c1457..3a7235139 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -17,7 +17,6 @@ package rtc import ( "context" "math" - "strings" "sync" "time" @@ -34,6 +33,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry" util "github.com/livekit/mediatransportutil" ) @@ -56,7 +56,7 @@ type MediaTrack struct { rttFromXR atomic.Bool enableRegression bool - regressionTargetCodec string + regressionTargetCodec mime.MimeType } type MediaTrackParams struct { @@ -87,7 +87,7 @@ func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { // TODO: disable codec regression until simulcast-codec clients knows that if ti.BackupCodecPolicy == livekit.BackupCodecPolicy_REGRESSION && len(ti.Codecs) > 1 { t.enableRegression = true - t.regressionTargetCodec = ti.Codecs[1].MimeType + t.regressionTargetCodec = mime.NormalizeMimeType(ti.Codecs[1].MimeType) t.params.Logger.Debugw("track enabled codec regression", "regressionCodec", t.regressionTargetCodec) } @@ -122,20 +122,20 @@ func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { DynacastPauseDelay: params.VideoConfig.DynacastPauseDelay, Logger: params.Logger, }) - t.MediaTrackReceiver.OnSetupReceiver(func(mime string) { + t.MediaTrackReceiver.OnSetupReceiver(func(mime mime.MimeType) { t.dynacastManager.AddCodec(mime) }) t.MediaTrackReceiver.OnSubscriberMaxQualityChange( - func(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, layer int32) { + func(subscriberID livekit.ParticipantID, mimeType mime.MimeType, layer int32) { t.dynacastManager.NotifySubscriberMaxQuality( subscriberID, - codec.MimeType, + mimeType, buffer.SpatialLayerToVideoQuality(layer, t.MediaTrackReceiver.TrackInfo()), ) }, ) t.MediaTrackReceiver.OnCodecRegression(func(old, new webrtc.RTPCodecParameters) { - t.dynacastManager.HandleCodecRegression(old.MimeType, new.MimeType) + t.dynacastManager.HandleCodecRegression(mime.NormalizeMimeType(old.MimeType), mime.NormalizeMimeType(new.MimeType)) }) } @@ -252,7 +252,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe ti := t.MediaTrackReceiver.TrackInfoClone() t.lock.Lock() - mime := strings.ToLower(track.Codec().MimeType) + mimeType := mime.NormalizeMimeType(track.Codec().MimeType) layer := buffer.RidToSpatialLayer(track.RID(), ti) t.params.Logger.Debugw( "AddReceiver", @@ -261,11 +261,11 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe "ssrc", track.SSRC(), "codec", track.Codec(), ) - wr := t.MediaTrackReceiver.Receiver(mime) + wr := t.MediaTrackReceiver.Receiver(mimeType) if wr == nil { priority := -1 for idx, c := range ti.Codecs { - if strings.EqualFold(mime, c.MimeType) { + if mime.IsMimeTypeStringEqual(track.Codec().MimeType, c.MimeType) { priority = idx break } @@ -283,7 +283,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe } } if priority < 0 { - t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(ti)) + t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mimeType, "track", logger.Proto(ti)) t.lock.Unlock() return false } @@ -292,7 +292,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe receiver, track, ti, - LoggerWithCodecMime(t.params.Logger, mime), + LoggerWithCodecMime(t.params.Logger, mimeType), t.params.OnRTCP, t.params.VideoConfig.StreamTrackerManager, sfu.WithPliThrottleConfig(t.params.PLIThrottleConfig), @@ -303,7 +303,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe ) newWR.OnCloseHandler(func() { t.MediaTrackReceiver.SetClosing() - t.MediaTrackReceiver.ClearReceiver(mime, false) + t.MediaTrackReceiver.ClearReceiver(mimeType, false) if t.MediaTrackReceiver.TryClose() { if t.dynacastManager != nil { t.dynacastManager.Close() @@ -325,7 +325,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe parameters := receiver.GetParameters() for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, c.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { potentialCodecs = append(potentialCodecs, nc) break } @@ -344,7 +344,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe for ssrc, info := range t.params.SimTracks { if info.Mid == mid { - t.MediaTrackReceiver.SetLayerSsrc(mime, info.Rid, ssrc) + t.MediaTrackReceiver.SetLayerSsrc(mimeType, info.Rid, ssrc) } } wr = newWR @@ -379,17 +379,17 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe bitrates = int(ti.Layers[layer].GetBitrate()) } - t.MediaTrackReceiver.SetLayerSsrc(mime, track.RID(), uint32(track.SSRC())) + t.MediaTrackReceiver.SetLayerSsrc(mimeType, track.RID(), uint32(track.SSRC())) - if newCodec && t.enableRegression && strings.EqualFold(mime, t.regressionTargetCodec) { - t.params.Logger.Infow("regression target codec received", "codec", mime) + if newCodec && t.enableRegression && mimeType == t.regressionTargetCodec { + t.params.Logger.Infow("regression target codec received", "codec", mimeType) for _, c := range ti.Codecs { - if strings.EqualFold(c.MimeType, mime) { + if mime.NormalizeMimeType(c.MimeType) == mimeType { continue } t.params.Logger.Debugw("suspending codec for codec regression", "codec", c.MimeType) - if r := t.MediaTrackReceiver.Receiver(c.MimeType); r != nil { + if r := t.MediaTrackReceiver.Receiver(mime.NormalizeMimeType(c.MimeType)); r != nil { if rtcreceiver, ok := r.(*sfu.WebRTCReceiver); ok { rtcreceiver.SetCodecState(sfu.ReceiverCodecStateSuspended) } @@ -409,7 +409,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe context.Background(), t.params.ParticipantID, t.ID(), - mime, + mimeType, int(layer), stats, ) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 970dba5a4..6aff85278 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -35,6 +35,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/livekit-server/pkg/telemetry" ) @@ -127,7 +128,7 @@ type MediaTrackReceiverParams struct { AudioConfig sfu.AudioConfig Telemetry telemetry.TelemetryService Logger logger.Logger - RegressionTargetCodec string + RegressionTargetCodec mime.MimeType } type MediaTrackReceiver struct { @@ -140,7 +141,7 @@ type MediaTrackReceiver struct { state mediaTrackReceiverState isExpectedToResume bool - onSetupReceiver func(mime string) + onSetupReceiver func(mime mime.MimeType) onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) onClose []func(isExpectedToResume bool) onCodecRegression func(old, new webrtc.RTPCodecParameters) @@ -179,7 +180,7 @@ func (t *MediaTrackReceiver) Restart() { } } -func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime string)) { +func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime mime.MimeType)) { t.lock.Lock() t.onSetupReceiver = f t.lock.Unlock() @@ -198,12 +199,12 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority // codec position maybe taken by DummyReceiver, check and upgrade to WebRTCReceiver var existingReceiver bool for _, r := range receivers { - if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) { + if r.Mime() == receiver.Mime() { existingReceiver = true if d, ok := r.TrackReceiver.(*DummyReceiver); ok { d.Upgrade(receiver) } else { - t.params.Logger.Errorw("receiver already exists, setup failed", nil, "mime", receiver.Codec().MimeType) + t.params.Logger.Errorw("receiver already exists, setup failed", nil, "mime", receiver.Mime()) } break } @@ -219,13 +220,13 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority if mid != "" { trackInfo := t.TrackInfoClone() if priority == 0 { - trackInfo.MimeType = receiver.Codec().MimeType + trackInfo.MimeType = receiver.Mime().String() trackInfo.Mid = mid } for i, ci := range trackInfo.Codecs { if i == priority { - ci.MimeType = receiver.Codec().MimeType + ci.MimeType = receiver.Mime().String() ci.Mid = mid break } @@ -237,20 +238,20 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority onSetupReceiver := t.onSetupReceiver t.lock.Unlock() - var receiverCodecs []string + var receiverCodecs []mime.MimeType for _, r := range receivers { - receiverCodecs = append(receiverCodecs, r.Codec().MimeType) + receiverCodecs = append(receiverCodecs, r.Mime()) } t.params.Logger.Debugw( "setup receiver", - "mime", receiver.Codec().MimeType, + "mime", receiver.Mime(), "priority", priority, "receivers", receiverCodecs, "mid", mid, ) if onSetupReceiver != nil { - onSetupReceiver(receiver.Codec().MimeType) + onSetupReceiver(receiver.Mime()) } } @@ -279,7 +280,7 @@ func (t *MediaTrackReceiver) HandleReceiverCodecChange(r sfu.TrackReceiver, code continue } - if strings.EqualFold(receiver.Codec().MimeType, t.params.RegressionTargetCodec) { + if receiver.Mime() == t.params.RegressionTargetCodec { backupCodecReceiver = receiver.TrackReceiver } @@ -338,7 +339,7 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete for i, c := range codecs { var exist bool for _, r := range receivers { - if strings.EqualFold(c.MimeType, r.Codec().MimeType) { + if mime.NormalizeMimeType(c.MimeType) == r.Mime() { exist = true break } @@ -357,11 +358,11 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete t.lock.Unlock() } -func (t *MediaTrackReceiver) ClearReceiver(mime string, isExpectedToResume bool) { +func (t *MediaTrackReceiver) ClearReceiver(mime mime.MimeType, isExpectedToResume bool) { t.lock.Lock() receivers := slices.Clone(t.receivers) for idx, receiver := range receivers { - if strings.EqualFold(receiver.Codec().MimeType, mime) { + if receiver.Mime() == mime { receivers[idx] = receivers[len(receivers)-1] receivers[len(receivers)-1] = nil receivers = receivers[:len(receivers)-1] @@ -384,7 +385,7 @@ func (t *MediaTrackReceiver) ClearAllReceivers(isExpectedToResume bool) { t.lock.Unlock() for _, r := range receivers { - t.removeAllSubscribersForMime(r.Codec().MimeType, isExpectedToResume) + t.removeAllSubscribersForMime(r.Mime(), isExpectedToResume) } } @@ -563,7 +564,7 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su codec := receiver.Codec() var found bool for _, pc := range potentialCodecs { - if strings.EqualFold(codec.MimeType, pc.MimeType) { + if mime.IsMimeTypeStringEqual(codec.MimeType, pc.MimeType) { found = true break } @@ -615,7 +616,7 @@ func (t *MediaTrackReceiver) RemoveSubscriber(subscriberID livekit.ParticipantID _ = t.MediaTrackSubscriptions.RemoveSubscriber(subscriberID, isExpectedToResume) } -func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime string, isExpectedToResume bool) { +func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime mime.MimeType, isExpectedToResume bool) { t.params.Logger.Debugw("removing all subscribers for mime", "mime", mime) for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribersForMime(mime) { t.RemoveSubscriber(subscriberID, isExpectedToResume) @@ -655,7 +656,7 @@ func (t *MediaTrackReceiver) updateTrackInfoOfReceivers() { } } -func (t *MediaTrackReceiver) SetLayerSsrc(mime string, rid string, ssrc uint32) { +func (t *MediaTrackReceiver) SetLayerSsrc(mimeType mime.MimeType, rid string, ssrc uint32) { t.lock.Lock() trackInfo := t.TrackInfoClone() layer := buffer.RidToSpatialLayer(rid, trackInfo) @@ -666,7 +667,7 @@ func (t *MediaTrackReceiver) SetLayerSsrc(mime string, rid string, ssrc uint32) quality := buffer.SpatialLayerToVideoQuality(layer, trackInfo) // set video layer ssrc info for i, ci := range trackInfo.Codecs { - if !strings.EqualFold(ci.MimeType, mime) { + if mime.NormalizeMimeType(ci.MimeType) == mimeType { continue } @@ -703,7 +704,7 @@ func (t *MediaTrackReceiver) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { trackInfo := t.TrackInfoClone() for _, c := range codecs { for _, origin := range trackInfo.Codecs { - if strings.Contains(origin.MimeType, c.Codec) { + if mime.GetMimeTypeCodec(origin.MimeType) == mime.NormalizeMimeTypeCodec(c.Codec) { origin.Cid = c.Cid break } @@ -724,7 +725,7 @@ func (t *MediaTrackReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { // patch Mid and SSRC of codecs/layers by keeping original if available for i, ci := range clonedInfo.Codecs { for _, originCi := range trackInfo.Codecs { - if !strings.EqualFold(ci.MimeType, originCi.MimeType) { + if !mime.IsMimeTypeStringEqual(ci.MimeType, originCi.MimeType) { continue } @@ -945,9 +946,9 @@ func (t *MediaTrackReceiver) PrimaryReceiver() sfu.TrackReceiver { return receivers[0].TrackReceiver } -func (t *MediaTrackReceiver) Receiver(mime string) sfu.TrackReceiver { +func (t *MediaTrackReceiver) Receiver(mime mime.MimeType) sfu.TrackReceiver { for _, r := range t.loadReceivers() { - if strings.EqualFold(r.Codec().MimeType, mime) { + if r.Mime() == mime { if dr, ok := r.TrackReceiver.(*DummyReceiver); ok { return dr.Receiver() } @@ -980,7 +981,7 @@ func (t *MediaTrackReceiver) SetRTT(rtt uint32) { } } -func (t *MediaTrackReceiver) GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime string) int32 { +func (t *MediaTrackReceiver) GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime mime.MimeType) int32 { receiver := t.Receiver(mime) if receiver == nil { return buffer.DefaultMaxLayerTemporal diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 8538db467..88a9df90f 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -16,7 +16,6 @@ package rtc import ( "errors" - "strings" "sync" "github.com/pion/rtcp" @@ -24,6 +23,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -46,7 +46,7 @@ type MediaTrackSubscriptions struct { subscribedTracks map[livekit.ParticipantID]types.SubscribedTrack onDownTrackCreated func(downTrack *sfu.DownTrack) - onSubscriberMaxQualityChange func(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, layer int32) + onSubscriberMaxQualityChange func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32) } type MediaTrackSubscriptionsParams struct { @@ -72,7 +72,7 @@ func (t *MediaTrackSubscriptions) OnDownTrackCreated(f func(downTrack *sfu.DownT t.onDownTrackCreated = f } -func (t *MediaTrackSubscriptions) OnSubscriberMaxQualityChange(f func(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, layer int32)) { +func (t *MediaTrackSubscriptions) OnSubscriberMaxQualityChange(f func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32)) { t.onSubscriberMaxQualityChange = f } @@ -179,7 +179,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * if t.onSubscriberMaxQualityChange != nil { go func() { spatial := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.params.MediaTrack.ToProto()) - t.onSubscriberMaxQualityChange(downTrack.SubscriberID(), codec, spatial) + t.onSubscriberMaxQualityChange(downTrack.SubscriberID(), mime.NormalizeMimeType(codec.MimeType), spatial) }() } } @@ -214,7 +214,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * downTrack.OnMaxLayerChanged(func(dt *sfu.DownTrack, layer int32) { if t.onSubscriberMaxQualityChange != nil { - t.onSubscriberMaxQualityChange(dt.SubscriberID(), dt.Codec(), layer) + t.onSubscriberMaxQualityChange(dt.SubscriberID(), dt.Mime(), layer) } }) @@ -280,7 +280,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * Stereo: info.Stereo, Red: !info.DisableRed, } - if addTrackParams.Red && (len(codecs) == 1 && strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus)) { + if addTrackParams.Red && (len(codecs) == 1 && mime.IsMimeTypeStringOpus(codecs[0].MimeType)) { addTrackParams.Red = false } @@ -379,13 +379,13 @@ func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID { return subs } -func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime string) []livekit.ParticipantID { +func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime mime.MimeType) []livekit.ParticipantID { t.subscribedTracksMu.RLock() defer t.subscribedTracksMu.RUnlock() subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks)) for id, subTrack := range t.subscribedTracks { - if !strings.EqualFold(subTrack.DownTrack().Codec().MimeType, mime) { + if subTrack.DownTrack().Mime() != mime { continue } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 201532736..266a2c741 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -49,6 +49,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/livekit-server/pkg/telemetry" @@ -1775,12 +1776,13 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver p.pendingTracksLock.Lock() p.pendingRemoteTracks = append(p.pendingRemoteTracks, &pendingRemoteTrack{track: rtcTrack, receiver: rtpReceiver}) p.pendingTracksLock.Unlock() - p.pubLogger.Debugw("webrtc Track published but can't find MediaTrack, add to pendingTracks", + p.pubLogger.Debugw( + "webrtc Track published but can't find MediaTrack, add to pendingTracks", "kind", track.Kind().String(), "webrtcTrackID", track.ID(), "rid", track.RID(), "SSRC", track.SSRC(), - "mime", codec.MimeType, + "mime", mime.NormalizeMimeType(codec.MimeType), ) return } @@ -1802,7 +1804,7 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver "webrtcTrackID", track.ID(), "rid", track.RID(), "SSRC", track.SSRC(), - "mime", codec.MimeType, + "mime", mime.NormalizeMimeType(codec.MimeType), "trackInfo", logger.Proto(publishedTrack.ToProto()), "fromSdp", fromSdp, ) @@ -2167,7 +2169,7 @@ func (p *ParticipantImpl) onSubscribedMaxQualityChange( // normalize the codec name for _, subscribedQuality := range subscribedQualities { - subscribedQuality.Codec = strings.ToLower(strings.TrimPrefix(subscribedQuality.Codec, "video/")) + subscribedQuality.Codec = strings.ToLower(strings.TrimPrefix(subscribedQuality.Codec, mime.MimeTypePrefixVideo)) } subscribedQualityUpdate := &livekit.SubscribedQualityUpdate{ @@ -2244,36 +2246,36 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l } else { seenCodecs := make(map[string]struct{}) for _, codec := range req.SimulcastCodecs { - mime := codec.Codec + mimeType := codec.Codec if req.Type == livekit.TrackType_VIDEO { - if !strings.HasPrefix(mime, "video/") { - mime = "video/" + mime + if !mime.IsMimeTypeStringVideo(mimeType) { + mimeType = mime.MimeTypePrefixVideo + mimeType } - if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mimeType}) { altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow("falling back to alternative codec", - "codec", mime, + "codec", mimeType, "altCodec", altCodec, "trackID", ti.Sid, ) // select an alternative MIME type that's generally supported - mime = altCodec + mimeType = altCodec } - } else if req.Type == livekit.TrackType_AUDIO && !strings.HasPrefix(mime, "audio/") { - mime = "audio/" + mime + } else if req.Type == livekit.TrackType_AUDIO && !mime.IsMimeTypeStringAudio(mimeType) { + mimeType = mime.MimeTypePrefixAudio + mimeType } - if _, ok := seenCodecs[mime]; ok || mime == "" { + if _, ok := seenCodecs[mimeType]; ok || mimeType == "" { continue } - seenCodecs[mime] = struct{}{} + seenCodecs[mimeType] = struct{}{} clonedLayers := make([]*livekit.VideoLayer, 0, len(req.Layers)) for _, l := range req.Layers { clonedLayers = append(clonedLayers, utils.CloneProto(l)) } ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ - MimeType: mime, + MimeType: mimeType, Cid: codec.Cid, Layers: clonedLayers, }) @@ -2391,7 +2393,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver "trackID", track.ID(), "rid", track.RID(), "SSRC", track.SSRC(), - "mime", track.Codec().MimeType, + "mime", mime.NormalizeMimeType(track.Codec().MimeType), "mid", mid, ) if mid == "" { @@ -2416,7 +2418,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver var codecFound int for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, c.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { codecFound++ break } @@ -2499,7 +2501,7 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M parameters := rtpReceiver.GetParameters() for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, c.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { potentialCodecs = append(potentialCodecs, nc) break } @@ -2508,10 +2510,10 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M // check for mime_type for tracks that do not have simulcast_codecs set if ti.MimeType != "" { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, ti.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, ti.MimeType) { alreadyAdded := false for _, pc := range potentialCodecs { - if strings.EqualFold(pc.MimeType, ti.MimeType) { + if mime.IsMimeTypeStringEqual(pc.MimeType, ti.MimeType) { alreadyAdded = true break } @@ -2528,7 +2530,7 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M for _, codec := range ti.Codecs { for ssrc, info := range p.params.SimTracks { if info.Mid == codec.Mid { - mt.SetLayerSsrc(codec.MimeType, info.Rid, ssrc) + mt.SetLayerSsrc(mime.NormalizeMimeType(codec.MimeType), info.Rid, ssrc) } } } @@ -2993,7 +2995,7 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool { for _, disableCodec := range disabled { // disable codec's fmtp is empty means disable this codec entirely - if strings.EqualFold(c.Mime, disableCodec.Mime) { + if mime.IsMimeTypeStringEqual(c.Mime, disableCodec.Mime) { return true } } @@ -3007,13 +3009,13 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod } // sort by compatibility, since we will look for backups in these. - if strings.EqualFold(c.Mime, webrtc.MimeTypeVP8) { + if mime.IsMimeTypeStringVP8(c.Mime) { if len(p.enabledPublishCodecs) > 0 { p.enabledPublishCodecs = slices.Insert(p.enabledPublishCodecs, 0, c) } else { p.enabledPublishCodecs = append(p.enabledPublishCodecs, c) } - } else if strings.EqualFold(c.Mime, webrtc.MimeTypeH264) { + } else if mime.IsMimeTypeStringH264(c.Mime) { p.enabledPublishCodecs = append(p.enabledPublishCodecs, c) } else { publishCodecs = append(publishCodecs, c) @@ -3034,7 +3036,7 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod func (p *ParticipantImpl) GetEnabledPublishCodecs() []*livekit.Codec { codecs := make([]*livekit.Codec, 0, len(p.enabledPublishCodecs)) for _, c := range p.enabledPublishCodecs { - if strings.EqualFold(c.Mime, webrtc.MimeTypeRTX) { + if mime.IsMimeTypeStringRTX(c.Mime) { continue } codecs = append(codecs, c) diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index c46fc1ed8..db9ad1e04 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -26,6 +26,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry/telemetryfakes" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -356,7 +357,7 @@ func TestDisableCodecs(t *testing.T) { codecs := transceiver.Receiver().GetParameters().Codecs var found264 bool for _, c := range codecs { - if strings.EqualFold(c.MimeType, "video/h264") { + if mime.IsMimeTypeStringH264(c.MimeType) { found264 = true } } @@ -390,7 +391,7 @@ func TestDisableCodecs(t *testing.T) { codecs = transceiver.Receiver().GetParameters().Codecs found264 = false for _, c := range codecs { - if strings.EqualFold(c.MimeType, "video/h264") { + if mime.IsMimeTypeStringH264(c.MimeType) { found264 = true } } @@ -410,7 +411,7 @@ func TestDisablePublishCodec(t *testing.T) { }) for _, codec := range participant.enabledPublishCodecs { - require.NotEqual(t, strings.ToLower(codec.Mime), "video/h264") + require.False(t, mime.IsMimeTypeStringH264(codec.Mime)) } sink := &routingfakes.FakeMessageSink{} @@ -421,7 +422,7 @@ func TestDisablePublishCodec(t *testing.T) { if published := res.GetTrackPublished(); published != nil { publishReceived.Store(true) require.NotEmpty(t, published.Track.Codecs) - require.Equal(t, "video/vp8", strings.ToLower(published.Track.Codecs[0].MimeType)) + require.True(t, mime.IsMimeTypeStringVP8(published.Track.Codecs[0].MimeType)) } } return nil @@ -446,7 +447,7 @@ func TestDisablePublishCodec(t *testing.T) { if published := res.GetTrackPublished(); published != nil { publishReceived.Store(true) require.NotEmpty(t, published.Track.Codecs) - require.Equal(t, "video/vp8", strings.ToLower(published.Track.Codecs[0].MimeType)) + require.True(t, mime.IsMimeTypeStringVP8(published.Track.Codecs[0].MimeType)) } } return nil @@ -493,13 +494,20 @@ func TestPreferVideoCodecForPublisher(t *testing.T) { require.NoError(t, err) transceiver, err := pc.AddTransceiverFromTrack(track, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendrecv}) require.NoError(t, err) - sdp, err := pc.CreateOffer(nil) - require.NoError(t, err) - pc.SetLocalDescription(sdp) codecs := transceiver.Receiver().GetParameters().Codecs + if i > 0 { + // the negotiated codecs order could be updated by first negotiation, reorder to make h264 not preferred + for mime.IsMimeTypeStringH264(codecs[0].MimeType) { + codecs = append(codecs[1:], codecs[0]) + } + } // h264 should not be preferred - require.NotEqual(t, codecs[0].MimeType, "video/h264") + require.False(t, mime.IsMimeTypeStringH264(codecs[0].MimeType), "codecs", codecs) + + sdp, err := pc.CreateOffer(nil) + require.NoError(t, err) + require.NoError(t, pc.SetLocalDescription(sdp)) sink := &routingfakes.FakeMessageSink{} participant.SetResponseSink(sink) @@ -528,7 +536,7 @@ func TestPreferVideoCodecForPublisher(t *testing.T) { if videoSectionIndex == i { codecs, err := codecsFromMediaDescription(m) require.NoError(t, err) - if strings.EqualFold(codecs[0].Name, "h264") { + if mime.IsMimeTypeCodecStringH264(codecs[0].Name) { h264Preferred = true break } @@ -626,7 +634,7 @@ func TestPreferAudioCodecForRed(t *testing.T) { } require.True(t, nackEnabled, "nack should be enabled for opus") - if strings.EqualFold(codecs[0].Name, "red") { + if mime.IsMimeTypeCodecStringRED(codecs[0].Name) { redPreferred = true break } diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index ae897795c..59d75ada9 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -22,6 +22,7 @@ import ( "github.com/pion/sdp/v3" "github.com/pion/webrtc/v4" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" lksdp "github.com/livekit/protocol/sdp" ) @@ -58,7 +59,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se var opusPayload uint8 for _, codec := range codecs { - if strings.EqualFold(codec.Name, "opus") { + if mime.IsMimeTypeCodecStringOpus(codec.Name) { opusPayload = codec.PayloadType break } @@ -70,7 +71,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se var preferredCodecs, leftCodecs []string for _, codec := range codecs { // codec contain opus/red - if !disableRed && strings.EqualFold(codec.Name, "red") && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { + if !disableRed && mime.IsMimeTypeCodecStringRED(codec.Name) && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) } else { leftCodecs = append(leftCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) @@ -138,20 +139,19 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess p.pendingTracksLock.RUnlock() continue } - var mime string + var mimeType string for _, c := range info.Codecs { if c.Cid == streamID { - mime = c.MimeType + mimeType = c.MimeType break } } - if mime == "" && len(info.Codecs) > 0 { - mime = info.Codecs[0].MimeType + if mimeType == "" && len(info.Codecs) > 0 { + mimeType = info.Codecs[0].MimeType } p.pendingTracksLock.RUnlock() - mime = strings.ToUpper(mime) - if mime != "" { + if mimeType != "" { codecs, err := codecsFromMediaDescription(unmatchVideo) if err != nil { p.pubLogger.Errorw("extract codecs from media section failed", err, "media", unmatchVideo) @@ -160,7 +160,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess var preferredCodecs, leftCodecs []string for _, c := range codecs { - if strings.HasSuffix(mime, strings.ToUpper(c.Name)) { + if mime.GetMimeTypeCodec(mimeType) == mime.NormalizeMimeTypeCodec(c.Name) { preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) } else { leftCodecs = append(leftCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) @@ -241,7 +241,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript continue } - opusPT, err := parsed.GetPayloadTypeForCodec(sdp.Codec{Name: "opus"}) + opusPT, err := parsed.GetPayloadTypeForCodec(sdp.Codec{Name: mime.MimeTypeCodecOpus.String()}) if err != nil { p.pubLogger.Infow("failed to get opus payload type", "error", err, "trackID", ti.Sid) continue diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index 5c83832f3..a45f4e177 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -259,7 +259,7 @@ func (t *SubscribedTrack) applySettings() { spatial = buffer.VideoQualityToSpatialLayer(quality, mt.ToProto()) if t.settings.Fps > 0 { - temporal = mt.GetTemporalLayerForSpatialFps(spatial, t.settings.Fps, dt.Codec().MimeType) + temporal = mt.GetTemporalLayerForSpatialFps(spatial, t.settings.Fps, dt.Mime()) } } diff --git a/pkg/rtc/subscriptionmanager.go b/pkg/rtc/subscriptionmanager.go index be97b2453..2d41f37c3 100644 --- a/pkg/rtc/subscriptionmanager.go +++ b/pkg/rtc/subscriptionmanager.go @@ -819,7 +819,7 @@ func (m *SubscriptionManager) handleSubscribedTrackClose(s *trackSubscription, i context.Background(), m.params.Participant.ID(), s.trackID, - dt.Codec().MimeType, + dt.Mime(), stats, ) } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index a274548ec..24e5801cb 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -44,6 +44,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" "github.com/livekit/livekit-server/pkg/sfu/datachannel" sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" @@ -416,11 +417,11 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat } setTWCCForVideo := func(info *interceptor.StreamInfo) { - if !strings.HasPrefix(info.MimeType, "video") { + if !mime.IsMimeTypeStringVideo(info.MimeType) { return } // rtx stream don't have rtcp feedback, always set twcc for rtx stream - twccFb := strings.HasSuffix(info.MimeType, "rtx") + twccFb := mime.GetMimeTypeCodec(info.MimeType) == mime.MimeTypeCodecRTX if !twccFb { for _, fb := range info.RTCPFeedback { if fb.Type == webrtc.TypeRTCPFBTransportCC { @@ -2089,7 +2090,7 @@ func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool codecs := sender.GetParameters().Codecs configCodecs := make([]webrtc.RTPCodecParameters, 0, len(codecs)) for _, c := range codecs { - if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) { + if mime.IsMimeTypeStringOpus(c.MimeType) { c.SDPFmtpLine = strings.ReplaceAll(c.SDPFmtpLine, ";sprop-stereo=1", "") if stereo { c.SDPFmtpLine += ";sprop-stereo=1" diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 512b2ff91..ac0c1cc97 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -28,6 +28,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/testutils" "github.com/livekit/protocol/livekit" ) @@ -389,9 +390,9 @@ func TestFilteringCandidates(t *testing.T) { ParticipantIdentity: "identity", Config: &WebRTCConfig{}, EnabledCodecs: []*livekit.Codec{ - {Mime: webrtc.MimeTypeOpus}, - {Mime: webrtc.MimeTypeVP8}, - {Mime: webrtc.MimeTypeH264}, + {Mime: mime.MimeTypeOpus.String()}, + {Mime: mime.MimeTypeVP8.String()}, + {Mime: mime.MimeTypeH264.String()}, }, Handler: &transportfakes.FakeHandler{}, } @@ -598,7 +599,7 @@ func TestConfigureAudioTransceiver(t *testing.T) { } { t.Run(fmt.Sprintf("nack=%v,stereo=%v", testcase.nack, testcase.stereo), func(t *testing.T) { var me webrtc.MediaEngine - registerCodecs(&me, []*livekit.Codec{{Mime: webrtc.MimeTypeOpus}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) + registerCodecs(&me, []*livekit.Codec{{Mime: mime.MimeTypeOpus.String()}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) pc, err := webrtc.NewAPI(webrtc.WithMediaEngine(&me)).NewPeerConnection(webrtc.Configuration{}) require.NoError(t, err) defer pc.Close() @@ -608,7 +609,7 @@ func TestConfigureAudioTransceiver(t *testing.T) { configureAudioTransceiver(tr, testcase.stereo, testcase.nack) codecs := tr.Sender().GetParameters().Codecs for _, codec := range codecs { - if strings.Contains(codec.MimeType, webrtc.MimeTypeOpus) { + if mime.IsMimeTypeStringOpus(codec.MimeType) { require.Equal(t, testcase.stereo, strings.Contains(codec.SDPFmtpLine, "sprop-stereo=1")) var nackEnabled bool for _, fb := range codec.RTCPFeedback { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index eeaedc52b..d44bb329c 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -29,6 +29,7 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" ) @@ -74,7 +75,7 @@ func (m MigrateState) String() string { // --------------------------------------------- type SubscribedCodecQuality struct { - CodecMime string + CodecMime mime.MimeType Quality livekit.VideoQuality } @@ -527,7 +528,7 @@ type MediaTrack interface { GetQualityForDimension(width, height uint32) livekit.VideoQuality // returns temporal layer that's appropriate for fps - GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime string) int32 + GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime mime.MimeType) int32 Receivers() []sfu.TrackReceiver ClearAllReceivers(isExpectedToResume bool) diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 382c609ae..877267f6c 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -6,6 +6,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -94,12 +95,12 @@ type FakeLocalMediaTrack struct { getQualityForDimensionReturnsOnCall map[int]struct { result1 livekit.VideoQuality } - GetTemporalLayerForSpatialFpsStub func(int32, uint32, string) int32 + GetTemporalLayerForSpatialFpsStub func(int32, uint32, mime.MimeType) int32 getTemporalLayerForSpatialFpsMutex sync.RWMutex getTemporalLayerForSpatialFpsArgsForCall []struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType } getTemporalLayerForSpatialFpsReturns struct { result1 int32 @@ -795,13 +796,13 @@ func (fake *FakeLocalMediaTrack) GetQualityForDimensionReturnsOnCall(i int, resu }{result1} } -func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 string) int32 { +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 mime.MimeType) int32 { fake.getTemporalLayerForSpatialFpsMutex.Lock() ret, specificReturn := fake.getTemporalLayerForSpatialFpsReturnsOnCall[len(fake.getTemporalLayerForSpatialFpsArgsForCall)] fake.getTemporalLayerForSpatialFpsArgsForCall = append(fake.getTemporalLayerForSpatialFpsArgsForCall, struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType }{arg1, arg2, arg3}) stub := fake.GetTemporalLayerForSpatialFpsStub fakeReturns := fake.getTemporalLayerForSpatialFpsReturns @@ -822,13 +823,13 @@ func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCallCount() int { return len(fake.getTemporalLayerForSpatialFpsArgsForCall) } -func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, string) int32) { +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, mime.MimeType) int32) { fake.getTemporalLayerForSpatialFpsMutex.Lock() defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() fake.GetTemporalLayerForSpatialFpsStub = stub } -func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, string) { +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, mime.MimeType) { fake.getTemporalLayerForSpatialFpsMutex.RLock() defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() argsForCall := fake.getTemporalLayerForSpatialFpsArgsForCall[i] diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index ce6e9355e..67304714d 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -6,6 +6,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -82,12 +83,12 @@ type FakeMediaTrack struct { getQualityForDimensionReturnsOnCall map[int]struct { result1 livekit.VideoQuality } - GetTemporalLayerForSpatialFpsStub func(int32, uint32, string) int32 + GetTemporalLayerForSpatialFpsStub func(int32, uint32, mime.MimeType) int32 getTemporalLayerForSpatialFpsMutex sync.RWMutex getTemporalLayerForSpatialFpsArgsForCall []struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType } getTemporalLayerForSpatialFpsReturns struct { result1 int32 @@ -675,13 +676,13 @@ func (fake *FakeMediaTrack) GetQualityForDimensionReturnsOnCall(i int, result1 l }{result1} } -func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 string) int32 { +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 mime.MimeType) int32 { fake.getTemporalLayerForSpatialFpsMutex.Lock() ret, specificReturn := fake.getTemporalLayerForSpatialFpsReturnsOnCall[len(fake.getTemporalLayerForSpatialFpsArgsForCall)] fake.getTemporalLayerForSpatialFpsArgsForCall = append(fake.getTemporalLayerForSpatialFpsArgsForCall, struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType }{arg1, arg2, arg3}) stub := fake.GetTemporalLayerForSpatialFpsStub fakeReturns := fake.getTemporalLayerForSpatialFpsReturns @@ -702,13 +703,13 @@ func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCallCount() int { return len(fake.getTemporalLayerForSpatialFpsArgsForCall) } -func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, string) int32) { +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, mime.MimeType) int32) { fake.getTemporalLayerForSpatialFpsMutex.Lock() defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() fake.GetTemporalLayerForSpatialFpsStub = stub } -func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, string) { +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, mime.MimeType) { fake.getTemporalLayerForSpatialFpsMutex.RLock() defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() argsForCall := fake.getTemporalLayerForSpatialFpsArgsForCall[i] diff --git a/pkg/rtc/utils.go b/pkg/rtc/utils.go index 1293cd42b..b6d0c29bf 100644 --- a/pkg/rtc/utils.go +++ b/pkg/rtc/utils.go @@ -23,6 +23,7 @@ import ( "github.com/pion/webrtc/v4" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -195,9 +196,9 @@ func LoggerWithPCTarget(l logger.Logger, target livekit.SignalTarget) logger.Log return l.WithValues("transport", target) } -func LoggerWithCodecMime(l logger.Logger, mime string) logger.Logger { - if mime != "" { - return l.WithValues("mime", mime) +func LoggerWithCodecMime(l logger.Logger, mimeType mime.MimeType) logger.Logger { + if mimeType != mime.MimeTypeUnknown { + return l.WithValues("mime", mimeType.String()) } return l } diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index 1f8fc81f2..a4d41a715 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -16,7 +16,6 @@ package rtc import ( "errors" - "strings" "sync" "github.com/pion/webrtc/v4" @@ -28,6 +27,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" ) // wrapper around WebRTC receiver, overriding its ID @@ -59,13 +59,14 @@ func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { codecs := params.UpstreamCodecs if len(codecs) == 1 { - if strings.EqualFold(codecs[0].MimeType, sfu.MimeTypeAudioRed) { + normalizedMimeType := mime.NormalizeMimeType(codecs[0].MimeType) + if normalizedMimeType == mime.MimeTypeRED { // if upstream is opus/red, then add opus to match clients that don't support red codecs = append(codecs, webrtc.RTPCodecParameters{ RTPCodecCapability: OpusCodecCapability, PayloadType: 111, }) - } else if !params.DisableRed && strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus) { + } else if !params.DisableRed && normalizedMimeType == mime.MimeTypeOpus { // if upstream is opus only and red enabled, add red to match clients that support red codecs = append(codecs, webrtc.RTPCodecParameters{ RTPCodecCapability: RedCodecCapability, @@ -95,16 +96,18 @@ func (r *WrappedReceiver) StreamID() string { func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) bool { r.lock.Lock() + codecMimeType := mime.NormalizeMimeType(codec.MimeType) var trackReceiver sfu.TrackReceiver for _, receiver := range r.receivers { - if c := receiver.Codec(); strings.EqualFold(c.MimeType, codec.MimeType) { + receiverMimeType := receiver.Mime() + if receiverMimeType == codecMimeType { trackReceiver = receiver break - } else if strings.EqualFold(c.MimeType, sfu.MimeTypeAudioRed) && strings.EqualFold(codec.MimeType, webrtc.MimeTypeOpus) { + } else if receiverMimeType == mime.MimeTypeRED && codecMimeType == mime.MimeTypeOpus { // audio opus/red can match opus only trackReceiver = receiver.GetPrimaryReceiverForRed() break - } else if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) && strings.EqualFold(codec.MimeType, sfu.MimeTypeAudioRed) { + } else if receiverMimeType == mime.MimeTypeOpus && codecMimeType == mime.MimeTypeRED { trackReceiver = receiver.GetRedReceiver() break } @@ -263,6 +266,13 @@ func (d *DummyReceiver) Codec() webrtc.RTPCodecParameters { return d.codec } +func (d *DummyReceiver) Mime() mime.MimeType { + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return r.Mime() + } + return mime.NormalizeMimeType(d.codec.MimeType) +} + func (d *DummyReceiver) HeaderExtensions() []webrtc.RTPHeaderExtensionParameter { if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { return r.HeaderExtensions() diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 95233ddaa..427ca1c11 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -32,6 +32,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/audio" + "github.com/livekit/livekit-server/pkg/sfu/mime" act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" @@ -94,7 +95,7 @@ type Buffer struct { rtpParameters webrtc.RTPParameters payloadType uint8 rtxPayloadType uint8 - mime string + mime mime.MimeType snRangeMap *utils.RangeMap[uint64, uint64] @@ -218,10 +219,10 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.clockRate = codec.ClockRate b.lastReport = mono.UnixNano() - b.mime = strings.ToLower(codec.MimeType) + b.mime = mime.NormalizeMimeType(codec.MimeType) b.rtpParameters = params for _, codecParameter := range params.Codecs { - if strings.EqualFold(codecParameter.MimeType, codec.MimeType) { + if mime.IsMimeTypeStringEqual(codecParameter.MimeType, codec.MimeType) { b.payloadType = uint8(codecParameter.PayloadType) break } @@ -234,7 +235,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili // find RTX payload type for _, codec := range params.Codecs { - if strings.EqualFold(codec.MimeType, "video/rtx") && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", b.payloadType)) { + if mime.IsMimeTypeStringRTX(codec.MimeType) && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", b.payloadType)) { b.rtxPayloadType = uint8(codec.PayloadType) break } @@ -248,7 +249,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili continue } b.ddExtID = uint8(ext.ID) - b.createDDParserAndFrameRateCalculator(codec.MimeType) + b.createDDParserAndFrameRateCalculator() case sdp.AudioLevelURI: b.audioLevelExtID = uint8(ext.ID) @@ -260,15 +261,15 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili } switch { - case strings.HasPrefix(b.mime, "audio/"): + case mime.IsMimeTypeAudio(b.mime): b.codecType = webrtc.RTPCodecTypeAudio b.bucket = bucket.NewBucket[uint64](InitPacketBufferSizeAudio) - case strings.HasPrefix(b.mime, "video/"): + case mime.IsMimeTypeVideo(b.mime): b.codecType = webrtc.RTPCodecTypeVideo b.bucket = bucket.NewBucket[uint64](InitPacketBufferSizeVideo) if b.frameRateCalculator[0] == nil { - b.createFrameRateCalculator(codec.MimeType) + b.createFrameRateCalculator() } if bitrates > 0 { pps := bitrates / 8 / 1200 @@ -292,7 +293,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili // pion use a single mediaengine to manage negotiated codecs of peerconnection, that means we can't have different // codec settings at track level for same codec type, so enable nack for all audio receivers but don't create nack queue // for red codec. - if strings.EqualFold(b.mime, "audio/red") { + if b.mime == mime.MimeTypeRED { break } b.logger.Debugw("Setting feedback", "type", webrtc.TypeRTCPFBNACK) @@ -313,8 +314,8 @@ func (b *Buffer) OnCodecChange(fn func(webrtc.RTPCodecParameters)) { b.Unlock() } -func (b *Buffer) createDDParserAndFrameRateCalculator(mime string) { - if IsSvcCodec(mime) || strings.EqualFold(mime, webrtc.MimeTypeVP8) { +func (b *Buffer) createDDParserAndFrameRateCalculator() { + if mime.IsMimeTypeSVC(b.mime) || b.mime == mime.MimeTypeVP8 { frc := NewFrameRateCalculatorDD(b.clockRate, b.logger) for i := range b.frameRateCalculator { b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) @@ -325,18 +326,18 @@ func (b *Buffer) createDDParserAndFrameRateCalculator(mime string) { } } -func (b *Buffer) createFrameRateCalculator(mime string) { - switch { - case strings.EqualFold(mime, webrtc.MimeTypeVP8): +func (b *Buffer) createFrameRateCalculator() { + switch b.mime { + case mime.MimeTypeVP8: b.frameRateCalculator[0] = NewFrameRateCalculatorVP8(b.clockRate, b.logger) - case strings.EqualFold(mime, webrtc.MimeTypeVP9): + case mime.MimeTypeVP9: frc := NewFrameRateCalculatorVP9(b.clockRate, b.logger) for i := range b.frameRateCalculator { b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) } - case strings.EqualFold(mime, webrtc.MimeTypeH265): + case mime.MimeTypeH265: b.frameRateCalculator[0] = NewFrameRateCalculatorH26x(b.clockRate, b.logger) } } @@ -763,7 +764,7 @@ func (b *Buffer) handleCodecChange(newPT uint8) { codecFound = true } - if strings.EqualFold(codec.MimeType, "video/rtx") && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", newPT)) { + if mime.IsMimeTypeStringRTX(codec.MimeType) && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", newPT)) { rtxFound = true rtxPt = uint8(codec.PayloadType) } @@ -776,21 +777,22 @@ func (b *Buffer) handleCodecChange(newPT uint8) { b.logger.Errorw("could not find codec for new payload type", nil, "pt", newPT, "rtpParameters", b.rtpParameters) return } - b.logger.Infow("codec changed", + b.logger.Infow( + "codec changed", "oldPayload", b.payloadType, "newPayload", newPT, "oldRtxPayload", b.rtxPayloadType, "newRtxPayload", rtxPt, "oldMime", b.mime, "newMime", newCodec.MimeType) b.payloadType = newPT b.rtxPayloadType = rtxPt - b.mime = strings.ToLower(newCodec.MimeType) + b.mime = mime.NormalizeMimeType(newCodec.MimeType) b.frameRateCalculated = false if b.ddExtID != 0 { - b.createDDParserAndFrameRateCalculator(b.mime) + b.createDDParserAndFrameRateCalculator() } if b.frameRateCalculator[0] == nil { - b.createFrameRateCalculator(b.mime) + b.createFrameRateCalculator() } b.bucket.ResyncOnNextPacket() @@ -874,8 +876,8 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat } } - switch utils.MatchMimeType(b.mime) { - case utils.MimeTypeVP8: + switch b.mime { + case mime.MimeTypeVP8: vp8Packet := VP8{} if err := vp8Packet.Unmarshal(rtpPacket.Payload); err != nil { b.logger.Warnw("could not unmarshal VP8 packet", err) @@ -891,7 +893,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat ep.Payload = vp8Packet ep.Spatial = InvalidLayerSpatial // vp8 don't have spatial scalability, reset to invalid - case utils.MimeTypeVP9: + case mime.MimeTypeVP9: if ep.DependencyDescriptor == nil { var vp9Packet codecs.VP9Packet _, err := vp9Packet.Unmarshal(rtpPacket.Payload) @@ -907,14 +909,14 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat } ep.KeyFrame = IsVP9KeyFrame(rtpPacket.Payload) - case utils.MimeTypeH264: + case mime.MimeTypeH264: ep.KeyFrame = IsH264KeyFrame(rtpPacket.Payload) ep.Spatial = InvalidLayerSpatial // h.264 don't have spatial scalability, reset to invalid - case utils.MimeTypeAV1: + case mime.MimeTypeAV1: ep.KeyFrame = IsAV1KeyFrame(rtpPacket.Payload) - case utils.MimeTypeH265: + case mime.MimeTypeH265: if ep.DependencyDescriptor == nil { if len(rtpPacket.Payload) < 2 { b.logger.Warnw("invalid H265 packet", nil) @@ -1218,19 +1220,3 @@ func (b *Buffer) GetTemporalLayerFpsForSpatial(layer int32) []float32 { } // --------------------------------------------------------------- - -// SVC-TODO: Have to use more conditions to differentiate between -// SVC-TODO: SVC and non-SVC (could be single layer or simulcast). -// SVC-TODO: May only need to differentiate between simulcast and non-simulcast -// SVC-TODO: i. e. may be possible to treat single layer as SVC to get proper/intended functionality. -func IsSvcCodec(mime string) bool { - switch utils.MatchMimeType(mime) { - case utils.MimeTypeAV1, utils.MimeTypeVP9: - return true - } - return false -} - -func IsRedCodec(mime string) bool { - return strings.HasSuffix(strings.ToLower(mime), "red") -} diff --git a/pkg/sfu/connectionquality/connectionstats.go b/pkg/sfu/connectionquality/connectionstats.go index eb254831a..f4f09b7c5 100644 --- a/pkg/sfu/connectionquality/connectionstats.go +++ b/pkg/sfu/connectionquality/connectionstats.go @@ -15,12 +15,10 @@ package connectionquality import ( - "strings" "sync" "time" "github.com/frostbyte73/core" - "github.com/pion/webrtc/v4" "go.uber.org/atomic" "google.golang.org/protobuf/types/known/timestamppb" @@ -28,6 +26,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" ) @@ -60,7 +59,7 @@ type ConnectionStatsParams struct { type ConnectionStats struct { params ConnectionStatsParams - codecMimeType atomic.String + codecMimeType atomic.Value // mime.MimeType isStarted atomic.Bool isVideo atomic.Bool @@ -88,19 +87,19 @@ func NewConnectionStats(params ConnectionStatsParams) *ConnectionStats { } } -func (cs *ConnectionStats) StartAt(codecMimeType string, isFECEnabled bool, at time.Time) { +func (cs *ConnectionStats) StartAt(codecMimeType mime.MimeType, isFECEnabled bool, at time.Time) { if cs.isStarted.Swap(true) { return } - cs.isVideo.Store(strings.HasPrefix(strings.ToLower(codecMimeType), "video/")) + cs.isVideo.Store(mime.IsMimeTypeVideo(codecMimeType)) cs.codecMimeType.Store(codecMimeType) cs.scorer.StartAt(getPacketLossWeight(codecMimeType, isFECEnabled), at) go cs.updateStatsWorker() } -func (cs *ConnectionStats) Start(codecMimeType string, isFECEnabled bool) { +func (cs *ConnectionStats) Start(codecMimeType mime.MimeType, isFECEnabled bool) { cs.StartAt(codecMimeType, isFECEnabled, time.Now()) } @@ -108,6 +107,12 @@ func (cs *ConnectionStats) Close() { cs.done.Break() } +func (cs *ConnectionStats) UpdateCodec(codecMimeType mime.MimeType, isFECEnabled bool) { + cs.isVideo.Store(mime.IsMimeTypeVideo(codecMimeType)) + cs.codecMimeType.Store(codecMimeType) + cs.scorer.UpdatePacketLossWeight(getPacketLossWeight(codecMimeType, isFECEnabled)) +} + func (cs *ConnectionStats) OnStatsUpdate(fn func(cs *ConnectionStats, stat *livekit.AnalyticsStat)) { cs.onStatsUpdate = fn } @@ -340,7 +345,7 @@ func (cs *ConnectionStats) getStat() { cs.onStatsUpdate(cs, &livekit.AnalyticsStat{ Score: score, Streams: analyticsStreams, - Mime: cs.codecMimeType.Load(), + Mime: cs.codecMimeType.Load().(mime.MimeType).String(), }) } } @@ -381,10 +386,10 @@ func (cs *ConnectionStats) updateStatsWorker() { // For video: // // o No in-built codec repair available, hence same for all codecs -func getPacketLossWeight(mimeType string, isFecEnabled bool) float64 { +func getPacketLossWeight(mimeType mime.MimeType, isFecEnabled bool) float64 { var plw float64 switch { - case strings.EqualFold(mimeType, webrtc.MimeTypeOpus): + case mimeType == mime.MimeTypeOpus: // 2.5%: fall to GOOD, 7.5%: fall to POOR plw = 8.0 if isFecEnabled { @@ -392,7 +397,7 @@ func getPacketLossWeight(mimeType string, isFecEnabled bool) float64 { plw /= 1.5 } - case strings.EqualFold(mimeType, "audio/red"): + case mimeType == mime.MimeTypeRED: // 5%: fall to GOOD, 15.0%: fall to POOR plw = 4.0 if isFecEnabled { @@ -400,7 +405,7 @@ func getPacketLossWeight(mimeType string, isFecEnabled bool) float64 { plw /= 1.5 } - case strings.HasPrefix(strings.ToLower(mimeType), "video/"): + case mime.IsMimeTypeVideo(mimeType): // 2%: fall to GOOD, 6%: fall to POOR plw = 10.0 } diff --git a/pkg/sfu/connectionquality/connectionstats_test.go b/pkg/sfu/connectionquality/connectionstats_test.go index ca380eb72..1e74eb41e 100644 --- a/pkg/sfu/connectionquality/connectionstats_test.go +++ b/pkg/sfu/connectionquality/connectionstats_test.go @@ -19,10 +19,10 @@ import ( "testing" "time" - "github.com/pion/webrtc/v4" "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -70,7 +70,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeOpus, false, now.Add(-duration)) + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) cs.UpdateMuteAt(false, now.Add(-1*time.Second)) // no data and not enough unmute time should return default state which is EXCELLENT quality @@ -484,7 +484,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeOpus, false, now.Add(-duration)) + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) cs.UpdateMuteAt(false, now.Add(-1*time.Second)) // RTT does not knock quality down because it is dependent and hence not taken into account @@ -517,7 +517,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeOpus, false, now.Add(-duration)) + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) cs.UpdateMuteAt(false, now.Add(-1*time.Second)) // Jitter does not knock quality down because it is dependent and hence not taken into account @@ -548,7 +548,7 @@ func TestConnectionQuality(t *testing.T) { } testCases := []struct { name string - mimeType string + mimeType mime.MimeType isFECEnabled bool packetsExpected uint32 expectedQualities []expectedQuality @@ -557,7 +557,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/opus" - no fec - 0 <= loss < 2.5%: EXCELLENT, 2.5% <= loss < 7.5%: GOOD, >= 7.5%: POOR { name: "audio/opus - no fec", - mimeType: "audio/opus", + mimeType: mime.MimeTypeOpus, isFECEnabled: false, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -581,7 +581,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/opus" - fec - 0 <= loss < 3.75%: EXCELLENT, 3.75% <= loss < 11.25%: GOOD, >= 11.25%: POOR { name: "audio/opus - fec", - mimeType: "audio/opus", + mimeType: mime.MimeTypeOpus, isFECEnabled: true, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -605,7 +605,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/red" - no fec - 0 <= loss < 5%: EXCELLENT, 5% <= loss < 15%: GOOD, >= 15%: POOR { name: "audio/red - no fec", - mimeType: "audio/red", + mimeType: mime.MimeTypeRED, isFECEnabled: false, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -629,7 +629,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/red" - fec - 0 <= loss < 7.5%: EXCELLENT, 7.5% <= loss < 22.5%: GOOD, >= 22.5%: POOR { name: "audio/red - fec", - mimeType: "audio/red", + mimeType: mime.MimeTypeRED, isFECEnabled: true, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -653,7 +653,7 @@ func TestConnectionQuality(t *testing.T) { // "video/*" - 0 <= loss < 2%: EXCELLENT, 2% <= loss < 6%: GOOD, >= 6%: POOR { name: "video/*", - mimeType: "video/vp8", + mimeType: mime.MimeTypeVP8, isFECEnabled: false, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -786,7 +786,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeVP8, false, now) + cs.StartAt(mime.MimeTypeVP8, false, now) for _, tr := range tc.transitions { cs.AddBitrateTransitionAt(tr.bitrate, now.Add(tr.offset)) @@ -878,7 +878,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeVP8, false, now) + cs.StartAt(mime.MimeTypeVP8, false, now) for _, tr := range tc.transitions { cs.AddLayerTransitionAt(tr.distance, now.Add(tr.offset)) diff --git a/pkg/sfu/connectionquality/scorer.go b/pkg/sfu/connectionquality/scorer.go index ebafcd9bd..af9290027 100644 --- a/pkg/sfu/connectionquality/scorer.go +++ b/pkg/sfu/connectionquality/scorer.go @@ -258,6 +258,13 @@ func (q *qualityScorer) Start(packetLossWeight float64) { q.startAtLocked(packetLossWeight, time.Now()) } +func (q *qualityScorer) UpdatePacketLossWeight(packetLossWeight float64) { + q.lock.Lock() + defer q.lock.Unlock() + + q.packetLossWeight = packetLossWeight +} + func (q *qualityScorer) updateMuteAtLocked(isMuted bool, at time.Time) { if isMuted { q.mutedAt = at diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 4505c76f2..4faa09886 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -40,6 +40,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" @@ -253,6 +254,7 @@ type DownTrack struct { params DowntrackParams id livekit.TrackID kind webrtc.RTPCodecType + mime mime.MimeType ssrc uint32 ssrcRTX uint32 payloadType atomic.Uint32 @@ -345,11 +347,12 @@ type DownTrack struct { // NewDownTrack returns a DownTrack. func NewDownTrack(params DowntrackParams) (*DownTrack, error) { codecs := params.Codecs + mimeType := mime.NormalizeMimeType(codecs[0].MimeType) var kind webrtc.RTPCodecType switch { - case strings.HasPrefix(codecs[0].MimeType, "audio/"): + case mime.IsMimeTypeAudio(mimeType): kind = webrtc.RTPCodecTypeAudio - case strings.HasPrefix(codecs[0].MimeType, "video/"): + case mime.IsMimeTypeVideo(mimeType): kind = webrtc.RTPCodecTypeVideo default: kind = webrtc.RTPCodecType(0) @@ -505,13 +508,13 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, } isFECEnabled := false - if strings.EqualFold(matchedUpstreamCodec.MimeType, MimeTypeAudioRed) { + if mime.IsMimeTypeStringRED(matchedUpstreamCodec.MimeType) { d.isRED = true for _, c := range d.upstreamCodecs { isFECEnabled = strings.Contains(strings.ToLower(c.SDPFmtpLine), "useinbandfec=1") // assume upstream primary codec is opus since we only support it for audio now - if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) { + if mime.IsMimeTypeStringOpus(c.MimeType) { d.upstreamPrimaryPT = uint8(c.PayloadType) break } @@ -525,7 +528,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.params.Logger.Errorw("failed to parse primary and secondary payload type for RED", err, "matchedCodec", codec) } d.primaryPT = uint8(primaryPT) - } else if strings.HasPrefix(strings.ToLower(matchedUpstreamCodec.MimeType), "audio/") { + } else if mime.IsMimeTypeStringAudio(matchedUpstreamCodec.MimeType) { isFECEnabled = strings.Contains(strings.ToLower(matchedUpstreamCodec.SDPFmtpLine), "fec") } @@ -558,6 +561,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.params.Logger.Debugw("DownTrack.Bind", logFields...) d.writeStream = t.WriteStream() + d.mime = mime.NormalizeMimeType(codec.MimeType) if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, d.ssrc).(*buffer.RTCPReader); rr != nil { rr.OnPacket(func(pkt []byte) { d.handleRTCP(pkt) @@ -580,10 +584,11 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.onBinding(nil) } d.setBindStateLocked(bindStateBound) + mimeType := d.mime d.bindLock.Unlock() d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) - d.connectionStats.Start(codec.MimeType, isFECEnabled) + d.connectionStats.Start(mimeType, isFECEnabled) d.params.Logger.Debugw("downtrack bound") } @@ -635,9 +640,9 @@ func (d *DownTrack) handleReceiverReady() { } } -func (d *DownTrack) handleUpstreamCodecChange(mime string) { +func (d *DownTrack) handleUpstreamCodecChange(mimeType string) { d.bindLock.Lock() - if strings.EqualFold(d.codec.MimeType, mime) { + if mime.IsMimeTypeStringEqual(d.codec.MimeType, mimeType) { d.bindLock.Unlock() return } @@ -653,7 +658,7 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { var codec webrtc.RTPCodecParameters for _, c := range d.upstreamCodecs { - if !strings.EqualFold(c.MimeType, mime) { + if !mime.IsMimeTypeStringEqual(d.codec.MimeType, mimeType) { continue } @@ -670,7 +675,7 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { "can't find matched codec for new upstream payload type", nil, "upstreamCodecs", d.upstreamCodecs, "remoteParameters", d.negotiatedCodecParameters, - "mime", mime, + "mime", mimeType, ) d.bindLock.Unlock() return @@ -679,6 +684,9 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { d.payloadType.Store(uint32(codec.PayloadType)) d.payloadTypeRTX.Store(uint32(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters))) d.codec = codec.RTPCodecCapability + d.mime = mime.NormalizeMimeType(codec.MimeType) + newMimeType := d.mime + isFECEnabled := strings.Contains(strings.ToLower(d.codec.SDPFmtpLine), "fec") d.bindLock.Unlock() d.params.Logger.Infow( @@ -690,6 +698,7 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { d.forwarder.Restart() d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) + d.connectionStats.UpdateCodec(newMimeType, isFECEnabled) } // Unbind implements the teardown logic when the track is no longer needed. This happens @@ -744,6 +753,12 @@ func (d *DownTrack) Codec() webrtc.RTPCodecCapability { return d.codec } +func (d *DownTrack) Mime() mime.MimeType { + d.bindLock.Lock() + defer d.bindLock.Unlock() + return d.mime +} + // StreamID is the group this track belongs too. This must be unique func (d *DownTrack) StreamID() string { return d.params.StreamID } @@ -1679,16 +1694,15 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan return } - mime := strings.ToLower(d.Codec().MimeType) var getBlankFrame func(bool) ([]byte, error) - switch { - case strings.EqualFold(mime, webrtc.MimeTypeOpus): + switch d.mime { + case mime.MimeTypeOpus: getBlankFrame = d.getOpusBlankFrame - case strings.EqualFold(mime, MimeTypeAudioRed): + case mime.MimeTypeRED: getBlankFrame = d.getOpusRedBlankFrame - case strings.EqualFold(mime, webrtc.MimeTypeVP8): + case mime.MimeTypeVP8: getBlankFrame = d.getVP8BlankFrame - case strings.EqualFold(mime, webrtc.MimeTypeH264): + case mime.MimeTypeH264: getBlankFrame = d.getH264BlankFrame default: close(done) @@ -1696,7 +1710,7 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan } frameRate := uint32(30) - if mime == strings.ToLower(webrtc.MimeTypeOpus) || mime == strings.ToLower(MimeTypeAudioRed) { + if d.mime == mime.MimeTypeOpus || d.mime == mime.MimeTypeRED { frameRate = 50 } @@ -2414,7 +2428,7 @@ func (d *DownTrack) sendPaddingOnMute() { if d.kind == webrtc.RTPCodecTypeVideo { d.sendPaddingOnMuteForVideo() - } else if strings.EqualFold(d.Codec().MimeType, webrtc.MimeTypeOpus) { + } else if d.mime == mime.MimeTypeOpus { d.sendSilentFrameOnMuteForOpus() } } diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index e0e9f590d..df8ba0a08 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "math/rand" - "strings" "sync" "time" @@ -35,6 +34,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/codecmunger" + "github.com/livekit/livekit-server/pkg/sfu/mime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/livekit-server/pkg/sfu/videolayerselector" @@ -210,7 +210,8 @@ func (r refInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { type Forwarder struct { lock sync.RWMutex - codec webrtc.RTPCodecCapability + mime mime.MimeType + clockRate uint32 kind webrtc.RTPCodecType logger logger.Logger skipReferenceTS bool @@ -249,6 +250,7 @@ func NewForwarder( rtpStats *rtpstats.RTPStatsSender, ) *Forwarder { f := &Forwarder{ + mime: mime.MimeTypeUnknown, kind: kind, logger: logger, skipReferenceTS: skipReferenceTS, @@ -299,11 +301,13 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ f.lock.Lock() defer f.lock.Unlock() - codecChanged := f.codec.MimeType != "" && f.codec.MimeType != codec.MimeType + toMimeType := mime.NormalizeMimeType(codec.MimeType) + codecChanged := f.mime != mime.MimeTypeUnknown && f.mime != toMimeType if codecChanged { - f.logger.Debugw("forwarder codec changed", "from", f.codec.MimeType, "to", codec.MimeType) + f.logger.Debugw("forwarder codec changed", "from", f.mime, "to", toMimeType) } - f.codec = codec + f.mime = toMimeType + f.clockRate = codec.ClockRate ddAvailable := func(exts []webrtc.RTPHeaderExtensionParameter) bool { for _, ext := range exts { @@ -314,8 +318,8 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ return false } - switch strings.ToLower(codec.MimeType) { - case "video/vp8": + switch f.mime { + case mime.MimeTypeVP8: f.codecMunger = codecmunger.NewVP8FromNull(f.codecMunger, f.logger) if f.vls != nil { if vls := videolayerselector.NewSimulcastFromOther(f.vls); vls != nil { @@ -328,16 +332,14 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ } f.vls.SetTemporalLayerSelector(temporallayerselector.NewVP8(f.logger)) - case "video/h264": - fallthrough - case "video/h265": + case mime.MimeTypeH264, mime.MimeTypeH265: if f.vls != nil { f.vls = videolayerselector.NewSimulcastFromOther(f.vls) } else { f.vls = videolayerselector.NewSimulcast(f.logger) } - case "video/vp9": + case mime.MimeTypeVP9: // DD-TODO : we only enable dd layer selector for av1/vp9 now, in the future we can enable it for vp8 too isDDAvailable := ddAvailable(extensions) if isDDAvailable { @@ -355,7 +357,7 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ } // SVC-TODO: Support for VP9 simulcast. When DD is not available, have to pick selector based on VP9 SVC or Simulcast - case "video/av1": + case mime.MimeTypeAV1: // DD-TODO : we only enable dd layer selector for av1/vp9 now, in the future we can enable it for vp8 too isDDAvailable := ddAvailable(extensions) if isDDAvailable { @@ -1513,7 +1515,7 @@ func (f *Forwarder) Pause(availableLayers []int32, brs Bitrates) VideoAllocation func (f *Forwarder) updateAllocation(alloc VideoAllocation, reason string) VideoAllocation { // restrict target temporal to 0 if codec does not support temporal layers - if alloc.TargetLayer.IsValid() && strings.ToLower(f.codec.MimeType) == "video/h264" { + if alloc.TargetLayer.IsValid() && f.mime == mime.MimeTypeH264 { alloc.TargetLayer.Temporal = 0 } @@ -1663,7 +1665,7 @@ func (f *Forwarder) getRefLayerRTPTimestamp(ts uint32, refLayer, targetLayer int } ntpDiff := mediatransportutil.NtpTime(srRef.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(srTarget.NtpTimestamp).Time()) - rtpDiff := ntpDiff.Nanoseconds() * int64(f.codec.ClockRate) / 1e9 + rtpDiff := ntpDiff.Nanoseconds() * int64(f.clockRate) / 1e9 // calculate other layer's time stamp at the same time as ref layer's NTP time normalizedOtherTS := srTarget.RtpTimestamp + uint32(rtpDiff) @@ -1795,7 +1797,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e } else { if !f.preStartTime.IsZero() { timeSinceFirst := time.Since(f.preStartTime) - rtpDiff := uint64(timeSinceFirst.Nanoseconds() * int64(f.codec.ClockRate) / 1e9) + rtpDiff := uint64(timeSinceFirst.Nanoseconds() * int64(f.clockRate) / 1e9) extExpectedTS = f.extFirstTS + rtpDiff if f.dummyStartTSOffset == 0 { f.dummyStartTSOffset = extExpectedTS - uint64(refTS) @@ -1839,7 +1841,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e // Ideally, extRefTS should not be ahead of extExpectedTS, but extExpectedTS uses the first packet's // wall clock time. So, if the first packet experienced abmormal latency, it is possible // for extRefTS > extExpectedTS - diffSeconds := float64(int64(extExpectedTS-extRefTS)) / float64(f.codec.ClockRate) + diffSeconds := float64(int64(extExpectedTS-extRefTS)) / float64(f.clockRate) if diffSeconds >= 0.0 { if f.resumeBehindThreshold > 0 && diffSeconds > f.resumeBehindThreshold { logTransitionInfo("resume, reference too far behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) @@ -1862,7 +1864,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e f.resumeBehindThreshold = 0.0 } else { // switching between layers, check if extRefTS is too far behind the last sent - diffSeconds := float64(int64(extRefTS-extLastTS)) / float64(f.codec.ClockRate) + diffSeconds := float64(int64(extRefTS-extLastTS)) / float64(f.clockRate) if diffSeconds < 0.0 { if math.Abs(diffSeconds) > LayerSwitchBehindThresholdSeconds { // this could be due to pacer trickling out this layer. Error out and wait for a more opportune time. @@ -1877,7 +1879,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e logTransition("layer switch, reference is slightly behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) extNextTS = extLastTS + 1 } else { - diffSeconds = float64(int64(extRefTS-extExpectedTS)) / float64(f.codec.ClockRate) + diffSeconds = float64(int64(extRefTS-extExpectedTS)) / float64(f.clockRate) if diffSeconds > SwitchAheadThresholdSeconds { logTransition("layer switch, reference too far ahead", extExpectedTS, extRefTS, extLastTS, diffSeconds) } @@ -2154,7 +2156,7 @@ func (f *Forwarder) GetSnTsForBlankFrames(frameRate uint32, numPackets int) ([]S if int64(extExpectedTS-extLastTS) <= 0 { extExpectedTS = extLastTS + 1 } - snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs(numPackets, f.codec.ClockRate, frameRate, frameEndNeeded, extExpectedTS) + snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs(numPackets, f.clockRate, frameRate, frameEndNeeded, extExpectedTS) return snts, frameEndNeeded, err } diff --git a/pkg/sfu/mime/mimetype.go b/pkg/sfu/mime/mimetype.go new file mode 100644 index 000000000..5e78c57d5 --- /dev/null +++ b/pkg/sfu/mime/mimetype.go @@ -0,0 +1,271 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mime + +import ( + "strings" + + "github.com/pion/webrtc/v4" +) + +const ( + MimeTypePrefixAudio = "audio/" + MimeTypePrefixVideo = "video/" +) + +type MimeTypeCodec int + +const ( + MimeTypeCodecUnknown MimeTypeCodec = iota + MimeTypeCodecH264 + MimeTypeCodecH265 + MimeTypeCodecOpus + MimeTypeCodecRED + MimeTypeCodecVP8 + MimeTypeCodecVP9 + MimeTypeCodecAV1 + MimeTypeCodecG722 + MimeTypeCodecPCMU + MimeTypeCodecPCMA + MimeTypeCodecRTX + MimeTypeCodecFlexFEC +) + +func (m MimeTypeCodec) String() string { + switch m { + case MimeTypeCodecUnknown: + return "MimeTypeCodecUnknown" + case MimeTypeCodecH264: + return "H264" + case MimeTypeCodecH265: + return "H265" + case MimeTypeCodecOpus: + return "opus" + case MimeTypeCodecRED: + return "red" + case MimeTypeCodecVP8: + return "VP8" + case MimeTypeCodecVP9: + return "VP9" + case MimeTypeCodecAV1: + return "AV1" + case MimeTypeCodecG722: + return "G722" + case MimeTypeCodecPCMU: + return "PCMU" + case MimeTypeCodecPCMA: + return "PCMA" + case MimeTypeCodecRTX: + return "rtx" + case MimeTypeCodecFlexFEC: + return "flexfec" + } + + return "MimeTypeCodecUnknown" +} + +func NormalizeMimeTypeCodec(codec string) MimeTypeCodec { + switch { + case strings.EqualFold(codec, "h264"): + return MimeTypeCodecH264 + case strings.EqualFold(codec, "h265"): + return MimeTypeCodecH265 + case strings.EqualFold(codec, "opus"): + return MimeTypeCodecOpus + case strings.EqualFold(codec, "red"): + return MimeTypeCodecRED + case strings.EqualFold(codec, "vp8"): + return MimeTypeCodecVP8 + case strings.EqualFold(codec, "vp9"): + return MimeTypeCodecVP9 + case strings.EqualFold(codec, "av1"): + return MimeTypeCodecAV1 + case strings.EqualFold(codec, "g722"): + return MimeTypeCodecG722 + case strings.EqualFold(codec, "pcmu"): + return MimeTypeCodecPCMU + case strings.EqualFold(codec, "pcma"): + return MimeTypeCodecPCMA + case strings.EqualFold(codec, "rtx"): + return MimeTypeCodecRTX + case strings.EqualFold(codec, "flexfec"): + return MimeTypeCodecFlexFEC + } + + return MimeTypeCodecUnknown +} + +func GetMimeTypeCodec(mime string) MimeTypeCodec { + i := strings.IndexByte(mime, '/') + if i == -1 { + return MimeTypeCodecUnknown + } + + return NormalizeMimeTypeCodec(mime[i+1:]) +} + +func IsMimeTypeCodecStringOpus(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecOpus +} + +func IsMimeTypeCodecStringRED(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecRED +} + +func IsMimeTypeCodecStringH264(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecH264 +} + +type MimeType int + +const ( + MimeTypeUnknown MimeType = iota + MimeTypeH264 + MimeTypeH265 + MimeTypeOpus + MimeTypeRED + MimeTypeVP8 + MimeTypeVP9 + MimeTypeAV1 + MimeTypeG722 + MimeTypePCMU + MimeTypePCMA + MimeTypeRTX + MimeTypeFlexFEC +) + +func (m MimeType) String() string { + switch m { + case MimeTypeUnknown: + return "MimeTypeUnknown" + case MimeTypeH264: + return webrtc.MimeTypeH264 + case MimeTypeH265: + return webrtc.MimeTypeH265 + case MimeTypeOpus: + return webrtc.MimeTypeOpus + case MimeTypeRED: + return "audio/red" + case MimeTypeVP8: + return webrtc.MimeTypeVP8 + case MimeTypeVP9: + return webrtc.MimeTypeVP9 + case MimeTypeAV1: + return webrtc.MimeTypeAV1 + case MimeTypeG722: + return webrtc.MimeTypeG722 + case MimeTypePCMU: + return webrtc.MimeTypePCMU + case MimeTypePCMA: + return webrtc.MimeTypePCMA + case MimeTypeRTX: + return webrtc.MimeTypeRTX + case MimeTypeFlexFEC: + return webrtc.MimeTypeFlexFEC + } + + return "MimeTypeUnknown" +} + +func NormalizeMimeType(mime string) MimeType { + switch { + case strings.EqualFold(mime, webrtc.MimeTypeH264): + return MimeTypeH264 + case strings.EqualFold(mime, webrtc.MimeTypeH265): + return MimeTypeH265 + case strings.EqualFold(mime, webrtc.MimeTypeOpus): + return MimeTypeOpus + case strings.EqualFold(mime, "audio/red"): + return MimeTypeRED + case strings.EqualFold(mime, webrtc.MimeTypeVP8): + return MimeTypeVP8 + case strings.EqualFold(mime, webrtc.MimeTypeVP9): + return MimeTypeVP9 + case strings.EqualFold(mime, webrtc.MimeTypeAV1): + return MimeTypeAV1 + case strings.EqualFold(mime, webrtc.MimeTypeG722): + return MimeTypeG722 + case strings.EqualFold(mime, webrtc.MimeTypePCMU): + return MimeTypePCMU + case strings.EqualFold(mime, webrtc.MimeTypePCMA): + return MimeTypePCMA + case strings.EqualFold(mime, webrtc.MimeTypeRTX): + return MimeTypeRTX + case strings.EqualFold(mime, webrtc.MimeTypeFlexFEC): + return MimeTypeFlexFEC + } + + return MimeTypeUnknown +} + +func IsMimeTypeStringEqual(mime1 string, mime2 string) bool { + return NormalizeMimeType(mime1) == NormalizeMimeType(mime2) +} + +func IsMimeTypeStringAudio(mime string) bool { + return strings.HasPrefix(mime, MimeTypePrefixAudio) +} + +func IsMimeTypeAudio(mimeType MimeType) bool { + return strings.HasPrefix(mimeType.String(), MimeTypePrefixAudio) +} + +func IsMimeTypeStringVideo(mime string) bool { + return strings.HasPrefix(mime, MimeTypePrefixVideo) +} + +func IsMimeTypeVideo(mimeType MimeType) bool { + return strings.HasPrefix(mimeType.String(), MimeTypePrefixVideo) +} + +// SVC-TODO: Have to use more conditions to differentiate between +// SVC-TODO: SVC and non-SVC (could be single layer or simulcast). +// SVC-TODO: May only need to differentiate between simulcast and non-simulcast +// SVC-TODO: i. e. may be possible to treat single layer as SVC to get proper/intended functionality. +func IsMimeTypeSVC(mimeType MimeType) bool { + switch mimeType { + case MimeTypeAV1, MimeTypeVP9: + return true + } + return false +} + +func IsMimeTypeStringSVC(mime string) bool { + return IsMimeTypeSVC(NormalizeMimeType(mime)) +} + +func IsMimeTypeStringRED(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeRED +} + +func IsMimeTypeStringOpus(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeOpus +} + +func IsMimeTypeStringRTX(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeRTX +} + +func IsMimeTypeStringVP8(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeVP8 +} + +func IsMimeTypeStringVP9(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeVP9 +} + +func IsMimeTypeStringH264(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeH264 +} diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 70c8d434c..8c16675bd 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -33,6 +33,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/audio" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/livekit-server/pkg/sfu/streamtracker" @@ -100,6 +101,7 @@ type TrackReceiver interface { // returns the initial codec of the receiver, it is determined by the track's codec // and will not change if the codec changes during the session (publisher changes codec) Codec() webrtc.RTPCodecParameters + Mime() mime.MimeType HeaderExtensions() []webrtc.RTPHeaderExtensionParameter IsClosed() bool @@ -254,8 +256,8 @@ func NewWebRTCReceiver( codecState: ReceiverCodecStateNormal, kind: track.Kind(), onRTCP: onRTCP, - isSVC: buffer.IsSvcCodec(track.Codec().MimeType), - isRED: buffer.IsRedCodec(track.Codec().MimeType), + isSVC: mime.IsMimeTypeStringSVC(track.Codec().MimeType), + isRED: mime.IsMimeTypeStringRED(track.Codec().MimeType), } for _, opt := range opts { @@ -278,9 +280,9 @@ func NewWebRTCReceiver( } }) w.connectionStats.Start( - w.codec.MimeType, + mime.NormalizeMimeType(w.codec.MimeType), // TODO: technically not correct to declare FEC on when RED. Need the primary codec's fmtp line to check. - strings.EqualFold(w.codec.MimeType, MimeTypeAudioRed) || strings.Contains(strings.ToLower(w.codec.SDPFmtpLine), "useinbandfec=1"), + mime.IsMimeTypeStringRED(w.codec.MimeType) || strings.Contains(strings.ToLower(w.codec.SDPFmtpLine), "useinbandfec=1"), ) w.streamTrackerManager = NewStreamTrackerManager(logger, trackInfo, w.isSVC, w.codec.ClockRate, streamTrackerManagerConfig) @@ -371,6 +373,10 @@ func (w *WebRTCReceiver) Codec() webrtc.RTPCodecParameters { return w.codec } +func (w *WebRTCReceiver) Mime() mime.MimeType { + return mime.NormalizeMimeType(w.codec.MimeType) +} + func (w *WebRTCReceiver) HeaderExtensions() []webrtc.RTPHeaderExtensionParameter { return w.receiver.GetParameters().HeaderExtensions } diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index 6023a89ab..f066205b3 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -27,10 +27,6 @@ import ( "github.com/livekit/protocol/logger" ) -const ( - MimeTypeAudioRed = "audio/red" -) - var ( ErrIncompleteRedHeader = errors.New("incomplete red block header") ErrIncompleteRedBlock = errors.New("incomplete red block payload") diff --git a/pkg/sfu/utils/helpers.go b/pkg/sfu/utils/helpers.go index 09990c183..d237d511a 100644 --- a/pkg/sfu/utils/helpers.go +++ b/pkg/sfu/utils/helpers.go @@ -17,8 +17,8 @@ package utils import ( "errors" "fmt" - "strings" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/pion/interceptor" "github.com/pion/rtp" "github.com/pion/webrtc/v4" @@ -29,7 +29,7 @@ import ( func CodecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) { // First attempt to match on MimeType + SDPFmtpLine for _, c := range haystack { - if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && + if mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && c.RTPCodecCapability.SDPFmtpLine == needle.RTPCodecCapability.SDPFmtpLine { return c, nil } @@ -37,7 +37,7 @@ func CodecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []web // Fallback to just MimeType for _, c := range haystack { - if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) { + if mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) { return c, nil } } diff --git a/pkg/sfu/utils/mimetype.go b/pkg/sfu/utils/mimetype.go deleted file mode 100644 index cf7735b88..000000000 --- a/pkg/sfu/utils/mimetype.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -type MimeType int - -const ( - MimeTypeUnknown MimeType = iota - MimeTypeVP8 - MimeTypeVP9 - MimeTypeH264 - MimeTypeAV1 - MimeTypeH265 -) - -func MatchMimeType(mimeType string) MimeType { - switch len(mimeType) { - case 9: - switch mimeType[0] { - case 'v', 'V': - switch mimeType[1] { - case 'i', 'I': - switch mimeType[2] { - case 'd', 'D': - switch mimeType[3] { - case 'e', 'E': - switch mimeType[4] { - case 'o', 'O': - switch mimeType[5] { - case '/': - switch mimeType[6] { - case 'v', 'V': - switch mimeType[7] { - case 'p', 'P': - switch mimeType[8] { - case '8': - return MimeTypeVP8 - case '9': - return MimeTypeVP9 - } - } - case 'a', 'A': - switch mimeType[7] { - case 'v', 'V': - switch mimeType[8] { - case '1': - return MimeTypeAV1 - } - } - } - } - } - } - } - } - } - case 10: - switch mimeType[0] { - case 'v', 'V': - switch mimeType[1] { - case 'i', 'I': - switch mimeType[2] { - case 'd', 'D': - switch mimeType[3] { - case 'e', 'E': - switch mimeType[4] { - case 'o', 'O': - switch mimeType[5] { - case '/': - switch mimeType[6] { - case 'h', 'H': - switch mimeType[7] { - case '2': - switch mimeType[8] { - case '6': - switch mimeType[9] { - case '4': - return MimeTypeH264 - case '5': - return MimeTypeH265 - } - } - } - } - } - } - } - } - } - } - } - return MimeTypeUnknown -} diff --git a/pkg/sfu/utils/mimetype_test.go b/pkg/sfu/utils/mimetype_test.go deleted file mode 100644 index 6cfcc7111..000000000 --- a/pkg/sfu/utils/mimetype_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func toLowerSwitch(mimeType string) MimeType { - switch strings.ToLower(mimeType) { - case "video/vp8": - return MimeTypeVP8 - case "video/vp9": - return MimeTypeVP9 - case "video/h264": - return MimeTypeH264 - case "video/av1": - return MimeTypeAV1 - default: - return MimeTypeUnknown - } -} - -func TestMimeTypeMatch(t *testing.T) { - require.Equal(t, MimeTypeVP8, MatchMimeType("VIDEO/VP8"), "VIDEO/VP8") - require.Equal(t, MimeTypeVP9, MatchMimeType("VIDEO/VP9"), "VIDEO/VP9") - require.Equal(t, MimeTypeH264, MatchMimeType("VIDEO/H264"), "VIDEO/H264") - require.Equal(t, MimeTypeAV1, MatchMimeType("VIDEO/AV1"), "VIDEO/AV1") -} - -func BenchmarkMimeTypeMatch(b *testing.B) { - mimeTypes := []string{ - "video/VP8", - "video/VP9", - "video/H264", - "video/AV1", - } - - b.Run("ToLower/switch", func(b *testing.B) { - for i := range b.N { - _ = toLowerSwitch(mimeTypes[i%len(mimeTypes)]) - } - }) - - b.Run("MatchMimeType", func(b *testing.B) { - for i := range b.N { - _ = MatchMimeType(mimeTypes[i%len(mimeTypes)]) - } - }) -} diff --git a/pkg/telemetry/events.go b/pkg/telemetry/events.go index 1e59e0b9f..b6e6c703c 100644 --- a/pkg/telemetry/events.go +++ b/pkg/telemetry/events.go @@ -20,6 +20,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -256,14 +257,14 @@ func (t *telemetryService) TrackMaxSubscribedVideoQuality( ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, - mime string, + mime mime.MimeType, maxQuality livekit.VideoQuality, ) { t.enqueue(func() { room := t.getRoomDetails(participantID) ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_MAX_SUBSCRIBED_VIDEO_QUALITY, room, participantID, track) ev.MaxSubscribedVideoQuality = maxQuality - ev.Mime = mime + ev.Mime = mime.String() t.SendEvent(ctx, ev) }) } @@ -393,7 +394,7 @@ func (t *telemetryService) TrackPublishRTPStats( ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, - mimeType string, + mimeType mime.MimeType, layer int, stats *livekit.RTPStats, ) { @@ -402,7 +403,7 @@ func (t *telemetryService) TrackPublishRTPStats( ev := newRoomEvent(livekit.AnalyticsEventType_TRACK_PUBLISH_STATS, room) ev.ParticipantId = string(participantID) ev.TrackId = string(trackID) - ev.Mime = mimeType + ev.Mime = mimeType.String() ev.VideoLayer = int32(layer) ev.RtpStats = stats t.SendEvent(ctx, ev) @@ -413,7 +414,7 @@ func (t *telemetryService) TrackSubscribeRTPStats( ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, - mimeType string, + mimeType mime.MimeType, stats *livekit.RTPStats, ) { t.enqueue(func() { @@ -421,7 +422,7 @@ func (t *telemetryService) TrackSubscribeRTPStats( ev := newRoomEvent(livekit.AnalyticsEventType_TRACK_SUBSCRIBE_STATS, room) ev.ParticipantId = string(participantID) ev.TrackId = string(trackID) - ev.Mime = mimeType + ev.Mime = mimeType.String() ev.RtpStats = stats t.SendEvent(ctx, ev) }) diff --git a/pkg/telemetry/telemetryfakes/fake_telemetry_service.go b/pkg/telemetry/telemetryfakes/fake_telemetry_service.go index cdfecfdae..f3605cc67 100644 --- a/pkg/telemetry/telemetryfakes/fake_telemetry_service.go +++ b/pkg/telemetry/telemetryfakes/fake_telemetry_service.go @@ -5,6 +5,7 @@ import ( "context" "sync" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/protocol/livekit" ) @@ -152,13 +153,13 @@ type FakeTelemetryService struct { arg1 context.Context arg2 []*livekit.AnalyticsStat } - TrackMaxSubscribedVideoQualityStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, string, livekit.VideoQuality) + TrackMaxSubscribedVideoQualityStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality) trackMaxSubscribedVideoQualityMutex sync.RWMutex trackMaxSubscribedVideoQualityArgsForCall []struct { arg1 context.Context arg2 livekit.ParticipantID arg3 *livekit.TrackInfo - arg4 string + arg4 mime.MimeType arg5 livekit.VideoQuality } TrackMutedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) @@ -168,13 +169,13 @@ type FakeTelemetryService struct { arg2 livekit.ParticipantID arg3 *livekit.TrackInfo } - TrackPublishRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, int, *livekit.RTPStats) + TrackPublishRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats) trackPublishRTPStatsMutex sync.RWMutex trackPublishRTPStatsArgsForCall []struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 int arg6 *livekit.RTPStats } @@ -216,13 +217,13 @@ type FakeTelemetryService struct { arg4 error arg5 bool } - TrackSubscribeRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, *livekit.RTPStats) + TrackSubscribeRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats) trackSubscribeRTPStatsMutex sync.RWMutex trackSubscribeRTPStatsArgsForCall []struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 *livekit.RTPStats } TrackSubscribeRequestedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) @@ -1003,13 +1004,13 @@ func (fake *FakeTelemetryService) SendStatsArgsForCall(i int) (context.Context, return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQuality(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 string, arg5 livekit.VideoQuality) { +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQuality(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 mime.MimeType, arg5 livekit.VideoQuality) { fake.trackMaxSubscribedVideoQualityMutex.Lock() fake.trackMaxSubscribedVideoQualityArgsForCall = append(fake.trackMaxSubscribedVideoQualityArgsForCall, struct { arg1 context.Context arg2 livekit.ParticipantID arg3 *livekit.TrackInfo - arg4 string + arg4 mime.MimeType arg5 livekit.VideoQuality }{arg1, arg2, arg3, arg4, arg5}) stub := fake.TrackMaxSubscribedVideoQualityStub @@ -1026,13 +1027,13 @@ func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCallCount() int return len(fake.trackMaxSubscribedVideoQualityArgsForCall) } -func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, string, livekit.VideoQuality)) { +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality)) { fake.trackMaxSubscribedVideoQualityMutex.Lock() defer fake.trackMaxSubscribedVideoQualityMutex.Unlock() fake.TrackMaxSubscribedVideoQualityStub = stub } -func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, string, livekit.VideoQuality) { +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality) { fake.trackMaxSubscribedVideoQualityMutex.RLock() defer fake.trackMaxSubscribedVideoQualityMutex.RUnlock() argsForCall := fake.trackMaxSubscribedVideoQualityArgsForCall[i] @@ -1073,13 +1074,13 @@ func (fake *FakeTelemetryService) TrackMutedArgsForCall(i int) (context.Context, return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 } -func (fake *FakeTelemetryService) TrackPublishRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 string, arg5 int, arg6 *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackPublishRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 mime.MimeType, arg5 int, arg6 *livekit.RTPStats) { fake.trackPublishRTPStatsMutex.Lock() fake.trackPublishRTPStatsArgsForCall = append(fake.trackPublishRTPStatsArgsForCall, struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 int arg6 *livekit.RTPStats }{arg1, arg2, arg3, arg4, arg5, arg6}) @@ -1097,13 +1098,13 @@ func (fake *FakeTelemetryService) TrackPublishRTPStatsCallCount() int { return len(fake.trackPublishRTPStatsArgsForCall) } -func (fake *FakeTelemetryService) TrackPublishRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, int, *livekit.RTPStats)) { +func (fake *FakeTelemetryService) TrackPublishRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats)) { fake.trackPublishRTPStatsMutex.Lock() defer fake.trackPublishRTPStatsMutex.Unlock() fake.TrackPublishRTPStatsStub = stub } -func (fake *FakeTelemetryService) TrackPublishRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, string, int, *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackPublishRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats) { fake.trackPublishRTPStatsMutex.RLock() defer fake.trackPublishRTPStatsMutex.RUnlock() argsForCall := fake.trackPublishRTPStatsArgsForCall[i] @@ -1283,13 +1284,13 @@ func (fake *FakeTelemetryService) TrackSubscribeFailedArgsForCall(i int) (contex return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 } -func (fake *FakeTelemetryService) TrackSubscribeRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 string, arg5 *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackSubscribeRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 mime.MimeType, arg5 *livekit.RTPStats) { fake.trackSubscribeRTPStatsMutex.Lock() fake.trackSubscribeRTPStatsArgsForCall = append(fake.trackSubscribeRTPStatsArgsForCall, struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 *livekit.RTPStats }{arg1, arg2, arg3, arg4, arg5}) stub := fake.TrackSubscribeRTPStatsStub @@ -1306,13 +1307,13 @@ func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCallCount() int { return len(fake.trackSubscribeRTPStatsArgsForCall) } -func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, *livekit.RTPStats)) { +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats)) { fake.trackSubscribeRTPStatsMutex.Lock() defer fake.trackSubscribeRTPStatsMutex.Unlock() fake.TrackSubscribeRTPStatsStub = stub } -func (fake *FakeTelemetryService) TrackSubscribeRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, string, *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats) { fake.trackSubscribeRTPStatsMutex.RLock() defer fake.trackSubscribeRTPStatsMutex.RUnlock() argsForCall := fake.trackSubscribeRTPStatsArgsForCall[i] diff --git a/pkg/telemetry/telemetryservice.go b/pkg/telemetry/telemetryservice.go index 226a52414..9d2cdf371 100644 --- a/pkg/telemetry/telemetryservice.go +++ b/pkg/telemetry/telemetryservice.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -64,9 +65,9 @@ type TelemetryService interface { // TrackPublishedUpdate - track metadata has been updated TrackPublishedUpdate(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) // TrackMaxSubscribedVideoQuality - publisher is notified of the max quality subscribers desire - TrackMaxSubscribedVideoQuality(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, mime string, maxQuality livekit.VideoQuality) - TrackPublishRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType string, layer int, stats *livekit.RTPStats) - TrackSubscribeRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType string, stats *livekit.RTPStats) + TrackMaxSubscribedVideoQuality(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, mime mime.MimeType, maxQuality livekit.VideoQuality) + TrackPublishRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, layer int, stats *livekit.RTPStats) + TrackSubscribeRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, stats *livekit.RTPStats) EgressStarted(ctx context.Context, info *livekit.EgressInfo) EgressUpdated(ctx context.Context, info *livekit.EgressInfo) EgressEnded(ctx context.Context, info *livekit.EgressInfo) diff --git a/test/client/client.go b/test/client/client.go index 7304532fa..b54fedb1e 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -22,7 +22,6 @@ import ( "net/http" "net/url" "path/filepath" - "strings" "sync" "time" @@ -43,6 +42,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" ) type SignalRequestHandler func(msg *livekit.SignalRequest) error @@ -103,9 +103,9 @@ var ( }, } extMimeMapping = map[string]string{ - ".ivf": webrtc.MimeTypeVP8, - ".h264": webrtc.MimeTypeH264, - ".ogg": webrtc.MimeTypeOpus, + ".ivf": mime.MimeTypeVP8.String(), + ".h264": mime.MimeTypeH264.String(), + ".ogg": mime.MimeTypeOpus.String(), } ) @@ -195,7 +195,7 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { var disabled bool if opts != nil { for _, dc := range opts.DisabledCodecs { - if strings.EqualFold(dc.MimeType, codec.Mime) && (dc.SDPFmtpLine == "" || dc.SDPFmtpLine == codec.FmtpLine) { + if mime.IsMimeTypeStringEqual(dc.MimeType, codec.Mime) && (dc.SDPFmtpLine == "" || dc.SDPFmtpLine == codec.FmtpLine) { disabled = true break } diff --git a/test/client/trackwriter.go b/test/client/trackwriter.go index 2dc4a5a38..803bf3348 100644 --- a/test/client/trackwriter.go +++ b/test/client/trackwriter.go @@ -18,7 +18,6 @@ import ( "context" "io" "os" - "strings" "time" "github.com/pion/webrtc/v4" @@ -27,6 +26,7 @@ import ( "github.com/pion/webrtc/v4/pkg/media/ivfreader" "github.com/pion/webrtc/v4/pkg/media/oggreader" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/logger" ) @@ -37,7 +37,7 @@ type TrackWriter struct { cancel context.CancelFunc track *webrtc.TrackLocalStaticSample filePath string - mime string + mime mime.MimeType ogg *oggreader.OggReader ivfheader *ivfreader.IVFFileHeader @@ -52,7 +52,7 @@ func NewTrackWriter(ctx context.Context, track *webrtc.TrackLocalStaticSample, f cancel: cancel, track: track, filePath: filePath, - mime: track.Codec().MimeType, + mime: mime.NormalizeMimeType(track.Codec().MimeType), } } @@ -67,23 +67,25 @@ func (w *TrackWriter) Start() error { return err } - logger.Debugw("starting track writer", + logger.Debugw( + "starting track writer", "trackID", w.track.ID(), - "mime", w.mime) + "mime", w.mime, + ) switch w.mime { - case webrtc.MimeTypeOpus: + case mime.MimeTypeOpus: w.ogg, _, err = oggreader.NewWith(file) if err != nil { return err } go w.writeOgg() - case webrtc.MimeTypeVP8: + case mime.MimeTypeVP8: w.ivf, w.ivfheader, err = ivfreader.NewWith(file) if err != nil { return err } go w.writeVP8() - case webrtc.MimeTypeH264: + case mime.MimeTypeH264: w.h264, err = h264reader.NewReader(file) if err != nil { return err @@ -104,7 +106,7 @@ func (w *TrackWriter) writeNull() { for { select { case <-time.After(20 * time.Millisecond): - if strings.EqualFold(w.mime, webrtc.MimeTypeH264) { + if w.mime == mime.MimeTypeH264 { w.track.WriteSample(h264Sample) } else { w.track.WriteSample(sample) diff --git a/test/singlenode_test.go b/test/singlenode_test.go index 32946c963..b11e5da77 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -38,6 +38,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/sfu/datachannel" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/testutils" testclient "github.com/livekit/livekit-server/test/client" ) @@ -156,7 +157,7 @@ func TestSinglePublisher(t *testing.T) { waitUntilConnected(t, c1, c2) // publish a track and ensure clients receive it ok - t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcamaudio") + t1, err := c1.AddStaticTrack("audio/OPUS", "audio", "webcamaudio") require.NoError(t, err) defer t1.Stop() t2, err := c1.AddStaticTrack("video/vp8", "video", "webcamvideo") @@ -280,7 +281,7 @@ func Test_RenegotiationWithDifferentCodecs(t *testing.T) { tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { return "" } @@ -308,9 +309,9 @@ func Test_RenegotiationWithDifferentCodecs(t *testing.T) { var vp8Found, h264Found bool tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { vp8Found = true - } else if strings.EqualFold(t.Codec().MimeType, "video/h264") { + } else if mime.IsMimeTypeStringH264(t.Codec().MimeType) { h264Found = true } } @@ -594,8 +595,8 @@ func TestDeviceCodecOverride(t *testing.T) { hasSeenVP8 := false for _, a := range desc.Attributes { if a.Key == "rtpmap" { - require.NotContains(t, a.Value, "H264", "should not contain H264 codec") - if strings.Contains(a.Value, "VP8") { + require.NotContains(t, a.Value, mime.MimeTypeCodecH264.String(), "should not contain H264 codec") + if strings.Contains(a.Value, mime.MimeTypeCodecVP8.String()) { hasSeenVP8 = true } } @@ -641,7 +642,7 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { return "" } } @@ -663,7 +664,7 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { remoteC1 := c2.GetRemoteParticipant(c1.ID()) require.NotNil(t, remoteC1) for _, track := range remoteC1.Tracks { - if strings.EqualFold(track.MimeType, "video/h264") { + if mime.IsMimeTypeStringH264(track.MimeType) { h264TrackID = track.Sid return true } @@ -698,7 +699,7 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { var vp8Count int tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { vp8Count++ } }