diff --git a/pkg/clientconfiguration/conf.go b/pkg/clientconfiguration/conf.go index cc069fc5c..992cca92d 100644 --- a/pkg/clientconfiguration/conf.go +++ b/pkg/clientconfiguration/conf.go @@ -15,6 +15,7 @@ package clientconfiguration import ( + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -28,7 +29,7 @@ var StaticConfigurations = []ConfigurationItem{ { Match: &ScriptMatch{Expr: `c.browser == "safari"`}, Configuration: &livekit.ClientConfiguration{DisabledCodecs: &livekit.DisabledCodecs{Codecs: []*livekit.Codec{ - {Mime: "video/av1"}, + {Mime: mime.MimeTypeAV1.String()}, }}}, Merge: false, }, @@ -37,7 +38,7 @@ var StaticConfigurations = []ConfigurationItem{ ((c.browser == "firefox" || c.browser == "firefox mobile") && (c.os == "linux" || c.os == "android"))`}, Configuration: &livekit.ClientConfiguration{ DisabledCodecs: &livekit.DisabledCodecs{ - Publish: []*livekit.Codec{{Mime: "video/h264"}}, + Publish: []*livekit.Codec{{Mime: mime.MimeTypeH264.String()}}, }, }, Merge: false, diff --git a/pkg/config/config.go b/pkg/config/config.go index 2a56fe730..5c04c5f91 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -22,7 +22,6 @@ import ( "time" "github.com/mitchellh/go-homedir" - "github.com/pion/webrtc/v4" "github.com/pkg/errors" "github.com/urfave/cli/v2" "gopkg.in/yaml.v3" @@ -31,6 +30,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/bwe/remotebwe" "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/livekit" @@ -336,13 +336,13 @@ var DefaultConfig = Config{ Room: RoomConfig{ AutoCreate: true, EnabledCodecs: []CodecSpec{ - {Mime: webrtc.MimeTypeOpus}, - {Mime: sfu.MimeTypeAudioRed}, - {Mime: webrtc.MimeTypeVP8}, - {Mime: webrtc.MimeTypeH264}, - {Mime: webrtc.MimeTypeVP9}, - {Mime: webrtc.MimeTypeAV1}, - {Mime: webrtc.MimeTypeRTX}, + {Mime: mime.MimeTypeOpus.String()}, + {Mime: mime.MimeTypeRED.String()}, + {Mime: mime.MimeTypeVP8.String()}, + {Mime: mime.MimeTypeH264.String()}, + {Mime: mime.MimeTypeVP9.String()}, + {Mime: mime.MimeTypeAV1.String()}, + {Mime: mime.MimeTypeRTX.String()}, }, EmptyTimeout: 5 * 60, DepartureTimeout: 20, diff --git a/pkg/rtc/dynacast/dynacastmanager.go b/pkg/rtc/dynacast/dynacastmanager.go index 92ba192c5..b2ef9bc20 100644 --- a/pkg/rtc/dynacast/dynacastmanager.go +++ b/pkg/rtc/dynacast/dynacastmanager.go @@ -15,7 +15,6 @@ package dynacast import ( - "strings" "sync" "time" @@ -26,6 +25,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/utils" ) @@ -38,10 +38,10 @@ type DynacastManager struct { params DynacastManagerParams lock sync.RWMutex - regressedCodec map[string]struct{} - dynacastQuality map[string]*DynacastQuality // mime type => DynacastQuality - maxSubscribedQuality map[string]livekit.VideoQuality - committedMaxSubscribedQuality map[string]livekit.VideoQuality + regressedCodec map[mime.MimeType]struct{} + dynacastQuality map[mime.MimeType]*DynacastQuality + maxSubscribedQuality map[mime.MimeType]livekit.VideoQuality + committedMaxSubscribedQuality map[mime.MimeType]livekit.VideoQuality maxSubscribedQualityDebounce func(func()) maxSubscribedQualityDebouncePending bool @@ -59,10 +59,10 @@ func NewDynacastManager(params DynacastManagerParams) *DynacastManager { } d := &DynacastManager{ params: params, - regressedCodec: make(map[string]struct{}), - dynacastQuality: make(map[string]*DynacastQuality), - maxSubscribedQuality: make(map[string]livekit.VideoQuality), - committedMaxSubscribedQuality: make(map[string]livekit.VideoQuality), + regressedCodec: make(map[mime.MimeType]struct{}), + dynacastQuality: make(map[mime.MimeType]*DynacastQuality), + maxSubscribedQuality: make(map[mime.MimeType]livekit.VideoQuality), + committedMaxSubscribedQuality: make(map[mime.MimeType]livekit.VideoQuality), qualityNotifyOpQueue: utils.NewOpsQueue(utils.OpsQueueParams{ Name: "quality-notify", MinSize: 64, @@ -83,11 +83,11 @@ func (d *DynacastManager) OnSubscribedMaxQualityChange(f func(subscribedQualitie d.lock.Unlock() } -func (d *DynacastManager) AddCodec(mime string) { +func (d *DynacastManager) AddCodec(mime mime.MimeType) { d.getOrCreateDynacastQuality(mime) } -func (d *DynacastManager) HandleCodecRegression(fromMime, toMime string) { +func (d *DynacastManager) HandleCodecRegression(fromMime, toMime mime.MimeType) { fromDq := d.getOrCreateDynacastQuality(fromMime) d.lock.Lock() @@ -96,32 +96,31 @@ func (d *DynacastManager) HandleCodecRegression(fromMime, toMime string) { return } - normalizedFromMime, normalizedToMime := strings.ToLower(fromMime), strings.ToLower(toMime) if fromDq == nil { // should not happen as we have added the codec on setup receiver - d.params.Logger.Warnw("regression from codec not found", nil, "mime", normalizedFromMime) + d.params.Logger.Warnw("regression from codec not found", nil, "mime", fromMime) d.lock.Unlock() return } - d.regressedCodec[normalizedFromMime] = struct{}{} - d.maxSubscribedQuality[normalizedFromMime] = livekit.VideoQuality_OFF + d.regressedCodec[fromMime] = struct{}{} + d.maxSubscribedQuality[fromMime] = livekit.VideoQuality_OFF // if the new codec is not added, notify the publisher to start publishing - if _, ok := d.maxSubscribedQuality[normalizedToMime]; !ok { - d.maxSubscribedQuality[normalizedToMime] = livekit.VideoQuality_HIGH + if _, ok := d.maxSubscribedQuality[toMime]; !ok { + d.maxSubscribedQuality[toMime] = livekit.VideoQuality_HIGH } d.lock.Unlock() d.update(false) fromDq.Stop() - ToDq := d.getOrCreateDynacastQuality(normalizedToMime) + ToDq := d.getOrCreateDynacastQuality(toMime) fromDq.RegressTo(ToDq) } func (d *DynacastManager) Restart() { d.lock.Lock() - d.committedMaxSubscribedQuality = make(map[string]livekit.VideoQuality) + d.committedMaxSubscribedQuality = make(map[mime.MimeType]livekit.VideoQuality) dqs := d.getDynacastQualitiesLocked() d.lock.Unlock() @@ -136,7 +135,7 @@ func (d *DynacastManager) Close() { d.lock.Lock() dqs := d.getDynacastQualitiesLocked() - d.dynacastQuality = make(map[string]*DynacastQuality) + d.dynacastQuality = make(map[mime.MimeType]*DynacastQuality) d.isClosed = true d.lock.Unlock() @@ -169,7 +168,7 @@ func (d *DynacastManager) ForceQuality(quality livekit.VideoQuality) { d.enqueueSubscribedQualityChange() } -func (d *DynacastManager) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, mime string, quality livekit.VideoQuality) { +func (d *DynacastManager) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, mime mime.MimeType, quality livekit.VideoQuality) { dq := d.getOrCreateDynacastQuality(mime) if dq != nil { dq.NotifySubscriberMaxQuality(subscriberID, quality) @@ -185,7 +184,7 @@ func (d *DynacastManager) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, } } -func (d *DynacastManager) getOrCreateDynacastQuality(mime string) *DynacastQuality { +func (d *DynacastManager) getOrCreateDynacastQuality(mimeType mime.MimeType) *DynacastQuality { d.lock.Lock() defer d.lock.Unlock() @@ -193,21 +192,18 @@ func (d *DynacastManager) getOrCreateDynacastQuality(mime string) *DynacastQuali return nil } - normalizedMime := strings.ToLower(mime) - if dq := d.dynacastQuality[normalizedMime]; dq != nil { + if dq := d.dynacastQuality[mimeType]; dq != nil { return dq } dq := NewDynacastQuality(DynacastQualityParams{ - MimeType: normalizedMime, + MimeType: mimeType, Logger: d.params.Logger, }) - dq.OnSubscribedMaxQualityChange(func(mimeType string, maxQuality livekit.VideoQuality) { - d.updateMaxQualityForMime(mimeType, maxQuality) - }) + dq.OnSubscribedMaxQualityChange(d.updateMaxQualityForMime) dq.Start() - d.dynacastQuality[normalizedMime] = dq + d.dynacastQuality[mimeType] = dq return dq } @@ -215,7 +211,7 @@ func (d *DynacastManager) getDynacastQualitiesLocked() []*DynacastQuality { return maps.Values(d.dynacastQuality) } -func (d *DynacastManager) updateMaxQualityForMime(mime string, maxQuality livekit.VideoQuality) { +func (d *DynacastManager) updateMaxQualityForMime(mime mime.MimeType, maxQuality livekit.VideoQuality) { d.lock.Lock() if _, ok := d.regressedCodec[mime]; !ok { d.maxSubscribedQuality[mime] = maxQuality @@ -297,7 +293,7 @@ func (d *DynacastManager) update(force bool) { ) // commit change - d.committedMaxSubscribedQuality = make(map[string]livekit.VideoQuality, len(d.maxSubscribedQuality)) + d.committedMaxSubscribedQuality = make(map[mime.MimeType]livekit.VideoQuality, len(d.maxSubscribedQuality)) for mime, quality := range d.maxSubscribedQuality { d.committedMaxSubscribedQuality[mime] = quality } @@ -321,7 +317,7 @@ func (d *DynacastManager) enqueueSubscribedQualityChange() { if quality == livekit.VideoQuality_OFF { subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedCodec{ - Codec: mime, + Codec: mime.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -337,7 +333,7 @@ func (d *DynacastManager) enqueueSubscribedQualityChange() { }) } subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedCodec{ - Codec: mime, + Codec: mime.String(), Qualities: subscribedQualities, }) } diff --git a/pkg/rtc/dynacast/dynacastmanager_test.go b/pkg/rtc/dynacast/dynacastmanager_test.go index f01b2edf9..8a30d7e97 100644 --- a/pkg/rtc/dynacast/dynacastmanager_test.go +++ b/pkg/rtc/dynacast/dynacastmanager_test.go @@ -16,15 +16,14 @@ package dynacast import ( "sort" - "strings" "sync" "testing" "time" - "github.com/pion/webrtc/v4" "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -40,15 +39,15 @@ func TestSubscribedMaxQuality(t *testing.T) { lock.Unlock() }) - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_HIGH) - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeAV1, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeAV1, livekit.VideoQuality_HIGH) // mute all subscribers of vp8 - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_OFF) expectedSubscribedQualities := []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -56,7 +55,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -87,17 +86,17 @@ func TestSubscribedMaxQuality(t *testing.T) { lock.Unlock() }) - dm.maxSubscribedQuality = map[string]livekit.VideoQuality{ - strings.ToLower(webrtc.MimeTypeVP8): livekit.VideoQuality_LOW, - strings.ToLower(webrtc.MimeTypeAV1): livekit.VideoQuality_LOW, + dm.maxSubscribedQuality = map[mime.MimeType]livekit.VideoQuality{ + mime.MimeTypeVP8: livekit.VideoQuality_LOW, + mime.MimeTypeAV1: livekit.VideoQuality_LOW, } - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_HIGH) - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeVP8, livekit.VideoQuality_MEDIUM) - dm.NotifySubscriberMaxQuality("s3", webrtc.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_MEDIUM) expectedSubscribedQualities := []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -105,7 +104,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -121,11 +120,11 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // "s1" dropping to MEDIUM should disable HIGH layer - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_MEDIUM) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -133,7 +132,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -149,13 +148,13 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // "s1" , "s2" , "s3" dropping to LOW should disable HIGH & MEDIUM - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_LOW) - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeVP8, livekit.VideoQuality_LOW) - dm.NotifySubscriberMaxQuality("s3", webrtc.MimeTypeAV1, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_LOW) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -163,7 +162,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -179,7 +178,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // muting "s2" only should not disable all qualities of vp8, no change of expected qualities - dm.NotifySubscriberMaxQuality("s2", webrtc.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_OFF) time.Sleep(100 * time.Millisecond) require.Eventually(t, func() bool { @@ -190,12 +189,12 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // muting "s1" and s3 also should disable all qualities - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_OFF) - dm.NotifySubscriberMaxQuality("s3", webrtc.MimeTypeAV1, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_OFF) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -203,7 +202,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -219,11 +218,11 @@ func TestSubscribedMaxQuality(t *testing.T) { }, 10*time.Second, 100*time.Millisecond) // unmuting "s1" should enable vp8 previously set max quality - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_LOW) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -231,7 +230,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -248,13 +247,13 @@ func TestSubscribedMaxQuality(t *testing.T) { // a higher quality from a different node should trigger that quality dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ - {CodecMime: webrtc.MimeTypeVP8, Quality: livekit.VideoQuality_HIGH}, - {CodecMime: webrtc.MimeTypeAV1, Quality: livekit.VideoQuality_MEDIUM}, + {CodecMime: mime.MimeTypeVP8, Quality: livekit.VideoQuality_HIGH}, + {CodecMime: mime.MimeTypeAV1, Quality: livekit.VideoQuality_MEDIUM}, }) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -262,7 +261,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -289,11 +288,11 @@ func TestCodecRegression(t *testing.T) { lock.Unlock() }) - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeAV1, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeAV1, livekit.VideoQuality_HIGH) expectedSubscribedQualities := []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -308,11 +307,11 @@ func TestCodecRegression(t *testing.T) { return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) }, 10*time.Second, 100*time.Millisecond) - dm.HandleCodecRegression(webrtc.MimeTypeAV1, webrtc.MimeTypeVP8) + dm.HandleCodecRegression(mime.MimeTypeAV1, mime.MimeTypeVP8) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -320,7 +319,7 @@ func TestCodecRegression(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, @@ -337,13 +336,13 @@ func TestCodecRegression(t *testing.T) { // av1 quality change should be forwarded to vp8 // av1 quality change of node should be ignored - dm.NotifySubscriberMaxQuality("s1", webrtc.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeAV1, livekit.VideoQuality_MEDIUM) dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ - {CodecMime: webrtc.MimeTypeAV1, Quality: livekit.VideoQuality_HIGH}, + {CodecMime: mime.MimeTypeAV1, Quality: livekit.VideoQuality_HIGH}, }) expectedSubscribedQualities = []*livekit.SubscribedCodec{ { - Codec: strings.ToLower(webrtc.MimeTypeAV1), + Codec: mime.MimeTypeAV1.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: false}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, @@ -351,7 +350,7 @@ func TestCodecRegression(t *testing.T) { }, }, { - Codec: strings.ToLower(webrtc.MimeTypeVP8), + Codec: mime.MimeTypeVP8.String(), Qualities: []*livekit.SubscribedQuality{ {Quality: livekit.VideoQuality_LOW, Enabled: true}, {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, diff --git a/pkg/rtc/dynacast/dynacastquality.go b/pkg/rtc/dynacast/dynacastquality.go index 957874b36..77b6aaad2 100644 --- a/pkg/rtc/dynacast/dynacastquality.go +++ b/pkg/rtc/dynacast/dynacastquality.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -27,7 +28,7 @@ const ( ) type DynacastQualityParams struct { - MimeType string + MimeType mime.MimeType Logger logger.Logger } @@ -44,7 +45,7 @@ type DynacastQuality struct { maxQualityTimer *time.Timer regressTo *DynacastQuality - onSubscribedMaxQualityChange func(mimeType string, maxSubscribedQuality livekit.VideoQuality) + onSubscribedMaxQualityChange func(mimeType mime.MimeType, maxSubscribedQuality livekit.VideoQuality) } func NewDynacastQuality(params DynacastQualityParams) *DynacastQuality { @@ -67,7 +68,7 @@ func (d *DynacastQuality) Stop() { d.stopMaxQualityTimer() } -func (d *DynacastQuality) OnSubscribedMaxQualityChange(f func(mimeType string, maxSubscribedQuality livekit.VideoQuality)) { +func (d *DynacastQuality) OnSubscribedMaxQualityChange(f func(mimeType mime.MimeType, maxSubscribedQuality livekit.VideoQuality)) { d.lock.Lock() defer d.lock.Unlock() d.onSubscribedMaxQualityChange = f diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 81aab0176..a6034468e 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -20,24 +20,24 @@ import ( "github.com/pion/webrtc/v4" - "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) var OpusCodecCapability = webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeOpus, + MimeType: mime.MimeTypeOpus.String(), ClockRate: 48000, Channels: 2, SDPFmtpLine: "minptime=10;useinbandfec=1", } var RedCodecCapability = webrtc.RTPCodecCapability{ - MimeType: sfu.MimeTypeAudioRed, + MimeType: mime.MimeTypeRED.String(), ClockRate: 48000, Channels: 2, SDPFmtpLine: "111/111", } var videoRTX = webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeRTX, + MimeType: mime.MimeTypeRTX.String(), ClockRate: 90000, } @@ -70,7 +70,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac for _, codec := range []webrtc.RTPCodecParameters{ { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeVP8, + MimeType: mime.MimeTypeVP8.String(), ClockRate: 90000, RTCPFeedback: rtcpFeedback.Video, }, @@ -78,7 +78,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeVP9, + MimeType: mime.MimeTypeVP9.String(), ClockRate: 90000, SDPFmtpLine: "profile-id=0", RTCPFeedback: rtcpFeedback.Video, @@ -87,7 +87,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeVP9, + MimeType: mime.MimeTypeVP9.String(), ClockRate: 90000, SDPFmtpLine: "profile-id=1", RTCPFeedback: rtcpFeedback.Video, @@ -96,7 +96,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, + MimeType: mime.MimeTypeH264.String(), ClockRate: 90000, SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", RTCPFeedback: rtcpFeedback.Video, @@ -105,7 +105,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, + MimeType: mime.MimeTypeH264.String(), ClockRate: 90000, SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", RTCPFeedback: rtcpFeedback.Video, @@ -114,7 +114,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeH264, + MimeType: mime.MimeTypeH264.String(), ClockRate: 90000, SDPFmtpLine: h264HighProfileFmtp, RTCPFeedback: rtcpFeedback.Video, @@ -123,7 +123,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac }, { RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeAV1, + MimeType: mime.MimeTypeAV1.String(), ClockRate: 90000, RTCPFeedback: rtcpFeedback.Video, }, @@ -141,7 +141,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac if filterOutH264HighProfile && codec.RTPCodecCapability.SDPFmtpLine == h264HighProfileFmtp { continue } - if strings.EqualFold(codec.MimeType, webrtc.MimeTypeRTX) { + if mime.IsMimeTypeStringRTX(codec.MimeType) { continue } if IsCodecEnabled(codecs, codec.RTPCodecCapability) { @@ -151,7 +151,7 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac if rtxEnabled { if err := me.RegisterCodec(webrtc.RTPCodecParameters{ RTPCodecCapability: webrtc.RTPCodecCapability{ - MimeType: webrtc.MimeTypeRTX, + MimeType: mime.MimeTypeRTX.String(), ClockRate: 90000, SDPFmtpLine: fmt.Sprintf("apt=%d", codec.PayloadType), }, @@ -196,7 +196,7 @@ func createMediaEngine(codecs []*livekit.Codec, config DirectionConfig, filterOu func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool { for _, codec := range codecs { - if !strings.EqualFold(codec.Mime, cap.MimeType) { + if !mime.IsMimeTypeStringEqual(codec.Mime, cap.MimeType) { continue } if codec.FmtpLine == "" || strings.EqualFold(codec.FmtpLine, cap.SDPFmtpLine) { @@ -208,10 +208,10 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { for _, c := range enabledCodecs { - if strings.HasPrefix(c.Mime, "video/") { + if mime.IsMimeTypeStringVideo(c.Mime) { return c.Mime } } // no viable codec in the list of enabled codecs, fall back to the most widely supported codec - return webrtc.MimeTypeVP8 + return mime.MimeTypeVP8.String() } diff --git a/pkg/rtc/mediaengine_test.go b/pkg/rtc/mediaengine_test.go index d91090165..122861a87 100644 --- a/pkg/rtc/mediaengine_test.go +++ b/pkg/rtc/mediaengine_test.go @@ -20,21 +20,22 @@ import ( "github.com/pion/webrtc/v4" "github.com/stretchr/testify/require" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) func TestIsCodecEnabled(t *testing.T) { t.Run("empty fmtp requirement should match all", func(t *testing.T) { enabledCodecs := []*livekit.Codec{{Mime: "video/h264"}} - require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264, SDPFmtpLine: "special"})) - require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264})) - require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String(), SDPFmtpLine: "special"})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String()})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeVP8.String()})) }) t.Run("when fmtp is provided, require match", func(t *testing.T) { enabledCodecs := []*livekit.Codec{{Mime: "video/h264", FmtpLine: "special"}} - require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264, SDPFmtpLine: "special"})) - require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264})) - require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String(), SDPFmtpLine: "special"})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String()})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeVP8.String()})) }) } diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index bcb4c1457..3a7235139 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -17,7 +17,6 @@ package rtc import ( "context" "math" - "strings" "sync" "time" @@ -34,6 +33,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry" util "github.com/livekit/mediatransportutil" ) @@ -56,7 +56,7 @@ type MediaTrack struct { rttFromXR atomic.Bool enableRegression bool - regressionTargetCodec string + regressionTargetCodec mime.MimeType } type MediaTrackParams struct { @@ -87,7 +87,7 @@ func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { // TODO: disable codec regression until simulcast-codec clients knows that if ti.BackupCodecPolicy == livekit.BackupCodecPolicy_REGRESSION && len(ti.Codecs) > 1 { t.enableRegression = true - t.regressionTargetCodec = ti.Codecs[1].MimeType + t.regressionTargetCodec = mime.NormalizeMimeType(ti.Codecs[1].MimeType) t.params.Logger.Debugw("track enabled codec regression", "regressionCodec", t.regressionTargetCodec) } @@ -122,20 +122,20 @@ func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { DynacastPauseDelay: params.VideoConfig.DynacastPauseDelay, Logger: params.Logger, }) - t.MediaTrackReceiver.OnSetupReceiver(func(mime string) { + t.MediaTrackReceiver.OnSetupReceiver(func(mime mime.MimeType) { t.dynacastManager.AddCodec(mime) }) t.MediaTrackReceiver.OnSubscriberMaxQualityChange( - func(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, layer int32) { + func(subscriberID livekit.ParticipantID, mimeType mime.MimeType, layer int32) { t.dynacastManager.NotifySubscriberMaxQuality( subscriberID, - codec.MimeType, + mimeType, buffer.SpatialLayerToVideoQuality(layer, t.MediaTrackReceiver.TrackInfo()), ) }, ) t.MediaTrackReceiver.OnCodecRegression(func(old, new webrtc.RTPCodecParameters) { - t.dynacastManager.HandleCodecRegression(old.MimeType, new.MimeType) + t.dynacastManager.HandleCodecRegression(mime.NormalizeMimeType(old.MimeType), mime.NormalizeMimeType(new.MimeType)) }) } @@ -252,7 +252,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe ti := t.MediaTrackReceiver.TrackInfoClone() t.lock.Lock() - mime := strings.ToLower(track.Codec().MimeType) + mimeType := mime.NormalizeMimeType(track.Codec().MimeType) layer := buffer.RidToSpatialLayer(track.RID(), ti) t.params.Logger.Debugw( "AddReceiver", @@ -261,11 +261,11 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe "ssrc", track.SSRC(), "codec", track.Codec(), ) - wr := t.MediaTrackReceiver.Receiver(mime) + wr := t.MediaTrackReceiver.Receiver(mimeType) if wr == nil { priority := -1 for idx, c := range ti.Codecs { - if strings.EqualFold(mime, c.MimeType) { + if mime.IsMimeTypeStringEqual(track.Codec().MimeType, c.MimeType) { priority = idx break } @@ -283,7 +283,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe } } if priority < 0 { - t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(ti)) + t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mimeType, "track", logger.Proto(ti)) t.lock.Unlock() return false } @@ -292,7 +292,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe receiver, track, ti, - LoggerWithCodecMime(t.params.Logger, mime), + LoggerWithCodecMime(t.params.Logger, mimeType), t.params.OnRTCP, t.params.VideoConfig.StreamTrackerManager, sfu.WithPliThrottleConfig(t.params.PLIThrottleConfig), @@ -303,7 +303,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe ) newWR.OnCloseHandler(func() { t.MediaTrackReceiver.SetClosing() - t.MediaTrackReceiver.ClearReceiver(mime, false) + t.MediaTrackReceiver.ClearReceiver(mimeType, false) if t.MediaTrackReceiver.TryClose() { if t.dynacastManager != nil { t.dynacastManager.Close() @@ -325,7 +325,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe parameters := receiver.GetParameters() for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, c.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { potentialCodecs = append(potentialCodecs, nc) break } @@ -344,7 +344,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe for ssrc, info := range t.params.SimTracks { if info.Mid == mid { - t.MediaTrackReceiver.SetLayerSsrc(mime, info.Rid, ssrc) + t.MediaTrackReceiver.SetLayerSsrc(mimeType, info.Rid, ssrc) } } wr = newWR @@ -379,17 +379,17 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe bitrates = int(ti.Layers[layer].GetBitrate()) } - t.MediaTrackReceiver.SetLayerSsrc(mime, track.RID(), uint32(track.SSRC())) + t.MediaTrackReceiver.SetLayerSsrc(mimeType, track.RID(), uint32(track.SSRC())) - if newCodec && t.enableRegression && strings.EqualFold(mime, t.regressionTargetCodec) { - t.params.Logger.Infow("regression target codec received", "codec", mime) + if newCodec && t.enableRegression && mimeType == t.regressionTargetCodec { + t.params.Logger.Infow("regression target codec received", "codec", mimeType) for _, c := range ti.Codecs { - if strings.EqualFold(c.MimeType, mime) { + if mime.NormalizeMimeType(c.MimeType) == mimeType { continue } t.params.Logger.Debugw("suspending codec for codec regression", "codec", c.MimeType) - if r := t.MediaTrackReceiver.Receiver(c.MimeType); r != nil { + if r := t.MediaTrackReceiver.Receiver(mime.NormalizeMimeType(c.MimeType)); r != nil { if rtcreceiver, ok := r.(*sfu.WebRTCReceiver); ok { rtcreceiver.SetCodecState(sfu.ReceiverCodecStateSuspended) } @@ -409,7 +409,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRe context.Background(), t.params.ParticipantID, t.ID(), - mime, + mimeType, int(layer), stats, ) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 970dba5a4..6aff85278 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -35,6 +35,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/livekit-server/pkg/telemetry" ) @@ -127,7 +128,7 @@ type MediaTrackReceiverParams struct { AudioConfig sfu.AudioConfig Telemetry telemetry.TelemetryService Logger logger.Logger - RegressionTargetCodec string + RegressionTargetCodec mime.MimeType } type MediaTrackReceiver struct { @@ -140,7 +141,7 @@ type MediaTrackReceiver struct { state mediaTrackReceiverState isExpectedToResume bool - onSetupReceiver func(mime string) + onSetupReceiver func(mime mime.MimeType) onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) onClose []func(isExpectedToResume bool) onCodecRegression func(old, new webrtc.RTPCodecParameters) @@ -179,7 +180,7 @@ func (t *MediaTrackReceiver) Restart() { } } -func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime string)) { +func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime mime.MimeType)) { t.lock.Lock() t.onSetupReceiver = f t.lock.Unlock() @@ -198,12 +199,12 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority // codec position maybe taken by DummyReceiver, check and upgrade to WebRTCReceiver var existingReceiver bool for _, r := range receivers { - if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) { + if r.Mime() == receiver.Mime() { existingReceiver = true if d, ok := r.TrackReceiver.(*DummyReceiver); ok { d.Upgrade(receiver) } else { - t.params.Logger.Errorw("receiver already exists, setup failed", nil, "mime", receiver.Codec().MimeType) + t.params.Logger.Errorw("receiver already exists, setup failed", nil, "mime", receiver.Mime()) } break } @@ -219,13 +220,13 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority if mid != "" { trackInfo := t.TrackInfoClone() if priority == 0 { - trackInfo.MimeType = receiver.Codec().MimeType + trackInfo.MimeType = receiver.Mime().String() trackInfo.Mid = mid } for i, ci := range trackInfo.Codecs { if i == priority { - ci.MimeType = receiver.Codec().MimeType + ci.MimeType = receiver.Mime().String() ci.Mid = mid break } @@ -237,20 +238,20 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority onSetupReceiver := t.onSetupReceiver t.lock.Unlock() - var receiverCodecs []string + var receiverCodecs []mime.MimeType for _, r := range receivers { - receiverCodecs = append(receiverCodecs, r.Codec().MimeType) + receiverCodecs = append(receiverCodecs, r.Mime()) } t.params.Logger.Debugw( "setup receiver", - "mime", receiver.Codec().MimeType, + "mime", receiver.Mime(), "priority", priority, "receivers", receiverCodecs, "mid", mid, ) if onSetupReceiver != nil { - onSetupReceiver(receiver.Codec().MimeType) + onSetupReceiver(receiver.Mime()) } } @@ -279,7 +280,7 @@ func (t *MediaTrackReceiver) HandleReceiverCodecChange(r sfu.TrackReceiver, code continue } - if strings.EqualFold(receiver.Codec().MimeType, t.params.RegressionTargetCodec) { + if receiver.Mime() == t.params.RegressionTargetCodec { backupCodecReceiver = receiver.TrackReceiver } @@ -338,7 +339,7 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete for i, c := range codecs { var exist bool for _, r := range receivers { - if strings.EqualFold(c.MimeType, r.Codec().MimeType) { + if mime.NormalizeMimeType(c.MimeType) == r.Mime() { exist = true break } @@ -357,11 +358,11 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete t.lock.Unlock() } -func (t *MediaTrackReceiver) ClearReceiver(mime string, isExpectedToResume bool) { +func (t *MediaTrackReceiver) ClearReceiver(mime mime.MimeType, isExpectedToResume bool) { t.lock.Lock() receivers := slices.Clone(t.receivers) for idx, receiver := range receivers { - if strings.EqualFold(receiver.Codec().MimeType, mime) { + if receiver.Mime() == mime { receivers[idx] = receivers[len(receivers)-1] receivers[len(receivers)-1] = nil receivers = receivers[:len(receivers)-1] @@ -384,7 +385,7 @@ func (t *MediaTrackReceiver) ClearAllReceivers(isExpectedToResume bool) { t.lock.Unlock() for _, r := range receivers { - t.removeAllSubscribersForMime(r.Codec().MimeType, isExpectedToResume) + t.removeAllSubscribersForMime(r.Mime(), isExpectedToResume) } } @@ -563,7 +564,7 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su codec := receiver.Codec() var found bool for _, pc := range potentialCodecs { - if strings.EqualFold(codec.MimeType, pc.MimeType) { + if mime.IsMimeTypeStringEqual(codec.MimeType, pc.MimeType) { found = true break } @@ -615,7 +616,7 @@ func (t *MediaTrackReceiver) RemoveSubscriber(subscriberID livekit.ParticipantID _ = t.MediaTrackSubscriptions.RemoveSubscriber(subscriberID, isExpectedToResume) } -func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime string, isExpectedToResume bool) { +func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime mime.MimeType, isExpectedToResume bool) { t.params.Logger.Debugw("removing all subscribers for mime", "mime", mime) for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribersForMime(mime) { t.RemoveSubscriber(subscriberID, isExpectedToResume) @@ -655,7 +656,7 @@ func (t *MediaTrackReceiver) updateTrackInfoOfReceivers() { } } -func (t *MediaTrackReceiver) SetLayerSsrc(mime string, rid string, ssrc uint32) { +func (t *MediaTrackReceiver) SetLayerSsrc(mimeType mime.MimeType, rid string, ssrc uint32) { t.lock.Lock() trackInfo := t.TrackInfoClone() layer := buffer.RidToSpatialLayer(rid, trackInfo) @@ -666,7 +667,7 @@ func (t *MediaTrackReceiver) SetLayerSsrc(mime string, rid string, ssrc uint32) quality := buffer.SpatialLayerToVideoQuality(layer, trackInfo) // set video layer ssrc info for i, ci := range trackInfo.Codecs { - if !strings.EqualFold(ci.MimeType, mime) { + if mime.NormalizeMimeType(ci.MimeType) == mimeType { continue } @@ -703,7 +704,7 @@ func (t *MediaTrackReceiver) UpdateCodecCid(codecs []*livekit.SimulcastCodec) { trackInfo := t.TrackInfoClone() for _, c := range codecs { for _, origin := range trackInfo.Codecs { - if strings.Contains(origin.MimeType, c.Codec) { + if mime.GetMimeTypeCodec(origin.MimeType) == mime.NormalizeMimeTypeCodec(c.Codec) { origin.Cid = c.Cid break } @@ -724,7 +725,7 @@ func (t *MediaTrackReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { // patch Mid and SSRC of codecs/layers by keeping original if available for i, ci := range clonedInfo.Codecs { for _, originCi := range trackInfo.Codecs { - if !strings.EqualFold(ci.MimeType, originCi.MimeType) { + if !mime.IsMimeTypeStringEqual(ci.MimeType, originCi.MimeType) { continue } @@ -945,9 +946,9 @@ func (t *MediaTrackReceiver) PrimaryReceiver() sfu.TrackReceiver { return receivers[0].TrackReceiver } -func (t *MediaTrackReceiver) Receiver(mime string) sfu.TrackReceiver { +func (t *MediaTrackReceiver) Receiver(mime mime.MimeType) sfu.TrackReceiver { for _, r := range t.loadReceivers() { - if strings.EqualFold(r.Codec().MimeType, mime) { + if r.Mime() == mime { if dr, ok := r.TrackReceiver.(*DummyReceiver); ok { return dr.Receiver() } @@ -980,7 +981,7 @@ func (t *MediaTrackReceiver) SetRTT(rtt uint32) { } } -func (t *MediaTrackReceiver) GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime string) int32 { +func (t *MediaTrackReceiver) GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime mime.MimeType) int32 { receiver := t.Receiver(mime) if receiver == nil { return buffer.DefaultMaxLayerTemporal diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 8538db467..88a9df90f 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -16,7 +16,6 @@ package rtc import ( "errors" - "strings" "sync" "github.com/pion/rtcp" @@ -24,6 +23,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -46,7 +46,7 @@ type MediaTrackSubscriptions struct { subscribedTracks map[livekit.ParticipantID]types.SubscribedTrack onDownTrackCreated func(downTrack *sfu.DownTrack) - onSubscriberMaxQualityChange func(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, layer int32) + onSubscriberMaxQualityChange func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32) } type MediaTrackSubscriptionsParams struct { @@ -72,7 +72,7 @@ func (t *MediaTrackSubscriptions) OnDownTrackCreated(f func(downTrack *sfu.DownT t.onDownTrackCreated = f } -func (t *MediaTrackSubscriptions) OnSubscriberMaxQualityChange(f func(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, layer int32)) { +func (t *MediaTrackSubscriptions) OnSubscriberMaxQualityChange(f func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32)) { t.onSubscriberMaxQualityChange = f } @@ -179,7 +179,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * if t.onSubscriberMaxQualityChange != nil { go func() { spatial := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.params.MediaTrack.ToProto()) - t.onSubscriberMaxQualityChange(downTrack.SubscriberID(), codec, spatial) + t.onSubscriberMaxQualityChange(downTrack.SubscriberID(), mime.NormalizeMimeType(codec.MimeType), spatial) }() } } @@ -214,7 +214,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * downTrack.OnMaxLayerChanged(func(dt *sfu.DownTrack, layer int32) { if t.onSubscriberMaxQualityChange != nil { - t.onSubscriberMaxQualityChange(dt.SubscriberID(), dt.Codec(), layer) + t.onSubscriberMaxQualityChange(dt.SubscriberID(), dt.Mime(), layer) } }) @@ -280,7 +280,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * Stereo: info.Stereo, Red: !info.DisableRed, } - if addTrackParams.Red && (len(codecs) == 1 && strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus)) { + if addTrackParams.Red && (len(codecs) == 1 && mime.IsMimeTypeStringOpus(codecs[0].MimeType)) { addTrackParams.Red = false } @@ -379,13 +379,13 @@ func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID { return subs } -func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime string) []livekit.ParticipantID { +func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime mime.MimeType) []livekit.ParticipantID { t.subscribedTracksMu.RLock() defer t.subscribedTracksMu.RUnlock() subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks)) for id, subTrack := range t.subscribedTracks { - if !strings.EqualFold(subTrack.DownTrack().Codec().MimeType, mime) { + if subTrack.DownTrack().Mime() != mime { continue } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 201532736..266a2c741 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -49,6 +49,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/livekit-server/pkg/telemetry" @@ -1775,12 +1776,13 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver p.pendingTracksLock.Lock() p.pendingRemoteTracks = append(p.pendingRemoteTracks, &pendingRemoteTrack{track: rtcTrack, receiver: rtpReceiver}) p.pendingTracksLock.Unlock() - p.pubLogger.Debugw("webrtc Track published but can't find MediaTrack, add to pendingTracks", + p.pubLogger.Debugw( + "webrtc Track published but can't find MediaTrack, add to pendingTracks", "kind", track.Kind().String(), "webrtcTrackID", track.ID(), "rid", track.RID(), "SSRC", track.SSRC(), - "mime", codec.MimeType, + "mime", mime.NormalizeMimeType(codec.MimeType), ) return } @@ -1802,7 +1804,7 @@ func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver "webrtcTrackID", track.ID(), "rid", track.RID(), "SSRC", track.SSRC(), - "mime", codec.MimeType, + "mime", mime.NormalizeMimeType(codec.MimeType), "trackInfo", logger.Proto(publishedTrack.ToProto()), "fromSdp", fromSdp, ) @@ -2167,7 +2169,7 @@ func (p *ParticipantImpl) onSubscribedMaxQualityChange( // normalize the codec name for _, subscribedQuality := range subscribedQualities { - subscribedQuality.Codec = strings.ToLower(strings.TrimPrefix(subscribedQuality.Codec, "video/")) + subscribedQuality.Codec = strings.ToLower(strings.TrimPrefix(subscribedQuality.Codec, mime.MimeTypePrefixVideo)) } subscribedQualityUpdate := &livekit.SubscribedQualityUpdate{ @@ -2244,36 +2246,36 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l } else { seenCodecs := make(map[string]struct{}) for _, codec := range req.SimulcastCodecs { - mime := codec.Codec + mimeType := codec.Codec if req.Type == livekit.TrackType_VIDEO { - if !strings.HasPrefix(mime, "video/") { - mime = "video/" + mime + if !mime.IsMimeTypeStringVideo(mimeType) { + mimeType = mime.MimeTypePrefixVideo + mimeType } - if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mimeType}) { altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow("falling back to alternative codec", - "codec", mime, + "codec", mimeType, "altCodec", altCodec, "trackID", ti.Sid, ) // select an alternative MIME type that's generally supported - mime = altCodec + mimeType = altCodec } - } else if req.Type == livekit.TrackType_AUDIO && !strings.HasPrefix(mime, "audio/") { - mime = "audio/" + mime + } else if req.Type == livekit.TrackType_AUDIO && !mime.IsMimeTypeStringAudio(mimeType) { + mimeType = mime.MimeTypePrefixAudio + mimeType } - if _, ok := seenCodecs[mime]; ok || mime == "" { + if _, ok := seenCodecs[mimeType]; ok || mimeType == "" { continue } - seenCodecs[mime] = struct{}{} + seenCodecs[mimeType] = struct{}{} clonedLayers := make([]*livekit.VideoLayer, 0, len(req.Layers)) for _, l := range req.Layers { clonedLayers = append(clonedLayers, utils.CloneProto(l)) } ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ - MimeType: mime, + MimeType: mimeType, Cid: codec.Cid, Layers: clonedLayers, }) @@ -2391,7 +2393,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver "trackID", track.ID(), "rid", track.RID(), "SSRC", track.SSRC(), - "mime", track.Codec().MimeType, + "mime", mime.NormalizeMimeType(track.Codec().MimeType), "mid", mid, ) if mid == "" { @@ -2416,7 +2418,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track sfu.TrackRemote, rtpReceiver var codecFound int for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, c.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { codecFound++ break } @@ -2499,7 +2501,7 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M parameters := rtpReceiver.GetParameters() for _, c := range ti.Codecs { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, c.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { potentialCodecs = append(potentialCodecs, nc) break } @@ -2508,10 +2510,10 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M // check for mime_type for tracks that do not have simulcast_codecs set if ti.MimeType != "" { for _, nc := range parameters.Codecs { - if strings.EqualFold(nc.MimeType, ti.MimeType) { + if mime.IsMimeTypeStringEqual(nc.MimeType, ti.MimeType) { alreadyAdded := false for _, pc := range potentialCodecs { - if strings.EqualFold(pc.MimeType, ti.MimeType) { + if mime.IsMimeTypeStringEqual(pc.MimeType, ti.MimeType) { alreadyAdded = true break } @@ -2528,7 +2530,7 @@ func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *M for _, codec := range ti.Codecs { for ssrc, info := range p.params.SimTracks { if info.Mid == codec.Mid { - mt.SetLayerSsrc(codec.MimeType, info.Rid, ssrc) + mt.SetLayerSsrc(mime.NormalizeMimeType(codec.MimeType), info.Rid, ssrc) } } } @@ -2993,7 +2995,7 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool { for _, disableCodec := range disabled { // disable codec's fmtp is empty means disable this codec entirely - if strings.EqualFold(c.Mime, disableCodec.Mime) { + if mime.IsMimeTypeStringEqual(c.Mime, disableCodec.Mime) { return true } } @@ -3007,13 +3009,13 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod } // sort by compatibility, since we will look for backups in these. - if strings.EqualFold(c.Mime, webrtc.MimeTypeVP8) { + if mime.IsMimeTypeStringVP8(c.Mime) { if len(p.enabledPublishCodecs) > 0 { p.enabledPublishCodecs = slices.Insert(p.enabledPublishCodecs, 0, c) } else { p.enabledPublishCodecs = append(p.enabledPublishCodecs, c) } - } else if strings.EqualFold(c.Mime, webrtc.MimeTypeH264) { + } else if mime.IsMimeTypeStringH264(c.Mime) { p.enabledPublishCodecs = append(p.enabledPublishCodecs, c) } else { publishCodecs = append(publishCodecs, c) @@ -3034,7 +3036,7 @@ func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Cod func (p *ParticipantImpl) GetEnabledPublishCodecs() []*livekit.Codec { codecs := make([]*livekit.Codec, 0, len(p.enabledPublishCodecs)) for _, c := range p.enabledPublishCodecs { - if strings.EqualFold(c.Mime, webrtc.MimeTypeRTX) { + if mime.IsMimeTypeStringRTX(c.Mime) { continue } codecs = append(codecs, c) diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index c46fc1ed8..db9ad1e04 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -26,6 +26,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry/telemetryfakes" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -356,7 +357,7 @@ func TestDisableCodecs(t *testing.T) { codecs := transceiver.Receiver().GetParameters().Codecs var found264 bool for _, c := range codecs { - if strings.EqualFold(c.MimeType, "video/h264") { + if mime.IsMimeTypeStringH264(c.MimeType) { found264 = true } } @@ -390,7 +391,7 @@ func TestDisableCodecs(t *testing.T) { codecs = transceiver.Receiver().GetParameters().Codecs found264 = false for _, c := range codecs { - if strings.EqualFold(c.MimeType, "video/h264") { + if mime.IsMimeTypeStringH264(c.MimeType) { found264 = true } } @@ -410,7 +411,7 @@ func TestDisablePublishCodec(t *testing.T) { }) for _, codec := range participant.enabledPublishCodecs { - require.NotEqual(t, strings.ToLower(codec.Mime), "video/h264") + require.False(t, mime.IsMimeTypeStringH264(codec.Mime)) } sink := &routingfakes.FakeMessageSink{} @@ -421,7 +422,7 @@ func TestDisablePublishCodec(t *testing.T) { if published := res.GetTrackPublished(); published != nil { publishReceived.Store(true) require.NotEmpty(t, published.Track.Codecs) - require.Equal(t, "video/vp8", strings.ToLower(published.Track.Codecs[0].MimeType)) + require.True(t, mime.IsMimeTypeStringVP8(published.Track.Codecs[0].MimeType)) } } return nil @@ -446,7 +447,7 @@ func TestDisablePublishCodec(t *testing.T) { if published := res.GetTrackPublished(); published != nil { publishReceived.Store(true) require.NotEmpty(t, published.Track.Codecs) - require.Equal(t, "video/vp8", strings.ToLower(published.Track.Codecs[0].MimeType)) + require.True(t, mime.IsMimeTypeStringVP8(published.Track.Codecs[0].MimeType)) } } return nil @@ -493,13 +494,20 @@ func TestPreferVideoCodecForPublisher(t *testing.T) { require.NoError(t, err) transceiver, err := pc.AddTransceiverFromTrack(track, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendrecv}) require.NoError(t, err) - sdp, err := pc.CreateOffer(nil) - require.NoError(t, err) - pc.SetLocalDescription(sdp) codecs := transceiver.Receiver().GetParameters().Codecs + if i > 0 { + // the negotiated codecs order could be updated by first negotiation, reorder to make h264 not preferred + for mime.IsMimeTypeStringH264(codecs[0].MimeType) { + codecs = append(codecs[1:], codecs[0]) + } + } // h264 should not be preferred - require.NotEqual(t, codecs[0].MimeType, "video/h264") + require.False(t, mime.IsMimeTypeStringH264(codecs[0].MimeType), "codecs", codecs) + + sdp, err := pc.CreateOffer(nil) + require.NoError(t, err) + require.NoError(t, pc.SetLocalDescription(sdp)) sink := &routingfakes.FakeMessageSink{} participant.SetResponseSink(sink) @@ -528,7 +536,7 @@ func TestPreferVideoCodecForPublisher(t *testing.T) { if videoSectionIndex == i { codecs, err := codecsFromMediaDescription(m) require.NoError(t, err) - if strings.EqualFold(codecs[0].Name, "h264") { + if mime.IsMimeTypeCodecStringH264(codecs[0].Name) { h264Preferred = true break } @@ -626,7 +634,7 @@ func TestPreferAudioCodecForRed(t *testing.T) { } require.True(t, nackEnabled, "nack should be enabled for opus") - if strings.EqualFold(codecs[0].Name, "red") { + if mime.IsMimeTypeCodecStringRED(codecs[0].Name) { redPreferred = true break } diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index ae897795c..59d75ada9 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -22,6 +22,7 @@ import ( "github.com/pion/sdp/v3" "github.com/pion/webrtc/v4" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" lksdp "github.com/livekit/protocol/sdp" ) @@ -58,7 +59,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se var opusPayload uint8 for _, codec := range codecs { - if strings.EqualFold(codec.Name, "opus") { + if mime.IsMimeTypeCodecStringOpus(codec.Name) { opusPayload = codec.PayloadType break } @@ -70,7 +71,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se var preferredCodecs, leftCodecs []string for _, codec := range codecs { // codec contain opus/red - if !disableRed && strings.EqualFold(codec.Name, "red") && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { + if !disableRed && mime.IsMimeTypeCodecStringRED(codec.Name) && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) } else { leftCodecs = append(leftCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) @@ -138,20 +139,19 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess p.pendingTracksLock.RUnlock() continue } - var mime string + var mimeType string for _, c := range info.Codecs { if c.Cid == streamID { - mime = c.MimeType + mimeType = c.MimeType break } } - if mime == "" && len(info.Codecs) > 0 { - mime = info.Codecs[0].MimeType + if mimeType == "" && len(info.Codecs) > 0 { + mimeType = info.Codecs[0].MimeType } p.pendingTracksLock.RUnlock() - mime = strings.ToUpper(mime) - if mime != "" { + if mimeType != "" { codecs, err := codecsFromMediaDescription(unmatchVideo) if err != nil { p.pubLogger.Errorw("extract codecs from media section failed", err, "media", unmatchVideo) @@ -160,7 +160,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess var preferredCodecs, leftCodecs []string for _, c := range codecs { - if strings.HasSuffix(mime, strings.ToUpper(c.Name)) { + if mime.GetMimeTypeCodec(mimeType) == mime.NormalizeMimeTypeCodec(c.Name) { preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) } else { leftCodecs = append(leftCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) @@ -241,7 +241,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript continue } - opusPT, err := parsed.GetPayloadTypeForCodec(sdp.Codec{Name: "opus"}) + opusPT, err := parsed.GetPayloadTypeForCodec(sdp.Codec{Name: mime.MimeTypeCodecOpus.String()}) if err != nil { p.pubLogger.Infow("failed to get opus payload type", "error", err, "trackID", ti.Sid) continue diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index 5c83832f3..a45f4e177 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -259,7 +259,7 @@ func (t *SubscribedTrack) applySettings() { spatial = buffer.VideoQualityToSpatialLayer(quality, mt.ToProto()) if t.settings.Fps > 0 { - temporal = mt.GetTemporalLayerForSpatialFps(spatial, t.settings.Fps, dt.Codec().MimeType) + temporal = mt.GetTemporalLayerForSpatialFps(spatial, t.settings.Fps, dt.Mime()) } } diff --git a/pkg/rtc/subscriptionmanager.go b/pkg/rtc/subscriptionmanager.go index be97b2453..2d41f37c3 100644 --- a/pkg/rtc/subscriptionmanager.go +++ b/pkg/rtc/subscriptionmanager.go @@ -819,7 +819,7 @@ func (m *SubscriptionManager) handleSubscribedTrackClose(s *trackSubscription, i context.Background(), m.params.Participant.ID(), s.trackID, - dt.Codec().MimeType, + dt.Mime(), stats, ) } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index a274548ec..24e5801cb 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -44,6 +44,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" "github.com/livekit/livekit-server/pkg/sfu/datachannel" sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" @@ -416,11 +417,11 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat } setTWCCForVideo := func(info *interceptor.StreamInfo) { - if !strings.HasPrefix(info.MimeType, "video") { + if !mime.IsMimeTypeStringVideo(info.MimeType) { return } // rtx stream don't have rtcp feedback, always set twcc for rtx stream - twccFb := strings.HasSuffix(info.MimeType, "rtx") + twccFb := mime.GetMimeTypeCodec(info.MimeType) == mime.MimeTypeCodecRTX if !twccFb { for _, fb := range info.RTCPFeedback { if fb.Type == webrtc.TypeRTCPFBTransportCC { @@ -2089,7 +2090,7 @@ func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool codecs := sender.GetParameters().Codecs configCodecs := make([]webrtc.RTPCodecParameters, 0, len(codecs)) for _, c := range codecs { - if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) { + if mime.IsMimeTypeStringOpus(c.MimeType) { c.SDPFmtpLine = strings.ReplaceAll(c.SDPFmtpLine, ";sprop-stereo=1", "") if stereo { c.SDPFmtpLine += ";sprop-stereo=1" diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 512b2ff91..ac0c1cc97 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -28,6 +28,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/testutils" "github.com/livekit/protocol/livekit" ) @@ -389,9 +390,9 @@ func TestFilteringCandidates(t *testing.T) { ParticipantIdentity: "identity", Config: &WebRTCConfig{}, EnabledCodecs: []*livekit.Codec{ - {Mime: webrtc.MimeTypeOpus}, - {Mime: webrtc.MimeTypeVP8}, - {Mime: webrtc.MimeTypeH264}, + {Mime: mime.MimeTypeOpus.String()}, + {Mime: mime.MimeTypeVP8.String()}, + {Mime: mime.MimeTypeH264.String()}, }, Handler: &transportfakes.FakeHandler{}, } @@ -598,7 +599,7 @@ func TestConfigureAudioTransceiver(t *testing.T) { } { t.Run(fmt.Sprintf("nack=%v,stereo=%v", testcase.nack, testcase.stereo), func(t *testing.T) { var me webrtc.MediaEngine - registerCodecs(&me, []*livekit.Codec{{Mime: webrtc.MimeTypeOpus}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) + registerCodecs(&me, []*livekit.Codec{{Mime: mime.MimeTypeOpus.String()}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) pc, err := webrtc.NewAPI(webrtc.WithMediaEngine(&me)).NewPeerConnection(webrtc.Configuration{}) require.NoError(t, err) defer pc.Close() @@ -608,7 +609,7 @@ func TestConfigureAudioTransceiver(t *testing.T) { configureAudioTransceiver(tr, testcase.stereo, testcase.nack) codecs := tr.Sender().GetParameters().Codecs for _, codec := range codecs { - if strings.Contains(codec.MimeType, webrtc.MimeTypeOpus) { + if mime.IsMimeTypeStringOpus(codec.MimeType) { require.Equal(t, testcase.stereo, strings.Contains(codec.SDPFmtpLine, "sprop-stereo=1")) var nackEnabled bool for _, fb := range codec.RTCPFeedback { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index eeaedc52b..d44bb329c 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -29,6 +29,7 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" ) @@ -74,7 +75,7 @@ func (m MigrateState) String() string { // --------------------------------------------- type SubscribedCodecQuality struct { - CodecMime string + CodecMime mime.MimeType Quality livekit.VideoQuality } @@ -527,7 +528,7 @@ type MediaTrack interface { GetQualityForDimension(width, height uint32) livekit.VideoQuality // returns temporal layer that's appropriate for fps - GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime string) int32 + GetTemporalLayerForSpatialFps(spatial int32, fps uint32, mime mime.MimeType) int32 Receivers() []sfu.TrackReceiver ClearAllReceivers(isExpectedToResume bool) diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 382c609ae..877267f6c 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -6,6 +6,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -94,12 +95,12 @@ type FakeLocalMediaTrack struct { getQualityForDimensionReturnsOnCall map[int]struct { result1 livekit.VideoQuality } - GetTemporalLayerForSpatialFpsStub func(int32, uint32, string) int32 + GetTemporalLayerForSpatialFpsStub func(int32, uint32, mime.MimeType) int32 getTemporalLayerForSpatialFpsMutex sync.RWMutex getTemporalLayerForSpatialFpsArgsForCall []struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType } getTemporalLayerForSpatialFpsReturns struct { result1 int32 @@ -795,13 +796,13 @@ func (fake *FakeLocalMediaTrack) GetQualityForDimensionReturnsOnCall(i int, resu }{result1} } -func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 string) int32 { +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 mime.MimeType) int32 { fake.getTemporalLayerForSpatialFpsMutex.Lock() ret, specificReturn := fake.getTemporalLayerForSpatialFpsReturnsOnCall[len(fake.getTemporalLayerForSpatialFpsArgsForCall)] fake.getTemporalLayerForSpatialFpsArgsForCall = append(fake.getTemporalLayerForSpatialFpsArgsForCall, struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType }{arg1, arg2, arg3}) stub := fake.GetTemporalLayerForSpatialFpsStub fakeReturns := fake.getTemporalLayerForSpatialFpsReturns @@ -822,13 +823,13 @@ func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCallCount() int { return len(fake.getTemporalLayerForSpatialFpsArgsForCall) } -func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, string) int32) { +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, mime.MimeType) int32) { fake.getTemporalLayerForSpatialFpsMutex.Lock() defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() fake.GetTemporalLayerForSpatialFpsStub = stub } -func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, string) { +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, mime.MimeType) { fake.getTemporalLayerForSpatialFpsMutex.RLock() defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() argsForCall := fake.getTemporalLayerForSpatialFpsArgsForCall[i] diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index ce6e9355e..67304714d 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -6,6 +6,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" ) @@ -82,12 +83,12 @@ type FakeMediaTrack struct { getQualityForDimensionReturnsOnCall map[int]struct { result1 livekit.VideoQuality } - GetTemporalLayerForSpatialFpsStub func(int32, uint32, string) int32 + GetTemporalLayerForSpatialFpsStub func(int32, uint32, mime.MimeType) int32 getTemporalLayerForSpatialFpsMutex sync.RWMutex getTemporalLayerForSpatialFpsArgsForCall []struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType } getTemporalLayerForSpatialFpsReturns struct { result1 int32 @@ -675,13 +676,13 @@ func (fake *FakeMediaTrack) GetQualityForDimensionReturnsOnCall(i int, result1 l }{result1} } -func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 string) int32 { +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFps(arg1 int32, arg2 uint32, arg3 mime.MimeType) int32 { fake.getTemporalLayerForSpatialFpsMutex.Lock() ret, specificReturn := fake.getTemporalLayerForSpatialFpsReturnsOnCall[len(fake.getTemporalLayerForSpatialFpsArgsForCall)] fake.getTemporalLayerForSpatialFpsArgsForCall = append(fake.getTemporalLayerForSpatialFpsArgsForCall, struct { arg1 int32 arg2 uint32 - arg3 string + arg3 mime.MimeType }{arg1, arg2, arg3}) stub := fake.GetTemporalLayerForSpatialFpsStub fakeReturns := fake.getTemporalLayerForSpatialFpsReturns @@ -702,13 +703,13 @@ func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCallCount() int { return len(fake.getTemporalLayerForSpatialFpsArgsForCall) } -func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, string) int32) { +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(int32, uint32, mime.MimeType) int32) { fake.getTemporalLayerForSpatialFpsMutex.Lock() defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() fake.GetTemporalLayerForSpatialFpsStub = stub } -func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, string) { +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (int32, uint32, mime.MimeType) { fake.getTemporalLayerForSpatialFpsMutex.RLock() defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() argsForCall := fake.getTemporalLayerForSpatialFpsArgsForCall[i] diff --git a/pkg/rtc/utils.go b/pkg/rtc/utils.go index 1293cd42b..b6d0c29bf 100644 --- a/pkg/rtc/utils.go +++ b/pkg/rtc/utils.go @@ -23,6 +23,7 @@ import ( "github.com/pion/webrtc/v4" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -195,9 +196,9 @@ func LoggerWithPCTarget(l logger.Logger, target livekit.SignalTarget) logger.Log return l.WithValues("transport", target) } -func LoggerWithCodecMime(l logger.Logger, mime string) logger.Logger { - if mime != "" { - return l.WithValues("mime", mime) +func LoggerWithCodecMime(l logger.Logger, mimeType mime.MimeType) logger.Logger { + if mimeType != mime.MimeTypeUnknown { + return l.WithValues("mime", mimeType.String()) } return l } diff --git a/pkg/rtc/wrappedreceiver.go b/pkg/rtc/wrappedreceiver.go index 1f8fc81f2..a4d41a715 100644 --- a/pkg/rtc/wrappedreceiver.go +++ b/pkg/rtc/wrappedreceiver.go @@ -16,7 +16,6 @@ package rtc import ( "errors" - "strings" "sync" "github.com/pion/webrtc/v4" @@ -28,6 +27,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" ) // wrapper around WebRTC receiver, overriding its ID @@ -59,13 +59,14 @@ func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { codecs := params.UpstreamCodecs if len(codecs) == 1 { - if strings.EqualFold(codecs[0].MimeType, sfu.MimeTypeAudioRed) { + normalizedMimeType := mime.NormalizeMimeType(codecs[0].MimeType) + if normalizedMimeType == mime.MimeTypeRED { // if upstream is opus/red, then add opus to match clients that don't support red codecs = append(codecs, webrtc.RTPCodecParameters{ RTPCodecCapability: OpusCodecCapability, PayloadType: 111, }) - } else if !params.DisableRed && strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus) { + } else if !params.DisableRed && normalizedMimeType == mime.MimeTypeOpus { // if upstream is opus only and red enabled, add red to match clients that support red codecs = append(codecs, webrtc.RTPCodecParameters{ RTPCodecCapability: RedCodecCapability, @@ -95,16 +96,18 @@ func (r *WrappedReceiver) StreamID() string { func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) bool { r.lock.Lock() + codecMimeType := mime.NormalizeMimeType(codec.MimeType) var trackReceiver sfu.TrackReceiver for _, receiver := range r.receivers { - if c := receiver.Codec(); strings.EqualFold(c.MimeType, codec.MimeType) { + receiverMimeType := receiver.Mime() + if receiverMimeType == codecMimeType { trackReceiver = receiver break - } else if strings.EqualFold(c.MimeType, sfu.MimeTypeAudioRed) && strings.EqualFold(codec.MimeType, webrtc.MimeTypeOpus) { + } else if receiverMimeType == mime.MimeTypeRED && codecMimeType == mime.MimeTypeOpus { // audio opus/red can match opus only trackReceiver = receiver.GetPrimaryReceiverForRed() break - } else if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) && strings.EqualFold(codec.MimeType, sfu.MimeTypeAudioRed) { + } else if receiverMimeType == mime.MimeTypeOpus && codecMimeType == mime.MimeTypeRED { trackReceiver = receiver.GetRedReceiver() break } @@ -263,6 +266,13 @@ func (d *DummyReceiver) Codec() webrtc.RTPCodecParameters { return d.codec } +func (d *DummyReceiver) Mime() mime.MimeType { + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return r.Mime() + } + return mime.NormalizeMimeType(d.codec.MimeType) +} + func (d *DummyReceiver) HeaderExtensions() []webrtc.RTPHeaderExtensionParameter { if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { return r.HeaderExtensions() diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 95233ddaa..427ca1c11 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -32,6 +32,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/audio" + "github.com/livekit/livekit-server/pkg/sfu/mime" act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" @@ -94,7 +95,7 @@ type Buffer struct { rtpParameters webrtc.RTPParameters payloadType uint8 rtxPayloadType uint8 - mime string + mime mime.MimeType snRangeMap *utils.RangeMap[uint64, uint64] @@ -218,10 +219,10 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.clockRate = codec.ClockRate b.lastReport = mono.UnixNano() - b.mime = strings.ToLower(codec.MimeType) + b.mime = mime.NormalizeMimeType(codec.MimeType) b.rtpParameters = params for _, codecParameter := range params.Codecs { - if strings.EqualFold(codecParameter.MimeType, codec.MimeType) { + if mime.IsMimeTypeStringEqual(codecParameter.MimeType, codec.MimeType) { b.payloadType = uint8(codecParameter.PayloadType) break } @@ -234,7 +235,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili // find RTX payload type for _, codec := range params.Codecs { - if strings.EqualFold(codec.MimeType, "video/rtx") && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", b.payloadType)) { + if mime.IsMimeTypeStringRTX(codec.MimeType) && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", b.payloadType)) { b.rtxPayloadType = uint8(codec.PayloadType) break } @@ -248,7 +249,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili continue } b.ddExtID = uint8(ext.ID) - b.createDDParserAndFrameRateCalculator(codec.MimeType) + b.createDDParserAndFrameRateCalculator() case sdp.AudioLevelURI: b.audioLevelExtID = uint8(ext.ID) @@ -260,15 +261,15 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili } switch { - case strings.HasPrefix(b.mime, "audio/"): + case mime.IsMimeTypeAudio(b.mime): b.codecType = webrtc.RTPCodecTypeAudio b.bucket = bucket.NewBucket[uint64](InitPacketBufferSizeAudio) - case strings.HasPrefix(b.mime, "video/"): + case mime.IsMimeTypeVideo(b.mime): b.codecType = webrtc.RTPCodecTypeVideo b.bucket = bucket.NewBucket[uint64](InitPacketBufferSizeVideo) if b.frameRateCalculator[0] == nil { - b.createFrameRateCalculator(codec.MimeType) + b.createFrameRateCalculator() } if bitrates > 0 { pps := bitrates / 8 / 1200 @@ -292,7 +293,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili // pion use a single mediaengine to manage negotiated codecs of peerconnection, that means we can't have different // codec settings at track level for same codec type, so enable nack for all audio receivers but don't create nack queue // for red codec. - if strings.EqualFold(b.mime, "audio/red") { + if b.mime == mime.MimeTypeRED { break } b.logger.Debugw("Setting feedback", "type", webrtc.TypeRTCPFBNACK) @@ -313,8 +314,8 @@ func (b *Buffer) OnCodecChange(fn func(webrtc.RTPCodecParameters)) { b.Unlock() } -func (b *Buffer) createDDParserAndFrameRateCalculator(mime string) { - if IsSvcCodec(mime) || strings.EqualFold(mime, webrtc.MimeTypeVP8) { +func (b *Buffer) createDDParserAndFrameRateCalculator() { + if mime.IsMimeTypeSVC(b.mime) || b.mime == mime.MimeTypeVP8 { frc := NewFrameRateCalculatorDD(b.clockRate, b.logger) for i := range b.frameRateCalculator { b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) @@ -325,18 +326,18 @@ func (b *Buffer) createDDParserAndFrameRateCalculator(mime string) { } } -func (b *Buffer) createFrameRateCalculator(mime string) { - switch { - case strings.EqualFold(mime, webrtc.MimeTypeVP8): +func (b *Buffer) createFrameRateCalculator() { + switch b.mime { + case mime.MimeTypeVP8: b.frameRateCalculator[0] = NewFrameRateCalculatorVP8(b.clockRate, b.logger) - case strings.EqualFold(mime, webrtc.MimeTypeVP9): + case mime.MimeTypeVP9: frc := NewFrameRateCalculatorVP9(b.clockRate, b.logger) for i := range b.frameRateCalculator { b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) } - case strings.EqualFold(mime, webrtc.MimeTypeH265): + case mime.MimeTypeH265: b.frameRateCalculator[0] = NewFrameRateCalculatorH26x(b.clockRate, b.logger) } } @@ -763,7 +764,7 @@ func (b *Buffer) handleCodecChange(newPT uint8) { codecFound = true } - if strings.EqualFold(codec.MimeType, "video/rtx") && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", newPT)) { + if mime.IsMimeTypeStringRTX(codec.MimeType) && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", newPT)) { rtxFound = true rtxPt = uint8(codec.PayloadType) } @@ -776,21 +777,22 @@ func (b *Buffer) handleCodecChange(newPT uint8) { b.logger.Errorw("could not find codec for new payload type", nil, "pt", newPT, "rtpParameters", b.rtpParameters) return } - b.logger.Infow("codec changed", + b.logger.Infow( + "codec changed", "oldPayload", b.payloadType, "newPayload", newPT, "oldRtxPayload", b.rtxPayloadType, "newRtxPayload", rtxPt, "oldMime", b.mime, "newMime", newCodec.MimeType) b.payloadType = newPT b.rtxPayloadType = rtxPt - b.mime = strings.ToLower(newCodec.MimeType) + b.mime = mime.NormalizeMimeType(newCodec.MimeType) b.frameRateCalculated = false if b.ddExtID != 0 { - b.createDDParserAndFrameRateCalculator(b.mime) + b.createDDParserAndFrameRateCalculator() } if b.frameRateCalculator[0] == nil { - b.createFrameRateCalculator(b.mime) + b.createFrameRateCalculator() } b.bucket.ResyncOnNextPacket() @@ -874,8 +876,8 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat } } - switch utils.MatchMimeType(b.mime) { - case utils.MimeTypeVP8: + switch b.mime { + case mime.MimeTypeVP8: vp8Packet := VP8{} if err := vp8Packet.Unmarshal(rtpPacket.Payload); err != nil { b.logger.Warnw("could not unmarshal VP8 packet", err) @@ -891,7 +893,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat ep.Payload = vp8Packet ep.Spatial = InvalidLayerSpatial // vp8 don't have spatial scalability, reset to invalid - case utils.MimeTypeVP9: + case mime.MimeTypeVP9: if ep.DependencyDescriptor == nil { var vp9Packet codecs.VP9Packet _, err := vp9Packet.Unmarshal(rtpPacket.Payload) @@ -907,14 +909,14 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat } ep.KeyFrame = IsVP9KeyFrame(rtpPacket.Payload) - case utils.MimeTypeH264: + case mime.MimeTypeH264: ep.KeyFrame = IsH264KeyFrame(rtpPacket.Payload) ep.Spatial = InvalidLayerSpatial // h.264 don't have spatial scalability, reset to invalid - case utils.MimeTypeAV1: + case mime.MimeTypeAV1: ep.KeyFrame = IsAV1KeyFrame(rtpPacket.Payload) - case utils.MimeTypeH265: + case mime.MimeTypeH265: if ep.DependencyDescriptor == nil { if len(rtpPacket.Payload) < 2 { b.logger.Warnw("invalid H265 packet", nil) @@ -1218,19 +1220,3 @@ func (b *Buffer) GetTemporalLayerFpsForSpatial(layer int32) []float32 { } // --------------------------------------------------------------- - -// SVC-TODO: Have to use more conditions to differentiate between -// SVC-TODO: SVC and non-SVC (could be single layer or simulcast). -// SVC-TODO: May only need to differentiate between simulcast and non-simulcast -// SVC-TODO: i. e. may be possible to treat single layer as SVC to get proper/intended functionality. -func IsSvcCodec(mime string) bool { - switch utils.MatchMimeType(mime) { - case utils.MimeTypeAV1, utils.MimeTypeVP9: - return true - } - return false -} - -func IsRedCodec(mime string) bool { - return strings.HasSuffix(strings.ToLower(mime), "red") -} diff --git a/pkg/sfu/connectionquality/connectionstats.go b/pkg/sfu/connectionquality/connectionstats.go index eb254831a..f4f09b7c5 100644 --- a/pkg/sfu/connectionquality/connectionstats.go +++ b/pkg/sfu/connectionquality/connectionstats.go @@ -15,12 +15,10 @@ package connectionquality import ( - "strings" "sync" "time" "github.com/frostbyte73/core" - "github.com/pion/webrtc/v4" "go.uber.org/atomic" "google.golang.org/protobuf/types/known/timestamppb" @@ -28,6 +26,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" ) @@ -60,7 +59,7 @@ type ConnectionStatsParams struct { type ConnectionStats struct { params ConnectionStatsParams - codecMimeType atomic.String + codecMimeType atomic.Value // mime.MimeType isStarted atomic.Bool isVideo atomic.Bool @@ -88,19 +87,19 @@ func NewConnectionStats(params ConnectionStatsParams) *ConnectionStats { } } -func (cs *ConnectionStats) StartAt(codecMimeType string, isFECEnabled bool, at time.Time) { +func (cs *ConnectionStats) StartAt(codecMimeType mime.MimeType, isFECEnabled bool, at time.Time) { if cs.isStarted.Swap(true) { return } - cs.isVideo.Store(strings.HasPrefix(strings.ToLower(codecMimeType), "video/")) + cs.isVideo.Store(mime.IsMimeTypeVideo(codecMimeType)) cs.codecMimeType.Store(codecMimeType) cs.scorer.StartAt(getPacketLossWeight(codecMimeType, isFECEnabled), at) go cs.updateStatsWorker() } -func (cs *ConnectionStats) Start(codecMimeType string, isFECEnabled bool) { +func (cs *ConnectionStats) Start(codecMimeType mime.MimeType, isFECEnabled bool) { cs.StartAt(codecMimeType, isFECEnabled, time.Now()) } @@ -108,6 +107,12 @@ func (cs *ConnectionStats) Close() { cs.done.Break() } +func (cs *ConnectionStats) UpdateCodec(codecMimeType mime.MimeType, isFECEnabled bool) { + cs.isVideo.Store(mime.IsMimeTypeVideo(codecMimeType)) + cs.codecMimeType.Store(codecMimeType) + cs.scorer.UpdatePacketLossWeight(getPacketLossWeight(codecMimeType, isFECEnabled)) +} + func (cs *ConnectionStats) OnStatsUpdate(fn func(cs *ConnectionStats, stat *livekit.AnalyticsStat)) { cs.onStatsUpdate = fn } @@ -340,7 +345,7 @@ func (cs *ConnectionStats) getStat() { cs.onStatsUpdate(cs, &livekit.AnalyticsStat{ Score: score, Streams: analyticsStreams, - Mime: cs.codecMimeType.Load(), + Mime: cs.codecMimeType.Load().(mime.MimeType).String(), }) } } @@ -381,10 +386,10 @@ func (cs *ConnectionStats) updateStatsWorker() { // For video: // // o No in-built codec repair available, hence same for all codecs -func getPacketLossWeight(mimeType string, isFecEnabled bool) float64 { +func getPacketLossWeight(mimeType mime.MimeType, isFecEnabled bool) float64 { var plw float64 switch { - case strings.EqualFold(mimeType, webrtc.MimeTypeOpus): + case mimeType == mime.MimeTypeOpus: // 2.5%: fall to GOOD, 7.5%: fall to POOR plw = 8.0 if isFecEnabled { @@ -392,7 +397,7 @@ func getPacketLossWeight(mimeType string, isFecEnabled bool) float64 { plw /= 1.5 } - case strings.EqualFold(mimeType, "audio/red"): + case mimeType == mime.MimeTypeRED: // 5%: fall to GOOD, 15.0%: fall to POOR plw = 4.0 if isFecEnabled { @@ -400,7 +405,7 @@ func getPacketLossWeight(mimeType string, isFecEnabled bool) float64 { plw /= 1.5 } - case strings.HasPrefix(strings.ToLower(mimeType), "video/"): + case mime.IsMimeTypeVideo(mimeType): // 2%: fall to GOOD, 6%: fall to POOR plw = 10.0 } diff --git a/pkg/sfu/connectionquality/connectionstats_test.go b/pkg/sfu/connectionquality/connectionstats_test.go index ca380eb72..1e74eb41e 100644 --- a/pkg/sfu/connectionquality/connectionstats_test.go +++ b/pkg/sfu/connectionquality/connectionstats_test.go @@ -19,10 +19,10 @@ import ( "testing" "time" - "github.com/pion/webrtc/v4" "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -70,7 +70,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeOpus, false, now.Add(-duration)) + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) cs.UpdateMuteAt(false, now.Add(-1*time.Second)) // no data and not enough unmute time should return default state which is EXCELLENT quality @@ -484,7 +484,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeOpus, false, now.Add(-duration)) + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) cs.UpdateMuteAt(false, now.Add(-1*time.Second)) // RTT does not knock quality down because it is dependent and hence not taken into account @@ -517,7 +517,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeOpus, false, now.Add(-duration)) + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) cs.UpdateMuteAt(false, now.Add(-1*time.Second)) // Jitter does not knock quality down because it is dependent and hence not taken into account @@ -548,7 +548,7 @@ func TestConnectionQuality(t *testing.T) { } testCases := []struct { name string - mimeType string + mimeType mime.MimeType isFECEnabled bool packetsExpected uint32 expectedQualities []expectedQuality @@ -557,7 +557,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/opus" - no fec - 0 <= loss < 2.5%: EXCELLENT, 2.5% <= loss < 7.5%: GOOD, >= 7.5%: POOR { name: "audio/opus - no fec", - mimeType: "audio/opus", + mimeType: mime.MimeTypeOpus, isFECEnabled: false, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -581,7 +581,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/opus" - fec - 0 <= loss < 3.75%: EXCELLENT, 3.75% <= loss < 11.25%: GOOD, >= 11.25%: POOR { name: "audio/opus - fec", - mimeType: "audio/opus", + mimeType: mime.MimeTypeOpus, isFECEnabled: true, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -605,7 +605,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/red" - no fec - 0 <= loss < 5%: EXCELLENT, 5% <= loss < 15%: GOOD, >= 15%: POOR { name: "audio/red - no fec", - mimeType: "audio/red", + mimeType: mime.MimeTypeRED, isFECEnabled: false, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -629,7 +629,7 @@ func TestConnectionQuality(t *testing.T) { // "audio/red" - fec - 0 <= loss < 7.5%: EXCELLENT, 7.5% <= loss < 22.5%: GOOD, >= 22.5%: POOR { name: "audio/red - fec", - mimeType: "audio/red", + mimeType: mime.MimeTypeRED, isFECEnabled: true, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -653,7 +653,7 @@ func TestConnectionQuality(t *testing.T) { // "video/*" - 0 <= loss < 2%: EXCELLENT, 2% <= loss < 6%: GOOD, >= 6%: POOR { name: "video/*", - mimeType: "video/vp8", + mimeType: mime.MimeTypeVP8, isFECEnabled: false, packetsExpected: 200, expectedQualities: []expectedQuality{ @@ -786,7 +786,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeVP8, false, now) + cs.StartAt(mime.MimeTypeVP8, false, now) for _, tr := range tc.transitions { cs.AddBitrateTransitionAt(tr.bitrate, now.Add(tr.offset)) @@ -878,7 +878,7 @@ func TestConnectionQuality(t *testing.T) { duration := 5 * time.Second now := time.Now() - cs.StartAt(webrtc.MimeTypeVP8, false, now) + cs.StartAt(mime.MimeTypeVP8, false, now) for _, tr := range tc.transitions { cs.AddLayerTransitionAt(tr.distance, now.Add(tr.offset)) diff --git a/pkg/sfu/connectionquality/scorer.go b/pkg/sfu/connectionquality/scorer.go index ebafcd9bd..af9290027 100644 --- a/pkg/sfu/connectionquality/scorer.go +++ b/pkg/sfu/connectionquality/scorer.go @@ -258,6 +258,13 @@ func (q *qualityScorer) Start(packetLossWeight float64) { q.startAtLocked(packetLossWeight, time.Now()) } +func (q *qualityScorer) UpdatePacketLossWeight(packetLossWeight float64) { + q.lock.Lock() + defer q.lock.Unlock() + + q.packetLossWeight = packetLossWeight +} + func (q *qualityScorer) updateMuteAtLocked(isMuted bool, at time.Time) { if isMuted { q.mutedAt = at diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 4505c76f2..4faa09886 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -40,6 +40,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" @@ -253,6 +254,7 @@ type DownTrack struct { params DowntrackParams id livekit.TrackID kind webrtc.RTPCodecType + mime mime.MimeType ssrc uint32 ssrcRTX uint32 payloadType atomic.Uint32 @@ -345,11 +347,12 @@ type DownTrack struct { // NewDownTrack returns a DownTrack. func NewDownTrack(params DowntrackParams) (*DownTrack, error) { codecs := params.Codecs + mimeType := mime.NormalizeMimeType(codecs[0].MimeType) var kind webrtc.RTPCodecType switch { - case strings.HasPrefix(codecs[0].MimeType, "audio/"): + case mime.IsMimeTypeAudio(mimeType): kind = webrtc.RTPCodecTypeAudio - case strings.HasPrefix(codecs[0].MimeType, "video/"): + case mime.IsMimeTypeVideo(mimeType): kind = webrtc.RTPCodecTypeVideo default: kind = webrtc.RTPCodecType(0) @@ -505,13 +508,13 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, } isFECEnabled := false - if strings.EqualFold(matchedUpstreamCodec.MimeType, MimeTypeAudioRed) { + if mime.IsMimeTypeStringRED(matchedUpstreamCodec.MimeType) { d.isRED = true for _, c := range d.upstreamCodecs { isFECEnabled = strings.Contains(strings.ToLower(c.SDPFmtpLine), "useinbandfec=1") // assume upstream primary codec is opus since we only support it for audio now - if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) { + if mime.IsMimeTypeStringOpus(c.MimeType) { d.upstreamPrimaryPT = uint8(c.PayloadType) break } @@ -525,7 +528,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.params.Logger.Errorw("failed to parse primary and secondary payload type for RED", err, "matchedCodec", codec) } d.primaryPT = uint8(primaryPT) - } else if strings.HasPrefix(strings.ToLower(matchedUpstreamCodec.MimeType), "audio/") { + } else if mime.IsMimeTypeStringAudio(matchedUpstreamCodec.MimeType) { isFECEnabled = strings.Contains(strings.ToLower(matchedUpstreamCodec.SDPFmtpLine), "fec") } @@ -558,6 +561,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.params.Logger.Debugw("DownTrack.Bind", logFields...) d.writeStream = t.WriteStream() + d.mime = mime.NormalizeMimeType(codec.MimeType) if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, d.ssrc).(*buffer.RTCPReader); rr != nil { rr.OnPacket(func(pkt []byte) { d.handleRTCP(pkt) @@ -580,10 +584,11 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.onBinding(nil) } d.setBindStateLocked(bindStateBound) + mimeType := d.mime d.bindLock.Unlock() d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) - d.connectionStats.Start(codec.MimeType, isFECEnabled) + d.connectionStats.Start(mimeType, isFECEnabled) d.params.Logger.Debugw("downtrack bound") } @@ -635,9 +640,9 @@ func (d *DownTrack) handleReceiverReady() { } } -func (d *DownTrack) handleUpstreamCodecChange(mime string) { +func (d *DownTrack) handleUpstreamCodecChange(mimeType string) { d.bindLock.Lock() - if strings.EqualFold(d.codec.MimeType, mime) { + if mime.IsMimeTypeStringEqual(d.codec.MimeType, mimeType) { d.bindLock.Unlock() return } @@ -653,7 +658,7 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { var codec webrtc.RTPCodecParameters for _, c := range d.upstreamCodecs { - if !strings.EqualFold(c.MimeType, mime) { + if !mime.IsMimeTypeStringEqual(d.codec.MimeType, mimeType) { continue } @@ -670,7 +675,7 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { "can't find matched codec for new upstream payload type", nil, "upstreamCodecs", d.upstreamCodecs, "remoteParameters", d.negotiatedCodecParameters, - "mime", mime, + "mime", mimeType, ) d.bindLock.Unlock() return @@ -679,6 +684,9 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { d.payloadType.Store(uint32(codec.PayloadType)) d.payloadTypeRTX.Store(uint32(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters))) d.codec = codec.RTPCodecCapability + d.mime = mime.NormalizeMimeType(codec.MimeType) + newMimeType := d.mime + isFECEnabled := strings.Contains(strings.ToLower(d.codec.SDPFmtpLine), "fec") d.bindLock.Unlock() d.params.Logger.Infow( @@ -690,6 +698,7 @@ func (d *DownTrack) handleUpstreamCodecChange(mime string) { d.forwarder.Restart() d.forwarder.DetermineCodec(codec.RTPCodecCapability, d.Receiver().HeaderExtensions()) + d.connectionStats.UpdateCodec(newMimeType, isFECEnabled) } // Unbind implements the teardown logic when the track is no longer needed. This happens @@ -744,6 +753,12 @@ func (d *DownTrack) Codec() webrtc.RTPCodecCapability { return d.codec } +func (d *DownTrack) Mime() mime.MimeType { + d.bindLock.Lock() + defer d.bindLock.Unlock() + return d.mime +} + // StreamID is the group this track belongs too. This must be unique func (d *DownTrack) StreamID() string { return d.params.StreamID } @@ -1679,16 +1694,15 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan return } - mime := strings.ToLower(d.Codec().MimeType) var getBlankFrame func(bool) ([]byte, error) - switch { - case strings.EqualFold(mime, webrtc.MimeTypeOpus): + switch d.mime { + case mime.MimeTypeOpus: getBlankFrame = d.getOpusBlankFrame - case strings.EqualFold(mime, MimeTypeAudioRed): + case mime.MimeTypeRED: getBlankFrame = d.getOpusRedBlankFrame - case strings.EqualFold(mime, webrtc.MimeTypeVP8): + case mime.MimeTypeVP8: getBlankFrame = d.getVP8BlankFrame - case strings.EqualFold(mime, webrtc.MimeTypeH264): + case mime.MimeTypeH264: getBlankFrame = d.getH264BlankFrame default: close(done) @@ -1696,7 +1710,7 @@ func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan } frameRate := uint32(30) - if mime == strings.ToLower(webrtc.MimeTypeOpus) || mime == strings.ToLower(MimeTypeAudioRed) { + if d.mime == mime.MimeTypeOpus || d.mime == mime.MimeTypeRED { frameRate = 50 } @@ -2414,7 +2428,7 @@ func (d *DownTrack) sendPaddingOnMute() { if d.kind == webrtc.RTPCodecTypeVideo { d.sendPaddingOnMuteForVideo() - } else if strings.EqualFold(d.Codec().MimeType, webrtc.MimeTypeOpus) { + } else if d.mime == mime.MimeTypeOpus { d.sendSilentFrameOnMuteForOpus() } } diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index e0e9f590d..df8ba0a08 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "math/rand" - "strings" "sync" "time" @@ -35,6 +34,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/codecmunger" + "github.com/livekit/livekit-server/pkg/sfu/mime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/livekit-server/pkg/sfu/videolayerselector" @@ -210,7 +210,8 @@ func (r refInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { type Forwarder struct { lock sync.RWMutex - codec webrtc.RTPCodecCapability + mime mime.MimeType + clockRate uint32 kind webrtc.RTPCodecType logger logger.Logger skipReferenceTS bool @@ -249,6 +250,7 @@ func NewForwarder( rtpStats *rtpstats.RTPStatsSender, ) *Forwarder { f := &Forwarder{ + mime: mime.MimeTypeUnknown, kind: kind, logger: logger, skipReferenceTS: skipReferenceTS, @@ -299,11 +301,13 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ f.lock.Lock() defer f.lock.Unlock() - codecChanged := f.codec.MimeType != "" && f.codec.MimeType != codec.MimeType + toMimeType := mime.NormalizeMimeType(codec.MimeType) + codecChanged := f.mime != mime.MimeTypeUnknown && f.mime != toMimeType if codecChanged { - f.logger.Debugw("forwarder codec changed", "from", f.codec.MimeType, "to", codec.MimeType) + f.logger.Debugw("forwarder codec changed", "from", f.mime, "to", toMimeType) } - f.codec = codec + f.mime = toMimeType + f.clockRate = codec.ClockRate ddAvailable := func(exts []webrtc.RTPHeaderExtensionParameter) bool { for _, ext := range exts { @@ -314,8 +318,8 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ return false } - switch strings.ToLower(codec.MimeType) { - case "video/vp8": + switch f.mime { + case mime.MimeTypeVP8: f.codecMunger = codecmunger.NewVP8FromNull(f.codecMunger, f.logger) if f.vls != nil { if vls := videolayerselector.NewSimulcastFromOther(f.vls); vls != nil { @@ -328,16 +332,14 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ } f.vls.SetTemporalLayerSelector(temporallayerselector.NewVP8(f.logger)) - case "video/h264": - fallthrough - case "video/h265": + case mime.MimeTypeH264, mime.MimeTypeH265: if f.vls != nil { f.vls = videolayerselector.NewSimulcastFromOther(f.vls) } else { f.vls = videolayerselector.NewSimulcast(f.logger) } - case "video/vp9": + case mime.MimeTypeVP9: // DD-TODO : we only enable dd layer selector for av1/vp9 now, in the future we can enable it for vp8 too isDDAvailable := ddAvailable(extensions) if isDDAvailable { @@ -355,7 +357,7 @@ func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions [ } // SVC-TODO: Support for VP9 simulcast. When DD is not available, have to pick selector based on VP9 SVC or Simulcast - case "video/av1": + case mime.MimeTypeAV1: // DD-TODO : we only enable dd layer selector for av1/vp9 now, in the future we can enable it for vp8 too isDDAvailable := ddAvailable(extensions) if isDDAvailable { @@ -1513,7 +1515,7 @@ func (f *Forwarder) Pause(availableLayers []int32, brs Bitrates) VideoAllocation func (f *Forwarder) updateAllocation(alloc VideoAllocation, reason string) VideoAllocation { // restrict target temporal to 0 if codec does not support temporal layers - if alloc.TargetLayer.IsValid() && strings.ToLower(f.codec.MimeType) == "video/h264" { + if alloc.TargetLayer.IsValid() && f.mime == mime.MimeTypeH264 { alloc.TargetLayer.Temporal = 0 } @@ -1663,7 +1665,7 @@ func (f *Forwarder) getRefLayerRTPTimestamp(ts uint32, refLayer, targetLayer int } ntpDiff := mediatransportutil.NtpTime(srRef.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(srTarget.NtpTimestamp).Time()) - rtpDiff := ntpDiff.Nanoseconds() * int64(f.codec.ClockRate) / 1e9 + rtpDiff := ntpDiff.Nanoseconds() * int64(f.clockRate) / 1e9 // calculate other layer's time stamp at the same time as ref layer's NTP time normalizedOtherTS := srTarget.RtpTimestamp + uint32(rtpDiff) @@ -1795,7 +1797,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e } else { if !f.preStartTime.IsZero() { timeSinceFirst := time.Since(f.preStartTime) - rtpDiff := uint64(timeSinceFirst.Nanoseconds() * int64(f.codec.ClockRate) / 1e9) + rtpDiff := uint64(timeSinceFirst.Nanoseconds() * int64(f.clockRate) / 1e9) extExpectedTS = f.extFirstTS + rtpDiff if f.dummyStartTSOffset == 0 { f.dummyStartTSOffset = extExpectedTS - uint64(refTS) @@ -1839,7 +1841,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e // Ideally, extRefTS should not be ahead of extExpectedTS, but extExpectedTS uses the first packet's // wall clock time. So, if the first packet experienced abmormal latency, it is possible // for extRefTS > extExpectedTS - diffSeconds := float64(int64(extExpectedTS-extRefTS)) / float64(f.codec.ClockRate) + diffSeconds := float64(int64(extExpectedTS-extRefTS)) / float64(f.clockRate) if diffSeconds >= 0.0 { if f.resumeBehindThreshold > 0 && diffSeconds > f.resumeBehindThreshold { logTransitionInfo("resume, reference too far behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) @@ -1862,7 +1864,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e f.resumeBehindThreshold = 0.0 } else { // switching between layers, check if extRefTS is too far behind the last sent - diffSeconds := float64(int64(extRefTS-extLastTS)) / float64(f.codec.ClockRate) + diffSeconds := float64(int64(extRefTS-extLastTS)) / float64(f.clockRate) if diffSeconds < 0.0 { if math.Abs(diffSeconds) > LayerSwitchBehindThresholdSeconds { // this could be due to pacer trickling out this layer. Error out and wait for a more opportune time. @@ -1877,7 +1879,7 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e logTransition("layer switch, reference is slightly behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) extNextTS = extLastTS + 1 } else { - diffSeconds = float64(int64(extRefTS-extExpectedTS)) / float64(f.codec.ClockRate) + diffSeconds = float64(int64(extRefTS-extExpectedTS)) / float64(f.clockRate) if diffSeconds > SwitchAheadThresholdSeconds { logTransition("layer switch, reference too far ahead", extExpectedTS, extRefTS, extLastTS, diffSeconds) } @@ -2154,7 +2156,7 @@ func (f *Forwarder) GetSnTsForBlankFrames(frameRate uint32, numPackets int) ([]S if int64(extExpectedTS-extLastTS) <= 0 { extExpectedTS = extLastTS + 1 } - snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs(numPackets, f.codec.ClockRate, frameRate, frameEndNeeded, extExpectedTS) + snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs(numPackets, f.clockRate, frameRate, frameEndNeeded, extExpectedTS) return snts, frameEndNeeded, err } diff --git a/pkg/sfu/mime/mimetype.go b/pkg/sfu/mime/mimetype.go new file mode 100644 index 000000000..5e78c57d5 --- /dev/null +++ b/pkg/sfu/mime/mimetype.go @@ -0,0 +1,271 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mime + +import ( + "strings" + + "github.com/pion/webrtc/v4" +) + +const ( + MimeTypePrefixAudio = "audio/" + MimeTypePrefixVideo = "video/" +) + +type MimeTypeCodec int + +const ( + MimeTypeCodecUnknown MimeTypeCodec = iota + MimeTypeCodecH264 + MimeTypeCodecH265 + MimeTypeCodecOpus + MimeTypeCodecRED + MimeTypeCodecVP8 + MimeTypeCodecVP9 + MimeTypeCodecAV1 + MimeTypeCodecG722 + MimeTypeCodecPCMU + MimeTypeCodecPCMA + MimeTypeCodecRTX + MimeTypeCodecFlexFEC +) + +func (m MimeTypeCodec) String() string { + switch m { + case MimeTypeCodecUnknown: + return "MimeTypeCodecUnknown" + case MimeTypeCodecH264: + return "H264" + case MimeTypeCodecH265: + return "H265" + case MimeTypeCodecOpus: + return "opus" + case MimeTypeCodecRED: + return "red" + case MimeTypeCodecVP8: + return "VP8" + case MimeTypeCodecVP9: + return "VP9" + case MimeTypeCodecAV1: + return "AV1" + case MimeTypeCodecG722: + return "G722" + case MimeTypeCodecPCMU: + return "PCMU" + case MimeTypeCodecPCMA: + return "PCMA" + case MimeTypeCodecRTX: + return "rtx" + case MimeTypeCodecFlexFEC: + return "flexfec" + } + + return "MimeTypeCodecUnknown" +} + +func NormalizeMimeTypeCodec(codec string) MimeTypeCodec { + switch { + case strings.EqualFold(codec, "h264"): + return MimeTypeCodecH264 + case strings.EqualFold(codec, "h265"): + return MimeTypeCodecH265 + case strings.EqualFold(codec, "opus"): + return MimeTypeCodecOpus + case strings.EqualFold(codec, "red"): + return MimeTypeCodecRED + case strings.EqualFold(codec, "vp8"): + return MimeTypeCodecVP8 + case strings.EqualFold(codec, "vp9"): + return MimeTypeCodecVP9 + case strings.EqualFold(codec, "av1"): + return MimeTypeCodecAV1 + case strings.EqualFold(codec, "g722"): + return MimeTypeCodecG722 + case strings.EqualFold(codec, "pcmu"): + return MimeTypeCodecPCMU + case strings.EqualFold(codec, "pcma"): + return MimeTypeCodecPCMA + case strings.EqualFold(codec, "rtx"): + return MimeTypeCodecRTX + case strings.EqualFold(codec, "flexfec"): + return MimeTypeCodecFlexFEC + } + + return MimeTypeCodecUnknown +} + +func GetMimeTypeCodec(mime string) MimeTypeCodec { + i := strings.IndexByte(mime, '/') + if i == -1 { + return MimeTypeCodecUnknown + } + + return NormalizeMimeTypeCodec(mime[i+1:]) +} + +func IsMimeTypeCodecStringOpus(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecOpus +} + +func IsMimeTypeCodecStringRED(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecRED +} + +func IsMimeTypeCodecStringH264(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecH264 +} + +type MimeType int + +const ( + MimeTypeUnknown MimeType = iota + MimeTypeH264 + MimeTypeH265 + MimeTypeOpus + MimeTypeRED + MimeTypeVP8 + MimeTypeVP9 + MimeTypeAV1 + MimeTypeG722 + MimeTypePCMU + MimeTypePCMA + MimeTypeRTX + MimeTypeFlexFEC +) + +func (m MimeType) String() string { + switch m { + case MimeTypeUnknown: + return "MimeTypeUnknown" + case MimeTypeH264: + return webrtc.MimeTypeH264 + case MimeTypeH265: + return webrtc.MimeTypeH265 + case MimeTypeOpus: + return webrtc.MimeTypeOpus + case MimeTypeRED: + return "audio/red" + case MimeTypeVP8: + return webrtc.MimeTypeVP8 + case MimeTypeVP9: + return webrtc.MimeTypeVP9 + case MimeTypeAV1: + return webrtc.MimeTypeAV1 + case MimeTypeG722: + return webrtc.MimeTypeG722 + case MimeTypePCMU: + return webrtc.MimeTypePCMU + case MimeTypePCMA: + return webrtc.MimeTypePCMA + case MimeTypeRTX: + return webrtc.MimeTypeRTX + case MimeTypeFlexFEC: + return webrtc.MimeTypeFlexFEC + } + + return "MimeTypeUnknown" +} + +func NormalizeMimeType(mime string) MimeType { + switch { + case strings.EqualFold(mime, webrtc.MimeTypeH264): + return MimeTypeH264 + case strings.EqualFold(mime, webrtc.MimeTypeH265): + return MimeTypeH265 + case strings.EqualFold(mime, webrtc.MimeTypeOpus): + return MimeTypeOpus + case strings.EqualFold(mime, "audio/red"): + return MimeTypeRED + case strings.EqualFold(mime, webrtc.MimeTypeVP8): + return MimeTypeVP8 + case strings.EqualFold(mime, webrtc.MimeTypeVP9): + return MimeTypeVP9 + case strings.EqualFold(mime, webrtc.MimeTypeAV1): + return MimeTypeAV1 + case strings.EqualFold(mime, webrtc.MimeTypeG722): + return MimeTypeG722 + case strings.EqualFold(mime, webrtc.MimeTypePCMU): + return MimeTypePCMU + case strings.EqualFold(mime, webrtc.MimeTypePCMA): + return MimeTypePCMA + case strings.EqualFold(mime, webrtc.MimeTypeRTX): + return MimeTypeRTX + case strings.EqualFold(mime, webrtc.MimeTypeFlexFEC): + return MimeTypeFlexFEC + } + + return MimeTypeUnknown +} + +func IsMimeTypeStringEqual(mime1 string, mime2 string) bool { + return NormalizeMimeType(mime1) == NormalizeMimeType(mime2) +} + +func IsMimeTypeStringAudio(mime string) bool { + return strings.HasPrefix(mime, MimeTypePrefixAudio) +} + +func IsMimeTypeAudio(mimeType MimeType) bool { + return strings.HasPrefix(mimeType.String(), MimeTypePrefixAudio) +} + +func IsMimeTypeStringVideo(mime string) bool { + return strings.HasPrefix(mime, MimeTypePrefixVideo) +} + +func IsMimeTypeVideo(mimeType MimeType) bool { + return strings.HasPrefix(mimeType.String(), MimeTypePrefixVideo) +} + +// SVC-TODO: Have to use more conditions to differentiate between +// SVC-TODO: SVC and non-SVC (could be single layer or simulcast). +// SVC-TODO: May only need to differentiate between simulcast and non-simulcast +// SVC-TODO: i. e. may be possible to treat single layer as SVC to get proper/intended functionality. +func IsMimeTypeSVC(mimeType MimeType) bool { + switch mimeType { + case MimeTypeAV1, MimeTypeVP9: + return true + } + return false +} + +func IsMimeTypeStringSVC(mime string) bool { + return IsMimeTypeSVC(NormalizeMimeType(mime)) +} + +func IsMimeTypeStringRED(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeRED +} + +func IsMimeTypeStringOpus(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeOpus +} + +func IsMimeTypeStringRTX(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeRTX +} + +func IsMimeTypeStringVP8(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeVP8 +} + +func IsMimeTypeStringVP9(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeVP9 +} + +func IsMimeTypeStringH264(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeH264 +} diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 70c8d434c..8c16675bd 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -33,6 +33,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/audio" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/rtpstats" "github.com/livekit/livekit-server/pkg/sfu/streamtracker" @@ -100,6 +101,7 @@ type TrackReceiver interface { // returns the initial codec of the receiver, it is determined by the track's codec // and will not change if the codec changes during the session (publisher changes codec) Codec() webrtc.RTPCodecParameters + Mime() mime.MimeType HeaderExtensions() []webrtc.RTPHeaderExtensionParameter IsClosed() bool @@ -254,8 +256,8 @@ func NewWebRTCReceiver( codecState: ReceiverCodecStateNormal, kind: track.Kind(), onRTCP: onRTCP, - isSVC: buffer.IsSvcCodec(track.Codec().MimeType), - isRED: buffer.IsRedCodec(track.Codec().MimeType), + isSVC: mime.IsMimeTypeStringSVC(track.Codec().MimeType), + isRED: mime.IsMimeTypeStringRED(track.Codec().MimeType), } for _, opt := range opts { @@ -278,9 +280,9 @@ func NewWebRTCReceiver( } }) w.connectionStats.Start( - w.codec.MimeType, + mime.NormalizeMimeType(w.codec.MimeType), // TODO: technically not correct to declare FEC on when RED. Need the primary codec's fmtp line to check. - strings.EqualFold(w.codec.MimeType, MimeTypeAudioRed) || strings.Contains(strings.ToLower(w.codec.SDPFmtpLine), "useinbandfec=1"), + mime.IsMimeTypeStringRED(w.codec.MimeType) || strings.Contains(strings.ToLower(w.codec.SDPFmtpLine), "useinbandfec=1"), ) w.streamTrackerManager = NewStreamTrackerManager(logger, trackInfo, w.isSVC, w.codec.ClockRate, streamTrackerManagerConfig) @@ -371,6 +373,10 @@ func (w *WebRTCReceiver) Codec() webrtc.RTPCodecParameters { return w.codec } +func (w *WebRTCReceiver) Mime() mime.MimeType { + return mime.NormalizeMimeType(w.codec.MimeType) +} + func (w *WebRTCReceiver) HeaderExtensions() []webrtc.RTPHeaderExtensionParameter { return w.receiver.GetParameters().HeaderExtensions } diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index 6023a89ab..f066205b3 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -27,10 +27,6 @@ import ( "github.com/livekit/protocol/logger" ) -const ( - MimeTypeAudioRed = "audio/red" -) - var ( ErrIncompleteRedHeader = errors.New("incomplete red block header") ErrIncompleteRedBlock = errors.New("incomplete red block payload") diff --git a/pkg/sfu/utils/helpers.go b/pkg/sfu/utils/helpers.go index 09990c183..d237d511a 100644 --- a/pkg/sfu/utils/helpers.go +++ b/pkg/sfu/utils/helpers.go @@ -17,8 +17,8 @@ package utils import ( "errors" "fmt" - "strings" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/pion/interceptor" "github.com/pion/rtp" "github.com/pion/webrtc/v4" @@ -29,7 +29,7 @@ import ( func CodecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) { // First attempt to match on MimeType + SDPFmtpLine for _, c := range haystack { - if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && + if mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && c.RTPCodecCapability.SDPFmtpLine == needle.RTPCodecCapability.SDPFmtpLine { return c, nil } @@ -37,7 +37,7 @@ func CodecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []web // Fallback to just MimeType for _, c := range haystack { - if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) { + if mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) { return c, nil } } diff --git a/pkg/sfu/utils/mimetype.go b/pkg/sfu/utils/mimetype.go deleted file mode 100644 index cf7735b88..000000000 --- a/pkg/sfu/utils/mimetype.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -type MimeType int - -const ( - MimeTypeUnknown MimeType = iota - MimeTypeVP8 - MimeTypeVP9 - MimeTypeH264 - MimeTypeAV1 - MimeTypeH265 -) - -func MatchMimeType(mimeType string) MimeType { - switch len(mimeType) { - case 9: - switch mimeType[0] { - case 'v', 'V': - switch mimeType[1] { - case 'i', 'I': - switch mimeType[2] { - case 'd', 'D': - switch mimeType[3] { - case 'e', 'E': - switch mimeType[4] { - case 'o', 'O': - switch mimeType[5] { - case '/': - switch mimeType[6] { - case 'v', 'V': - switch mimeType[7] { - case 'p', 'P': - switch mimeType[8] { - case '8': - return MimeTypeVP8 - case '9': - return MimeTypeVP9 - } - } - case 'a', 'A': - switch mimeType[7] { - case 'v', 'V': - switch mimeType[8] { - case '1': - return MimeTypeAV1 - } - } - } - } - } - } - } - } - } - case 10: - switch mimeType[0] { - case 'v', 'V': - switch mimeType[1] { - case 'i', 'I': - switch mimeType[2] { - case 'd', 'D': - switch mimeType[3] { - case 'e', 'E': - switch mimeType[4] { - case 'o', 'O': - switch mimeType[5] { - case '/': - switch mimeType[6] { - case 'h', 'H': - switch mimeType[7] { - case '2': - switch mimeType[8] { - case '6': - switch mimeType[9] { - case '4': - return MimeTypeH264 - case '5': - return MimeTypeH265 - } - } - } - } - } - } - } - } - } - } - } - return MimeTypeUnknown -} diff --git a/pkg/sfu/utils/mimetype_test.go b/pkg/sfu/utils/mimetype_test.go deleted file mode 100644 index 6cfcc7111..000000000 --- a/pkg/sfu/utils/mimetype_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func toLowerSwitch(mimeType string) MimeType { - switch strings.ToLower(mimeType) { - case "video/vp8": - return MimeTypeVP8 - case "video/vp9": - return MimeTypeVP9 - case "video/h264": - return MimeTypeH264 - case "video/av1": - return MimeTypeAV1 - default: - return MimeTypeUnknown - } -} - -func TestMimeTypeMatch(t *testing.T) { - require.Equal(t, MimeTypeVP8, MatchMimeType("VIDEO/VP8"), "VIDEO/VP8") - require.Equal(t, MimeTypeVP9, MatchMimeType("VIDEO/VP9"), "VIDEO/VP9") - require.Equal(t, MimeTypeH264, MatchMimeType("VIDEO/H264"), "VIDEO/H264") - require.Equal(t, MimeTypeAV1, MatchMimeType("VIDEO/AV1"), "VIDEO/AV1") -} - -func BenchmarkMimeTypeMatch(b *testing.B) { - mimeTypes := []string{ - "video/VP8", - "video/VP9", - "video/H264", - "video/AV1", - } - - b.Run("ToLower/switch", func(b *testing.B) { - for i := range b.N { - _ = toLowerSwitch(mimeTypes[i%len(mimeTypes)]) - } - }) - - b.Run("MatchMimeType", func(b *testing.B) { - for i := range b.N { - _ = MatchMimeType(mimeTypes[i%len(mimeTypes)]) - } - }) -} diff --git a/pkg/telemetry/events.go b/pkg/telemetry/events.go index 1e59e0b9f..b6e6c703c 100644 --- a/pkg/telemetry/events.go +++ b/pkg/telemetry/events.go @@ -20,6 +20,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -256,14 +257,14 @@ func (t *telemetryService) TrackMaxSubscribedVideoQuality( ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, - mime string, + mime mime.MimeType, maxQuality livekit.VideoQuality, ) { t.enqueue(func() { room := t.getRoomDetails(participantID) ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_MAX_SUBSCRIBED_VIDEO_QUALITY, room, participantID, track) ev.MaxSubscribedVideoQuality = maxQuality - ev.Mime = mime + ev.Mime = mime.String() t.SendEvent(ctx, ev) }) } @@ -393,7 +394,7 @@ func (t *telemetryService) TrackPublishRTPStats( ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, - mimeType string, + mimeType mime.MimeType, layer int, stats *livekit.RTPStats, ) { @@ -402,7 +403,7 @@ func (t *telemetryService) TrackPublishRTPStats( ev := newRoomEvent(livekit.AnalyticsEventType_TRACK_PUBLISH_STATS, room) ev.ParticipantId = string(participantID) ev.TrackId = string(trackID) - ev.Mime = mimeType + ev.Mime = mimeType.String() ev.VideoLayer = int32(layer) ev.RtpStats = stats t.SendEvent(ctx, ev) @@ -413,7 +414,7 @@ func (t *telemetryService) TrackSubscribeRTPStats( ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, - mimeType string, + mimeType mime.MimeType, stats *livekit.RTPStats, ) { t.enqueue(func() { @@ -421,7 +422,7 @@ func (t *telemetryService) TrackSubscribeRTPStats( ev := newRoomEvent(livekit.AnalyticsEventType_TRACK_SUBSCRIBE_STATS, room) ev.ParticipantId = string(participantID) ev.TrackId = string(trackID) - ev.Mime = mimeType + ev.Mime = mimeType.String() ev.RtpStats = stats t.SendEvent(ctx, ev) }) diff --git a/pkg/telemetry/telemetryfakes/fake_telemetry_service.go b/pkg/telemetry/telemetryfakes/fake_telemetry_service.go index cdfecfdae..f3605cc67 100644 --- a/pkg/telemetry/telemetryfakes/fake_telemetry_service.go +++ b/pkg/telemetry/telemetryfakes/fake_telemetry_service.go @@ -5,6 +5,7 @@ import ( "context" "sync" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/protocol/livekit" ) @@ -152,13 +153,13 @@ type FakeTelemetryService struct { arg1 context.Context arg2 []*livekit.AnalyticsStat } - TrackMaxSubscribedVideoQualityStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, string, livekit.VideoQuality) + TrackMaxSubscribedVideoQualityStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality) trackMaxSubscribedVideoQualityMutex sync.RWMutex trackMaxSubscribedVideoQualityArgsForCall []struct { arg1 context.Context arg2 livekit.ParticipantID arg3 *livekit.TrackInfo - arg4 string + arg4 mime.MimeType arg5 livekit.VideoQuality } TrackMutedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) @@ -168,13 +169,13 @@ type FakeTelemetryService struct { arg2 livekit.ParticipantID arg3 *livekit.TrackInfo } - TrackPublishRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, int, *livekit.RTPStats) + TrackPublishRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats) trackPublishRTPStatsMutex sync.RWMutex trackPublishRTPStatsArgsForCall []struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 int arg6 *livekit.RTPStats } @@ -216,13 +217,13 @@ type FakeTelemetryService struct { arg4 error arg5 bool } - TrackSubscribeRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, *livekit.RTPStats) + TrackSubscribeRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats) trackSubscribeRTPStatsMutex sync.RWMutex trackSubscribeRTPStatsArgsForCall []struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 *livekit.RTPStats } TrackSubscribeRequestedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) @@ -1003,13 +1004,13 @@ func (fake *FakeTelemetryService) SendStatsArgsForCall(i int) (context.Context, return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQuality(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 string, arg5 livekit.VideoQuality) { +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQuality(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 mime.MimeType, arg5 livekit.VideoQuality) { fake.trackMaxSubscribedVideoQualityMutex.Lock() fake.trackMaxSubscribedVideoQualityArgsForCall = append(fake.trackMaxSubscribedVideoQualityArgsForCall, struct { arg1 context.Context arg2 livekit.ParticipantID arg3 *livekit.TrackInfo - arg4 string + arg4 mime.MimeType arg5 livekit.VideoQuality }{arg1, arg2, arg3, arg4, arg5}) stub := fake.TrackMaxSubscribedVideoQualityStub @@ -1026,13 +1027,13 @@ func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCallCount() int return len(fake.trackMaxSubscribedVideoQualityArgsForCall) } -func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, string, livekit.VideoQuality)) { +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality)) { fake.trackMaxSubscribedVideoQualityMutex.Lock() defer fake.trackMaxSubscribedVideoQualityMutex.Unlock() fake.TrackMaxSubscribedVideoQualityStub = stub } -func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, string, livekit.VideoQuality) { +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality) { fake.trackMaxSubscribedVideoQualityMutex.RLock() defer fake.trackMaxSubscribedVideoQualityMutex.RUnlock() argsForCall := fake.trackMaxSubscribedVideoQualityArgsForCall[i] @@ -1073,13 +1074,13 @@ func (fake *FakeTelemetryService) TrackMutedArgsForCall(i int) (context.Context, return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 } -func (fake *FakeTelemetryService) TrackPublishRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 string, arg5 int, arg6 *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackPublishRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 mime.MimeType, arg5 int, arg6 *livekit.RTPStats) { fake.trackPublishRTPStatsMutex.Lock() fake.trackPublishRTPStatsArgsForCall = append(fake.trackPublishRTPStatsArgsForCall, struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 int arg6 *livekit.RTPStats }{arg1, arg2, arg3, arg4, arg5, arg6}) @@ -1097,13 +1098,13 @@ func (fake *FakeTelemetryService) TrackPublishRTPStatsCallCount() int { return len(fake.trackPublishRTPStatsArgsForCall) } -func (fake *FakeTelemetryService) TrackPublishRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, int, *livekit.RTPStats)) { +func (fake *FakeTelemetryService) TrackPublishRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats)) { fake.trackPublishRTPStatsMutex.Lock() defer fake.trackPublishRTPStatsMutex.Unlock() fake.TrackPublishRTPStatsStub = stub } -func (fake *FakeTelemetryService) TrackPublishRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, string, int, *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackPublishRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats) { fake.trackPublishRTPStatsMutex.RLock() defer fake.trackPublishRTPStatsMutex.RUnlock() argsForCall := fake.trackPublishRTPStatsArgsForCall[i] @@ -1283,13 +1284,13 @@ func (fake *FakeTelemetryService) TrackSubscribeFailedArgsForCall(i int) (contex return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 } -func (fake *FakeTelemetryService) TrackSubscribeRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 string, arg5 *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackSubscribeRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 mime.MimeType, arg5 *livekit.RTPStats) { fake.trackSubscribeRTPStatsMutex.Lock() fake.trackSubscribeRTPStatsArgsForCall = append(fake.trackSubscribeRTPStatsArgsForCall, struct { arg1 context.Context arg2 livekit.ParticipantID arg3 livekit.TrackID - arg4 string + arg4 mime.MimeType arg5 *livekit.RTPStats }{arg1, arg2, arg3, arg4, arg5}) stub := fake.TrackSubscribeRTPStatsStub @@ -1306,13 +1307,13 @@ func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCallCount() int { return len(fake.trackSubscribeRTPStatsArgsForCall) } -func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, string, *livekit.RTPStats)) { +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats)) { fake.trackSubscribeRTPStatsMutex.Lock() defer fake.trackSubscribeRTPStatsMutex.Unlock() fake.TrackSubscribeRTPStatsStub = stub } -func (fake *FakeTelemetryService) TrackSubscribeRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, string, *livekit.RTPStats) { +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats) { fake.trackSubscribeRTPStatsMutex.RLock() defer fake.trackSubscribeRTPStatsMutex.RUnlock() argsForCall := fake.trackSubscribeRTPStatsArgsForCall[i] diff --git a/pkg/telemetry/telemetryservice.go b/pkg/telemetry/telemetryservice.go index 226a52414..9d2cdf371 100644 --- a/pkg/telemetry/telemetryservice.go +++ b/pkg/telemetry/telemetryservice.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -64,9 +65,9 @@ type TelemetryService interface { // TrackPublishedUpdate - track metadata has been updated TrackPublishedUpdate(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) // TrackMaxSubscribedVideoQuality - publisher is notified of the max quality subscribers desire - TrackMaxSubscribedVideoQuality(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, mime string, maxQuality livekit.VideoQuality) - TrackPublishRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType string, layer int, stats *livekit.RTPStats) - TrackSubscribeRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType string, stats *livekit.RTPStats) + TrackMaxSubscribedVideoQuality(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, mime mime.MimeType, maxQuality livekit.VideoQuality) + TrackPublishRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, layer int, stats *livekit.RTPStats) + TrackSubscribeRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, stats *livekit.RTPStats) EgressStarted(ctx context.Context, info *livekit.EgressInfo) EgressUpdated(ctx context.Context, info *livekit.EgressInfo) EgressEnded(ctx context.Context, info *livekit.EgressInfo) diff --git a/test/client/client.go b/test/client/client.go index 7304532fa..b54fedb1e 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -22,7 +22,6 @@ import ( "net/http" "net/url" "path/filepath" - "strings" "sync" "time" @@ -43,6 +42,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" ) type SignalRequestHandler func(msg *livekit.SignalRequest) error @@ -103,9 +103,9 @@ var ( }, } extMimeMapping = map[string]string{ - ".ivf": webrtc.MimeTypeVP8, - ".h264": webrtc.MimeTypeH264, - ".ogg": webrtc.MimeTypeOpus, + ".ivf": mime.MimeTypeVP8.String(), + ".h264": mime.MimeTypeH264.String(), + ".ogg": mime.MimeTypeOpus.String(), } ) @@ -195,7 +195,7 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { var disabled bool if opts != nil { for _, dc := range opts.DisabledCodecs { - if strings.EqualFold(dc.MimeType, codec.Mime) && (dc.SDPFmtpLine == "" || dc.SDPFmtpLine == codec.FmtpLine) { + if mime.IsMimeTypeStringEqual(dc.MimeType, codec.Mime) && (dc.SDPFmtpLine == "" || dc.SDPFmtpLine == codec.FmtpLine) { disabled = true break } diff --git a/test/client/trackwriter.go b/test/client/trackwriter.go index 2dc4a5a38..803bf3348 100644 --- a/test/client/trackwriter.go +++ b/test/client/trackwriter.go @@ -18,7 +18,6 @@ import ( "context" "io" "os" - "strings" "time" "github.com/pion/webrtc/v4" @@ -27,6 +26,7 @@ import ( "github.com/pion/webrtc/v4/pkg/media/ivfreader" "github.com/pion/webrtc/v4/pkg/media/oggreader" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/logger" ) @@ -37,7 +37,7 @@ type TrackWriter struct { cancel context.CancelFunc track *webrtc.TrackLocalStaticSample filePath string - mime string + mime mime.MimeType ogg *oggreader.OggReader ivfheader *ivfreader.IVFFileHeader @@ -52,7 +52,7 @@ func NewTrackWriter(ctx context.Context, track *webrtc.TrackLocalStaticSample, f cancel: cancel, track: track, filePath: filePath, - mime: track.Codec().MimeType, + mime: mime.NormalizeMimeType(track.Codec().MimeType), } } @@ -67,23 +67,25 @@ func (w *TrackWriter) Start() error { return err } - logger.Debugw("starting track writer", + logger.Debugw( + "starting track writer", "trackID", w.track.ID(), - "mime", w.mime) + "mime", w.mime, + ) switch w.mime { - case webrtc.MimeTypeOpus: + case mime.MimeTypeOpus: w.ogg, _, err = oggreader.NewWith(file) if err != nil { return err } go w.writeOgg() - case webrtc.MimeTypeVP8: + case mime.MimeTypeVP8: w.ivf, w.ivfheader, err = ivfreader.NewWith(file) if err != nil { return err } go w.writeVP8() - case webrtc.MimeTypeH264: + case mime.MimeTypeH264: w.h264, err = h264reader.NewReader(file) if err != nil { return err @@ -104,7 +106,7 @@ func (w *TrackWriter) writeNull() { for { select { case <-time.After(20 * time.Millisecond): - if strings.EqualFold(w.mime, webrtc.MimeTypeH264) { + if w.mime == mime.MimeTypeH264 { w.track.WriteSample(h264Sample) } else { w.track.WriteSample(sample) diff --git a/test/singlenode_test.go b/test/singlenode_test.go index 32946c963..b11e5da77 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -38,6 +38,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/sfu/datachannel" + "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/testutils" testclient "github.com/livekit/livekit-server/test/client" ) @@ -156,7 +157,7 @@ func TestSinglePublisher(t *testing.T) { waitUntilConnected(t, c1, c2) // publish a track and ensure clients receive it ok - t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcamaudio") + t1, err := c1.AddStaticTrack("audio/OPUS", "audio", "webcamaudio") require.NoError(t, err) defer t1.Stop() t2, err := c1.AddStaticTrack("video/vp8", "video", "webcamvideo") @@ -280,7 +281,7 @@ func Test_RenegotiationWithDifferentCodecs(t *testing.T) { tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { return "" } @@ -308,9 +309,9 @@ func Test_RenegotiationWithDifferentCodecs(t *testing.T) { var vp8Found, h264Found bool tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { vp8Found = true - } else if strings.EqualFold(t.Codec().MimeType, "video/h264") { + } else if mime.IsMimeTypeStringH264(t.Codec().MimeType) { h264Found = true } } @@ -594,8 +595,8 @@ func TestDeviceCodecOverride(t *testing.T) { hasSeenVP8 := false for _, a := range desc.Attributes { if a.Key == "rtpmap" { - require.NotContains(t, a.Value, "H264", "should not contain H264 codec") - if strings.Contains(a.Value, "VP8") { + require.NotContains(t, a.Value, mime.MimeTypeCodecH264.String(), "should not contain H264 codec") + if strings.Contains(a.Value, mime.MimeTypeCodecVP8.String()) { hasSeenVP8 = true } } @@ -641,7 +642,7 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { return "" } } @@ -663,7 +664,7 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { remoteC1 := c2.GetRemoteParticipant(c1.ID()) require.NotNil(t, remoteC1) for _, track := range remoteC1.Tracks { - if strings.EqualFold(track.MimeType, "video/h264") { + if mime.IsMimeTypeStringH264(track.MimeType) { h264TrackID = track.Sid return true } @@ -698,7 +699,7 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { var vp8Count int tracks := c2.SubscribedTracks()[c1.ID()] for _, t := range tracks { - if strings.EqualFold(t.Codec().MimeType, "video/vp8") { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { vp8Count++ } }