mirror of
https://github.com/livekit/livekit.git
synced 2026-05-10 23:37:13 +00:00
Merge remote-tracking branch 'origin/master' into raja_fr
This commit is contained in:
@@ -18,7 +18,7 @@ require (
|
||||
github.com/jxskiss/base62 v1.1.0
|
||||
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1
|
||||
github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e
|
||||
github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1
|
||||
github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e
|
||||
github.com/livekit/psrpc v0.5.0
|
||||
github.com/mackerelio/go-osstat v0.2.4
|
||||
github.com/magefile/mage v1.15.0
|
||||
@@ -27,7 +27,7 @@ require (
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pion/dtls/v2 v2.2.7
|
||||
github.com/pion/ice/v2 v2.3.11
|
||||
github.com/pion/interceptor v0.1.24
|
||||
github.com/pion/interceptor v0.1.25
|
||||
github.com/pion/rtcp v1.2.10
|
||||
github.com/pion/rtp v1.8.2
|
||||
github.com/pion/sctp v1.8.9
|
||||
@@ -47,7 +47,7 @@ require (
|
||||
github.com/urfave/negroni/v3 v3.0.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
|
||||
golang.org/x/sync v0.4.0
|
||||
golang.org/x/sync v0.5.0
|
||||
google.golang.org/protobuf v1.31.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -125,8 +125,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD
|
||||
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ=
|
||||
github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e h1:yNeIo7MSMUWgoLu7LkNKnBYnJBFPFH9Wq4S6h1kS44M=
|
||||
github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY=
|
||||
github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 h1:WPWxU9w5XHAsonxnSSIIXbWMty9b5uHnTnyKS9TpaXM=
|
||||
github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ=
|
||||
github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e h1:YShBpEjkEBY7yil2gjMWlkVkxs3OI58LIIYsBdb8aBU=
|
||||
github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ=
|
||||
github.com/livekit/psrpc v0.5.0 h1:g+yYNSs6Y1/vM7UlFkB2s/ARe2y3RKWZhX8ata5j+eo=
|
||||
github.com/livekit/psrpc v0.5.0/go.mod h1:1XYH1LLoD/YbvBvt6xg2KQ/J3InLXSJK6PL/+DKmuAU=
|
||||
github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs=
|
||||
@@ -185,8 +185,8 @@ github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ
|
||||
github.com/pion/ice/v2 v2.3.11 h1:rZjVmUwyT55cmN8ySMpL7rsS8KYsJERsrxJLLxpKhdw=
|
||||
github.com/pion/ice/v2 v2.3.11/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E=
|
||||
github.com/pion/interceptor v0.1.18/go.mod h1:tpvvF4cPM6NGxFA1DUMbhabzQBxdWMATDGEUYOR9x6I=
|
||||
github.com/pion/interceptor v0.1.24 h1:lN4ua3yUAJCgNKQKcZIM52wFjBgjN0r7shLj91PkJ0c=
|
||||
github.com/pion/interceptor v0.1.24/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y=
|
||||
github.com/pion/interceptor v0.1.25 h1:pwY9r7P6ToQ3+IF0bajN0xmk/fNw/suTgaTdlwTDmhc=
|
||||
github.com/pion/interceptor v0.1.25/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y=
|
||||
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/mdns v0.0.8 h1:HhicWIg7OX5PVilyBO6plhMetInbzkVJAhbdJiAeVaI=
|
||||
@@ -335,8 +335,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
|
||||
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
||||
@@ -40,7 +40,7 @@ func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit.
|
||||
for _, c := range s.confs {
|
||||
matched, err := c.Match.Match(clientInfo)
|
||||
if err != nil {
|
||||
logger.Errorw("matchrule failed", err, "clientInfo", clientInfo.String())
|
||||
logger.Errorw("matchrule failed", err, "clientInfo", logger.Proto(clientInfo))
|
||||
continue
|
||||
}
|
||||
if !matched {
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
// Copyright 2023 LiveKit, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package rtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
"github.com/livekit/psrpc"
|
||||
)
|
||||
|
||||
const (
|
||||
RoomAgentTopic = "room"
|
||||
PublisherAgentTopic = "publisher"
|
||||
)
|
||||
|
||||
type AgentClient interface {
|
||||
CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse
|
||||
JobRequest(ctx context.Context, job *livekit.Job)
|
||||
}
|
||||
|
||||
type agentClient struct {
|
||||
client rpc.AgentInternalClient
|
||||
}
|
||||
|
||||
func NewAgentClient(bus psrpc.MessageBus) (AgentClient, error) {
|
||||
client, err := rpc.NewAgentInternalClient(bus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &agentClient{client: client}, nil
|
||||
}
|
||||
|
||||
func (c *agentClient) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse {
|
||||
res := &rpc.CheckEnabledResponse{}
|
||||
resChan, err := c.client.CheckEnabled(ctx, req, psrpc.WithRequestTimeout(time.Second))
|
||||
if err != nil {
|
||||
return res
|
||||
}
|
||||
|
||||
for r := range resChan {
|
||||
if r.Err != nil {
|
||||
continue
|
||||
}
|
||||
if r.Result.RoomEnabled {
|
||||
res.RoomEnabled = true
|
||||
if res.PublisherEnabled {
|
||||
return res
|
||||
}
|
||||
}
|
||||
if r.Result.PublisherEnabled {
|
||||
res.PublisherEnabled = true
|
||||
if res.RoomEnabled {
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (c *agentClient) JobRequest(ctx context.Context, job *livekit.Job) {
|
||||
var topic string
|
||||
var logError bool
|
||||
switch job.Type {
|
||||
case livekit.JobType_JT_ROOM:
|
||||
topic = RoomAgentTopic
|
||||
case livekit.JobType_JT_PUBLISHER:
|
||||
topic = PublisherAgentTopic
|
||||
logError = true
|
||||
}
|
||||
|
||||
_, err := c.client.JobRequest(ctx, topic, job)
|
||||
if err != nil && logError {
|
||||
logger.Warnw("agent job request failed", err)
|
||||
}
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool
|
||||
return false
|
||||
}
|
||||
|
||||
func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string {
|
||||
func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string {
|
||||
// sort these by compatibility, since we are looking for backups
|
||||
if slices.ContainsFunc(enabledCodecs, func(c *livekit.Codec) bool {
|
||||
return strings.EqualFold(c.Mime, webrtc.MimeTypeVP8)
|
||||
@@ -146,9 +146,11 @@ func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string {
|
||||
}) {
|
||||
return webrtc.MimeTypeH264
|
||||
}
|
||||
if len(enabledCodecs) > 0 {
|
||||
return enabledCodecs[0].Mime
|
||||
for _, c := range enabledCodecs {
|
||||
if strings.HasPrefix(c.Mime, "video/") {
|
||||
return c.Mime
|
||||
}
|
||||
}
|
||||
// uh oh. this should not happen
|
||||
return ""
|
||||
// no viable codec in the list of enabled codecs, fall back to the most widely supported codec
|
||||
return webrtc.MimeTypeVP8
|
||||
}
|
||||
|
||||
+11
-2
@@ -239,13 +239,22 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra
|
||||
t.params.Logger.Debugw("AddReceiver", "mime", track.Codec().MimeType)
|
||||
wr := t.MediaTrackReceiver.Receiver(mime)
|
||||
if wr == nil {
|
||||
var priority int
|
||||
priority := -1
|
||||
for idx, c := range t.params.TrackInfo.Codecs {
|
||||
if strings.HasSuffix(mime, c.MimeType) {
|
||||
if strings.EqualFold(mime, c.MimeType) {
|
||||
priority = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(t.params.TrackInfo.Codecs) == 0 {
|
||||
priority = 0
|
||||
}
|
||||
if priority < 0 {
|
||||
t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(t.params.TrackInfo))
|
||||
t.lock.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
newWR := sfu.NewWebRTCReceiver(
|
||||
receiver,
|
||||
track,
|
||||
|
||||
+66
-38
@@ -101,7 +101,8 @@ type ParticipantParams struct {
|
||||
PLIThrottleConfig config.PLIThrottleConfig
|
||||
CongestionControlConfig config.CongestionControlConfig
|
||||
// codecs that are enabled for this room
|
||||
EnabledCodecs []*livekit.Codec
|
||||
PublishEnabledCodecs []*livekit.Codec
|
||||
SubscribeEnabledCodecs []*livekit.Codec
|
||||
Logger logger.Logger
|
||||
SimTracks map[uint32]SimulcastTrackInfo
|
||||
Grants *auth.ClaimGrants
|
||||
@@ -138,6 +139,7 @@ type ParticipantImpl struct {
|
||||
resSinkMu sync.Mutex
|
||||
resSink routing.MessageSink
|
||||
grants *auth.ClaimGrants
|
||||
hidden atomic.Bool
|
||||
isPublisher atomic.Bool
|
||||
|
||||
// when first connected
|
||||
@@ -248,8 +250,9 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) {
|
||||
p.migrateState.Store(types.MigrateStateInit)
|
||||
p.state.Store(livekit.ParticipantInfo_JOINING)
|
||||
p.grants = params.Grants
|
||||
p.hidden.Store(p.grants.Video.Hidden)
|
||||
p.SetResponseSink(params.Sink)
|
||||
p.setupEnabledCodecs(params.EnabledCodecs, params.ClientConf.GetDisabledCodecs())
|
||||
p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs())
|
||||
|
||||
p.supervisor.OnPublicationError(p.onPublicationError)
|
||||
|
||||
@@ -425,6 +428,7 @@ func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermissio
|
||||
p.params.Logger.Infow("updating participant permission", "permission", permission)
|
||||
|
||||
video.UpdateFromPermission(permission)
|
||||
p.hidden.Store(permission.Hidden)
|
||||
p.dirty.Store(true)
|
||||
|
||||
canPublish := video.GetCanPublish()
|
||||
@@ -710,7 +714,7 @@ func (p *ParticipantImpl) SetMigrateInfo(
|
||||
p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted)
|
||||
|
||||
p.pendingTracks[t.GetCid()] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}, migrated: true}
|
||||
p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", ti.String())
|
||||
p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", logger.Proto(ti))
|
||||
}
|
||||
p.pendingTracksLock.Unlock()
|
||||
|
||||
@@ -734,7 +738,7 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea
|
||||
"sendLeave", sendLeave,
|
||||
"reason", reason.String(),
|
||||
"isExpectedToResume", isExpectedToResume,
|
||||
"clientInfo", p.params.ClientInfo.String(),
|
||||
"clientInfo", logger.Proto(p.params.ClientInfo),
|
||||
)
|
||||
p.clearDisconnectTimer()
|
||||
p.clearMigrationTimer()
|
||||
@@ -1020,10 +1024,7 @@ func (p *ParticipantImpl) CanPublishData() bool {
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) Hidden() bool {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
return p.grants.Video.Hidden
|
||||
return p.hidden.Load()
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) IsRecorder() bool {
|
||||
@@ -1033,6 +1034,13 @@ func (p *ParticipantImpl) IsRecorder() bool {
|
||||
return p.grants.Video.Recorder
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) IsAgent() bool {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
return p.grants.Video.Agent
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) VerifySubscribeParticipantInfo(pID livekit.ParticipantID, version uint32) {
|
||||
if !p.IsReady() {
|
||||
// we have not sent a JoinResponse yet. metadata would be covered in JoinResponse
|
||||
@@ -1623,10 +1631,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l
|
||||
seenCodecs := make(map[string]struct{})
|
||||
for _, codec := range req.SimulcastCodecs {
|
||||
mime := codec.Codec
|
||||
if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") {
|
||||
mime = "video/" + mime
|
||||
if req.Type == livekit.TrackType_VIDEO {
|
||||
if !strings.HasPrefix(mime, "video/") {
|
||||
mime = "video/" + mime
|
||||
}
|
||||
if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) {
|
||||
altCodec := selectAlternativeCodec(p.enabledPublishCodecs)
|
||||
altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs)
|
||||
p.pubLogger.Infow("falling back to alternative codec",
|
||||
"codec", mime,
|
||||
"altCodec", altCodec,
|
||||
@@ -1639,7 +1649,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l
|
||||
mime = "audio/" + mime
|
||||
}
|
||||
|
||||
if _, ok := seenCodecs[mime]; ok {
|
||||
if _, ok := seenCodecs[mime]; ok || mime == "" {
|
||||
continue
|
||||
}
|
||||
seenCodecs[mime] = struct{}{}
|
||||
@@ -1658,12 +1668,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l
|
||||
} else {
|
||||
p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti)
|
||||
}
|
||||
p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", ti.String(), "request", req.String())
|
||||
p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req))
|
||||
return nil
|
||||
}
|
||||
|
||||
p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}}
|
||||
p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", ti.String(), "request", req.String())
|
||||
p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req))
|
||||
return ti
|
||||
}
|
||||
|
||||
@@ -1681,7 +1691,7 @@ func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.Trac
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) {
|
||||
p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", ti.String())
|
||||
p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", logger.Proto(ti))
|
||||
_ = p.writeMessage(&livekit.SignalResponse{
|
||||
Message: &livekit.SignalResponse_TrackPublished{
|
||||
TrackPublished: &livekit.TrackPublishedResponse{
|
||||
@@ -1761,12 +1771,32 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei
|
||||
// use existing media track to handle simulcast
|
||||
mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack)
|
||||
if !ok {
|
||||
signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()))
|
||||
signalCid, ti, migrated := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()))
|
||||
if ti == nil {
|
||||
p.pendingTracksLock.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// check if the migrated track has correct codec
|
||||
if migrated && len(ti.Codecs) > 0 {
|
||||
parameters := rtpReceiver.GetParameters()
|
||||
var codecFound int
|
||||
for _, c := range ti.Codecs {
|
||||
for _, nc := range parameters.Codecs {
|
||||
if strings.EqualFold(nc.MimeType, c.MimeType) {
|
||||
codecFound++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if codecFound != len(ti.Codecs) {
|
||||
p.params.Logger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters)
|
||||
p.pendingTracksLock.Unlock()
|
||||
p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
ti.MimeType = track.Codec().MimeType
|
||||
mt = p.addMediaTrack(signalCid, track.ID(), ti)
|
||||
newTrack = true
|
||||
@@ -1793,7 +1823,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack {
|
||||
p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", ti.String())
|
||||
p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti))
|
||||
rtpReceiver := p.TransportManager.GetPublisherRTPReceiver(ti.Mid)
|
||||
if rtpReceiver == nil {
|
||||
p.pubLogger.Errorw("could not find receiver for migrated track", nil, "trackID", ti.Sid)
|
||||
@@ -1975,7 +2005,7 @@ func (p *ParticipantImpl) onUpTrackManagerClose() {
|
||||
p.postRtcp(nil)
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) {
|
||||
func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo, bool) {
|
||||
signalCid := clientId
|
||||
pendingInfo := p.pendingTracks[clientId]
|
||||
if pendingInfo == nil {
|
||||
@@ -2011,10 +2041,10 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp
|
||||
// if still not found, we are done
|
||||
if pendingInfo == nil {
|
||||
p.pubLogger.Errorw("track info not published prior to track", nil, "clientId", clientId)
|
||||
return signalCid, nil
|
||||
return signalCid, nil, false
|
||||
}
|
||||
|
||||
return signalCid, pendingInfo.trackInfos[0]
|
||||
return signalCid, pendingInfo.trackInfos[0], pendingInfo.migrated
|
||||
}
|
||||
|
||||
// setStableTrackID either generates a new TrackID or reuses a previously used one
|
||||
@@ -2205,7 +2235,7 @@ func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason
|
||||
|
||||
scr := types.SignallingCloseReasonUnknown
|
||||
switch reason {
|
||||
case types.ParticipantCloseReasonPublicationError:
|
||||
case types.ParticipantCloseReasonPublicationError, types.ParticipantCloseReasonMigrateCodecMismatch:
|
||||
scr = types.SignallingCloseReasonFullReconnectPublicationError
|
||||
case types.ParticipantCloseReasonSubscriptionError:
|
||||
scr = types.SignallingCloseReasonFullReconnectSubscriptionError
|
||||
@@ -2334,9 +2364,7 @@ func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket, data []byte) er
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) {
|
||||
subscribeCodecs := make([]*livekit.Codec, 0, len(codecs))
|
||||
publishCodecs := make([]*livekit.Codec, 0, len(codecs))
|
||||
func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Codec, subscribeEnabledCodecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) {
|
||||
shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool {
|
||||
for _, disableCodec := range disabled {
|
||||
// disable codec's fmtp is empty means disable this codec entirely
|
||||
@@ -2346,22 +2374,22 @@ func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCo
|
||||
}
|
||||
return false
|
||||
}
|
||||
for _, c := range codecs {
|
||||
var publishDisabled bool
|
||||
var subscribeDisabled bool
|
||||
|
||||
publishCodecs := make([]*livekit.Codec, 0, len(publishEnabledCodecs))
|
||||
for _, c := range publishEnabledCodecs {
|
||||
if shouldDisable(c, disabledCodecs.GetCodecs()) || shouldDisable(c, disabledCodecs.GetPublish()) {
|
||||
continue
|
||||
}
|
||||
publishCodecs = append(publishCodecs, c)
|
||||
}
|
||||
p.enabledPublishCodecs = publishCodecs
|
||||
|
||||
subscribeCodecs := make([]*livekit.Codec, 0, len(subscribeEnabledCodecs))
|
||||
for _, c := range subscribeEnabledCodecs {
|
||||
if shouldDisable(c, disabledCodecs.GetCodecs()) {
|
||||
publishDisabled = true
|
||||
subscribeDisabled = true
|
||||
} else if shouldDisable(c, disabledCodecs.GetPublish()) {
|
||||
publishDisabled = true
|
||||
}
|
||||
if !publishDisabled {
|
||||
publishCodecs = append(publishCodecs, c)
|
||||
}
|
||||
if !subscribeDisabled {
|
||||
subscribeCodecs = append(subscribeCodecs, c)
|
||||
continue
|
||||
}
|
||||
subscribeCodecs = append(subscribeCodecs, c)
|
||||
}
|
||||
p.enabledSubscribeCodecs = subscribeCodecs
|
||||
p.enabledPublishCodecs = publishCodecs
|
||||
}
|
||||
|
||||
@@ -284,7 +284,7 @@ func TestMuteSetting(t *testing.T) {
|
||||
Muted: true,
|
||||
})
|
||||
|
||||
_, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO)
|
||||
_, ti, _ := p.getPendingTrack("cid", livekit.TrackType_AUDIO)
|
||||
require.NotNil(t, ti)
|
||||
require.True(t, ti.Muted)
|
||||
})
|
||||
@@ -748,19 +748,20 @@ func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *p
|
||||
}
|
||||
sid := livekit.ParticipantID(utils.NewGuid(utils.ParticipantPrefix))
|
||||
p, _ := NewParticipant(ParticipantParams{
|
||||
SID: sid,
|
||||
Identity: identity,
|
||||
Config: rtcConf,
|
||||
Sink: &routingfakes.FakeMessageSink{},
|
||||
ProtocolVersion: opts.protocolVersion,
|
||||
PLIThrottleConfig: conf.RTC.PLIThrottle,
|
||||
Grants: grants,
|
||||
EnabledCodecs: enabledCodecs,
|
||||
ClientConf: opts.clientConf,
|
||||
ClientInfo: ClientInfo{ClientInfo: opts.clientInfo},
|
||||
Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false),
|
||||
Telemetry: &telemetryfakes.FakeTelemetryService{},
|
||||
VersionGenerator: utils.NewDefaultTimedVersionGenerator(),
|
||||
SID: sid,
|
||||
Identity: identity,
|
||||
Config: rtcConf,
|
||||
Sink: &routingfakes.FakeMessageSink{},
|
||||
ProtocolVersion: opts.protocolVersion,
|
||||
PLIThrottleConfig: conf.RTC.PLIThrottle,
|
||||
Grants: grants,
|
||||
PublishEnabledCodecs: enabledCodecs,
|
||||
SubscribeEnabledCodecs: enabledCodecs,
|
||||
ClientConf: opts.clientConf,
|
||||
ClientInfo: ClientInfo{ClientInfo: opts.clientInfo},
|
||||
Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false),
|
||||
Telemetry: &telemetryfakes.FakeTelemetryService{},
|
||||
VersionGenerator: utils.NewDefaultTimedVersionGenerator(),
|
||||
})
|
||||
p.isPublisher.Store(opts.publisher)
|
||||
p.updateState(livekit.ParticipantInfo_ACTIVE)
|
||||
|
||||
@@ -46,7 +46,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se
|
||||
}
|
||||
|
||||
p.pendingTracksLock.RLock()
|
||||
_, info := p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
|
||||
_, info, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
|
||||
// if RED is disabled for this track, don't prefer RED codec in offer
|
||||
disableRed := info != nil && info.DisableRed
|
||||
p.pendingTracksLock.RUnlock()
|
||||
@@ -132,7 +132,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess
|
||||
if mt != nil {
|
||||
info = mt.ToProto()
|
||||
} else {
|
||||
_, info = p.getPendingTrack(streamID, livekit.TrackType_VIDEO)
|
||||
_, info, _ = p.getPendingTrack(streamID, livekit.TrackType_VIDEO)
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
@@ -239,7 +239,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript
|
||||
track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack)
|
||||
if track == nil {
|
||||
p.pendingTracksLock.RLock()
|
||||
_, ti = p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
|
||||
_, ti, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
|
||||
p.pendingTracksLock.RUnlock()
|
||||
} else {
|
||||
ti = track.TrackInfo(false)
|
||||
|
||||
+59
-9
@@ -30,6 +30,7 @@ import (
|
||||
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
"github.com/livekit/protocol/utils"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/config"
|
||||
@@ -77,11 +78,15 @@ type Room struct {
|
||||
egressLauncher EgressLauncher
|
||||
trackManager *RoomTrackManager
|
||||
|
||||
// agents
|
||||
agentClient AgentClient
|
||||
publisherAgentsEnabled bool
|
||||
|
||||
// map of identity -> Participant
|
||||
participants map[livekit.ParticipantIdentity]types.LocalParticipant
|
||||
participantOpts map[livekit.ParticipantIdentity]*ParticipantOptions
|
||||
participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource
|
||||
hasPublished sync.Map // map of identity -> bool
|
||||
hasPublished map[livekit.ParticipantIdentity]bool
|
||||
bufferFactory *buffer.FactoryOfBufferFactory
|
||||
|
||||
// batch update participant info for non-publishers
|
||||
@@ -113,6 +118,7 @@ func NewRoom(
|
||||
audioConfig *config.AudioConfig,
|
||||
serverInfo *livekit.ServerInfo,
|
||||
telemetry telemetry.TelemetryService,
|
||||
agentClient AgentClient,
|
||||
egressLauncher EgressLauncher,
|
||||
) *Room {
|
||||
r := &Room{
|
||||
@@ -127,16 +133,19 @@ func NewRoom(
|
||||
audioConfig: audioConfig,
|
||||
telemetry: telemetry,
|
||||
egressLauncher: egressLauncher,
|
||||
agentClient: agentClient,
|
||||
trackManager: NewRoomTrackManager(),
|
||||
serverInfo: serverInfo,
|
||||
participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant),
|
||||
participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions),
|
||||
participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource),
|
||||
hasPublished: make(map[livekit.ParticipantIdentity]bool),
|
||||
bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSize),
|
||||
batchedUpdates: make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo),
|
||||
closed: make(chan struct{}),
|
||||
trailer: []byte(utils.RandomSecret()),
|
||||
}
|
||||
|
||||
r.protoProxy = utils.NewProtoProxy[*livekit.Room](roomUpdateInterval, r.updateProto)
|
||||
if r.protoRoom.EmptyTimeout == 0 {
|
||||
r.protoRoom.EmptyTimeout = DefaultEmptyTimeout
|
||||
@@ -145,6 +154,21 @@ func NewRoom(
|
||||
r.protoRoom.CreationTime = time.Now().Unix()
|
||||
}
|
||||
|
||||
if agentClient != nil {
|
||||
go func() {
|
||||
res := r.agentClient.CheckEnabled(context.Background(), &rpc.CheckEnabledRequest{})
|
||||
if res.PublisherEnabled {
|
||||
r.lock.Lock()
|
||||
r.publisherAgentsEnabled = true
|
||||
// if there are already published tracks, start the agents
|
||||
for identity := range r.hasPublished {
|
||||
r.launchPublisherAgent(r.participants[identity])
|
||||
}
|
||||
r.lock.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go r.audioUpdateWorker()
|
||||
go r.connectionQualityWorker()
|
||||
go r.changeUpdateWorker()
|
||||
@@ -474,6 +498,7 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek
|
||||
delete(r.participants, identity)
|
||||
delete(r.participantOpts, identity)
|
||||
delete(r.participantRequestSources, identity)
|
||||
delete(r.hasPublished, identity)
|
||||
if !p.Hidden() {
|
||||
r.protoRoom.NumParticipants--
|
||||
}
|
||||
@@ -509,7 +534,6 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek
|
||||
for _, t := range p.GetPublishedTracks() {
|
||||
r.trackManager.RemoveTrack(t)
|
||||
}
|
||||
r.hasPublished.Delete(p.Identity())
|
||||
|
||||
p.OnTrackUpdated(nil)
|
||||
p.OnTrackPublished(nil)
|
||||
@@ -898,10 +922,19 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types.
|
||||
|
||||
r.trackManager.AddTrack(track, participant.Identity(), participant.ID())
|
||||
|
||||
// auto egress
|
||||
if r.internal != nil {
|
||||
if r.internal.ParticipantEgress != nil {
|
||||
if _, hasPublished := r.hasPublished.Swap(participant.Identity(), true); !hasPublished {
|
||||
// launch jobs
|
||||
r.lock.Lock()
|
||||
hasPublished := r.hasPublished[participant.Identity()]
|
||||
r.hasPublished[participant.Identity()] = true
|
||||
publisherAgentsEnabled := r.publisherAgentsEnabled
|
||||
r.lock.Unlock()
|
||||
|
||||
if !hasPublished {
|
||||
if publisherAgentsEnabled {
|
||||
r.launchPublisherAgent(participant)
|
||||
}
|
||||
if r.internal != nil && r.internal.ParticipantEgress != nil {
|
||||
go func() {
|
||||
if err := StartParticipantEgress(
|
||||
context.Background(),
|
||||
r.egressLauncher,
|
||||
@@ -913,9 +946,11 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types.
|
||||
); err != nil {
|
||||
r.Logger.Errorw("failed to launch participant egress", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if r.internal.TrackEgress != nil {
|
||||
}
|
||||
if r.internal != nil && r.internal.TrackEgress != nil {
|
||||
go func() {
|
||||
if err := StartTrackEgress(
|
||||
context.Background(),
|
||||
r.egressLauncher,
|
||||
@@ -927,7 +962,7 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types.
|
||||
); err != nil {
|
||||
r.Logger.Errorw("failed to launch track egress", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1286,6 +1321,21 @@ func (r *Room) connectionQualityWorker() {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) launchPublisherAgent(p types.Participant) {
|
||||
if p == nil || p.IsRecorder() || p.IsAgent() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
r.agentClient.JobRequest(context.Background(), &livekit.Job{
|
||||
Id: utils.NewGuid("JP_"),
|
||||
Type: livekit.JobType_JT_PUBLISHER,
|
||||
Room: r.ToProto(),
|
||||
Participant: p.ToProto(),
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *Room) DebugInfo() map[string]interface{} {
|
||||
info := map[string]interface{}{
|
||||
"Name": r.protoRoom.Name,
|
||||
|
||||
@@ -739,7 +739,7 @@ func newRoomWithParticipants(t *testing.T, opts testRoomOpts) *Room {
|
||||
Region: "testregion",
|
||||
},
|
||||
telemetry.NewTelemetryService(webhook.NewDefaultNotifier("", "", nil), &telemetryfakes.FakeAnalyticsService{}),
|
||||
nil,
|
||||
nil, nil,
|
||||
)
|
||||
for i := 0; i < opts.num+opts.numHidden; i++ {
|
||||
identity := livekit.ParticipantIdentity(fmt.Sprintf("p%d", i))
|
||||
|
||||
@@ -833,7 +833,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni
|
||||
}
|
||||
|
||||
dcCloseHandler := func() {
|
||||
t.params.Logger.Infow(dc.Label() + " data channel close")
|
||||
t.params.Logger.Debugw(dc.Label() + " data channel close")
|
||||
}
|
||||
|
||||
dcErrorHandler := func(err error) {
|
||||
|
||||
@@ -104,6 +104,7 @@ const (
|
||||
ParticipantCloseReasonPublicationError
|
||||
ParticipantCloseReasonSubscriptionError
|
||||
ParticipantCloseReasonDataChannelError
|
||||
ParticipantCloseReasonMigrateCodecMismatch
|
||||
)
|
||||
|
||||
func (p ParticipantCloseReason) String() string {
|
||||
@@ -154,6 +155,8 @@ func (p ParticipantCloseReason) String() string {
|
||||
return "SUBSCRIPTION_ERROR"
|
||||
case ParticipantCloseReasonDataChannelError:
|
||||
return "DATA_CHANNEL_ERROR"
|
||||
case ParticipantCloseReasonMigrateCodecMismatch:
|
||||
return "MIGRATE_CODEC_MISMATCH"
|
||||
default:
|
||||
return fmt.Sprintf("%d", int(p))
|
||||
}
|
||||
@@ -184,7 +187,7 @@ func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason {
|
||||
return livekit.DisconnectReason_SERVER_SHUTDOWN
|
||||
case ParticipantCloseReasonOvercommitted:
|
||||
return livekit.DisconnectReason_SERVER_SHUTDOWN
|
||||
case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError:
|
||||
case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch:
|
||||
return livekit.DisconnectReason_STATE_MISMATCH
|
||||
default:
|
||||
// the other types will map to unknown reason
|
||||
@@ -261,6 +264,7 @@ type Participant interface {
|
||||
// permissions
|
||||
Hidden() bool
|
||||
IsRecorder() bool
|
||||
IsAgent() bool
|
||||
|
||||
Start()
|
||||
Close(sendLeave bool, reason ParticipantCloseReason, isExpectedToResume bool) error
|
||||
|
||||
@@ -408,6 +408,16 @@ type FakeLocalParticipant struct {
|
||||
identityReturnsOnCall map[int]struct {
|
||||
result1 livekit.ParticipantIdentity
|
||||
}
|
||||
IsAgentStub func() bool
|
||||
isAgentMutex sync.RWMutex
|
||||
isAgentArgsForCall []struct {
|
||||
}
|
||||
isAgentReturns struct {
|
||||
result1 bool
|
||||
}
|
||||
isAgentReturnsOnCall map[int]struct {
|
||||
result1 bool
|
||||
}
|
||||
IsClosedStub func() bool
|
||||
isClosedMutex sync.RWMutex
|
||||
isClosedArgsForCall []struct {
|
||||
@@ -2986,6 +2996,59 @@ func (fake *FakeLocalParticipant) IdentityReturnsOnCall(i int, result1 livekit.P
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalParticipant) IsAgent() bool {
|
||||
fake.isAgentMutex.Lock()
|
||||
ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)]
|
||||
fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct {
|
||||
}{})
|
||||
stub := fake.IsAgentStub
|
||||
fakeReturns := fake.isAgentReturns
|
||||
fake.recordInvocation("IsAgent", []interface{}{})
|
||||
fake.isAgentMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub()
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeLocalParticipant) IsAgentCallCount() int {
|
||||
fake.isAgentMutex.RLock()
|
||||
defer fake.isAgentMutex.RUnlock()
|
||||
return len(fake.isAgentArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeLocalParticipant) IsAgentCalls(stub func() bool) {
|
||||
fake.isAgentMutex.Lock()
|
||||
defer fake.isAgentMutex.Unlock()
|
||||
fake.IsAgentStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeLocalParticipant) IsAgentReturns(result1 bool) {
|
||||
fake.isAgentMutex.Lock()
|
||||
defer fake.isAgentMutex.Unlock()
|
||||
fake.IsAgentStub = nil
|
||||
fake.isAgentReturns = struct {
|
||||
result1 bool
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalParticipant) IsAgentReturnsOnCall(i int, result1 bool) {
|
||||
fake.isAgentMutex.Lock()
|
||||
defer fake.isAgentMutex.Unlock()
|
||||
fake.IsAgentStub = nil
|
||||
if fake.isAgentReturnsOnCall == nil {
|
||||
fake.isAgentReturnsOnCall = make(map[int]struct {
|
||||
result1 bool
|
||||
})
|
||||
}
|
||||
fake.isAgentReturnsOnCall[i] = struct {
|
||||
result1 bool
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalParticipant) IsClosed() bool {
|
||||
fake.isClosedMutex.Lock()
|
||||
ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)]
|
||||
@@ -6031,6 +6094,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} {
|
||||
defer fake.iDMutex.RUnlock()
|
||||
fake.identityMutex.RLock()
|
||||
defer fake.identityMutex.RUnlock()
|
||||
fake.isAgentMutex.RLock()
|
||||
defer fake.isAgentMutex.RUnlock()
|
||||
fake.isClosedMutex.RLock()
|
||||
defer fake.isClosedMutex.RUnlock()
|
||||
fake.isDisconnectedMutex.RLock()
|
||||
|
||||
@@ -118,6 +118,16 @@ type FakeParticipant struct {
|
||||
identityReturnsOnCall map[int]struct {
|
||||
result1 livekit.ParticipantIdentity
|
||||
}
|
||||
IsAgentStub func() bool
|
||||
isAgentMutex sync.RWMutex
|
||||
isAgentArgsForCall []struct {
|
||||
}
|
||||
isAgentReturns struct {
|
||||
result1 bool
|
||||
}
|
||||
isAgentReturnsOnCall map[int]struct {
|
||||
result1 bool
|
||||
}
|
||||
IsPublisherStub func() bool
|
||||
isPublisherMutex sync.RWMutex
|
||||
isPublisherArgsForCall []struct {
|
||||
@@ -780,6 +790,59 @@ func (fake *FakeParticipant) IdentityReturnsOnCall(i int, result1 livekit.Partic
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeParticipant) IsAgent() bool {
|
||||
fake.isAgentMutex.Lock()
|
||||
ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)]
|
||||
fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct {
|
||||
}{})
|
||||
stub := fake.IsAgentStub
|
||||
fakeReturns := fake.isAgentReturns
|
||||
fake.recordInvocation("IsAgent", []interface{}{})
|
||||
fake.isAgentMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub()
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeParticipant) IsAgentCallCount() int {
|
||||
fake.isAgentMutex.RLock()
|
||||
defer fake.isAgentMutex.RUnlock()
|
||||
return len(fake.isAgentArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeParticipant) IsAgentCalls(stub func() bool) {
|
||||
fake.isAgentMutex.Lock()
|
||||
defer fake.isAgentMutex.Unlock()
|
||||
fake.IsAgentStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeParticipant) IsAgentReturns(result1 bool) {
|
||||
fake.isAgentMutex.Lock()
|
||||
defer fake.isAgentMutex.Unlock()
|
||||
fake.IsAgentStub = nil
|
||||
fake.isAgentReturns = struct {
|
||||
result1 bool
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeParticipant) IsAgentReturnsOnCall(i int, result1 bool) {
|
||||
fake.isAgentMutex.Lock()
|
||||
defer fake.isAgentMutex.Unlock()
|
||||
fake.IsAgentStub = nil
|
||||
if fake.isAgentReturnsOnCall == nil {
|
||||
fake.isAgentReturnsOnCall = make(map[int]struct {
|
||||
result1 bool
|
||||
})
|
||||
}
|
||||
fake.isAgentReturnsOnCall[i] = struct {
|
||||
result1 bool
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeParticipant) IsPublisher() bool {
|
||||
fake.isPublisherMutex.Lock()
|
||||
ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)]
|
||||
@@ -1318,6 +1381,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} {
|
||||
defer fake.iDMutex.RUnlock()
|
||||
fake.identityMutex.RLock()
|
||||
defer fake.identityMutex.RUnlock()
|
||||
fake.isAgentMutex.RLock()
|
||||
defer fake.isAgentMutex.RUnlock()
|
||||
fake.isPublisherMutex.RLock()
|
||||
defer fake.isPublisherMutex.RUnlock()
|
||||
fake.isRecorderMutex.RLock()
|
||||
|
||||
@@ -160,8 +160,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission(
|
||||
u.params.Logger.Debugw(
|
||||
"skipping older subscription permission version",
|
||||
"existingValue", perms,
|
||||
"existingVersion", u.subscriptionPermissionVersion.ToProto().String(),
|
||||
"requestingValue", subscriptionPermission.String(),
|
||||
"existingVersion", u.subscriptionPermissionVersion.String(),
|
||||
"requestingValue", logger.Proto(subscriptionPermission),
|
||||
"requestingVersion", timedVersion.String(),
|
||||
)
|
||||
u.lock.Unlock()
|
||||
@@ -178,7 +178,7 @@ func (u *UpTrackManager) UpdateSubscriptionPermission(
|
||||
if subscriptionPermission == nil {
|
||||
u.params.Logger.Debugw(
|
||||
"updating subscription permission, setting to nil",
|
||||
"version", u.subscriptionPermissionVersion.ToProto().String(),
|
||||
"version", u.subscriptionPermissionVersion.String(),
|
||||
)
|
||||
// possible to get a nil when migrating
|
||||
u.lock.Unlock()
|
||||
@@ -187,8 +187,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission(
|
||||
|
||||
u.params.Logger.Debugw(
|
||||
"updating subscription permission",
|
||||
"permissions", u.subscriptionPermission.String(),
|
||||
"version", u.subscriptionPermissionVersion.ToProto().String(),
|
||||
"permissions", logger.Proto(u.subscriptionPermission),
|
||||
"version", u.subscriptionPermissionVersion.String(),
|
||||
)
|
||||
if err := u.parseSubscriptionPermissionsLocked(subscriptionPermission, func(pID livekit.ParticipantID) types.LocalParticipant {
|
||||
u.lock.Unlock()
|
||||
@@ -247,7 +247,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) {
|
||||
u.publishedTracks[track.ID()] = track
|
||||
}
|
||||
u.lock.Unlock()
|
||||
u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", track.ToProto().String())
|
||||
u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", logger.Proto(track.ToProto()))
|
||||
|
||||
track.AddOnClose(func() {
|
||||
notifyClose := false
|
||||
|
||||
@@ -0,0 +1,468 @@
|
||||
// Copyright 2023 LiveKit, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
"github.com/livekit/psrpc"
|
||||
)
|
||||
|
||||
const AgentServiceVersion = "0.1.0"
|
||||
|
||||
type AgentService struct {
|
||||
upgrader websocket.Upgrader
|
||||
|
||||
*AgentHandler
|
||||
}
|
||||
|
||||
type AgentHandler struct {
|
||||
agentServer rpc.AgentInternalServer
|
||||
roomTopic string
|
||||
publisherTopic string
|
||||
|
||||
mu sync.Mutex
|
||||
availability map[string]chan *availability
|
||||
unregistered map[*websocket.Conn]*worker
|
||||
roomRegistered bool
|
||||
roomWorkers map[string]*worker
|
||||
publisherRegistered bool
|
||||
publisherWorkers map[string]*worker
|
||||
}
|
||||
|
||||
type worker struct {
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
sigConn *WSSignalConnection
|
||||
|
||||
id string
|
||||
jobType livekit.JobType
|
||||
status livekit.WorkerStatus
|
||||
activeJobs int
|
||||
}
|
||||
|
||||
type availability struct {
|
||||
workerID string
|
||||
available bool
|
||||
}
|
||||
|
||||
func NewAgentService(bus psrpc.MessageBus) (*AgentService, error) {
|
||||
s := &AgentService{
|
||||
upgrader: websocket.Upgrader{},
|
||||
}
|
||||
|
||||
// allow connections from any origin, since script may be hosted anywhere
|
||||
// security is enforced by access tokens
|
||||
s.upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
agentServer, err := rpc.NewAgentInternalServer(s, bus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.AgentHandler = NewAgentHandler(agentServer, rtc.RoomAgentTopic, rtc.PublisherAgentTopic)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) {
|
||||
// reject non websocket requests
|
||||
if !websocket.IsWebSocketUpgrade(r) {
|
||||
writer.WriteHeader(404)
|
||||
return
|
||||
}
|
||||
|
||||
// require a claim
|
||||
claims := GetGrants(r.Context())
|
||||
if claims == nil || claims.Video == nil || !claims.Video.Agent {
|
||||
handleError(writer, http.StatusUnauthorized, rtc.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
|
||||
// upgrade
|
||||
conn, err := s.upgrader.Upgrade(writer, r, nil)
|
||||
if err != nil {
|
||||
handleError(writer, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.HandleConnection(conn)
|
||||
}
|
||||
|
||||
func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, publisherTopic string) *AgentHandler {
|
||||
return &AgentHandler{
|
||||
agentServer: agentServer,
|
||||
roomTopic: roomTopic,
|
||||
publisherTopic: publisherTopic,
|
||||
availability: make(map[string]chan *availability),
|
||||
unregistered: make(map[*websocket.Conn]*worker),
|
||||
roomWorkers: make(map[string]*worker),
|
||||
publisherWorkers: make(map[string]*worker),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) HandleConnection(conn *websocket.Conn) {
|
||||
sigConn := NewWSSignalConnection(conn)
|
||||
w := &worker{
|
||||
conn: conn,
|
||||
sigConn: sigConn,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.unregistered[conn] = w
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
if w.id == "" {
|
||||
delete(s.unregistered, conn)
|
||||
} else {
|
||||
switch w.jobType {
|
||||
case livekit.JobType_JT_ROOM:
|
||||
delete(s.roomWorkers, w.id)
|
||||
if s.roomRegistered && !s.roomAvailableLocked() {
|
||||
s.roomRegistered = false
|
||||
s.agentServer.DeregisterJobRequestTopic(s.roomTopic)
|
||||
}
|
||||
case livekit.JobType_JT_PUBLISHER:
|
||||
delete(s.publisherWorkers, w.id)
|
||||
if s.publisherRegistered && !s.publisherAvailableLocked() {
|
||||
s.publisherRegistered = false
|
||||
s.agentServer.DeregisterJobRequestTopic(s.publisherTopic)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
// handle incoming requests from websocket
|
||||
for {
|
||||
req, _, err := sigConn.ReadWorkerMessage()
|
||||
if err != nil {
|
||||
// normal/expected closure
|
||||
if err == io.EOF ||
|
||||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
|
||||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
|
||||
websocket.IsCloseError(
|
||||
err,
|
||||
websocket.CloseAbnormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseNoStatusReceived,
|
||||
) {
|
||||
logger.Infow("exit ws read loop for closed connection", "wsError", err)
|
||||
} else {
|
||||
logger.Errorw("error reading from websocket", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
switch m := req.Message.(type) {
|
||||
case *livekit.WorkerMessage_Register:
|
||||
go s.handleRegister(w, m.Register)
|
||||
case *livekit.WorkerMessage_Availability:
|
||||
go s.handleAvailability(w, m.Availability)
|
||||
case *livekit.WorkerMessage_JobUpdate:
|
||||
go s.handleJobUpdate(w, m.JobUpdate)
|
||||
case *livekit.WorkerMessage_Status:
|
||||
go s.handleStatus(w, m.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) handleRegister(worker *worker, msg *livekit.RegisterWorkerRequest) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
switch msg.Type {
|
||||
case livekit.JobType_JT_ROOM:
|
||||
worker.id = msg.WorkerId
|
||||
delete(s.unregistered, worker.conn)
|
||||
s.roomWorkers[worker.id] = worker
|
||||
|
||||
if !s.roomRegistered {
|
||||
err := s.agentServer.RegisterJobRequestTopic(s.roomTopic)
|
||||
if err != nil {
|
||||
logger.Errorw("failed to register room agents", err)
|
||||
} else {
|
||||
s.roomRegistered = true
|
||||
}
|
||||
}
|
||||
|
||||
case livekit.JobType_JT_PUBLISHER:
|
||||
worker.id = msg.WorkerId
|
||||
delete(s.unregistered, worker.conn)
|
||||
s.publisherWorkers[worker.id] = worker
|
||||
|
||||
if !s.publisherRegistered {
|
||||
err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic)
|
||||
if err != nil {
|
||||
logger.Errorw("failed to register publisher agents", err)
|
||||
} else {
|
||||
s.publisherRegistered = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err := worker.sigConn.WriteServerMessage(&livekit.ServerMessage{
|
||||
Message: &livekit.ServerMessage_Register{
|
||||
Register: &livekit.RegisterWorkerResponse{
|
||||
WorkerId: worker.id,
|
||||
ServerVersion: AgentServiceVersion,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorw("failed to write server message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) handleAvailability(w *worker, msg *livekit.AvailabilityResponse) {
|
||||
s.mu.Lock()
|
||||
availabilityChan, ok := s.availability[msg.JobId]
|
||||
s.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
availabilityChan <- &availability{
|
||||
workerID: w.id,
|
||||
available: msg.Available,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) handleJobUpdate(w *worker, msg *livekit.JobStatusUpdate) {
|
||||
switch msg.Status {
|
||||
case livekit.JobStatus_JS_SUCCESS:
|
||||
logger.Debugw("job complete", "jobID", msg.JobId)
|
||||
case livekit.JobStatus_JS_FAILED:
|
||||
logger.Warnw("job failed", errors.New(msg.Error), "jobID", msg.JobId)
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.activeJobs--
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
w.mu.Lock()
|
||||
w.status = msg.Status
|
||||
w.mu.Unlock()
|
||||
|
||||
switch w.jobType {
|
||||
case livekit.JobType_JT_ROOM:
|
||||
if s.roomRegistered && !s.roomAvailableLocked() {
|
||||
s.roomRegistered = false
|
||||
s.agentServer.DeregisterJobRequestTopic(s.roomTopic)
|
||||
} else if !s.roomRegistered && s.roomAvailableLocked() {
|
||||
if err := s.agentServer.RegisterJobRequestTopic(s.roomTopic); err != nil {
|
||||
logger.Errorw("failed to register room agents", err)
|
||||
} else {
|
||||
s.roomRegistered = true
|
||||
}
|
||||
}
|
||||
case livekit.JobType_JT_PUBLISHER:
|
||||
if s.publisherRegistered && !s.publisherAvailableLocked() {
|
||||
s.publisherRegistered = false
|
||||
s.agentServer.DeregisterJobRequestTopic(s.publisherTopic)
|
||||
} else if !s.publisherRegistered && s.publisherAvailableLocked() {
|
||||
if err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic); err != nil {
|
||||
logger.Errorw("failed to register publisher agents", err)
|
||||
} else {
|
||||
s.publisherRegistered = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) CheckEnabled(_ context.Context, _ *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) {
|
||||
s.mu.Lock()
|
||||
res := &rpc.CheckEnabledResponse{
|
||||
RoomEnabled: len(s.roomWorkers) > 0,
|
||||
PublisherEnabled: len(s.publisherWorkers) > 0,
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) {
|
||||
s.mu.Lock()
|
||||
ac := make(chan *availability, 100)
|
||||
s.availability[job.Id] = ac
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
delete(s.availability, job.Id)
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
var pool map[string]*worker
|
||||
switch job.Type {
|
||||
case livekit.JobType_JT_ROOM:
|
||||
pool = s.roomWorkers
|
||||
case livekit.JobType_JT_PUBLISHER:
|
||||
pool = s.publisherWorkers
|
||||
}
|
||||
|
||||
attempted := make(map[string]bool)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out")
|
||||
default:
|
||||
s.mu.Lock()
|
||||
var selected *worker
|
||||
for _, w := range pool {
|
||||
if attempted[w.id] {
|
||||
continue
|
||||
}
|
||||
if w.status == livekit.WorkerStatus_WS_AVAILABLE {
|
||||
if w.activeJobs > 0 {
|
||||
selected = w
|
||||
break
|
||||
} else if selected == nil {
|
||||
selected = w
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if selected == nil {
|
||||
return nil, psrpc.NewErrorf(psrpc.Unavailable, "no workers available")
|
||||
}
|
||||
|
||||
attempted[selected.id] = true
|
||||
_, err := selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Availability{
|
||||
Availability: &livekit.AvailabilityRequest{Job: job},
|
||||
}})
|
||||
if err != nil {
|
||||
logger.Errorw("failed to send availability request", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out")
|
||||
case res := <-ac:
|
||||
if res.available {
|
||||
_, err = selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Assignment{
|
||||
Assignment: &livekit.JobAssignment{Job: job},
|
||||
}})
|
||||
if err != nil {
|
||||
logger.Errorw("failed to assign job", err)
|
||||
} else {
|
||||
selected.mu.Lock()
|
||||
selected.activeJobs++
|
||||
selected.mu.Unlock()
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var pool map[string]*worker
|
||||
switch job.Type {
|
||||
case livekit.JobType_JT_ROOM:
|
||||
pool = s.roomWorkers
|
||||
case livekit.JobType_JT_PUBLISHER:
|
||||
pool = s.publisherWorkers
|
||||
}
|
||||
|
||||
var affinity float32
|
||||
for _, w := range pool {
|
||||
if w.status == livekit.WorkerStatus_WS_AVAILABLE {
|
||||
if w.activeJobs > 0 {
|
||||
return 1
|
||||
} else {
|
||||
affinity = 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return affinity
|
||||
}
|
||||
|
||||
func (s *AgentHandler) NumConnections() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return len(s.unregistered) + len(s.roomWorkers) + len(s.publisherWorkers)
|
||||
}
|
||||
|
||||
func (s *AgentHandler) DrainConnections(interval time.Duration) {
|
||||
// jitter drain start
|
||||
time.Sleep(time.Duration(rand.Int63n(int64(interval))))
|
||||
|
||||
t := time.NewTicker(interval)
|
||||
defer t.Stop()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for conn := range s.unregistered {
|
||||
_ = conn.Close()
|
||||
<-t.C
|
||||
}
|
||||
for _, w := range s.roomWorkers {
|
||||
_ = w.conn.Close()
|
||||
<-t.C
|
||||
}
|
||||
for _, w := range s.publisherWorkers {
|
||||
_ = w.conn.Close()
|
||||
<-t.C
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentHandler) roomAvailableLocked() bool {
|
||||
for _, w := range s.roomWorkers {
|
||||
if w.status == livekit.WorkerStatus_WS_AVAILABLE {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *AgentHandler) publisherAvailableLocked() bool {
|
||||
for _, w := range s.publisherWorkers {
|
||||
if w.status == livekit.WorkerStatus_WS_AVAILABLE {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
// Copyright 2023 LiveKit, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
"github.com/livekit/protocol/utils"
|
||||
)
|
||||
|
||||
type IOClient interface {
|
||||
CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error)
|
||||
GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error)
|
||||
ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error)
|
||||
}
|
||||
|
||||
type egressLauncher struct {
|
||||
client rpc.EgressClient
|
||||
io IOClient
|
||||
}
|
||||
|
||||
func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
return &egressLauncher{
|
||||
client: client,
|
||||
io: io,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
|
||||
info, err := s.StartEgressWithClusterId(ctx, "", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = s.io.CreateEgress(ctx, info)
|
||||
if err != nil {
|
||||
logger.Errorw("failed to create egress", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
|
||||
if s.client == nil {
|
||||
return nil, ErrEgressNotConnected
|
||||
}
|
||||
|
||||
// Ensure we have an Egress ID
|
||||
if req.EgressId == "" {
|
||||
req.EgressId = utils.NewGuid(utils.EgressPrefix)
|
||||
}
|
||||
|
||||
return s.client.StartEgress(ctx, clusterId, req)
|
||||
}
|
||||
+2
-46
@@ -25,40 +25,23 @@ import (
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/protocol/egress"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
"github.com/livekit/protocol/utils"
|
||||
)
|
||||
|
||||
type EgressService struct {
|
||||
launcher rtc.EgressLauncher
|
||||
client rpc.EgressClient
|
||||
io IOClient
|
||||
roomService livekit.RoomService
|
||||
store ServiceStore
|
||||
launcher rtc.EgressLauncher
|
||||
}
|
||||
|
||||
type egressLauncher struct {
|
||||
client rpc.EgressClient
|
||||
io IOClient
|
||||
}
|
||||
|
||||
func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
return &egressLauncher{
|
||||
client: client,
|
||||
io: io,
|
||||
}
|
||||
}
|
||||
|
||||
func NewEgressService(
|
||||
client rpc.EgressClient,
|
||||
launcher rtc.EgressLauncher,
|
||||
store ServiceStore,
|
||||
io IOClient,
|
||||
rs livekit.RoomService,
|
||||
launcher rtc.EgressLauncher,
|
||||
) *EgressService {
|
||||
return &EgressService{
|
||||
client: client,
|
||||
@@ -189,33 +172,6 @@ func (s *EgressService) startEgress(ctx context.Context, roomName livekit.RoomNa
|
||||
return s.launcher.StartEgress(ctx, req)
|
||||
}
|
||||
|
||||
func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
|
||||
info, err := s.StartEgressWithClusterId(ctx, "", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = s.io.CreateEgress(ctx, info)
|
||||
if err != nil {
|
||||
logger.Errorw("failed to create egress", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
|
||||
if s.client == nil {
|
||||
return nil, ErrEgressNotConnected
|
||||
}
|
||||
|
||||
// Ensure we have an Egress ID
|
||||
if req.EgressId == "" {
|
||||
req.EgressId = utils.NewGuid(utils.EgressPrefix)
|
||||
}
|
||||
|
||||
return s.client.StartEgress(ctx, clusterId, req)
|
||||
}
|
||||
|
||||
type LayoutMetadata struct {
|
||||
Layout string `json:"layout"`
|
||||
}
|
||||
|
||||
@@ -18,10 +18,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
|
||||
@@ -74,13 +71,6 @@ type IngressStore interface {
|
||||
DeleteIngress(ctx context.Context, info *livekit.IngressInfo) error
|
||||
}
|
||||
|
||||
//counterfeiter:generate . IOClient
|
||||
type IOClient interface {
|
||||
CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error)
|
||||
GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error)
|
||||
ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error)
|
||||
}
|
||||
|
||||
//counterfeiter:generate . RoomAllocator
|
||||
type RoomAllocator interface {
|
||||
CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, bool, error)
|
||||
|
||||
@@ -60,8 +60,10 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre
|
||||
}()
|
||||
|
||||
// find existing room and update it
|
||||
var created bool
|
||||
rm, internal, err := r.roomStore.LoadRoom(ctx, livekit.RoomName(req.Name), true)
|
||||
if err == ErrRoomNotFound {
|
||||
created = true
|
||||
rm = &livekit.Room{
|
||||
Sid: utils.NewGuid(utils.RoomPrefix),
|
||||
Name: req.Name,
|
||||
@@ -114,7 +116,7 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre
|
||||
return nil, false, routing.ErrNodeLimitReached
|
||||
}
|
||||
|
||||
return rm, false, nil
|
||||
return rm, created, nil
|
||||
}
|
||||
|
||||
// select a new node
|
||||
|
||||
@@ -69,6 +69,7 @@ type RoomManager struct {
|
||||
roomStore ObjectStore
|
||||
telemetry telemetry.TelemetryService
|
||||
clientConfManager clientconfiguration.ClientConfigurationManager
|
||||
agentClient rtc.AgentClient
|
||||
egressLauncher rtc.EgressLauncher
|
||||
versionGenerator utils.TimedVersionGenerator
|
||||
turnAuthHandler *TURNAuthHandler
|
||||
@@ -89,6 +90,7 @@ func NewLocalRoomManager(
|
||||
router routing.Router,
|
||||
telemetry telemetry.TelemetryService,
|
||||
clientConfManager clientconfiguration.ClientConfigurationManager,
|
||||
agentClient rtc.AgentClient,
|
||||
egressLauncher rtc.EgressLauncher,
|
||||
versionGenerator utils.TimedVersionGenerator,
|
||||
turnAuthHandler *TURNAuthHandler,
|
||||
@@ -108,6 +110,7 @@ func NewLocalRoomManager(
|
||||
telemetry: telemetry,
|
||||
clientConfManager: clientConfManager,
|
||||
egressLauncher: egressLauncher,
|
||||
agentClient: agentClient,
|
||||
versionGenerator: versionGenerator,
|
||||
turnAuthHandler: turnAuthHandler,
|
||||
bus: bus,
|
||||
@@ -393,7 +396,8 @@ func (r *RoomManager) StartSession(
|
||||
Trailer: room.Trailer(),
|
||||
PLIThrottleConfig: r.config.RTC.PLIThrottle,
|
||||
CongestionControlConfig: r.config.RTC.CongestionControl,
|
||||
EnabledCodecs: protoRoom.EnabledCodecs,
|
||||
PublishEnabledCodecs: protoRoom.EnabledCodecs,
|
||||
SubscribeEnabledCodecs: protoRoom.EnabledCodecs,
|
||||
Grants: pi.Grants,
|
||||
Logger: pLogger,
|
||||
ClientConf: clientConf,
|
||||
@@ -526,7 +530,7 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room
|
||||
}
|
||||
|
||||
// construct ice servers
|
||||
newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.egressLauncher)
|
||||
newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.agentClient, r.egressLauncher)
|
||||
|
||||
roomTopic := rpc.FormatRoomTopic(roomName)
|
||||
roomServer := utils.Must(rpc.NewTypedRoomServer(r, r.bus))
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/rpc"
|
||||
"github.com/livekit/protocol/utils"
|
||||
)
|
||||
|
||||
// A rooms service that supports a single node
|
||||
@@ -41,6 +42,7 @@ type RoomService struct {
|
||||
router routing.MessageRouter
|
||||
roomAllocator RoomAllocator
|
||||
roomStore ServiceStore
|
||||
agentClient rtc.AgentClient
|
||||
egressLauncher rtc.EgressLauncher
|
||||
topicFormatter rpc.TopicFormatter
|
||||
roomClient rpc.TypedRoomClient
|
||||
@@ -54,6 +56,7 @@ func NewRoomService(
|
||||
router routing.MessageRouter,
|
||||
roomAllocator RoomAllocator,
|
||||
serviceStore ServiceStore,
|
||||
agentClient rtc.AgentClient,
|
||||
egressLauncher rtc.EgressLauncher,
|
||||
topicFormatter rpc.TopicFormatter,
|
||||
roomClient rpc.TypedRoomClient,
|
||||
@@ -66,6 +69,7 @@ func NewRoomService(
|
||||
router: router,
|
||||
roomAllocator: roomAllocator,
|
||||
roomStore: serviceStore,
|
||||
agentClient: agentClient,
|
||||
egressLauncher: egressLauncher,
|
||||
topicFormatter: topicFormatter,
|
||||
roomClient: roomClient,
|
||||
@@ -112,17 +116,29 @@ func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomReq
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if created && req.Egress != nil && req.Egress.Room != nil {
|
||||
egress := &rpc.StartEgressRequest{
|
||||
Request: &rpc.StartEgressRequest_RoomComposite{
|
||||
RoomComposite: req.Egress.Room,
|
||||
},
|
||||
RoomId: rm.Sid,
|
||||
if created {
|
||||
go func() {
|
||||
s.agentClient.JobRequest(ctx, &livekit.Job{
|
||||
Id: utils.NewGuid("JR_"),
|
||||
Type: livekit.JobType_JT_ROOM,
|
||||
Room: rm,
|
||||
})
|
||||
}()
|
||||
|
||||
if req.Egress != nil && req.Egress.Room != nil {
|
||||
_, err = s.egressLauncher.StartEgress(ctx, &rpc.StartEgressRequest{
|
||||
Request: &rpc.StartEgressRequest_RoomComposite{
|
||||
RoomComposite: req.Egress.Room,
|
||||
},
|
||||
RoomId: rm.Sid,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
_, err = s.egressLauncher.StartEgress(ctx, egress)
|
||||
}
|
||||
|
||||
return rm, err
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
func (s *RoomService) ListRooms(ctx context.Context, req *livekit.ListRoomsRequest) (*livekit.ListRoomsResponse, error) {
|
||||
@@ -427,7 +443,7 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat
|
||||
|
||||
// no one has joined the room, would not have been created on an RTC node.
|
||||
// in this case, we'd want to run create again
|
||||
_, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{
|
||||
room, created, err := s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{
|
||||
Name: req.Room,
|
||||
Metadata: req.Metadata,
|
||||
})
|
||||
@@ -465,6 +481,16 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if created {
|
||||
go func() {
|
||||
s.agentClient.JobRequest(ctx, &livekit.Job{
|
||||
Id: utils.NewGuid("JR_"),
|
||||
Type: livekit.JobType_JT_ROOM,
|
||||
Room: room,
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
return room, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -136,6 +136,7 @@ func newTestRoomService(conf config.RoomConfig) *TestRoomService {
|
||||
allocator,
|
||||
store,
|
||||
nil,
|
||||
nil,
|
||||
rpc.NewTopicFormatter(),
|
||||
&rpcfakes.FakeTypedRoomClient{},
|
||||
&rpcfakes.FakeTypedParticipantClient{},
|
||||
|
||||
@@ -31,17 +31,17 @@ import (
|
||||
"github.com/ua-parser/uap-go/uaparser"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/utils"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/psrpc"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/config"
|
||||
"github.com/livekit/livekit-server/pkg/routing"
|
||||
"github.com/livekit/livekit-server/pkg/routing/selector"
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
|
||||
"github.com/livekit/livekit-server/pkg/utils"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
putil "github.com/livekit/protocol/utils"
|
||||
"github.com/livekit/psrpc"
|
||||
)
|
||||
|
||||
type RTCService struct {
|
||||
@@ -54,6 +54,7 @@ type RTCService struct {
|
||||
isDev bool
|
||||
limits config.LimitConfig
|
||||
parser *uaparser.Parser
|
||||
agentClient rtc.AgentClient
|
||||
telemetry telemetry.TelemetryService
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -66,6 +67,7 @@ func NewRTCService(
|
||||
store ServiceStore,
|
||||
router routing.MessageRouter,
|
||||
currentNode routing.LocalNode,
|
||||
agentClient rtc.AgentClient,
|
||||
telemetry telemetry.TelemetryService,
|
||||
) *RTCService {
|
||||
s := &RTCService{
|
||||
@@ -78,6 +80,7 @@ func NewRTCService(
|
||||
isDev: conf.Development,
|
||||
limits: conf.Limit,
|
||||
parser: uaparser.NewFromSaved(),
|
||||
agentClient: agentClient,
|
||||
telemetry: telemetry,
|
||||
connections: map[*websocket.Conn]struct{}{},
|
||||
}
|
||||
@@ -229,6 +232,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Warnw("failed to start connection, retrying", err, fieldsWithAttempt...)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
prometheus.IncrementParticipantJoinFail(1)
|
||||
handleError(w, http.StatusInternalServerError, err, loggerFields...)
|
||||
@@ -514,10 +518,16 @@ type connectionResult struct {
|
||||
ResponseSource routing.MessageSource
|
||||
}
|
||||
|
||||
func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomName, pi routing.ParticipantInit, timeout time.Duration) (connectionResult, *livekit.SignalResponse, error) {
|
||||
func (s *RTCService) startConnection(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
timeout time.Duration,
|
||||
) (connectionResult, *livekit.SignalResponse, error) {
|
||||
var cr connectionResult
|
||||
var created bool
|
||||
var err error
|
||||
cr.Room, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)})
|
||||
cr.Room, created, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)})
|
||||
if err != nil {
|
||||
return cr, nil, err
|
||||
}
|
||||
@@ -538,6 +548,17 @@ func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomN
|
||||
cr.ResponseSource.Close()
|
||||
return cr, nil, err
|
||||
}
|
||||
|
||||
if created && s.agentClient != nil {
|
||||
go func() {
|
||||
s.agentClient.JobRequest(ctx, &livekit.Job{
|
||||
Id: putil.NewGuid("JR_"),
|
||||
Type: livekit.JobType_JT_ROOM,
|
||||
Room: cr.Room,
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
return cr, initialResponse, nil
|
||||
}
|
||||
|
||||
@@ -559,5 +580,4 @@ func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ type LivekitServer struct {
|
||||
config *config.Config
|
||||
ioService *IOInfoService
|
||||
rtcService *RTCService
|
||||
agentService *AgentService
|
||||
httpServer *http.Server
|
||||
promServer *http.Server
|
||||
router routing.Router
|
||||
@@ -66,6 +67,7 @@ func NewLivekitServer(conf *config.Config,
|
||||
ingressService *IngressService,
|
||||
ioService *IOInfoService,
|
||||
rtcService *RTCService,
|
||||
agentService *AgentService,
|
||||
keyProvider auth.KeyProvider,
|
||||
router routing.Router,
|
||||
roomManager *RoomManager,
|
||||
@@ -77,6 +79,7 @@ func NewLivekitServer(conf *config.Config,
|
||||
config: conf,
|
||||
ioService: ioService,
|
||||
rtcService: rtcService,
|
||||
agentService: agentService,
|
||||
router: router,
|
||||
roomManager: roomManager,
|
||||
signalServer: signalServer,
|
||||
@@ -125,6 +128,7 @@ func NewLivekitServer(conf *config.Config,
|
||||
mux.Handle(egressServer.PathPrefix(), egressServer)
|
||||
mux.Handle(ingressServer.PathPrefix(), ingressServer)
|
||||
mux.Handle("/rtc", rtcService)
|
||||
mux.Handle("/agent", agentService)
|
||||
mux.HandleFunc("/rtc/validate", rtcService.Validate)
|
||||
mux.HandleFunc("/", s.defaultHandler)
|
||||
|
||||
|
||||
+4
-1
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/livekit/livekit-server/pkg/clientconfiguration"
|
||||
"github.com/livekit/livekit-server/pkg/config"
|
||||
"github.com/livekit/livekit-server/pkg/routing"
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
|
||||
"github.com/livekit/protocol/auth"
|
||||
@@ -62,16 +63,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
|
||||
NewIOInfoService,
|
||||
wire.Bind(new(IOClient), new(*IOInfoService)),
|
||||
rpc.NewEgressClient,
|
||||
rpc.NewIngressClient,
|
||||
getEgressStore,
|
||||
NewEgressLauncher,
|
||||
NewEgressService,
|
||||
rpc.NewIngressClient,
|
||||
getIngressStore,
|
||||
getIngressConfig,
|
||||
NewIngressService,
|
||||
NewRoomAllocator,
|
||||
NewRoomService,
|
||||
NewRTCService,
|
||||
NewAgentService,
|
||||
rtc.NewAgentClient,
|
||||
getSignalRelayConfig,
|
||||
NewDefaultSignalServer,
|
||||
routing.NewSignalClient,
|
||||
|
||||
+14
-5
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/livekit/livekit-server/pkg/clientconfiguration"
|
||||
"github.com/livekit/livekit-server/pkg/config"
|
||||
"github.com/livekit/livekit-server/pkg/routing"
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
|
||||
"github.com/livekit/protocol/auth"
|
||||
@@ -55,6 +56,10 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
agentClient, err := rtc.NewAgentClient(messageBus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
egressClient, err := rpc.NewEgressClient(messageBus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -86,22 +91,26 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient)
|
||||
roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, agentClient, rtcEgressLauncher, topicFormatter, roomClient, participantClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
egressService := NewEgressService(egressClient, objectStore, ioInfoService, roomService, rtcEgressLauncher)
|
||||
egressService := NewEgressService(egressClient, rtcEgressLauncher, objectStore, ioInfoService, roomService)
|
||||
ingressConfig := getIngressConfig(conf)
|
||||
ingressClient, err := rpc.NewIngressClient(messageBus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, roomService, telemetryService)
|
||||
rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, telemetryService)
|
||||
rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, agentClient, telemetryService)
|
||||
agentService, err := NewAgentService(messageBus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientConfigurationManager := createClientConfiguration()
|
||||
timedVersionGenerator := utils.NewDefaultTimedVersionGenerator()
|
||||
turnAuthHandler := NewTURNAuthHandler(keyProvider)
|
||||
roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus)
|
||||
roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, agentClient, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -114,7 +123,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, signalServer, server, currentNode)
|
||||
livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -83,6 +83,40 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSSignalConnection) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) {
|
||||
for {
|
||||
// handle special messages and pass on the rest
|
||||
messageType, payload, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
msg := &livekit.WorkerMessage{}
|
||||
switch messageType {
|
||||
case websocket.BinaryMessage:
|
||||
if c.useJSON {
|
||||
c.mu.Lock()
|
||||
// switch to protobuf if client supports it
|
||||
c.useJSON = false
|
||||
c.mu.Unlock()
|
||||
}
|
||||
// protobuf encoded
|
||||
err := proto.Unmarshal(payload, msg)
|
||||
return msg, len(payload), err
|
||||
case websocket.TextMessage:
|
||||
c.mu.Lock()
|
||||
// json encoded, also write back JSON
|
||||
c.useJSON = true
|
||||
c.mu.Unlock()
|
||||
err := protojson.Unmarshal(payload, msg)
|
||||
return msg, len(payload), err
|
||||
default:
|
||||
logger.Debugw("unsupported message", "message", messageType)
|
||||
return nil, len(payload), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, error) {
|
||||
var msgType int
|
||||
var payload []byte
|
||||
@@ -105,6 +139,28 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, er
|
||||
return len(payload), c.conn.WriteMessage(msgType, payload)
|
||||
}
|
||||
|
||||
func (c *WSSignalConnection) WriteServerMessage(msg *livekit.ServerMessage) (int, error) {
|
||||
var msgType int
|
||||
var payload []byte
|
||||
var err error
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.useJSON {
|
||||
msgType = websocket.TextMessage
|
||||
payload, err = protojson.Marshal(msg)
|
||||
} else {
|
||||
msgType = websocket.BinaryMessage
|
||||
payload, err = proto.Marshal(msg)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(payload), c.conn.WriteMessage(msgType, payload)
|
||||
}
|
||||
|
||||
func (c *WSSignalConnection) pingWorker() {
|
||||
for {
|
||||
<-time.After(pingFrequency)
|
||||
|
||||
+39
-17
@@ -16,8 +16,8 @@ package audio
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,11 +40,13 @@ type AudioLevel struct {
|
||||
smoothFactor float64
|
||||
activeThreshold float64
|
||||
|
||||
smoothedLevel atomic.Float64
|
||||
lock sync.Mutex
|
||||
smoothedLevel float64
|
||||
|
||||
loudestObservedLevel uint8
|
||||
activeDuration uint32 // ms
|
||||
observedDuration uint32 // ms
|
||||
lastObservedAt time.Time
|
||||
}
|
||||
|
||||
func NewAudioLevel(params AudioLevelParams) *AudioLevel {
|
||||
@@ -64,8 +66,13 @@ func NewAudioLevel(params AudioLevelParams) *AudioLevel {
|
||||
return l
|
||||
}
|
||||
|
||||
// Observes a new frame, must be called from the same thread
|
||||
func (l *AudioLevel) Observe(level uint8, durationMs uint32) {
|
||||
// Observes a new frame
|
||||
func (l *AudioLevel) Observe(level uint8, durationMs uint32, arrivalTime time.Time) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
l.lastObservedAt = arrivalTime
|
||||
|
||||
l.observedDuration += durationMs
|
||||
|
||||
if level <= l.params.ActiveLevel {
|
||||
@@ -76,6 +83,7 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) {
|
||||
}
|
||||
|
||||
if l.observedDuration >= l.params.ObserveDuration {
|
||||
smoothedLevel := float64(0.0)
|
||||
// compute and reset
|
||||
if l.activeDuration >= l.minActiveDuration {
|
||||
// adjust loudest observed level by how much of the window was active.
|
||||
@@ -87,25 +95,39 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) {
|
||||
linearLevel := ConvertAudioLevel(adjustedLevel)
|
||||
|
||||
// exponential smoothing to dampen transients
|
||||
smoothedLevel := l.smoothedLevel.Load()
|
||||
smoothedLevel += (linearLevel - smoothedLevel) * l.smoothFactor
|
||||
l.smoothedLevel.Store(smoothedLevel)
|
||||
} else {
|
||||
l.smoothedLevel.Store(0)
|
||||
smoothedLevel = l.smoothedLevel + (linearLevel-l.smoothedLevel)*l.smoothFactor
|
||||
}
|
||||
l.loudestObservedLevel = silentAudioLevel
|
||||
l.activeDuration = 0
|
||||
l.observedDuration = 0
|
||||
l.resetLocked(smoothedLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// returns current soothed audio level
|
||||
func (l *AudioLevel) GetLevel() (float64, bool) {
|
||||
smoothedLevel := l.smoothedLevel.Load()
|
||||
active := smoothedLevel >= l.activeThreshold
|
||||
return smoothedLevel, active
|
||||
func (l *AudioLevel) GetLevel(now time.Time) (float64, bool) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
l.resetIfStaleLocked(now)
|
||||
|
||||
return l.smoothedLevel, l.smoothedLevel >= l.activeThreshold
|
||||
}
|
||||
|
||||
func (l *AudioLevel) resetIfStaleLocked(arrivalTime time.Time) {
|
||||
if arrivalTime.Sub(l.lastObservedAt).Milliseconds() < int64(2*l.params.ObserveDuration) {
|
||||
return
|
||||
}
|
||||
|
||||
l.resetLocked(0.0)
|
||||
}
|
||||
|
||||
func (l *AudioLevel) resetLocked(smoothedLevel float64) {
|
||||
l.smoothedLevel = smoothedLevel
|
||||
l.loudestObservedLevel = silentAudioLevel
|
||||
l.activeDuration = 0
|
||||
l.observedDuration = 0
|
||||
}
|
||||
|
||||
// ---------------------------------------------------
|
||||
|
||||
// convert decibel back to linear
|
||||
func ConvertAudioLevel(level float64) float64 {
|
||||
return math.Pow(10, level*negInv20)
|
||||
|
||||
@@ -16,6 +16,7 @@ package audio
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -30,46 +31,79 @@ const (
|
||||
|
||||
func TestAudioLevel(t *testing.T) {
|
||||
t.Run("initially to return not noisy, within a few samples", func(t *testing.T) {
|
||||
clock := time.Now()
|
||||
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
|
||||
_, noisy := a.GetLevel()
|
||||
|
||||
_, noisy := a.GetLevel(clock)
|
||||
require.False(t, noisy)
|
||||
|
||||
observeSamples(a, 28, 5)
|
||||
_, noisy = a.GetLevel()
|
||||
observeSamples(a, 28, 5, clock)
|
||||
clock = clock.Add(5 * 20 * time.Millisecond)
|
||||
|
||||
_, noisy = a.GetLevel(clock)
|
||||
require.False(t, noisy)
|
||||
})
|
||||
|
||||
t.Run("not noisy when all samples are below threshold", func(t *testing.T) {
|
||||
clock := time.Now()
|
||||
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
|
||||
|
||||
observeSamples(a, 35, 100)
|
||||
_, noisy := a.GetLevel()
|
||||
observeSamples(a, 35, 100, clock)
|
||||
clock = clock.Add(100 * 20 * time.Millisecond)
|
||||
|
||||
_, noisy := a.GetLevel(clock)
|
||||
require.False(t, noisy)
|
||||
})
|
||||
|
||||
t.Run("not noisy when less than percentile samples are above threshold", func(t *testing.T) {
|
||||
clock := time.Now()
|
||||
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
|
||||
|
||||
observeSamples(a, 35, samplesPerBatch-2)
|
||||
observeSamples(a, 25, 1)
|
||||
observeSamples(a, 35, 1)
|
||||
observeSamples(a, 35, samplesPerBatch-2, clock)
|
||||
clock = clock.Add((samplesPerBatch - 2) * 20 * time.Millisecond)
|
||||
observeSamples(a, 25, 1, clock)
|
||||
clock = clock.Add(20 * time.Millisecond)
|
||||
observeSamples(a, 35, 1, clock)
|
||||
clock = clock.Add(20 * time.Millisecond)
|
||||
|
||||
_, noisy := a.GetLevel()
|
||||
_, noisy := a.GetLevel(clock)
|
||||
require.False(t, noisy)
|
||||
})
|
||||
|
||||
t.Run("noisy when higher than percentile samples are above threshold", func(t *testing.T) {
|
||||
clock := time.Now()
|
||||
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
|
||||
|
||||
observeSamples(a, 35, samplesPerBatch-16)
|
||||
observeSamples(a, 25, 8)
|
||||
observeSamples(a, 29, 8)
|
||||
observeSamples(a, 35, samplesPerBatch-16, clock)
|
||||
clock = clock.Add((samplesPerBatch - 16) * 20 * time.Millisecond)
|
||||
observeSamples(a, 25, 8, clock)
|
||||
clock = clock.Add(8 * 20 * time.Millisecond)
|
||||
observeSamples(a, 29, 8, clock)
|
||||
clock = clock.Add(8 * 20 * time.Millisecond)
|
||||
|
||||
level, noisy := a.GetLevel()
|
||||
level, noisy := a.GetLevel(clock)
|
||||
require.True(t, noisy)
|
||||
require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel)))
|
||||
require.Less(t, level, ConvertAudioLevel(float64(25)))
|
||||
})
|
||||
|
||||
t.Run("not noisy when samples are stale", func(t *testing.T) {
|
||||
clock := time.Now()
|
||||
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
|
||||
|
||||
observeSamples(a, 25, 100, clock)
|
||||
clock = clock.Add(100 * 20 * time.Millisecond)
|
||||
level, noisy := a.GetLevel(clock)
|
||||
require.True(t, noisy)
|
||||
require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel)))
|
||||
require.Less(t, level, ConvertAudioLevel(float64(20)))
|
||||
|
||||
// let enough time pass to make the samples stale
|
||||
clock = clock.Add(1500 * time.Millisecond)
|
||||
level, noisy = a.GetLevel(clock)
|
||||
require.Equal(t, float64(0.0), level)
|
||||
require.False(t, noisy)
|
||||
})
|
||||
}
|
||||
|
||||
func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration uint32) *AudioLevel {
|
||||
@@ -80,8 +114,8 @@ func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration ui
|
||||
})
|
||||
}
|
||||
|
||||
func observeSamples(a *AudioLevel, level uint8, count int) {
|
||||
func observeSamples(a *AudioLevel, level uint8, count int, baseTime time.Time) {
|
||||
for i := 0; i < count; i++ {
|
||||
a.Observe(level, 20)
|
||||
a.Observe(level, 20, baseTime.Add(+time.Duration(i*20)*time.Millisecond))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/sfu/audio"
|
||||
dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
|
||||
"github.com/livekit/livekit-server/pkg/sfu/utils"
|
||||
sutils "github.com/livekit/livekit-server/pkg/utils"
|
||||
"github.com/livekit/mediatransportutil"
|
||||
@@ -39,8 +40,6 @@ import (
|
||||
"github.com/livekit/mediatransportutil/pkg/twcc"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
|
||||
dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -593,7 +592,7 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime time.Time) {
|
||||
if (p.Timestamp - b.latestTSForAudioLevel) < (1 << 31) {
|
||||
duration := (int64(p.Timestamp) - int64(b.latestTSForAudioLevel)) * 1e3 / int64(b.clockRate)
|
||||
if duration > 0 {
|
||||
b.audioLevel.Observe(ext.Level, uint32(duration))
|
||||
b.audioLevel.Observe(ext.Level, uint32(duration), arrivalTime)
|
||||
}
|
||||
|
||||
b.latestTSForAudioLevel = p.Timestamp
|
||||
@@ -624,9 +623,6 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime time.Time, flow
|
||||
if b.ddParser != nil {
|
||||
ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet)
|
||||
if err != nil {
|
||||
if err != ErrFrameEarlierThanKeyFrame {
|
||||
b.logger.Warnw("could not parse dependency descriptor", err)
|
||||
}
|
||||
return nil
|
||||
} else if ddVal != nil {
|
||||
ep.DependencyDescriptor = ddVal
|
||||
@@ -855,7 +851,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return b.audioLevel.GetLevel()
|
||||
return b.audioLevel.GetLevel(time.Now())
|
||||
}
|
||||
|
||||
func (b *Buffer) OnFpsChanged(f func()) {
|
||||
|
||||
@@ -27,7 +27,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe")
|
||||
ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe")
|
||||
ErrDDStructureAttachedToNonFirstPacket = fmt.Errorf("dependency descriptor structure is attached to non-first packet of a frame")
|
||||
)
|
||||
|
||||
type DependencyDescriptorParser struct {
|
||||
@@ -39,7 +40,6 @@ type DependencyDescriptorParser struct {
|
||||
|
||||
seqWrapAround *utils.WrapAround[uint16, uint64]
|
||||
frameWrapAround *utils.WrapAround[uint16, uint64]
|
||||
structureExtSeq uint64
|
||||
structureExtFrameNum uint64
|
||||
activeDecodeTargetsExtSeq uint64
|
||||
activeDecodeTargetsMask uint32
|
||||
@@ -66,12 +66,15 @@ type ExtDependencyDescriptor struct {
|
||||
ActiveDecodeTargetsUpdated bool
|
||||
Integrity bool
|
||||
ExtFrameNum uint64
|
||||
// the frame number of the keyframe which the current frame depends on
|
||||
ExtKeyFrameNum uint64
|
||||
}
|
||||
|
||||
func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescriptor, VideoLayer, error) {
|
||||
var videoLayer VideoLayer
|
||||
ddBuf := pkt.GetExtension(r.ddExtID)
|
||||
if ddBuf == nil {
|
||||
r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber)
|
||||
return nil, videoLayer, nil
|
||||
}
|
||||
|
||||
@@ -82,7 +85,9 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr
|
||||
}
|
||||
_, err := ext.Unmarshal(ddBuf)
|
||||
if err != nil {
|
||||
// r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
|
||||
if err != dd.ErrDDReaderNoStructure {
|
||||
r.logger.Warnw("failed to parse generic dependency descriptor", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
|
||||
}
|
||||
return nil, videoLayer, err
|
||||
}
|
||||
|
||||
@@ -108,17 +113,21 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr
|
||||
}
|
||||
|
||||
if ddVal.AttachedStructure != nil {
|
||||
r.logger.Debugw("parsed dependency descriptor", "extSeq", extSeq, "extFN", extFN, "structureID", ddVal.AttachedStructure.StructureId, "descriptor", ddVal.String())
|
||||
if extSeq > r.structureExtSeq {
|
||||
r.structure = ddVal.AttachedStructure
|
||||
r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure)
|
||||
r.structureExtSeq = extSeq
|
||||
r.structureExtFrameNum = extFN
|
||||
extDD.StructureUpdated = true
|
||||
extDD.ActiveDecodeTargetsUpdated = true
|
||||
// The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present,
|
||||
// so don't need to notify max layer change here.
|
||||
if !ddVal.FirstPacketInFrame {
|
||||
r.logger.Warnw("attached structure is not the first packet in frame", nil, "extSeq", extSeq, "extFN", extFN)
|
||||
return nil, videoLayer, ErrDDStructureAttachedToNonFirstPacket
|
||||
}
|
||||
|
||||
if r.structure == nil || ddVal.AttachedStructure.StructureId != r.structure.StructureId {
|
||||
r.logger.Infow("structure updated", "structureID", ddVal.AttachedStructure.StructureId, "extSeq", extSeq, "extFN", extFN, "descriptor", ddVal.String())
|
||||
}
|
||||
r.structure = ddVal.AttachedStructure
|
||||
r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure)
|
||||
r.structureExtFrameNum = extFN
|
||||
extDD.StructureUpdated = true
|
||||
extDD.ActiveDecodeTargetsUpdated = true
|
||||
// The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present,
|
||||
// so don't need to notify max layer change here.
|
||||
}
|
||||
|
||||
if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil && extSeq > r.activeDecodeTargetsExtSeq {
|
||||
@@ -143,6 +152,7 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr
|
||||
}
|
||||
|
||||
extDD.DecodeTargets = r.decodeTargets
|
||||
extDD.ExtKeyFrameNum = r.structureExtFrameNum
|
||||
|
||||
return extDD, videoLayer, nil
|
||||
}
|
||||
|
||||
@@ -129,31 +129,30 @@ func (r *RTPStatsReceiver) Update(
|
||||
pktSize := uint64(hdrSize + payloadSize + paddingSize)
|
||||
gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest)
|
||||
if gapSN <= 0 { // duplicate OR out-of-order
|
||||
if payloadSize == 0 {
|
||||
// do not start on a padding only packet
|
||||
if resTS.IsRestart {
|
||||
r.logger.Infow(
|
||||
"rolling back timestamp restart",
|
||||
"tsBefore", resTS.PreExtendedStart,
|
||||
"tsAfter", r.timestamp.GetExtendedStart(),
|
||||
"snBefore", resSN.PreExtendedStart,
|
||||
"snAfter", r.sequenceNumber.GetExtendedStart(),
|
||||
)
|
||||
r.timestamp.RollbackRestart(resTS.PreExtendedStart)
|
||||
}
|
||||
if resSN.IsRestart {
|
||||
r.logger.Infow(
|
||||
"rolling back sequence number restart",
|
||||
"snBefore", resSN.PreExtendedStart,
|
||||
"snAfter", r.sequenceNumber.GetExtendedStart(),
|
||||
"tsBefore", resTS.PreExtendedStart,
|
||||
"tsAfter", r.timestamp.GetExtendedStart(),
|
||||
)
|
||||
r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart)
|
||||
flowState.IsNotHandled = true
|
||||
return
|
||||
}
|
||||
// before start, don't restart
|
||||
if resTS.IsRestart {
|
||||
r.logger.Infow(
|
||||
"rolling back timestamp restart",
|
||||
"tsBefore", resTS.PreExtendedStart,
|
||||
"tsAfter", r.timestamp.GetExtendedStart(),
|
||||
"snBefore", resSN.PreExtendedStart,
|
||||
"snAfter", r.sequenceNumber.GetExtendedStart(),
|
||||
)
|
||||
r.timestamp.RollbackRestart(resTS.PreExtendedStart)
|
||||
}
|
||||
if resSN.IsRestart {
|
||||
r.logger.Infow(
|
||||
"rolling back sequence number restart",
|
||||
"snBefore", resSN.PreExtendedStart,
|
||||
"snAfter", r.sequenceNumber.GetExtendedStart(),
|
||||
"tsBefore", resTS.PreExtendedStart,
|
||||
"tsAfter", r.timestamp.GetExtendedStart(),
|
||||
)
|
||||
r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart)
|
||||
flowState.IsNotHandled = true
|
||||
return
|
||||
}
|
||||
|
||||
if -gapSN >= cNumSequenceNumbers/2 {
|
||||
r.logger.Warnw(
|
||||
"large sequence number gap negative", nil,
|
||||
|
||||
@@ -130,7 +130,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
require.Equal(t, timestamp, r.timestamp.GetHighest())
|
||||
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
|
||||
|
||||
// out-of-order
|
||||
// out-of-order, would cause a restart which is disallowed
|
||||
packet = getPacket(sequenceNumber-10, timestamp-30000, 1000)
|
||||
flowState = r.Update(
|
||||
time.Now(),
|
||||
@@ -142,14 +142,15 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
0,
|
||||
)
|
||||
require.False(t, flowState.HasLoss)
|
||||
require.True(t, flowState.IsNotHandled)
|
||||
require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest())
|
||||
require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest()))
|
||||
require.Equal(t, timestamp, r.timestamp.GetHighest())
|
||||
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
|
||||
require.Equal(t, uint64(1), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(0), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(0), r.packetsDuplicate)
|
||||
|
||||
// duplicate
|
||||
// duplicate of the above out-of-order packet, but would not be handled as it causes a restart
|
||||
packet = getPacket(sequenceNumber-10, timestamp-30000, 1000)
|
||||
flowState = r.Update(
|
||||
time.Now(),
|
||||
@@ -161,12 +162,13 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
0,
|
||||
)
|
||||
require.False(t, flowState.HasLoss)
|
||||
require.True(t, flowState.IsNotHandled)
|
||||
require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest())
|
||||
require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest()))
|
||||
require.Equal(t, timestamp, r.timestamp.GetHighest())
|
||||
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
|
||||
require.Equal(t, uint64(2), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(1), r.packetsDuplicate)
|
||||
require.Equal(t, uint64(0), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(0), r.packetsDuplicate)
|
||||
|
||||
// loss
|
||||
sequenceNumber += 10
|
||||
@@ -184,10 +186,10 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
require.True(t, flowState.HasLoss)
|
||||
require.Equal(t, uint64(sequenceNumber-9), flowState.LossStartInclusive)
|
||||
require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive)
|
||||
require.Equal(t, uint64(17), r.packetsLost)
|
||||
require.Equal(t, uint64(9), r.packetsLost)
|
||||
|
||||
// out-of-order should decrement number of lost packets
|
||||
packet = getPacket(sequenceNumber-15, timestamp-45000, 1000)
|
||||
packet = getPacket(sequenceNumber-6, timestamp-45000, 1000)
|
||||
flowState = r.Update(
|
||||
time.Now(),
|
||||
packet.Header.SequenceNumber,
|
||||
@@ -202,9 +204,9 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest()))
|
||||
require.Equal(t, timestamp, r.timestamp.GetHighest())
|
||||
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
|
||||
require.Equal(t, uint64(3), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(1), r.packetsDuplicate)
|
||||
require.Equal(t, uint64(16), r.packetsLost)
|
||||
require.Equal(t, uint64(1), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(0), r.packetsDuplicate)
|
||||
require.Equal(t, uint64(8), r.packetsLost)
|
||||
|
||||
// test sequence number history
|
||||
// with a gap
|
||||
@@ -223,7 +225,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
require.True(t, flowState.HasLoss)
|
||||
require.Equal(t, uint64(sequenceNumber-1), flowState.LossStartInclusive)
|
||||
require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive)
|
||||
require.Equal(t, uint64(17), r.packetsLost)
|
||||
require.Equal(t, uint64(9), r.packetsLost)
|
||||
require.False(t, r.history.IsSet(uint64(sequenceNumber)-1))
|
||||
|
||||
// out-of-order
|
||||
@@ -240,8 +242,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
0,
|
||||
)
|
||||
require.False(t, flowState.HasLoss)
|
||||
require.Equal(t, uint64(16), r.packetsLost)
|
||||
require.Equal(t, uint64(4), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(8), r.packetsLost)
|
||||
require.Equal(t, uint64(2), r.packetsOutOfOrder)
|
||||
require.True(t, r.history.IsSet(uint64(sequenceNumber)))
|
||||
|
||||
// padding only
|
||||
@@ -257,8 +259,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
|
||||
25,
|
||||
)
|
||||
require.False(t, flowState.HasLoss)
|
||||
require.Equal(t, uint64(16), r.packetsLost)
|
||||
require.Equal(t, uint64(4), r.packetsOutOfOrder)
|
||||
require.Equal(t, uint64(8), r.packetsLost)
|
||||
require.Equal(t, uint64(2), r.packetsOutOfOrder)
|
||||
require.True(t, r.history.IsSet(uint64(sequenceNumber)))
|
||||
require.True(t, r.history.IsSet(uint64(sequenceNumber)-1))
|
||||
require.True(t, r.history.IsSet(uint64(sequenceNumber)-2))
|
||||
|
||||
@@ -18,6 +18,18 @@ import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDDReaderNoStructure = errors.New("DependencyDescriptorReader: Structure is nil")
|
||||
ErrDDReaderTemplateWithoutStructure = errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil")
|
||||
ErrDDReaderTooManyTemplates = errors.New("DependencyDescriptorReader: too many templates")
|
||||
ErrDDReaderTooManyTemporalLayers = errors.New("DependencyDescriptorReader: too many temporal layers")
|
||||
ErrDDReaderTooManySpatialLayers = errors.New("DependencyDescriptorReader: too many spatial layers")
|
||||
ErrDDReaderInvalidTemplateIndex = errors.New("DependencyDescriptorReader: invalid template index")
|
||||
ErrDDReaderInvalidSpatialLayer = errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions")
|
||||
ErrDDReaderNumDTIMismatch = errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets")
|
||||
ErrDDReaderNumChainDiffsMismatch = errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains")
|
||||
)
|
||||
|
||||
type DependencyDescriptorReader struct {
|
||||
// Output.
|
||||
descriptor *DependencyDescriptor
|
||||
@@ -59,7 +71,7 @@ func (r *DependencyDescriptorReader) Parse() (int, error) {
|
||||
|
||||
if r.structure == nil {
|
||||
r.buffer.Invalidate()
|
||||
return 0, errors.New("DependencyDescriptorReader: Structure is nil")
|
||||
return 0, ErrDDReaderNoStructure
|
||||
}
|
||||
|
||||
if r.activeDecodeTargetsPresentFlag {
|
||||
@@ -140,7 +152,7 @@ func (r *DependencyDescriptorReader) readExtendedFields() error {
|
||||
return err
|
||||
}
|
||||
if r.descriptor.AttachedStructure == nil {
|
||||
return errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil")
|
||||
return ErrDDReaderTemplateWithoutStructure
|
||||
}
|
||||
bitmask := uint32((uint64(1) << r.descriptor.AttachedStructure.NumDecodeTargets) - 1)
|
||||
r.descriptor.ActiveDecodeTargetsBitmask = &bitmask
|
||||
@@ -203,7 +215,7 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error {
|
||||
)
|
||||
for {
|
||||
if len(templates) == MaxTemplates {
|
||||
return errors.New("DependencyDescriptorReader: too many templates")
|
||||
return ErrDDReaderTooManyTemplates
|
||||
}
|
||||
|
||||
var lastTemplate FrameDependencyTemplate
|
||||
@@ -220,13 +232,13 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error {
|
||||
if nextLayerIdc == nextTemporalLayer {
|
||||
temporalId++
|
||||
if temporalId >= MaxTemporalIds {
|
||||
return errors.New("DependencyDescriptorReader: too many temporal layers")
|
||||
return ErrDDReaderTooManyTemporalLayers
|
||||
}
|
||||
} else if nextLayerIdc == nextSpatialLayer {
|
||||
spatialId++
|
||||
temporalId = 0
|
||||
if spatialId >= MaxSpatialIds {
|
||||
return errors.New("DependencyDescriptorReader: too many spatial layers")
|
||||
return ErrDDReaderTooManySpatialLayers
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,7 +352,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error {
|
||||
|
||||
if templateIndex >= len(r.structure.Templates) {
|
||||
r.buffer.Invalidate()
|
||||
return errors.New("DependencyDescriptorReader: invalid template index")
|
||||
return ErrDDReaderInvalidTemplateIndex
|
||||
}
|
||||
|
||||
// Copy all the fields from the matching template
|
||||
@@ -374,7 +386,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error {
|
||||
// then each spatial layer got one.
|
||||
if r.descriptor.FrameDependencies.SpatialId >= len(r.structure.Resolutions) {
|
||||
r.buffer.Invalidate()
|
||||
return errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions")
|
||||
return ErrDDReaderInvalidSpatialLayer
|
||||
}
|
||||
res := r.structure.Resolutions[r.descriptor.FrameDependencies.SpatialId]
|
||||
r.descriptor.Resolution = &res
|
||||
@@ -385,7 +397,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error {
|
||||
|
||||
func (r *DependencyDescriptorReader) readFrameDtis() error {
|
||||
if len(r.descriptor.FrameDependencies.DecodeTargetIndications) != r.structure.NumDecodeTargets {
|
||||
return errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets")
|
||||
return ErrDDReaderNumDTIMismatch
|
||||
}
|
||||
|
||||
for i := range r.descriptor.FrameDependencies.DecodeTargetIndications {
|
||||
@@ -420,7 +432,7 @@ func (r *DependencyDescriptorReader) readFrameFdiffs() error {
|
||||
|
||||
func (r *DependencyDescriptorReader) readFrameChains() error {
|
||||
if len(r.descriptor.FrameDependencies.ChainDiffs) != r.structure.NumChains {
|
||||
return errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains")
|
||||
return ErrDDReaderNumChainDiffsMismatch
|
||||
}
|
||||
|
||||
for i := range r.descriptor.FrameDependencies.ChainDiffs {
|
||||
|
||||
@@ -148,7 +148,7 @@ func (w *DependencyDescriptorWriter) calculateMatch(idx int, template *FrameDepe
|
||||
result.NeedCustomDtis = w.descriptor.FrameDependencies.DecodeTargetIndications != nil && !reflect.DeepEqual(w.descriptor.FrameDependencies.DecodeTargetIndications, template.DecodeTargetIndications)
|
||||
|
||||
for i := 0; i < w.structure.NumChains; i++ {
|
||||
if w.activeChains&(1<<i) != 0 && w.descriptor.FrameDependencies.ChainDiffs[i] != template.ChainDiffs[i] {
|
||||
if w.activeChains&(1<<i) != 0 && (len(w.descriptor.FrameDependencies.ChainDiffs) <= i || len(template.ChainDiffs) <= i || w.descriptor.FrameDependencies.ChainDiffs[i] != template.ChainDiffs[i]) {
|
||||
result.NeedCustomChains = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -193,7 +193,7 @@ func (s *StreamTrackerDependencyDescriptor) Observe(temporalLayer int32, pktSize
|
||||
|
||||
for _, dt := range ddVal.DecodeTargets {
|
||||
if len(dtis) <= dt.Target {
|
||||
s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtls", dtis)
|
||||
s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtis", dtis)
|
||||
continue
|
||||
}
|
||||
// we are not dropping discardable frames now, so only ingore not present frames
|
||||
|
||||
@@ -214,9 +214,9 @@ func (s *StreamTrackerManager) AddTracker(layer int32) streamtracker.StreamTrack
|
||||
})
|
||||
}
|
||||
|
||||
s.logger.Debugw("StreamTrackerManager add track", "layer", layer)
|
||||
s.logger.Debugw("stream tracker add track", "layer", layer)
|
||||
tracker.OnStatusChanged(func(status streamtracker.StreamStatus) {
|
||||
s.logger.Debugw("StreamTrackerManager OnStatusChanged", "layer", layer, "status", status)
|
||||
s.logger.Debugw("stream tracker status changed", "layer", layer, "status", status)
|
||||
if status == streamtracker.StreamStatusStopped {
|
||||
s.removeAvailableLayer(layer)
|
||||
} else {
|
||||
@@ -289,6 +289,10 @@ func (s *StreamTrackerManager) GetTracker(layer int32) streamtracker.StreamTrack
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
if int(layer) >= len(s.trackers) {
|
||||
s.logger.Errorw("unexpected layer", nil, "layer", layer)
|
||||
return nil
|
||||
}
|
||||
return s.trackers[layer]
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ package videolayerselector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/sfu/buffer"
|
||||
@@ -31,6 +32,8 @@ type DependencyDescriptor struct {
|
||||
previousActiveDecodeTargetsBitmask *uint32
|
||||
activeDecodeTargetsBitmask *uint32
|
||||
structure *dede.FrameDependencyStructure
|
||||
extKeyFrameNum uint64
|
||||
keyFrameValid bool
|
||||
|
||||
chains []*FrameChain
|
||||
|
||||
@@ -79,38 +82,76 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
|
||||
Temporal: int32(fd.TemporalId),
|
||||
}
|
||||
|
||||
if !d.keyFrameValid && dd.AttachedStructure == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// early return if this frame is already forwarded or dropped
|
||||
sd, err := d.decisions.GetDecision(extFrameNum)
|
||||
if err != nil {
|
||||
// do not mark as dropped as only error is an old frame
|
||||
d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d",
|
||||
incomingLayer,
|
||||
dd.FrameNumber,
|
||||
extFrameNum,
|
||||
extPkt.Packet.SequenceNumber,
|
||||
), "err", err)
|
||||
// d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d",
|
||||
// incomingLayer,
|
||||
// dd.FrameNumber,
|
||||
// extFrameNum,
|
||||
// extPkt.Packet.SequenceNumber,
|
||||
// ), "err", err)
|
||||
return
|
||||
}
|
||||
switch sd {
|
||||
case selectorDecisionDropped:
|
||||
// a packet of an alreadty dropped frame, maintain decision
|
||||
d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d",
|
||||
incomingLayer,
|
||||
dd.FrameNumber,
|
||||
extFrameNum,
|
||||
extPkt.Packet.SequenceNumber,
|
||||
))
|
||||
// d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d",
|
||||
// incomingLayer,
|
||||
// dd.FrameNumber,
|
||||
// extFrameNum,
|
||||
// extPkt.Packet.SequenceNumber,
|
||||
// ))
|
||||
return
|
||||
}
|
||||
|
||||
if ddwdt.StructureUpdated {
|
||||
d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets)
|
||||
// TODO-REMOVE: remove this log after stable
|
||||
d.logger.Infow("update dependency structure",
|
||||
"structureID", dd.AttachedStructure.StructureId,
|
||||
"structure", dd.AttachedStructure,
|
||||
"decodeTargets", ddwdt.DecodeTargets,
|
||||
"efn", extFrameNum,
|
||||
"sn", extPkt.Packet.SequenceNumber,
|
||||
"isKeyFrame", extPkt.KeyFrame,
|
||||
"currentKeyframe", d.extKeyFrameNum,
|
||||
)
|
||||
|
||||
d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets, extFrameNum)
|
||||
}
|
||||
|
||||
if ddwdt.ExtKeyFrameNum != d.extKeyFrameNum {
|
||||
// keyframe mismatch, drop and reset chains
|
||||
// TODO-REMOVE: remove this log after stable
|
||||
d.logger.Infow("drop packet for keyframe mismatch", "incoming", incomingLayer, "efn", extFrameNum, "sn", extPkt.Packet.SequenceNumber, "requiredKeyFrame", ddwdt.ExtKeyFrameNum, "structureKeyFrame", d.extKeyFrameNum)
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
d.invalidateKeyFrame()
|
||||
return
|
||||
}
|
||||
|
||||
if ddwdt.ActiveDecodeTargetsUpdated {
|
||||
d.updateActiveDecodeTargets(*dd.ActiveDecodeTargetsBitmask)
|
||||
}
|
||||
|
||||
// TODO-REMOVE: remove this log after stable
|
||||
if len(fd.ChainDiffs) != len(d.chains) {
|
||||
d.logger.Warnw("frame chain diff length mismatch", nil,
|
||||
"incoming", incomingLayer,
|
||||
"efn", extFrameNum,
|
||||
"sn", extPkt.Packet.SequenceNumber,
|
||||
"chainDiffs", fd.ChainDiffs,
|
||||
"chains", len(d.chains),
|
||||
"requiredKeyFrame", ddwdt.ExtKeyFrameNum,
|
||||
"structureKeyFrame", d.extKeyFrameNum)
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
return
|
||||
}
|
||||
|
||||
for _, chain := range d.chains {
|
||||
chain.OnFrame(extFrameNum, fd)
|
||||
}
|
||||
@@ -133,7 +174,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
|
||||
if err != nil {
|
||||
d.decodeTargetsLock.RUnlock()
|
||||
// dtis error, dependency descriptor might lost
|
||||
d.logger.Debugw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), "err", err)
|
||||
d.logger.Warnw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), err)
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
return
|
||||
}
|
||||
@@ -148,34 +189,34 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
|
||||
|
||||
if highestDecodeTarget.Target < 0 {
|
||||
// no active decode target, do not select
|
||||
d.logger.Debugw(
|
||||
"drop packet for no target found",
|
||||
"highestDecodeTarget", highestDecodeTarget,
|
||||
"decodeTargets", d.decodeTargets,
|
||||
"tagetLayer", d.targetLayer,
|
||||
"incoming", incomingLayer,
|
||||
"fn", dd.FrameNumber,
|
||||
"efn", extFrameNum,
|
||||
"sn", extPkt.Packet.SequenceNumber,
|
||||
"isKeyFrame", extPkt.KeyFrame,
|
||||
)
|
||||
// d.logger.Debugw(
|
||||
// "drop packet for no target found",
|
||||
// "highestDecodeTarget", highestDecodeTarget,
|
||||
// "decodeTargets", d.decodeTargets,
|
||||
// "tagetLayer", d.targetLayer,
|
||||
// "incoming", incomingLayer,
|
||||
// "fn", dd.FrameNumber,
|
||||
// "efn", extFrameNum,
|
||||
// "sn", extPkt.Packet.SequenceNumber,
|
||||
// "isKeyFrame", extPkt.KeyFrame,
|
||||
// )
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
return
|
||||
}
|
||||
|
||||
// DD-TODO : if bandwidth in congest, could drop the 'Discardable' frame
|
||||
if dti == dede.DecodeTargetNotPresent {
|
||||
d.logger.Debugw(
|
||||
"drop packet for decode target not present",
|
||||
"highestDecodeTarget", highestDecodeTarget,
|
||||
"decodeTargets", d.decodeTargets,
|
||||
"tagetLayer", d.targetLayer,
|
||||
"incoming", incomingLayer,
|
||||
"fn", dd.FrameNumber,
|
||||
"efn", extFrameNum,
|
||||
"sn", extPkt.Packet.SequenceNumber,
|
||||
"isKeyFrame", extPkt.KeyFrame,
|
||||
)
|
||||
// d.logger.Debugw(
|
||||
// "drop packet for decode target not present",
|
||||
// "highestDecodeTarget", highestDecodeTarget,
|
||||
// "decodeTargets", d.decodeTargets,
|
||||
// "tagetLayer", d.targetLayer,
|
||||
// "incoming", incomingLayer,
|
||||
// "fn", dd.FrameNumber,
|
||||
// "efn", extFrameNum,
|
||||
// "sn", extPkt.Packet.SequenceNumber,
|
||||
// "isKeyFrame", extPkt.KeyFrame,
|
||||
// )
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
return
|
||||
}
|
||||
@@ -195,17 +236,17 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
|
||||
}
|
||||
}
|
||||
if !isDecodable {
|
||||
d.logger.Debugw(
|
||||
"drop packet for not decodable",
|
||||
"highestDecodeTarget", highestDecodeTarget,
|
||||
"decodeTargets", d.decodeTargets,
|
||||
"tagetLayer", d.targetLayer,
|
||||
"incoming", incomingLayer,
|
||||
"fn", dd.FrameNumber,
|
||||
"efn", extFrameNum,
|
||||
"sn", extPkt.Packet.SequenceNumber,
|
||||
"isKeyFrame", extPkt.KeyFrame,
|
||||
)
|
||||
// d.logger.Debugw(
|
||||
// "drop packet for not decodable",
|
||||
// "highestDecodeTarget", highestDecodeTarget,
|
||||
// "decodeTargets", d.decodeTargets,
|
||||
// "tagetLayer", d.targetLayer,
|
||||
// "incoming", incomingLayer,
|
||||
// "fn", dd.FrameNumber,
|
||||
// "efn", extFrameNum,
|
||||
// "sn", extPkt.Packet.SequenceNumber,
|
||||
// "isKeyFrame", extPkt.KeyFrame,
|
||||
// )
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
return
|
||||
}
|
||||
@@ -263,11 +304,33 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
|
||||
// d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
|
||||
}
|
||||
}
|
||||
bytes, err := ddExtension.Marshal()
|
||||
if err != nil {
|
||||
d.logger.Warnw("error marshalling dependency descriptor extension", err)
|
||||
} else {
|
||||
result.DependencyDescriptorExtension = bytes
|
||||
|
||||
var ddMarshaled bool
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
d.logger.Errorw("panic marshalling dependency descriptor extension", nil,
|
||||
"efn", extFrameNum,
|
||||
"sn", extPkt.Packet.SequenceNumber,
|
||||
"keyframeRequired", ddwdt.ExtKeyFrameNum,
|
||||
"currentKeyframe", d.extKeyFrameNum,
|
||||
"panic", r,
|
||||
"stack", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
bytes, err := ddExtension.Marshal()
|
||||
if err != nil {
|
||||
d.logger.Warnw("error marshalling dependency descriptor extension", err)
|
||||
} else {
|
||||
result.DependencyDescriptorExtension = bytes
|
||||
ddMarshaled = true
|
||||
}
|
||||
}()
|
||||
|
||||
if !ddMarshaled {
|
||||
// drop packet if we can't marshal dependency descriptor
|
||||
d.decisions.AddDropped(extFrameNum)
|
||||
return
|
||||
}
|
||||
|
||||
if ddwdt.Integrity {
|
||||
@@ -284,8 +347,10 @@ func (d *DependencyDescriptor) Rollback() {
|
||||
d.Base.Rollback()
|
||||
}
|
||||
|
||||
func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget) {
|
||||
func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget, extFrameNum uint64) {
|
||||
d.structure = structure
|
||||
d.extKeyFrameNum = extFrameNum
|
||||
d.keyFrameValid = true
|
||||
|
||||
d.chains = d.chains[:0]
|
||||
|
||||
@@ -329,9 +394,17 @@ func (d *DependencyDescriptor) updateActiveDecodeTargets(activeDecodeTargetsBitm
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DependencyDescriptor) invalidateKeyFrame() {
|
||||
d.keyFrameValid = false
|
||||
d.chains = d.chains[:0]
|
||||
d.decodeTargetsLock.Lock()
|
||||
d.decodeTargets = d.decodeTargets[:0]
|
||||
d.decodeTargetsLock.Unlock()
|
||||
}
|
||||
|
||||
func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) {
|
||||
layer = d.GetRequestSpatial()
|
||||
if !d.currentLayer.IsValid() {
|
||||
if !d.currentLayer.IsValid() || !d.keyFrameValid {
|
||||
// always declare not locked when trying to resume from nothing
|
||||
return false, layer
|
||||
}
|
||||
@@ -339,7 +412,7 @@ func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) {
|
||||
d.decodeTargetsLock.RLock()
|
||||
defer d.decodeTargetsLock.RUnlock()
|
||||
for _, dt := range d.decodeTargets {
|
||||
if dt.Active() && dt.Layer.Spatial <= d.GetTarget().Spatial && dt.Valid() {
|
||||
if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() {
|
||||
d.logger.Debugw(fmt.Sprintf("checking sync, matching decode target, layer: %d, dt: %s, dts: %+v", layer, dt, d.decodeTargets))
|
||||
return true, layer
|
||||
}
|
||||
|
||||
@@ -253,11 +253,23 @@ func TestDependencyDescriptor(t *testing.T) {
|
||||
}
|
||||
require.True(t, switchToLower)
|
||||
|
||||
// sync with requested layer
|
||||
// not sync with requested layer
|
||||
ddSelector.SetRequestSpatial(targetLayer.Spatial)
|
||||
locked, layer := ddSelector.CheckSync()
|
||||
require.True(t, locked)
|
||||
require.False(t, locked)
|
||||
require.Equal(t, targetLayer.Spatial, layer)
|
||||
// request to current layer, sync
|
||||
ddSelector.SetRequestSpatial(ddSelector.GetCurrent().Spatial)
|
||||
locked, _ = ddSelector.CheckSync()
|
||||
require.True(t, locked)
|
||||
|
||||
// should drop frame that relies on a keyframe is not present in current selection
|
||||
framesPrevious := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 1000)
|
||||
ret = ddSelector.Select(framesPrevious[1], 0)
|
||||
require.False(t, ret.IsSelected)
|
||||
// keyframe lost, out of sync
|
||||
locked, _ = ddSelector.CheckSync()
|
||||
require.False(t, locked)
|
||||
}
|
||||
|
||||
func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buffer.ExtPacket {
|
||||
@@ -279,7 +291,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
|
||||
return decodeTargets[i].Layer.GreaterThan(decodeTargets[j].Layer)
|
||||
})
|
||||
|
||||
chainDiffs := make([]int, len(decodeTargets))
|
||||
chainDiffs := make([]int, int(maxLayer.Spatial)+1)
|
||||
dtis := make([]dd.DecodeTargetIndication, len(decodeTargets))
|
||||
for _, dt := range decodeTargets {
|
||||
dtis[dt.Target] = dd.DecodeTargetSwitch
|
||||
@@ -319,6 +331,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
|
||||
ActiveDecodeTargetsUpdated: true,
|
||||
Integrity: true,
|
||||
ExtFrameNum: uint64(startFrameNumber),
|
||||
ExtKeyFrameNum: uint64(startFrameNumber),
|
||||
},
|
||||
Packet: &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
@@ -356,7 +369,6 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
|
||||
}
|
||||
|
||||
frame := &buffer.ExtPacket{
|
||||
KeyFrame: true,
|
||||
DependencyDescriptor: &buffer.ExtDependencyDescriptor{
|
||||
Descriptor: &dd.DependencyDescriptor{
|
||||
FrameNumber: startFrameNumber,
|
||||
@@ -367,9 +379,10 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
|
||||
DecodeTargetIndications: frameDtis,
|
||||
},
|
||||
},
|
||||
DecodeTargets: decodeTargets,
|
||||
Integrity: true,
|
||||
ExtFrameNum: uint64(startFrameNumber),
|
||||
DecodeTargets: decodeTargets,
|
||||
Integrity: true,
|
||||
ExtFrameNum: uint64(startFrameNumber),
|
||||
ExtKeyFrameNum: keyFrame.DependencyDescriptor.ExtFrameNum,
|
||||
},
|
||||
Packet: &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
|
||||
+158
@@ -0,0 +1,158 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/atomic"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/utils"
|
||||
)
|
||||
|
||||
type agentClient struct {
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
|
||||
registered atomic.Int32
|
||||
roomAvailability atomic.Int32
|
||||
roomJobs atomic.Int32
|
||||
participantAvailability atomic.Int32
|
||||
participantJobs atomic.Int32
|
||||
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newAgentClient(token string) (*agentClient, error) {
|
||||
host := fmt.Sprintf("ws://localhost:%d", defaultServerPort)
|
||||
u, err := url.Parse(host + "/agent")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestHeader := make(http.Header)
|
||||
requestHeader.Set("Authorization", "Bearer "+token)
|
||||
|
||||
connectUrl := u.String()
|
||||
conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &agentClient{
|
||||
conn: conn,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *agentClient) Run() error {
|
||||
go c.read()
|
||||
|
||||
workerID := utils.NewGuid("W_")
|
||||
|
||||
if err := c.write(&livekit.WorkerMessage{
|
||||
Message: &livekit.WorkerMessage_Register{
|
||||
Register: &livekit.RegisterWorkerRequest{
|
||||
Type: livekit.JobType_JT_ROOM,
|
||||
WorkerId: workerID,
|
||||
Version: "version",
|
||||
Name: "name",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.write(&livekit.WorkerMessage{
|
||||
Message: &livekit.WorkerMessage_Register{
|
||||
Register: &livekit.RegisterWorkerRequest{
|
||||
Type: livekit.JobType_JT_PUBLISHER,
|
||||
WorkerId: workerID,
|
||||
Version: "version",
|
||||
Name: "name",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *agentClient) read() {
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
_, b, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg := &livekit.ServerMessage{}
|
||||
if err = proto.Unmarshal(b, msg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch m := msg.Message.(type) {
|
||||
case *livekit.ServerMessage_Assignment:
|
||||
go c.handleAssignment(m.Assignment)
|
||||
case *livekit.ServerMessage_Availability:
|
||||
go c.handleAvailability(m.Availability)
|
||||
case *livekit.ServerMessage_Register:
|
||||
go c.handleRegister(m.Register)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *agentClient) handleAssignment(req *livekit.JobAssignment) {
|
||||
if req.Job.Type == livekit.JobType_JT_ROOM {
|
||||
c.roomJobs.Inc()
|
||||
} else {
|
||||
c.participantJobs.Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *agentClient) handleAvailability(req *livekit.AvailabilityRequest) {
|
||||
if req.Job.Type == livekit.JobType_JT_ROOM {
|
||||
c.roomAvailability.Inc()
|
||||
} else {
|
||||
c.participantAvailability.Inc()
|
||||
}
|
||||
|
||||
c.write(&livekit.WorkerMessage{
|
||||
Message: &livekit.WorkerMessage_Availability{
|
||||
Availability: &livekit.AvailabilityResponse{
|
||||
JobId: req.Job.Id,
|
||||
Available: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *agentClient) handleRegister(req *livekit.RegisterWorkerResponse) {
|
||||
c.registered.Inc()
|
||||
}
|
||||
|
||||
func (c *agentClient) write(msg *livekit.WorkerMessage) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
b, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.conn.WriteMessage(websocket.BinaryMessage, b)
|
||||
}
|
||||
|
||||
func (c *agentClient) close() {
|
||||
close(c.done)
|
||||
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
_ = c.conn.Close()
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
// Copyright 2023 LiveKit, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/livekit/protocol/auth"
|
||||
)
|
||||
|
||||
func TestAgents(t *testing.T) {
|
||||
_, finish := setupSingleNodeTest("TestAgents")
|
||||
defer finish()
|
||||
|
||||
ac1, err := newAgentClient(agentToken())
|
||||
require.NoError(t, err)
|
||||
ac2, err := newAgentClient(agentToken())
|
||||
require.NoError(t, err)
|
||||
defer ac1.close()
|
||||
defer ac2.close()
|
||||
ac1.Run()
|
||||
ac2.Run()
|
||||
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
require.Equal(t, int32(2), ac1.registered.Load())
|
||||
require.Equal(t, int32(2), ac2.registered.Load())
|
||||
|
||||
c1 := createRTCClient("c1", defaultServerPort, nil)
|
||||
c2 := createRTCClient("c2", defaultServerPort, nil)
|
||||
waitUntilConnected(t, c1, c2)
|
||||
|
||||
// publish 2 tracks
|
||||
t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam")
|
||||
require.NoError(t, err)
|
||||
defer t1.Stop()
|
||||
t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam")
|
||||
require.NoError(t, err)
|
||||
defer t2.Stop()
|
||||
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load())
|
||||
require.Equal(t, int32(1), ac1.participantJobs.Load()+ac2.participantJobs.Load())
|
||||
|
||||
// publish 2 tracks
|
||||
t3, err := c2.AddStaticTrack("audio/opus", "audio", "webcam")
|
||||
require.NoError(t, err)
|
||||
defer t3.Stop()
|
||||
t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam")
|
||||
require.NoError(t, err)
|
||||
defer t4.Stop()
|
||||
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load())
|
||||
require.Equal(t, int32(2), ac1.participantJobs.Load()+ac2.participantJobs.Load())
|
||||
}
|
||||
|
||||
func agentToken() string {
|
||||
at := auth.NewAccessToken(testApiKey, testApiSecret).
|
||||
AddGrant(&auth.VideoGrant{Agent: true})
|
||||
t, err := at.ToJWT()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
Reference in New Issue
Block a user