diff --git a/config-sample.yaml b/config-sample.yaml index 4f9fb3a1a..4a70f2ae8 100644 --- a/config-sample.yaml +++ b/config-sample.yaml @@ -197,10 +197,9 @@ keys: # stream_buffer_size: 1000 # PSRPC -# since v1.5.1, a more reliable, psrpc based signal relay is available -# this gives us the ability to reliably proxy messages between a signal server and RTC node +# since v1.5.1, a more reliable, psrpc based internal rpc # psrpc: -# # enable the internal psrpc api client for roomservice api calls +# # enable the psrpc internal api client for roomservice calls # enabled: true # # maximum number of rpc attempts # max_attempts: 3 diff --git a/go.mod b/go.mod index 4c01e2787..5df1d6509 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e - github.com/livekit/protocol v1.8.2-0.20231026030639-f8b1277b3c7b + github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 github.com/livekit/psrpc v0.5.0 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 @@ -37,7 +37,7 @@ require ( github.com/pion/webrtc/v3 v3.2.21 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.17.0 - github.com/redis/go-redis/v9 v9.2.1 + github.com/redis/go-redis/v9 v9.3.0 github.com/rs/cors v1.10.1 github.com/stretchr/testify v1.8.4 github.com/thoas/go-funk v0.9.3 @@ -61,7 +61,7 @@ require ( github.com/eapache/channels v1.1.0 // indirect github.com/eapache/queue v1.1.0 // indirect github.com/go-jose/go-jose/v3 v3.0.0 // indirect - github.com/go-logr/logr v1.2.4 // indirect + github.com/go-logr/logr v1.3.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/subcommands v1.2.0 // indirect @@ -101,7 +101,7 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.14.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20231030173426-d783a09b4405 // indirect google.golang.org/grpc v1.59.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 630b60db0..4aedde2df 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,8 @@ github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44 github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc= github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo= github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= -github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= -github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= +github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -125,8 +125,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e h1:yNeIo7MSMUWgoLu7LkNKnBYnJBFPFH9Wq4S6h1kS44M= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY= -github.com/livekit/protocol v1.8.2-0.20231026030639-f8b1277b3c7b h1:ExuLaXyk6pGe2DVgXef7YQB0BNA7eDxidmthSkfGB2w= -github.com/livekit/protocol v1.8.2-0.20231026030639-f8b1277b3c7b/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= +github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 h1:WPWxU9w5XHAsonxnSSIIXbWMty9b5uHnTnyKS9TpaXM= +github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= github.com/livekit/psrpc v0.5.0 h1:g+yYNSs6Y1/vM7UlFkB2s/ARe2y3RKWZhX8ata5j+eo= github.com/livekit/psrpc v0.5.0/go.mod h1:1XYH1LLoD/YbvBvt6xg2KQ/J3InLXSJK6PL/+DKmuAU= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= @@ -234,8 +234,8 @@ github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdO github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= -github.com/redis/go-redis/v9 v9.2.1 h1:WlYJg71ODF0dVspZZCpYmoF1+U1Jjk9Rwd7pq6QmlCg= -github.com/redis/go-redis/v9 v9.2.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0= +github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= @@ -412,8 +412,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231030173426-d783a09b4405 h1:AB/lmRny7e2pLhFEYIbl5qkDAUt2h0ZRO4wGPhZf+ik= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231030173426-d783a09b4405/go.mod h1:67X1fPuzjcrkymZzZV1vvkFeTn2Rvc6lYF9MYFGCcwE= google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/pkg/clientconfiguration/conf.go b/pkg/clientconfiguration/conf.go index 8c37b7ba4..b8d32aa6a 100644 --- a/pkg/clientconfiguration/conf.go +++ b/pkg/clientconfiguration/conf.go @@ -33,7 +33,8 @@ var StaticConfigurations = []ConfigurationItem{ // Merge: false, // }, { - Match: &ScriptMatch{Expr: `c.device_model == "Xiaomi 2201117TI" && c.os == "android"`}, + Match: &ScriptMatch{Expr: `(c.device_model == "xiaomi 2201117ti" && c.os == "android") || + ((c.browser == "firefox" || c.browser == "firefox mobile") && (c.os == "linux" || c.os == "android"))`}, Configuration: &livekit.ClientConfiguration{ DisabledCodecs: &livekit.DisabledCodecs{ Publish: []*livekit.Codec{{Mime: "video/h264"}}, diff --git a/pkg/clientconfiguration/conf_test.go b/pkg/clientconfiguration/conf_test.go index 093a98f19..9c3554db0 100644 --- a/pkg/clientconfiguration/conf_test.go +++ b/pkg/clientconfiguration/conf_test.go @@ -55,7 +55,7 @@ func TestScriptMatchConfiguration(t *testing.T) { Merge: true, }, { - Match: &ScriptMatch{Expr: `c.sdk == "ANDROID"`}, + Match: &ScriptMatch{Expr: `c.sdk == "android"`}, Configuration: &livekit.ClientConfiguration{ Video: &livekit.VideoConfiguration{ HardwareEncoder: livekit.ClientConfigSetting_DISABLED, @@ -98,7 +98,8 @@ func TestScriptMatch(t *testing.T) { {name: "simple match", expr: `c.protocol > 5`, result: true}, {name: "invalid expr", expr: `cc.protocol > 5`, err: true}, {name: "unexist field", expr: `c.protocols > 5`, err: true}, - {name: "combined condition", expr: `c.protocol > 5 && (c.sdk=="ANDROID" || c.sdk=="IOS")`, result: true}, + {name: "combined condition", expr: `c.protocol > 5 && (c.sdk=="android" || c.sdk=="ios")`, result: true}, + {name: "combined condition2", expr: `(c.device_model == "xiaomi 2201117ti" && c.os == "android) || ((c.browser == "firefox" || c.browser == "firefox mobile") && (c.os == "linux" || c.os == "android"))`, result: false}, } for _, c := range cases { diff --git a/pkg/clientconfiguration/match.go b/pkg/clientconfiguration/match.go index 3c3514220..b915c2c6a 100644 --- a/pkg/clientconfiguration/match.go +++ b/pkg/clientconfiguration/match.go @@ -17,6 +17,7 @@ package clientconfiguration import ( "context" "errors" + "strings" "github.com/d5/tengo/v2" @@ -69,19 +70,19 @@ func (c *clientObject) IndexGet(index tengo.Object) (res tengo.Object, err error switch field.Value { case "sdk": - return &tengo.String{Value: c.info.Sdk.String()}, nil + return &tengo.String{Value: strings.ToLower(c.info.Sdk.String())}, nil case "version": return &tengo.String{Value: c.info.Version}, nil case "protocol": return &tengo.Int{Value: int64(c.info.Protocol)}, nil case "os": - return &tengo.String{Value: c.info.Os}, nil + return &tengo.String{Value: strings.ToLower(c.info.Os)}, nil case "os_version": return &tengo.String{Value: c.info.OsVersion}, nil case "device_model": - return &tengo.String{Value: c.info.DeviceModel}, nil + return &tengo.String{Value: strings.ToLower(c.info.DeviceModel)}, nil case "browser": - return &tengo.String{Value: c.info.Browser}, nil + return &tengo.String{Value: strings.ToLower(c.info.Browser)}, nil case "browser_version": return &tengo.String{Value: c.info.BrowserVersion}, nil case "address": diff --git a/pkg/clientconfiguration/staticconfiguration.go b/pkg/clientconfiguration/staticconfiguration.go index 2071d9112..2d0fcf984 100644 --- a/pkg/clientconfiguration/staticconfiguration.go +++ b/pkg/clientconfiguration/staticconfiguration.go @@ -15,8 +15,6 @@ package clientconfiguration import ( - "fmt" - "google.golang.org/protobuf/proto" "github.com/livekit/protocol/livekit" @@ -42,7 +40,7 @@ func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit. for _, c := range s.confs { matched, err := c.Match.Match(clientInfo) if err != nil { - logger.Errorw(fmt.Sprintf("matchrule failed, clientInfo: %s", clientInfo.String()), err) + logger.Errorw("matchrule failed", err, "clientInfo", clientInfo.String()) continue } if !matched { diff --git a/pkg/config/config.go b/pkg/config/config.go index 13d780d81..a87caf55b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,6 +30,7 @@ import ( "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/logger" redisLiveKit "github.com/livekit/protocol/redis" + "github.com/livekit/protocol/rpc" ) type CongestionControlProbeMode string @@ -71,7 +72,7 @@ type Config struct { Keys map[string]string `yaml:"keys,omitempty"` Region string `yaml:"region,omitempty"` SignalRelay SignalRelayConfig `yaml:"signal_relay,omitempty"` - PSRPC PSRPCConfig `yaml:"psrpc,omitempty"` + PSRPC rpc.PSRPCConfig `yaml:"psrpc,omitempty"` // LogLevel is deprecated LogLevel string `yaml:"log_level,omitempty"` Logging LoggingConfig `yaml:"logging,omitempty"` @@ -273,14 +274,6 @@ type SignalRelayConfig struct { StreamBufferSize int `yaml:"stream_buffer_size,omitempty"` } -type PSRPCConfig struct { - Enabled bool `yaml:"enabled,omitempty"` - MaxAttempts int `yaml:"max_attempts,omitempty"` - Timeout time.Duration `yaml:"timeout,omitempty"` - Backoff time.Duration `yaml:"backoff,omitempty"` - BufferSize int `yaml:"buffer_size,omitempty"` -} - // RegionConfig lists available regions and their latitude/longitude, so the selector would prefer // regions that are closer type RegionConfig struct { @@ -496,13 +489,8 @@ var DefaultConfig = Config{ MaxRetryInterval: 4 * time.Second, StreamBufferSize: 1000, }, - PSRPC: PSRPCConfig{ - MaxAttempts: 3, - Timeout: 500 * time.Millisecond, - Backoff: 500 * time.Millisecond, - BufferSize: 1000, - }, - Keys: map[string]string{}, + PSRPC: rpc.DefaultPSRPCConfig, + Keys: map[string]string{}, } func NewConfig(confString string, strictMode bool, c *cli.Context, baseFlags []cli.Flag) (*Config, error) { diff --git a/pkg/routing/roomclient.go b/pkg/routing/roomclient.go deleted file mode 100644 index f5b6d2bf4..000000000 --- a/pkg/routing/roomclient.go +++ /dev/null @@ -1,25 +0,0 @@ -package routing - -import ( - "github.com/livekit/livekit-server/pkg/config" - "github.com/livekit/livekit-server/pkg/telemetry/prometheus" - "github.com/livekit/protocol/logger" - protopsrpc "github.com/livekit/protocol/psrpc" - "github.com/livekit/protocol/rpc" - "github.com/livekit/psrpc" - "github.com/livekit/psrpc/pkg/middleware" -) - -func NewRoomClient(bus psrpc.MessageBus, config config.PSRPCConfig) (rpc.TypedRoomClient, error) { - return rpc.NewTypedRoomClient( - bus, - protopsrpc.WithClientLogger(logger.GetLogger()), - middleware.WithClientMetrics(prometheus.PSRPCMetricsObserver{}), - psrpc.WithClientChannelSize(config.BufferSize), - middleware.WithRPCRetries(middleware.RetryOptions{ - MaxAttempts: config.MaxAttempts, - Timeout: config.Timeout, - Backoff: config.Backoff, - }), - ) -} diff --git a/pkg/routing/topic.go b/pkg/routing/topic.go deleted file mode 100644 index 24ebb0bba..000000000 --- a/pkg/routing/topic.go +++ /dev/null @@ -1,22 +0,0 @@ -package routing - -import ( - "context" - - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/rpc" -) - -type topicFormatter struct{} - -func NewTopicFormatter() rpc.TopicFormatter { - return topicFormatter{} -} - -func (f topicFormatter) ParticipantTopic(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) rpc.ParticipantTopic { - return rpc.FormatParticipantTopic(roomName, identity) -} - -func (f topicFormatter) RoomTopic(ctx context.Context, roomName livekit.RoomName) rpc.RoomTopic { - return rpc.FormatRoomTopic(roomName) -} diff --git a/pkg/rtc/clientinfo.go b/pkg/rtc/clientinfo.go index 62117e0d1..ca85ec583 100644 --- a/pkg/rtc/clientinfo.go +++ b/pkg/rtc/clientinfo.go @@ -26,7 +26,7 @@ type ClientInfo struct { } func (c ClientInfo) isFirefox() bool { - return c.ClientInfo != nil && strings.EqualFold(c.ClientInfo.Browser, "firefox") + return c.ClientInfo != nil && (strings.EqualFold(c.ClientInfo.Browser, "firefox") || strings.EqualFold(c.ClientInfo.Browser, "firefox mobile")) } func (c ClientInfo) isSafari() bool { @@ -41,6 +41,10 @@ func (c ClientInfo) isLinux() bool { return c.ClientInfo != nil && strings.EqualFold(c.ClientInfo.Os, "linux") } +func (c ClientInfo) isAndroid() bool { + return c.ClientInfo != nil && strings.EqualFold(c.ClientInfo.Os, "android") +} + func (c ClientInfo) SupportsAudioRED() bool { return !c.isFirefox() && !c.isSafari() } @@ -85,7 +89,7 @@ func (c ClientInfo) SupportsChangeRTPSenderEncodingActive() bool { } func (c ClientInfo) ComplyWithCodecOrderInSDPAnswer() bool { - return !(c.isLinux() && c.isFirefox()) + return !((c.isLinux() || c.isAndroid()) && c.isFirefox()) } // compareVersion compares a semver against the current client SDK version diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index a739410db..0edd60565 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -18,6 +18,7 @@ import ( "strings" "github.com/pion/webrtc/v3" + "golang.org/x/exp/slices" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/protocol/livekit" @@ -132,3 +133,22 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool } return false } + +func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string { + // sort these by compatibility, since we are looking for backups + if slices.ContainsFunc(enabledCodecs, func(c *livekit.Codec) bool { + return strings.EqualFold(c.Mime, webrtc.MimeTypeVP8) + }) { + return webrtc.MimeTypeVP8 + } + if slices.ContainsFunc(enabledCodecs, func(c *livekit.Codec) bool { + return strings.EqualFold(c.Mime, webrtc.MimeTypeH264) + }) { + return webrtc.MimeTypeH264 + } + if len(enabledCodecs) > 0 { + return enabledCodecs[0].Mime + } + // uh oh. this should not happen + return "" +} diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 33ca4921f..15db1a8a3 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -304,14 +304,18 @@ func (t *MediaTrackSubscriptions) closeSubscribedTrack(subTrack types.Subscribed return } - dt.CloseWithFlush(!willBeResumed) - if willBeResumed { + dt.CloseWithFlush(false) + + // cache transceiver for potential re-use on resume tr := dt.GetTransceiver() if tr != nil { sub := subTrack.Subscriber() sub.CacheDownTrack(subTrack.ID(), tr, dt.GetState()) } + } else { + // flushing blocks, avoid blocking when publisher removes all its subscribers + go dt.CloseWithFlush(true) } } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 701482ff8..88d7f2aaf 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -88,18 +88,19 @@ func (p participantUpdateInfo) String() string { // --------------------------------------------------------------- type ParticipantParams struct { - Identity livekit.ParticipantIdentity - Name livekit.ParticipantName - SID livekit.ParticipantID - Config *WebRTCConfig - Sink routing.MessageSink - AudioConfig config.AudioConfig - VideoConfig config.VideoConfig - ProtocolVersion types.ProtocolVersion - Telemetry telemetry.TelemetryService - Trailer []byte - PLIThrottleConfig config.PLIThrottleConfig - CongestionControlConfig config.CongestionControlConfig + Identity livekit.ParticipantIdentity + Name livekit.ParticipantName + SID livekit.ParticipantID + Config *WebRTCConfig + Sink routing.MessageSink + AudioConfig config.AudioConfig + VideoConfig config.VideoConfig + ProtocolVersion types.ProtocolVersion + Telemetry telemetry.TelemetryService + Trailer []byte + PLIThrottleConfig config.PLIThrottleConfig + CongestionControlConfig config.CongestionControlConfig + // codecs that are enabled for this room EnabledCodecs []*livekit.Codec Logger logger.Logger SimTracks map[uint32]SimulcastTrackInfo @@ -157,6 +158,10 @@ type ParticipantImpl struct { // migrated in muted tracks are not fired need close at participant close mutedTrackNotFired []*MediaTrack + // supported codecs + enabledPublishCodecs []*livekit.Codec + enabledSubscribeCodecs []*livekit.Codec + *TransportManager *UpTrackManager *SubscriptionManager @@ -244,6 +249,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.state.Store(livekit.ParticipantInfo_JOINING) p.grants = params.Grants p.SetResponseSink(params.Sink) + p.setupEnabledCodecs(params.EnabledCodecs, params.ClientConf.GetDisabledCodecs()) p.supervisor.OnPublicationError(p.onPublicationError) @@ -1109,9 +1115,9 @@ func (p *ParticipantImpl) setupTransportManager() error { ProtocolVersion: p.params.ProtocolVersion, Telemetry: p.params.Telemetry, CongestionControlConfig: p.params.CongestionControlConfig, - EnabledCodecs: p.params.EnabledCodecs, + EnabledPublishCodecs: p.enabledPublishCodecs, + EnabledSubscribeCodecs: p.enabledSubscribeCodecs, SimTracks: p.params.SimTracks, - ClientConf: p.params.ClientConf, ClientInfo: p.params.ClientInfo, Migration: p.params.Migration, AllowTCPFallback: p.params.AllowTCPFallback, @@ -1614,19 +1620,33 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l ti.Stream = StreamFromTrackSource(ti.Source) } p.setStableTrackID(req.Cid, ti) + seenCodecs := make(map[string]struct{}) for _, codec := range req.SimulcastCodecs { mime := codec.Codec if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") { mime = "video/" + mime + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { + altCodec := selectAlternativeCodec(p.enabledPublishCodecs) + p.pubLogger.Infow("falling back to alternative codec", + "codec", mime, + "altCodec", altCodec, + "trackID", ti.Sid, + ) + // select an alternative MIME type that's generally supported + mime = altCodec + } } else if req.Type == livekit.TrackType_AUDIO && !strings.HasPrefix(mime, "audio/") { mime = "audio/" + mime } - if IsCodecEnabled(p.params.EnabledCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { - ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ - MimeType: mime, - Cid: codec.Cid, - }) + + if _, ok := seenCodecs[mime]; ok { + continue } + seenCodecs[mime] = struct{}{} + ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ + MimeType: mime, + Cid: codec.Cid, + }) } p.params.Telemetry.TrackPublishRequested(context.Background(), p.ID(), p.Identity(), ti) @@ -2313,3 +2333,35 @@ func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket, data []byte) er } return err } + +func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) { + subscribeCodecs := make([]*livekit.Codec, 0, len(codecs)) + publishCodecs := make([]*livekit.Codec, 0, len(codecs)) + shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool { + for _, disableCodec := range disabled { + // disable codec's fmtp is empty means disable this codec entirely + if strings.EqualFold(c.Mime, disableCodec.Mime) { + return true + } + } + return false + } + for _, c := range codecs { + var publishDisabled bool + var subscribeDisabled bool + if shouldDisable(c, disabledCodecs.GetCodecs()) { + publishDisabled = true + subscribeDisabled = true + } else if shouldDisable(c, disabledCodecs.GetPublish()) { + publishDisabled = true + } + if !publishDisabled { + publishCodecs = append(publishCodecs, c) + } + if !subscribeDisabled { + subscribeCodecs = append(subscribeCodecs, c) + } + } + p.enabledSubscribeCodecs = subscribeCodecs + p.enabledPublishCodecs = publishCodecs +} diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 1454a27d6..2b2ce2505 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -466,6 +466,71 @@ func TestDisableCodecs(t *testing.T) { require.False(t, found264) } +func TestDisablePublishCodec(t *testing.T) { + participant := newParticipantForTestWithOpts("123", &participantOpts{ + publisher: true, + clientConf: &livekit.ClientConfiguration{ + DisabledCodecs: &livekit.DisabledCodecs{ + Publish: []*livekit.Codec{ + {Mime: "video/h264"}, + }, + }, + }, + }) + + for _, codec := range participant.enabledPublishCodecs { + require.NotEqual(t, strings.ToLower(codec.Mime), "video/h264") + } + + sink := &routingfakes.FakeMessageSink{} + participant.SetResponseSink(sink) + var publishReceived atomic.Bool + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if published := res.GetTrackPublished(); published != nil { + publishReceived.Store(true) + require.NotEmpty(t, published.Track.Codecs) + require.Equal(t, "video/vp8", strings.ToLower(published.Track.Codecs[0].MimeType)) + } + } + return nil + }) + + // simulcast codec response should pick an alternative + participant.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid1", + Type: livekit.TrackType_VIDEO, + SimulcastCodecs: []*livekit.SimulcastCodec{{ + Codec: "h264", + Cid: "cid1", + }}, + }) + + require.Eventually(t, func() bool { return publishReceived.Load() }, 5*time.Second, 10*time.Millisecond) + + // publishing a supported codec should not change + publishReceived.Store(false) + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if published := res.GetTrackPublished(); published != nil { + publishReceived.Store(true) + require.NotEmpty(t, published.Track.Codecs) + require.Equal(t, "video/vp8", strings.ToLower(published.Track.Codecs[0].MimeType)) + } + } + return nil + }) + participant.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid2", + Type: livekit.TrackType_VIDEO, + SimulcastCodecs: []*livekit.SimulcastCodec{{ + Codec: "vp8", + Cid: "cid2", + }}, + }) + require.Eventually(t, func() bool { return publishReceived.Load() }, 5*time.Second, 10*time.Millisecond) +} + func TestPreferVideoCodecForPublisher(t *testing.T) { participant := newParticipantForTestWithOpts("123", &participantOpts{ publisher: true, @@ -641,7 +706,6 @@ func TestPreferAudioCodecForRed(t *testing.T) { require.Equalf(t, !disableRed, redPreferred, "offer : \n%s\nanswer sdp: \n%s", sdp, answer.SDP) }) } - } type participantOpts struct { diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 0864bf490..6805bd430 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -56,9 +56,9 @@ type TransportManagerParams struct { ProtocolVersion types.ProtocolVersion Telemetry telemetry.TelemetryService CongestionControlConfig config.CongestionControlConfig - EnabledCodecs []*livekit.Codec + EnabledSubscribeCodecs []*livekit.Codec + EnabledPublishCodecs []*livekit.Codec SimTracks map[uint32]SimulcastTrackInfo - ClientConf *livekit.ClientConfiguration ClientInfo ClientInfo Migration bool AllowTCPFallback bool @@ -112,35 +112,6 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro } t.mediaLossProxy.OnMediaLossUpdate(t.onMediaLossUpdate) - subscribeCodecs := make([]*livekit.Codec, 0, len(params.EnabledCodecs)) - publishCodecs := make([]*livekit.Codec, 0, len(params.EnabledCodecs)) - shouldDisable := func(c *livekit.Codec, disabledCodecs []*livekit.Codec) bool { - for _, disableCodec := range disabledCodecs { - // disable codec's fmtp is empty means disable this codec entirely - if strings.EqualFold(c.Mime, disableCodec.Mime) && (disableCodec.FmtpLine == "" || disableCodec.FmtpLine == c.FmtpLine) { - return true - } - } - return false - } - for _, c := range params.EnabledCodecs { - var publishDisabled bool - var subscribeDisabled bool - if shouldDisable(c, params.ClientConf.GetDisabledCodecs().GetCodecs()) { - publishDisabled = true - subscribeDisabled = true - } - if shouldDisable(c, params.ClientConf.GetDisabledCodecs().GetPublish()) { - publishDisabled = true - } - if !publishDisabled { - publishCodecs = append(publishCodecs, c) - } - if !subscribeDisabled { - subscribeCodecs = append(subscribeCodecs, c) - } - } - publisher, err := NewPCTransport(TransportParams{ ParticipantID: params.SID, ParticipantIdentity: params.Identity, @@ -149,7 +120,7 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro DirectionConfig: params.Config.Publisher, CongestionControlConfig: params.CongestionControlConfig, Telemetry: params.Telemetry, - EnabledCodecs: publishCodecs, + EnabledCodecs: params.EnabledPublishCodecs, Logger: LoggerWithPCTarget(params.Logger, livekit.SignalTarget_PUBLISHER), SimTracks: params.SimTracks, ClientInfo: params.ClientInfo, @@ -181,7 +152,7 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro DirectionConfig: params.Config.Subscriber, CongestionControlConfig: params.CongestionControlConfig, Telemetry: params.Telemetry, - EnabledCodecs: subscribeCodecs, + EnabledCodecs: params.EnabledSubscribeCodecs, Logger: LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER), ClientInfo: params.ClientInfo, IsOfferer: true, diff --git a/pkg/service/ioservice.go b/pkg/service/ioservice.go index 4a653fbe5..b957e8b15 100644 --- a/pkg/service/ioservice.go +++ b/pkg/service/ioservice.go @@ -38,7 +38,6 @@ type IOInfoService struct { } func NewIOInfoService( - nodeID livekit.NodeID, bus psrpc.MessageBus, es EgressStore, is IngressStore, @@ -126,7 +125,6 @@ func (s *IOInfoService) GetEgress(ctx context.Context, req *rpc.GetEgressRequest } func (s *IOInfoService) ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) { - var items []*livekit.EgressInfo if req.EgressId != "" { info, err := s.es.LoadEgress(ctx, req.EgressId) if err != nil { @@ -134,16 +132,13 @@ func (s *IOInfoService) ListEgress(ctx context.Context, req *livekit.ListEgressR return nil, err } - if !req.Active || int32(info.Status) < int32(livekit.EgressStatus_EGRESS_COMPLETE) { - items = []*livekit.EgressInfo{info} - } - } else { - var err error - items, err = s.es.ListEgress(ctx, livekit.RoomName(req.RoomName), req.Active) - if err != nil { - logger.Errorw("failed to list egress", err) - return nil, err - } + return &livekit.ListEgressResponse{Items: []*livekit.EgressInfo{info}}, nil + } + + items, err := s.es.ListEgress(ctx, livekit.RoomName(req.RoomName), req.Active) + if err != nil { + logger.Errorw("failed to list egress", err) + return nil, err } return &livekit.ListEgressResponse{Items: items}, nil @@ -223,33 +218,3 @@ func (s *IOInfoService) Stop() { s.ioServer.Shutdown() } } - -// deprecated -func (s *IOInfoService) UpdateEgressInfo(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) { - err := s.es.UpdateEgress(ctx, info) - - switch info.Status { - case livekit.EgressStatus_EGRESS_ACTIVE: - s.telemetry.EgressUpdated(ctx, info) - - case livekit.EgressStatus_EGRESS_COMPLETE, - livekit.EgressStatus_EGRESS_FAILED, - livekit.EgressStatus_EGRESS_ABORTED, - livekit.EgressStatus_EGRESS_LIMIT_REACHED: - - // log results - if info.Error != "" { - logger.Errorw("egress failed", errors.New(info.Error), "egressID", info.EgressId) - } else { - logger.Infow("egress ended", "egressID", info.EgressId) - } - - s.telemetry.EgressEnded(ctx, info) - } - if err != nil { - logger.Errorw("could not update egress", err) - return nil, err - } - - return &emptypb.Empty{}, nil -} diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index f37bfeee8..918e847bf 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -49,6 +49,8 @@ const ( iceConfigTTL = 5 * time.Minute ) +var affinityEpoch = time.Date(2000, 0, 0, 0, 0, 0, 0, time.UTC) + type iceConfigCacheEntry struct { iceConfig *livekit.ICEConfig modifiedAt time.Time @@ -70,10 +72,13 @@ type RoomManager struct { egressLauncher rtc.EgressLauncher versionGenerator utils.TimedVersionGenerator turnAuthHandler *TURNAuthHandler - roomServer rpc.TypedRoomServer + bus psrpc.MessageBus rooms map[livekit.RoomName]*rtc.Room + roomServers utils.MultitonService[rpc.RoomTopic] + participantServers utils.MultitonService[rpc.ParticipantTopic] + iceConfigCache map[livekit.ParticipantIdentity]*iceConfigCacheEntry } @@ -105,6 +110,7 @@ func NewLocalRoomManager( egressLauncher: egressLauncher, versionGenerator: versionGenerator, turnAuthHandler: turnAuthHandler, + bus: bus, rooms: make(map[livekit.RoomName]*rtc.Room), @@ -119,11 +125,6 @@ func NewLocalRoomManager( }, } - r.roomServer, err = rpc.NewTypedRoomServer(r, bus) - if err != nil { - return nil, err - } - // hook up to router router.OnNewParticipantRTC(r.StartSession) router.OnRTCMessage(r.handleRTCMessage) @@ -220,7 +221,8 @@ func (r *RoomManager) Stop() { room.Close() } - r.roomServer.Kill() + r.roomServers.Kill() + r.participantServers.Kill() if r.rtcConfig != nil { if r.rtcConfig.UDPMux != nil { @@ -433,13 +435,17 @@ func (r *RoomManager) StartSession( _ = participant.Close(true, types.ParticipantCloseReasonJoinFailed, false) return err } - if r.config.PSRPC.Enabled { - if err := r.roomServer.RegisterAllParticipantTopics(rpc.FormatParticipantTopic(roomName, participant.Identity())); err != nil { - pLogger.Errorw("could not join register participant topic", err) - _ = participant.Close(true, types.ParticipantCloseReasonMessageBusFailed, false) - return err - } + + participantTopic := rpc.FormatParticipantTopic(roomName, participant.Identity()) + participantServer := utils.Must(rpc.NewTypedParticipantServer(r, r.bus)) + killParticipantServer := r.participantServers.Replace(participantTopic, participantServer) + if err := participantServer.RegisterAllParticipantTopics(participantTopic); err != nil { + killParticipantServer() + pLogger.Errorw("could not join register participant topic", err) + _ = participant.Close(true, types.ParticipantCloseReasonMessageBusFailed, false) + return err } + if err = r.roomStore.StoreParticipant(ctx, roomName, participant.ToProto()); err != nil { pLogger.Errorw("could not store participant", err) } @@ -459,14 +465,12 @@ func (r *RoomManager) StartSession( clientMeta := &livekit.AnalyticsClientMeta{Region: r.currentNode.Region, Node: r.currentNode.Id} r.telemetry.ParticipantJoined(ctx, protoRoom, participant.ToProto(), pi.Client, clientMeta, true) participant.OnClose(func(p types.LocalParticipant) { + killParticipantServer() + if err := r.roomStore.DeleteParticipant(ctx, roomName, p.Identity()); err != nil { pLogger.Errorw("could not delete participant", err) } - if r.config.PSRPC.Enabled { - r.roomServer.DeregisterAllParticipantTopics(rpc.FormatParticipantTopic(roomName, participant.Identity())) - } - // update room store with new numParticipants proto := room.ToProto() persistRoomForParticipantCount(proto) @@ -507,12 +511,6 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room return nil, err } - if r.config.PSRPC.Enabled { - if err := r.roomServer.RegisterAllRoomTopics(rpc.FormatRoomTopic(roomName)); err != nil { - return nil, err - } - } - r.lock.Lock() currentRoom := r.rooms[roomName] @@ -530,10 +528,17 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room // construct ice servers newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.egressLauncher) + roomTopic := rpc.FormatRoomTopic(roomName) + roomServer := utils.Must(rpc.NewTypedRoomServer(r, r.bus)) + killRoomServer := r.roomServers.Replace(roomTopic, roomServer) + if err := roomServer.RegisterAllRoomTopics(roomTopic); err != nil { + killRoomServer() + r.lock.Unlock() + return nil, err + } + newRoom.OnClose(func() { - if r.config.PSRPC.Enabled { - r.roomServer.DeregisterAllRoomTopics(rpc.FormatRoomTopic(roomName)) - } + killRoomServer() roomInfo := newRoom.ToProto() r.telemetry.RoomEnded(ctx, roomInfo) @@ -648,19 +653,29 @@ func (r *RoomManager) handleRTCMessage(ctx context.Context, roomName livekit.Roo } } -func (r *RoomManager) roomLogger(room *rtc.Room) logger.Logger { - return rtc.LoggerWithParticipant(rtc.LoggerWithRoom(logger.GetLogger(), room.Name(), room.ID()), "", "", false) +type participantReq interface { + GetRoom() string + GetIdentity() string +} + +func (r *RoomManager) roomAndParticipantForReq(ctx context.Context, req participantReq) (*rtc.Room, types.LocalParticipant, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.GetRoom())) + if room == nil { + return nil, nil, ErrRoomNotFound + } + + participant := room.GetParticipant(livekit.ParticipantIdentity(req.GetIdentity())) + if participant == nil { + return nil, nil, ErrParticipantNotFound + } + + return room, participant, nil } func (r *RoomManager) RemoveParticipant(ctx context.Context, req *livekit.RoomParticipantIdentity) (*livekit.RemoveParticipantResponse, error) { - room := r.GetRoom(ctx, livekit.RoomName(req.Room)) - if room == nil { - return nil, ErrRoomNotFound - } - - participant := room.GetParticipant(livekit.ParticipantIdentity(req.Identity)) - if participant == nil { - return nil, ErrParticipantNotFound + room, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err } participant.GetLogger().Infow("removing participant") @@ -669,14 +684,9 @@ func (r *RoomManager) RemoveParticipant(ctx context.Context, req *livekit.RoomPa } func (r *RoomManager) MutePublishedTrack(ctx context.Context, req *livekit.MuteRoomTrackRequest) (*livekit.MuteRoomTrackResponse, error) { - room := r.GetRoom(ctx, livekit.RoomName(req.Room)) - if room == nil { - return nil, ErrRoomNotFound - } - - participant := room.GetParticipant(livekit.ParticipantIdentity(req.Identity)) - if participant == nil { - return nil, ErrParticipantNotFound + _, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err } participant.GetLogger().Debugw("setting track muted", @@ -690,14 +700,9 @@ func (r *RoomManager) MutePublishedTrack(ctx context.Context, req *livekit.MuteR } func (r *RoomManager) UpdateParticipant(ctx context.Context, req *livekit.UpdateParticipantRequest) (*livekit.ParticipantInfo, error) { - room := r.GetRoom(ctx, livekit.RoomName(req.Room)) - if room == nil { - return nil, ErrRoomNotFound - } - - participant := room.GetParticipant(livekit.ParticipantIdentity(req.Identity)) - if participant == nil { - return nil, ErrParticipantNotFound + room, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err } participant.GetLogger().Debugw("updating participant", @@ -730,14 +735,9 @@ func (r *RoomManager) DeleteRoom(ctx context.Context, req *livekit.DeleteRoomReq } func (r *RoomManager) UpdateSubscriptions(ctx context.Context, req *livekit.UpdateSubscriptionsRequest) (*livekit.UpdateSubscriptionsResponse, error) { - room := r.GetRoom(ctx, livekit.RoomName(req.Room)) - if room == nil { - return nil, ErrRoomNotFound - } - - participant := room.GetParticipant(livekit.ParticipantIdentity(req.Identity)) - if participant == nil { - return nil, ErrParticipantNotFound + room, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err } participant.GetLogger().Debugw("updating participant subscriptions") @@ -756,7 +756,7 @@ func (r *RoomManager) SendData(ctx context.Context, req *livekit.SendDataRequest return nil, ErrRoomNotFound } - r.roomLogger(room).Debugw("api send data", "size", len(req.Data)) + room.Logger.Debugw("api send data", "size", len(req.Data)) up := &livekit.UserPacket{ Payload: req.Data, DestinationSids: req.DestinationSids, @@ -773,7 +773,7 @@ func (r *RoomManager) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return nil, ErrRoomNotFound } - r.roomLogger(room).Debugw("updating room") + room.Logger.Debugw("updating room") room.SetMetadata(req.Metadata) return room.ToProto(), nil } diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index dd9cefebc..aa5e323d2 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -35,38 +35,41 @@ import ( // A rooms service that supports a single node type RoomService struct { - roomConf config.RoomConfig - apiConf config.APIConfig - psrpcConf config.PSRPCConfig - router routing.MessageRouter - roomAllocator RoomAllocator - roomStore ServiceStore - egressLauncher rtc.EgressLauncher - topicFormatter rpc.TopicFormatter - roomClient rpc.TypedRoomClient + roomConf config.RoomConfig + apiConf config.APIConfig + psrpcConf rpc.PSRPCConfig + router routing.MessageRouter + roomAllocator RoomAllocator + roomStore ServiceStore + egressLauncher rtc.EgressLauncher + topicFormatter rpc.TopicFormatter + roomClient rpc.TypedRoomClient + participantClient rpc.TypedParticipantClient } func NewRoomService( roomConf config.RoomConfig, apiConf config.APIConfig, - psrpcConf config.PSRPCConfig, + psrpcConf rpc.PSRPCConfig, router routing.MessageRouter, roomAllocator RoomAllocator, serviceStore ServiceStore, egressLauncher rtc.EgressLauncher, topicFormatter rpc.TopicFormatter, roomClient rpc.TypedRoomClient, + participantClient rpc.TypedParticipantClient, ) (svc *RoomService, err error) { svc = &RoomService{ - roomConf: roomConf, - apiConf: apiConf, - psrpcConf: psrpcConf, - router: router, - roomAllocator: roomAllocator, - roomStore: serviceStore, - egressLauncher: egressLauncher, - topicFormatter: topicFormatter, - roomClient: roomClient, + roomConf: roomConf, + apiConf: apiConf, + psrpcConf: psrpcConf, + router: router, + roomAllocator: roomAllocator, + roomStore: serviceStore, + egressLauncher: egressLauncher, + topicFormatter: topicFormatter, + roomClient: roomClient, + participantClient: participantClient, } return } @@ -229,7 +232,7 @@ func (s *RoomService) RemoveParticipant(ctx context.Context, req *livekit.RoomPa } if s.psrpcConf.Enabled { - return s.roomClient.RemoveParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + return s.participantClient.RemoveParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ @@ -265,7 +268,7 @@ func (s *RoomService) MutePublishedTrack(ctx context.Context, req *livekit.MuteR } if s.psrpcConf.Enabled { - return s.roomClient.MutePublishedTrack(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + return s.participantClient.MutePublishedTrack(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ @@ -319,7 +322,7 @@ func (s *RoomService) UpdateParticipant(ctx context.Context, req *livekit.Update } if s.psrpcConf.Enabled { - return s.roomClient.UpdateParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + return s.participantClient.UpdateParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ @@ -368,7 +371,7 @@ func (s *RoomService) UpdateSubscriptions(ctx context.Context, req *livekit.Upda } if s.psrpcConf.Enabled { - return s.roomClient.UpdateSubscriptions(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + return s.participantClient.UpdateSubscriptions(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) } err := s.writeParticipantMessage(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity), &livekit.RTCNodeMessage{ diff --git a/pkg/service/roomservice_test.go b/pkg/service/roomservice_test.go index bf25870e7..f15a72cd7 100644 --- a/pkg/service/roomservice_test.go +++ b/pkg/service/roomservice_test.go @@ -23,10 +23,10 @@ import ( "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/rpc/rpcfakes" "github.com/livekit/livekit-server/pkg/config" - "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/routing/routingfakes" "github.com/livekit/livekit-server/pkg/service" "github.com/livekit/livekit-server/pkg/service/servicefakes" @@ -131,13 +131,14 @@ func newTestRoomService(conf config.RoomConfig) *TestRoomService { svc, err := service.NewRoomService( conf, config.APIConfig{ExecutionTimeout: 2}, - config.PSRPCConfig{}, + rpc.PSRPCConfig{}, router, allocator, store, nil, - routing.NewTopicFormatter(), + rpc.NewTopicFormatter(), &rpcfakes.FakeTypedRoomClient{}, + &rpcfakes.FakeTypedParticipantClient{}, ) if err != nil { panic(err) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 0aa93088b..2e2b39592 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -31,8 +31,10 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" redisLiveKit "github.com/livekit/protocol/redis" "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" @@ -74,8 +76,10 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewDefaultSignalServer, routing.NewSignalClient, getPSRPCConfig, - routing.NewTopicFormatter, - routing.NewRoomClient, + getPSRPCClientParams, + rpc.NewTopicFormatter, + rpc.NewTypedRoomClient, + rpc.NewTypedParticipantClient, NewLocalRoomManager, NewTURNAuthHandler, getTURNAuthHandlerFunc, @@ -200,10 +204,14 @@ func getSignalRelayConfig(config *config.Config) config.SignalRelayConfig { return config.SignalRelay } -func getPSRPCConfig(config *config.Config) config.PSRPCConfig { +func getPSRPCConfig(config *config.Config) rpc.PSRPCConfig { return config.PSRPC } +func getPSRPCClientParams(config rpc.PSRPCConfig, bus psrpc.MessageBus) rpc.ClientParams { + return rpc.NewClientParams(config, bus, logger.GetLogger(), prometheus.PSRPCMetricsObserver{}) +} + func newInProcessTurnServer(conf *config.Config, authHandler turn.AuthHandler) (*turn.Server, error) { return NewTurnServer(conf, authHandler, false) } diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index cbc3d041a..322d125e6 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -12,8 +12,10 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" redis2 "github.com/livekit/protocol/redis" "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" @@ -69,17 +71,22 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live } analyticsService := telemetry.NewAnalyticsService(conf, currentNode) telemetryService := telemetry.NewTelemetryService(queuedNotifier, analyticsService) - ioInfoService, err := NewIOInfoService(nodeID, messageBus, egressStore, ingressStore, telemetryService) + ioInfoService, err := NewIOInfoService(messageBus, egressStore, ingressStore, telemetryService) if err != nil { return nil, err } rtcEgressLauncher := NewEgressLauncher(egressClient, ioInfoService) - topicFormatter := routing.NewTopicFormatter() - roomClient, err := routing.NewRoomClient(messageBus, psrpcConfig) + topicFormatter := rpc.NewTopicFormatter() + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + roomClient, err := rpc.NewTypedRoomClient(clientParams) if err != nil { return nil, err } - roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient) + participantClient, err := rpc.NewTypedParticipantClient(clientParams) + if err != nil { + return nil, err + } + roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient) if err != nil { return nil, err } @@ -233,10 +240,14 @@ func getSignalRelayConfig(config2 *config.Config) config.SignalRelayConfig { return config2.SignalRelay } -func getPSRPCConfig(config2 *config.Config) config.PSRPCConfig { +func getPSRPCConfig(config2 *config.Config) rpc.PSRPCConfig { return config2.PSRPC } +func getPSRPCClientParams(config2 rpc.PSRPCConfig, bus psrpc.MessageBus) rpc.ClientParams { + return rpc.NewClientParams(config2, bus, logger.GetLogger(), prometheus.PSRPCMetricsObserver{}) +} + func newInProcessTurnServer(conf *config.Config, authHandler turn.AuthHandler) (*turn.Server, error) { return NewTurnServer(conf, authHandler, false) } diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 77c6c58ab..3bb0a6715 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -614,7 +614,12 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime time.Time, flow ep.Temporal = 0 if b.ddParser != nil { ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet) - if err == nil && ddVal != nil { + if err != nil { + if err != ErrFrameEarlierThanKeyFrame { + b.logger.Warnw("could not parse dependency descriptor", err) + } + return nil + } else if ddVal != nil { ep.DependencyDescriptor = ddVal ep.VideoLayer = videoLayer // DD-TODO : notify active decode target change if changed. diff --git a/pkg/sfu/buffer/dependencydescriptorparser.go b/pkg/sfu/buffer/dependencydescriptorparser.go index ca69bd763..05b675ad0 100644 --- a/pkg/sfu/buffer/dependencydescriptorparser.go +++ b/pkg/sfu/buffer/dependencydescriptorparser.go @@ -26,6 +26,10 @@ import ( "github.com/livekit/protocol/logger" ) +var ( + ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe") +) + type DependencyDescriptorParser struct { structure *dd.FrameDependencyStructure ddExtID uint8 @@ -36,6 +40,7 @@ type DependencyDescriptorParser struct { seqWrapAround *utils.WrapAround[uint16, uint64] frameWrapAround *utils.WrapAround[uint16, uint64] structureExtSeq uint64 + structureExtFrameNum uint64 activeDecodeTargetsExtSeq uint64 activeDecodeTargetsMask uint32 frameChecker *FrameIntegrityChecker @@ -88,6 +93,12 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } extFN := r.frameWrapAround.Update(ddVal.FrameNumber).ExtendedVal + + if extFN < r.structureExtFrameNum { + r.logger.Debugw("drop frame which is earlier than current structure", "frameNum", extFN, "structureFrameNum", r.structureExtFrameNum) + return nil, videoLayer, ErrFrameEarlierThanKeyFrame + } + r.frameChecker.AddPacket(extSeq, extFN, &ddVal) extDD := &ExtDependencyDescriptor{ @@ -97,11 +108,12 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } if ddVal.AttachedStructure != nil { - r.logger.Debugw(fmt.Sprintf("parsed dependency descriptor\n%s", ddVal.String())) + r.logger.Debugw("parsed dependency descriptor", "extSeq", extSeq, "extFN", extFN, "structureID", ddVal.AttachedStructure.StructureId, "descriptor", ddVal.String()) if extSeq > r.structureExtSeq { r.structure = ddVal.AttachedStructure r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure) r.structureExtSeq = extSeq + r.structureExtFrameNum = extFN extDD.StructureUpdated = true extDD.ActiveDecodeTargetsUpdated = true // The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present, diff --git a/pkg/sfu/buffer/rtpstats_sender.go b/pkg/sfu/buffer/rtpstats_sender.go index 567a5efad..fd2e9dd58 100644 --- a/pkg/sfu/buffer/rtpstats_sender.go +++ b/pkg/sfu/buffer/rtpstats_sender.go @@ -394,22 +394,7 @@ func (r *RTPStatsSender) Update( "tsBefore", r.extStartTS, "tsAfter", extTimestamp, ) - if extTimestamp == 0 { // TODO-REMOVE-AFTER-DEBUG - r.logger.Errorw( - "invalid start timestamp", nil, - "snBefore", r.extStartSN, - "snAfter", extSequenceNumber, - "snHighest", r.extHighestSN, - "tsBefore", r.extStartTS, - "tsAfter", extTimestamp, - "tsHighest", r.extHighestTS, - "firstTime", r.firstTime.String(), - "startTime", r.startTime.String(), - ) - } - if extTimestamp != 0 { - r.extStartTS = extTimestamp - } + r.extStartTS = extTimestamp } if extTimestamp > r.extHighestTS { @@ -554,7 +539,11 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt "receivedRR", rr, "extStartSN", r.extStartSN, "extHighestSN", r.extHighestSN, + "extStartTS", r.extStartTS, + "extHighestTS", r.extHighestTS, "extLastRRSN", s.extLastRRSN, + "firstTime", r.firstTime.String(), + "startTime", r.startTime.String(), "extReceivedRRSN", extReceivedRRSN, "packetsInInterval", extReceivedRRSN-s.extLastRRSN, "intervalStats", is.ToString(), diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 920a473a5..f995a90dd 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -246,8 +246,6 @@ type DownTrack struct { totalRepeatedNACKs atomic.Uint32 - keyFrameRequestGeneration atomic.Uint32 - blankFramesGeneration atomic.Uint32 connectionStats *connectionquality.ConnectionStats @@ -273,6 +271,10 @@ type DownTrack struct { maxLayerNotifierCh chan struct{} maxLayerNotifierChClosed bool + keyFrameRequesterChMu sync.RWMutex + keyFrameRequesterCh chan struct{} + keyFrameRequesterChClosed bool + cbMu sync.RWMutex onStatsUpdate func(dt *DownTrack, stat *livekit.AnalyticsStat) onMaxSubscribedLayerChanged func(dt *DownTrack, layer int32) @@ -294,13 +296,14 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { } d := &DownTrack{ - params: params, - id: params.Receiver.TrackID(), - upstreamCodecs: codecs, - kind: kind, - codec: codecs[0].RTPCodecCapability, - pacer: params.Pacer, - maxLayerNotifierCh: make(chan struct{}, 1), + params: params, + id: params.Receiver.TrackID(), + upstreamCodecs: codecs, + kind: kind, + codec: codecs[0].RTPCodecCapability, + pacer: params.Pacer, + maxLayerNotifierCh: make(chan struct{}, 1), + keyFrameRequesterCh: make(chan struct{}, 1), } d.forwarder = NewForwarder( d.kind, @@ -346,6 +349,7 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { } if d.kind == webrtc.RTPCodecTypeVideo { go d.maxLayerNotifierWorker() + go d.keyFrameRequester() } return d, nil @@ -584,57 +588,58 @@ func (d *DownTrack) GetTransceiver() *webrtc.RTPTransceiver { return d.transceiver.Load() } -func (d *DownTrack) maybeStartKeyFrameRequester() { - // - // Always move to next generation to abandon any running key frame requester - // This ensures that it is stopped if forwarding is disabled due to mute - // or paused due to bandwidth constraints. A new key frame requester is - // started if a layer lock is required. - // - d.stopKeyFrameRequester() - - locked, layer := d.forwarder.CheckSync() - if !locked { - go d.keyFrameRequester(d.keyFrameRequestGeneration.Load(), layer) - } -} - -func (d *DownTrack) stopKeyFrameRequester() { - d.keyFrameRequestGeneration.Inc() -} - -func (d *DownTrack) keyFrameRequester(generation uint32, layer int32) { - if d.IsClosed() || layer == buffer.InvalidLayerSpatial { +func (d *DownTrack) postKeyFrameRequestEvent() { + if d.kind != webrtc.RTPCodecTypeVideo { return } - interval := 2 * d.rtpStats.GetRtt() - if interval < keyFrameIntervalMin { - interval = keyFrameIntervalMin + d.keyFrameRequesterChMu.RLock() + if !d.keyFrameRequesterChClosed { + select { + case d.keyFrameRequesterCh <- struct{}{}: + default: + } } - if interval > keyFrameIntervalMax { - interval = keyFrameIntervalMax + d.keyFrameRequesterChMu.RUnlock() +} + +func (d *DownTrack) keyFrameRequester() { + getInterval := func() time.Duration { + interval := 2 * d.rtpStats.GetRtt() + if interval < keyFrameIntervalMin { + interval = keyFrameIntervalMin + } + if interval > keyFrameIntervalMax { + interval = keyFrameIntervalMax + } + return time.Duration(interval) * time.Millisecond } - ticker := time.NewTicker(time.Duration(interval) * time.Millisecond) + + interval := getInterval() + ticker := time.NewTicker(interval) defer ticker.Stop() for { - locked, _ := d.forwarder.CheckSync() - if locked { + if d.IsClosed() { return } - if d.writable.Load() { - d.params.Logger.Debugw("sending PLI for layer lock", "generation", generation, "layer", layer) + select { + case _, more := <-d.keyFrameRequesterCh: + if !more { + return + } + case <-ticker.C: + } + + locked, layer := d.forwarder.CheckSync() + if !locked && layer != buffer.InvalidLayerSpatial && d.writable.Load() { + d.params.Logger.Debugw("sending PLI for layer lock", "layer", layer) d.params.Receiver.SendPLI(layer, false) d.rtpStats.UpdateLayerLockPliAndTime(1) } - <-ticker.C - - if generation != d.keyFrameRequestGeneration.Load() || !d.writable.Load() { - return - } + ticker.Reset(getInterval()) } } @@ -1035,11 +1040,15 @@ func (d *DownTrack) CloseWithFlush(flush bool) { close(d.maxLayerNotifierCh) d.maxLayerNotifierChMu.Unlock() + d.keyFrameRequesterChMu.Lock() + d.keyFrameRequesterChClosed = true + close(d.keyFrameRequesterCh) + d.keyFrameRequesterChMu.Unlock() + if onCloseHandler := d.getOnCloseHandler(); onCloseHandler != nil { onCloseHandler(!flush) } - d.stopKeyFrameRequester() d.ClearStreamAllocatorReportInterval() } @@ -1223,7 +1232,7 @@ func (d *DownTrack) DistanceToDesired() float64 { func (d *DownTrack) AllocateOptimal(allowOvershoot bool) VideoAllocation { al, brs := d.params.Receiver.GetLayeredBitrate() allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot) - d.maybeStartKeyFrameRequester() + d.postKeyFrameRequestEvent() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) return allocation } @@ -1265,7 +1274,7 @@ func (d *DownTrack) ProvisionalAllocateGetBestWeightedTransition() VideoTransiti func (d *DownTrack) ProvisionalAllocateCommit() VideoAllocation { allocation := d.forwarder.ProvisionalAllocateCommit() - d.maybeStartKeyFrameRequester() + d.postKeyFrameRequestEvent() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) return allocation } @@ -1273,7 +1282,7 @@ func (d *DownTrack) ProvisionalAllocateCommit() VideoAllocation { func (d *DownTrack) AllocateNextHigher(availableChannelCapacity int64, allowOvershoot bool) (VideoAllocation, bool) { al, brs := d.params.Receiver.GetLayeredBitrate() allocation, available := d.forwarder.AllocateNextHigher(availableChannelCapacity, al, brs, allowOvershoot) - d.maybeStartKeyFrameRequester() + d.postKeyFrameRequestEvent() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) return allocation, available } @@ -1294,7 +1303,6 @@ func (d *DownTrack) GetNextHigherTransition(allowOvershoot bool) (VideoTransitio func (d *DownTrack) Pause() VideoAllocation { al, brs := d.params.Receiver.GetLayeredBitrate() allocation := d.forwarder.Pause(al, brs) - d.maybeStartKeyFrameRequester() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) return allocation } @@ -1513,10 +1521,9 @@ func (d *DownTrack) handleRTCP(bytes []byte) { pliOnce := true sendPliOnce := func() { _, layer := d.forwarder.CheckSync() - isAnyMuted := d.forwarder.IsAnyMuted() - d.params.Logger.Debugw("received PLI/FIR RTCP", "layer", layer, "isAnyMuted", isAnyMuted) + d.params.Logger.Debugw("received PLI/FIR RTCP", "layer", layer) if pliOnce { - if layer != buffer.InvalidLayerSpatial && !isAnyMuted { + if layer != buffer.InvalidLayerSpatial { d.params.Logger.Debugw("sending PLI RTCP", "layer", layer) d.params.Receiver.SendPLI(layer, false) d.isNACKThrottled.Store(true) @@ -1845,10 +1852,6 @@ func (d *DownTrack) GetAndResetBytesSent() (uint32, uint32) { func (d *DownTrack) onBindAndConnectedChange() { d.writable.Store(d.connected.Load() && d.bound.Load()) if d.connected.Load() && d.bound.Load() && !d.bindAndConnectedOnce.Swap(true) { - if d.kind == webrtc.RTPCodecTypeVideo { - d.maybeStartKeyFrameRequester() - } - if d.activePaddingOnMuteUpTrack.Load() { go d.sendPaddingOnMute() } diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index ecd0b981c..5373c4848 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -1375,7 +1375,11 @@ func (f *Forwarder) updateAllocation(alloc VideoAllocation, reason string) Video func (f *Forwarder) setTargetLayer(targetLayer buffer.VideoLayer, requestLayerSpatial int32) { f.vls.SetTarget(targetLayer) - f.vls.SetRequestSpatial(requestLayerSpatial) + if targetLayer.IsValid() { + f.vls.SetRequestSpatial(requestLayerSpatial) + } else { + f.vls.SetRequestSpatial(buffer.InvalidLayerSpatial) + } } func (f *Forwarder) Resync() { @@ -1393,7 +1397,7 @@ func (f *Forwarder) resyncLocked() { } } -func (f *Forwarder) CheckSync() (locked bool, layer int32) { +func (f *Forwarder) CheckSync() (bool, int32) { f.lock.RLock() defer f.lock.RUnlock() diff --git a/pkg/sfu/utils/wraparound.go b/pkg/sfu/utils/wraparound.go index 1b5502c62..b04b331c2 100644 --- a/pkg/sfu/utils/wraparound.go +++ b/pkg/sfu/utils/wraparound.go @@ -70,16 +70,15 @@ func (w *WrapAround[T, ET]) Update(val T) (result WrapAroundUpdateResult[ET]) { return } - result.PreExtendedHighest = w.extendedHighest - gap := val - w.highest if gap > T(w.fullRange>>1) { // out-of-order - result.IsRestart, result.PreExtendedStart, result.ExtendedVal = w.maybeAdjustStart(val) - return + return w.maybeAdjustStart(val) } // in-order + result.PreExtendedHighest = w.extendedHighest + if val < w.highest { w.cycles += w.fullRange } @@ -124,7 +123,7 @@ func (w *WrapAround[T, ET]) updateExtendedHighest() { w.extendedHighest = getExtendedHighest(w.cycles, w.highest) } -func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtendedStart ET, extendedVal ET) { +func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (result WrapAroundUpdateResult[ET]) { // re-adjust start if necessary. The conditions are // 1. Not seen more than half the range yet // 1. wrap back compared to start and not completed a half cycle, sequences like (10, 65530) in uint16 space @@ -135,14 +134,19 @@ func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtended if w.isWrapBack(val, w.highest) { cycles -= w.fullRange } - extendedVal = getExtendedHighest(cycles, val) + result.PreExtendedHighest = w.extendedHighest + result.ExtendedVal = getExtendedHighest(cycles, val) return } if val-w.start > T(w.fullRange>>1) { // out-of-order with existing start => a new start - isRestart = true - preExtendedStart = w.GetExtendedStart() + result.IsRestart = true + if val > w.start { + result.PreExtendedStart = w.fullRange + ET(w.start) + } else { + result.PreExtendedStart = ET(w.start) + } if w.isWrapBack(val, w.highest) { w.cycles = w.fullRange @@ -155,7 +159,8 @@ func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtended cycles -= w.fullRange } } - extendedVal = getExtendedHighest(cycles, val) + result.PreExtendedHighest = w.extendedHighest + result.ExtendedVal = getExtendedHighest(cycles, val) return } diff --git a/pkg/sfu/utils/wraparound_test.go b/pkg/sfu/utils/wraparound_test.go index 9b69f105f..55d700650 100644 --- a/pkg/sfu/utils/wraparound_test.go +++ b/pkg/sfu/utils/wraparound_test.go @@ -67,8 +67,8 @@ func TestWrapAroundUint16(t *testing.T) { input: (1 << 16) - 6, updated: WrapAroundUpdateResult[uint32]{ IsRestart: true, - PreExtendedStart: 8, - PreExtendedHighest: 10, + PreExtendedStart: (1 << 16) + 8, + PreExtendedHighest: (1 << 16) + 10, ExtendedVal: (1 << 16) - 6, }, start: (1 << 16) - 6, @@ -236,8 +236,8 @@ func TestWrapAroundUint16RollbackRestartAndResetHighest(t *testing.T) { res = w.Update(65533) expectedResult = WrapAroundUpdateResult[uint64]{ IsRestart: true, - PreExtendedStart: 23, - PreExtendedHighest: 25, + PreExtendedStart: (1 << 16) + 23, + PreExtendedHighest: (1 << 16) + 25, ExtendedVal: 65533, } require.Equal(t, expectedResult, res) @@ -267,6 +267,52 @@ func TestWrapAroundUint16RollbackRestartAndResetHighest(t *testing.T) { require.Equal(t, uint64(0x7f1234), w.GetExtendedHighest()) } +func TestWrapAroundUint16WrapAroundRestartDuplicate(t *testing.T) { + w := NewWrapAround[uint16, uint64]() + + // initialize + w.Update(65534) + require.Equal(t, uint16(65534), w.GetStart()) + require.Equal(t, uint64(65534), w.GetExtendedStart()) + require.Equal(t, uint16(65534), w.GetHighest()) + require.Equal(t, uint64(65534), w.GetExtendedHighest()) + + // an in-order update with a roll over + w.Update(32) + require.Equal(t, uint16(65534), w.GetStart()) + require.Equal(t, uint64(65534), w.GetExtendedStart()) + require.Equal(t, uint16(32), w.GetHighest()) + require.Equal(t, uint64(65568), w.GetExtendedHighest()) + + // duplicate of start + res := w.Update(65534) + expectedResult := WrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: 65568, + ExtendedVal: 65534, + } + require.Equal(t, expectedResult, res) + require.Equal(t, uint16(65534), w.GetStart()) + require.Equal(t, uint64(65534), w.GetExtendedStart()) + require.Equal(t, uint16(32), w.GetHighest()) + require.Equal(t, uint64(65568), w.GetExtendedHighest()) + + // duplicate of start - again + res = w.Update(65534) + expectedResult = WrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: 65568, + ExtendedVal: 65534, + } + require.Equal(t, expectedResult, res) + require.Equal(t, uint16(65534), w.GetStart()) + require.Equal(t, uint64(65534), w.GetExtendedStart()) + require.Equal(t, uint16(32), w.GetHighest()) + require.Equal(t, uint64(65568), w.GetExtendedHighest()) +} + func TestWrapAroundUint32(t *testing.T) { w := NewWrapAround[uint32, uint64]() testCases := []struct { @@ -314,8 +360,8 @@ func TestWrapAroundUint32(t *testing.T) { input: (1 << 32) - 6, updated: WrapAroundUpdateResult[uint64]{ IsRestart: true, - PreExtendedStart: 8, - PreExtendedHighest: 10, + PreExtendedStart: (1 << 32) + 8, + PreExtendedHighest: (1 << 32) + 10, ExtendedVal: (1 << 32) - 6, }, start: (1 << 32) - 6, diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go index 1f8a9390b..f03f78c6c 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -58,7 +58,9 @@ func (d *DependencyDescriptor) IsOvershootOkay() bool { func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) { // a packet is always relevant for the svc codec - result.IsRelevant = true + if d.currentLayer.IsValid() { + result.IsRelevant = true + } ddwdt := extPkt.DependencyDescriptor if ddwdt == nil { @@ -92,7 +94,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r switch sd { case selectorDecisionDropped: // a packet of an alreadty dropped frame, maintain decision - d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sm: %d", + d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", incomingLayer, dd.FrameNumber, extFrameNum, @@ -243,6 +245,8 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r "sn", extPkt.Packet.SequenceNumber, "isKeyFrame", extPkt.KeyFrame, ) + + result.IsRelevant = true } ddExtension := &dede.DependencyDescriptorExtension{ @@ -335,7 +339,7 @@ func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { d.decodeTargetsLock.RLock() defer d.decodeTargetsLock.RUnlock() for _, dt := range d.decodeTargets { - if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() { + if dt.Active() && dt.Layer.Spatial <= d.GetTarget().Spatial && dt.Valid() { d.logger.Debugw(fmt.Sprintf("checking sync, matching decode target, layer: %d, dt: %s, dts: %+v", layer, dt, d.decodeTargets)) return true, layer } diff --git a/pkg/sfu/videolayerselector/dependencydescriptor_test.go b/pkg/sfu/videolayerselector/dependencydescriptor_test.go index a33f3b15f..c013e46af 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor_test.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor_test.go @@ -139,7 +139,7 @@ func TestDependencyDescriptor(t *testing.T) { // no dd ext, dropped ret := ddSelector.Select(&buffer.ExtPacket{Packet: &rtp.Packet{}}, 0) require.False(t, ret.IsSelected) - require.True(t, ret.IsRelevant) + require.False(t, ret.IsRelevant) // non key frame, dropped ret = ddSelector.Select(&buffer.ExtPacket{ @@ -156,7 +156,7 @@ func TestDependencyDescriptor(t *testing.T) { Packet: &rtp.Packet{}, }, 0) require.False(t, ret.IsSelected) - require.True(t, ret.IsRelevant) + require.False(t, ret.IsRelevant) frames := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 3) // key frame, update structure and decode targets @@ -253,16 +253,11 @@ func TestDependencyDescriptor(t *testing.T) { } require.True(t, switchToLower) - // not sync with requested layer + // sync with requested layer ddSelector.SetRequestSpatial(targetLayer.Spatial) locked, layer := ddSelector.CheckSync() - require.False(t, locked) - require.Equal(t, targetLayer.Spatial, layer) - - // request to current layer, sync - ddSelector.SetRequestSpatial(ddSelector.GetCurrent().Spatial) - locked, _ = ddSelector.CheckSync() require.True(t, locked) + require.Equal(t, targetLayer.Spatial, layer) } func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buffer.ExtPacket { diff --git a/pkg/sfu/videolayerselector/framechain.go b/pkg/sfu/videolayerselector/framechain.go index 613f2d264..0777fb0fa 100644 --- a/pkg/sfu/videolayerselector/framechain.go +++ b/pkg/sfu/videolayerselector/framechain.go @@ -45,6 +45,11 @@ func (fc *FrameChain) OnFrame(extFrameNum uint64, fd *dd.FrameDependencyTemplate return false } + if len(fd.ChainDiffs) <= fc.chainIdx { + fc.logger.Warnw("invalid frame chain diff", nil, "chanIdx", fc.chainIdx, "frame", extFrameNum, "fd", fd) + return fc.broken + } + // A decodable frame with frame_chain_fdiff equal to 0 indicates that the Chain is intact. if fd.ChainDiffs[fc.chainIdx] == 0 { if fc.broken {