Merge remote-tracking branch 'origin/master' into raja_fr

This commit is contained in:
boks1971
2023-11-08 15:51:40 +05:30
44 changed files with 1674 additions and 338 deletions
+3 -3
View File
@@ -18,7 +18,7 @@ require (
github.com/jxskiss/base62 v1.1.0
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1
github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e
github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1
github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e
github.com/livekit/psrpc v0.5.0
github.com/mackerelio/go-osstat v0.2.4
github.com/magefile/mage v1.15.0
@@ -27,7 +27,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5
github.com/pion/dtls/v2 v2.2.7
github.com/pion/ice/v2 v2.3.11
github.com/pion/interceptor v0.1.24
github.com/pion/interceptor v0.1.25
github.com/pion/rtcp v1.2.10
github.com/pion/rtp v1.8.2
github.com/pion/sctp v1.8.9
@@ -47,7 +47,7 @@ require (
github.com/urfave/negroni/v3 v3.0.0
go.uber.org/atomic v1.11.0
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
golang.org/x/sync v0.4.0
golang.org/x/sync v0.5.0
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1
)
+6 -6
View File
@@ -125,8 +125,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ=
github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e h1:yNeIo7MSMUWgoLu7LkNKnBYnJBFPFH9Wq4S6h1kS44M=
github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY=
github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 h1:WPWxU9w5XHAsonxnSSIIXbWMty9b5uHnTnyKS9TpaXM=
github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ=
github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e h1:YShBpEjkEBY7yil2gjMWlkVkxs3OI58LIIYsBdb8aBU=
github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ=
github.com/livekit/psrpc v0.5.0 h1:g+yYNSs6Y1/vM7UlFkB2s/ARe2y3RKWZhX8ata5j+eo=
github.com/livekit/psrpc v0.5.0/go.mod h1:1XYH1LLoD/YbvBvt6xg2KQ/J3InLXSJK6PL/+DKmuAU=
github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs=
@@ -185,8 +185,8 @@ github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ
github.com/pion/ice/v2 v2.3.11 h1:rZjVmUwyT55cmN8ySMpL7rsS8KYsJERsrxJLLxpKhdw=
github.com/pion/ice/v2 v2.3.11/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E=
github.com/pion/interceptor v0.1.18/go.mod h1:tpvvF4cPM6NGxFA1DUMbhabzQBxdWMATDGEUYOR9x6I=
github.com/pion/interceptor v0.1.24 h1:lN4ua3yUAJCgNKQKcZIM52wFjBgjN0r7shLj91PkJ0c=
github.com/pion/interceptor v0.1.24/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y=
github.com/pion/interceptor v0.1.25 h1:pwY9r7P6ToQ3+IF0bajN0xmk/fNw/suTgaTdlwTDmhc=
github.com/pion/interceptor v0.1.25/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.8 h1:HhicWIg7OX5PVilyBO6plhMetInbzkVJAhbdJiAeVaI=
@@ -335,8 +335,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -40,7 +40,7 @@ func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit.
for _, c := range s.confs {
matched, err := c.Match.Match(clientInfo)
if err != nil {
logger.Errorw("matchrule failed", err, "clientInfo", clientInfo.String())
logger.Errorw("matchrule failed", err, "clientInfo", logger.Proto(clientInfo))
continue
}
if !matched {
+92
View File
@@ -0,0 +1,92 @@
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rtc
import (
"context"
"time"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/psrpc"
)
// Topics used when dispatching agent jobs over the internal RPC bus.
// RoomAgentTopic targets agents that handle whole-room jobs;
// PublisherAgentTopic targets agents that handle per-publisher jobs.
const (
RoomAgentTopic = "room"
PublisherAgentTopic = "publisher"
)
// AgentClient dispatches agent-related requests over the message bus.
// NOTE(review): this view is a diff rendering with indentation stripped;
// only comments are added here, code bytes are untouched.
type AgentClient interface {
// CheckEnabled asks all registered agents which job types (room/publisher)
// are enabled. It never returns an error: failures yield a zero-value response.
CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse
// JobRequest dispatches a job to an agent. Errors are not surfaced to the
// caller; they are at most logged (see implementation).
JobRequest(ctx context.Context, job *livekit.Job)
}
// agentClient is the psrpc-backed implementation of AgentClient.
type agentClient struct {
// client is the generated internal RPC client used for all agent calls.
client rpc.AgentInternalClient
}
// NewAgentClient builds an AgentClient backed by the given message bus.
// It returns an error only when the underlying internal RPC client
// cannot be constructed.
func NewAgentClient(bus psrpc.MessageBus) (AgentClient, error) {
	internal, err := rpc.NewAgentInternalClient(bus)
	if err != nil {
		return nil, err
	}
	return &agentClient{client: internal}, nil
}
// CheckEnabled fans the request out to all agent workers (1s timeout) and
// merges their answers: a capability is enabled if any responder enables it.
// RPC setup failure or per-result errors are treated as "not enabled";
// the method never returns an error. It stops reading further results as
// soon as both capabilities are known to be enabled.
func (c *agentClient) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse {
	merged := &rpc.CheckEnabledResponse{}
	results, err := c.client.CheckEnabled(ctx, req, psrpc.WithRequestTimeout(time.Second))
	if err != nil {
		// best-effort: report everything disabled rather than failing
		return merged
	}
	for result := range results {
		if result.Err != nil {
			// ignore individual responder failures
			continue
		}
		merged.RoomEnabled = merged.RoomEnabled || result.Result.RoomEnabled
		merged.PublisherEnabled = merged.PublisherEnabled || result.Result.PublisherEnabled
		if merged.RoomEnabled && merged.PublisherEnabled {
			// both capabilities confirmed; no need to drain remaining results
			return merged
		}
	}
	return merged
}
// JobRequest sends a job to the agent topic matching its type. Failures are
// fire-and-forget: only publisher-job failures are logged (room jobs and
// unknown types fail silently, matching the original behavior — an
// unrecognized type is still sent, with an empty topic).
func (c *agentClient) JobRequest(ctx context.Context, job *livekit.Job) {
	topic := ""
	logError := false
	switch job.Type {
	case livekit.JobType_JT_ROOM:
		topic = RoomAgentTopic
	case livekit.JobType_JT_PUBLISHER:
		topic, logError = PublisherAgentTopic, true
	}
	if _, err := c.client.JobRequest(ctx, topic, job); err != nil && logError {
		logger.Warnw("agent job request failed", err)
	}
}
+7 -5
View File
@@ -134,7 +134,7 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool
return false
}
func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string {
func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string {
// sort these by compatibility, since we are looking for backups
if slices.ContainsFunc(enabledCodecs, func(c *livekit.Codec) bool {
return strings.EqualFold(c.Mime, webrtc.MimeTypeVP8)
@@ -146,9 +146,11 @@ func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string {
}) {
return webrtc.MimeTypeH264
}
if len(enabledCodecs) > 0 {
return enabledCodecs[0].Mime
for _, c := range enabledCodecs {
if strings.HasPrefix(c.Mime, "video/") {
return c.Mime
}
}
// uh oh. this should not happen
return ""
// no viable codec in the list of enabled codecs, fall back to the most widely supported codec
return webrtc.MimeTypeVP8
}
+11 -2
View File
@@ -239,13 +239,22 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra
t.params.Logger.Debugw("AddReceiver", "mime", track.Codec().MimeType)
wr := t.MediaTrackReceiver.Receiver(mime)
if wr == nil {
var priority int
priority := -1
for idx, c := range t.params.TrackInfo.Codecs {
if strings.HasSuffix(mime, c.MimeType) {
if strings.EqualFold(mime, c.MimeType) {
priority = idx
break
}
}
if len(t.params.TrackInfo.Codecs) == 0 {
priority = 0
}
if priority < 0 {
t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(t.params.TrackInfo))
t.lock.Unlock()
return false
}
newWR := sfu.NewWebRTCReceiver(
receiver,
track,
+66 -38
View File
@@ -101,7 +101,8 @@ type ParticipantParams struct {
PLIThrottleConfig config.PLIThrottleConfig
CongestionControlConfig config.CongestionControlConfig
// codecs that are enabled for this room
EnabledCodecs []*livekit.Codec
PublishEnabledCodecs []*livekit.Codec
SubscribeEnabledCodecs []*livekit.Codec
Logger logger.Logger
SimTracks map[uint32]SimulcastTrackInfo
Grants *auth.ClaimGrants
@@ -138,6 +139,7 @@ type ParticipantImpl struct {
resSinkMu sync.Mutex
resSink routing.MessageSink
grants *auth.ClaimGrants
hidden atomic.Bool
isPublisher atomic.Bool
// when first connected
@@ -248,8 +250,9 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) {
p.migrateState.Store(types.MigrateStateInit)
p.state.Store(livekit.ParticipantInfo_JOINING)
p.grants = params.Grants
p.hidden.Store(p.grants.Video.Hidden)
p.SetResponseSink(params.Sink)
p.setupEnabledCodecs(params.EnabledCodecs, params.ClientConf.GetDisabledCodecs())
p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs())
p.supervisor.OnPublicationError(p.onPublicationError)
@@ -425,6 +428,7 @@ func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermissio
p.params.Logger.Infow("updating participant permission", "permission", permission)
video.UpdateFromPermission(permission)
p.hidden.Store(permission.Hidden)
p.dirty.Store(true)
canPublish := video.GetCanPublish()
@@ -710,7 +714,7 @@ func (p *ParticipantImpl) SetMigrateInfo(
p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted)
p.pendingTracks[t.GetCid()] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}, migrated: true}
p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", ti.String())
p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", logger.Proto(ti))
}
p.pendingTracksLock.Unlock()
@@ -734,7 +738,7 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea
"sendLeave", sendLeave,
"reason", reason.String(),
"isExpectedToResume", isExpectedToResume,
"clientInfo", p.params.ClientInfo.String(),
"clientInfo", logger.Proto(p.params.ClientInfo),
)
p.clearDisconnectTimer()
p.clearMigrationTimer()
@@ -1020,10 +1024,7 @@ func (p *ParticipantImpl) CanPublishData() bool {
}
func (p *ParticipantImpl) Hidden() bool {
p.lock.RLock()
defer p.lock.RUnlock()
return p.grants.Video.Hidden
return p.hidden.Load()
}
func (p *ParticipantImpl) IsRecorder() bool {
@@ -1033,6 +1034,13 @@ func (p *ParticipantImpl) IsRecorder() bool {
return p.grants.Video.Recorder
}
func (p *ParticipantImpl) IsAgent() bool {
p.lock.RLock()
defer p.lock.RUnlock()
return p.grants.Video.Agent
}
func (p *ParticipantImpl) VerifySubscribeParticipantInfo(pID livekit.ParticipantID, version uint32) {
if !p.IsReady() {
// we have not sent a JoinResponse yet. metadata would be covered in JoinResponse
@@ -1623,10 +1631,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l
seenCodecs := make(map[string]struct{})
for _, codec := range req.SimulcastCodecs {
mime := codec.Codec
if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") {
mime = "video/" + mime
if req.Type == livekit.TrackType_VIDEO {
if !strings.HasPrefix(mime, "video/") {
mime = "video/" + mime
}
if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) {
altCodec := selectAlternativeCodec(p.enabledPublishCodecs)
altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs)
p.pubLogger.Infow("falling back to alternative codec",
"codec", mime,
"altCodec", altCodec,
@@ -1639,7 +1649,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l
mime = "audio/" + mime
}
if _, ok := seenCodecs[mime]; ok {
if _, ok := seenCodecs[mime]; ok || mime == "" {
continue
}
seenCodecs[mime] = struct{}{}
@@ -1658,12 +1668,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l
} else {
p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti)
}
p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", ti.String(), "request", req.String())
p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req))
return nil
}
p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}}
p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", ti.String(), "request", req.String())
p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req))
return ti
}
@@ -1681,7 +1691,7 @@ func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.Trac
}
func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) {
p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", ti.String())
p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", logger.Proto(ti))
_ = p.writeMessage(&livekit.SignalResponse{
Message: &livekit.SignalResponse_TrackPublished{
TrackPublished: &livekit.TrackPublishedResponse{
@@ -1761,12 +1771,32 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei
// use existing media track to handle simulcast
mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack)
if !ok {
signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()))
signalCid, ti, migrated := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()))
if ti == nil {
p.pendingTracksLock.Unlock()
return nil, false
}
// check if the migrated track has correct codec
if migrated && len(ti.Codecs) > 0 {
parameters := rtpReceiver.GetParameters()
var codecFound int
for _, c := range ti.Codecs {
for _, nc := range parameters.Codecs {
if strings.EqualFold(nc.MimeType, c.MimeType) {
codecFound++
break
}
}
}
if codecFound != len(ti.Codecs) {
p.params.Logger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters)
p.pendingTracksLock.Unlock()
p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch)
return nil, false
}
}
ti.MimeType = track.Codec().MimeType
mt = p.addMediaTrack(signalCid, track.ID(), ti)
newTrack = true
@@ -1793,7 +1823,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei
}
func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack {
p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", ti.String())
p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti))
rtpReceiver := p.TransportManager.GetPublisherRTPReceiver(ti.Mid)
if rtpReceiver == nil {
p.pubLogger.Errorw("could not find receiver for migrated track", nil, "trackID", ti.Sid)
@@ -1975,7 +2005,7 @@ func (p *ParticipantImpl) onUpTrackManagerClose() {
p.postRtcp(nil)
}
func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) {
func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo, bool) {
signalCid := clientId
pendingInfo := p.pendingTracks[clientId]
if pendingInfo == nil {
@@ -2011,10 +2041,10 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp
// if still not found, we are done
if pendingInfo == nil {
p.pubLogger.Errorw("track info not published prior to track", nil, "clientId", clientId)
return signalCid, nil
return signalCid, nil, false
}
return signalCid, pendingInfo.trackInfos[0]
return signalCid, pendingInfo.trackInfos[0], pendingInfo.migrated
}
// setStableTrackID either generates a new TrackID or reuses a previously used one
@@ -2205,7 +2235,7 @@ func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason
scr := types.SignallingCloseReasonUnknown
switch reason {
case types.ParticipantCloseReasonPublicationError:
case types.ParticipantCloseReasonPublicationError, types.ParticipantCloseReasonMigrateCodecMismatch:
scr = types.SignallingCloseReasonFullReconnectPublicationError
case types.ParticipantCloseReasonSubscriptionError:
scr = types.SignallingCloseReasonFullReconnectSubscriptionError
@@ -2334,9 +2364,7 @@ func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket, data []byte) er
return err
}
func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) {
subscribeCodecs := make([]*livekit.Codec, 0, len(codecs))
publishCodecs := make([]*livekit.Codec, 0, len(codecs))
func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Codec, subscribeEnabledCodecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) {
shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool {
for _, disableCodec := range disabled {
// disable codec's fmtp is empty means disable this codec entirely
@@ -2346,22 +2374,22 @@ func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCo
}
return false
}
for _, c := range codecs {
var publishDisabled bool
var subscribeDisabled bool
publishCodecs := make([]*livekit.Codec, 0, len(publishEnabledCodecs))
for _, c := range publishEnabledCodecs {
if shouldDisable(c, disabledCodecs.GetCodecs()) || shouldDisable(c, disabledCodecs.GetPublish()) {
continue
}
publishCodecs = append(publishCodecs, c)
}
p.enabledPublishCodecs = publishCodecs
subscribeCodecs := make([]*livekit.Codec, 0, len(subscribeEnabledCodecs))
for _, c := range subscribeEnabledCodecs {
if shouldDisable(c, disabledCodecs.GetCodecs()) {
publishDisabled = true
subscribeDisabled = true
} else if shouldDisable(c, disabledCodecs.GetPublish()) {
publishDisabled = true
}
if !publishDisabled {
publishCodecs = append(publishCodecs, c)
}
if !subscribeDisabled {
subscribeCodecs = append(subscribeCodecs, c)
continue
}
subscribeCodecs = append(subscribeCodecs, c)
}
p.enabledSubscribeCodecs = subscribeCodecs
p.enabledPublishCodecs = publishCodecs
}
+15 -14
View File
@@ -284,7 +284,7 @@ func TestMuteSetting(t *testing.T) {
Muted: true,
})
_, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO)
_, ti, _ := p.getPendingTrack("cid", livekit.TrackType_AUDIO)
require.NotNil(t, ti)
require.True(t, ti.Muted)
})
@@ -748,19 +748,20 @@ func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *p
}
sid := livekit.ParticipantID(utils.NewGuid(utils.ParticipantPrefix))
p, _ := NewParticipant(ParticipantParams{
SID: sid,
Identity: identity,
Config: rtcConf,
Sink: &routingfakes.FakeMessageSink{},
ProtocolVersion: opts.protocolVersion,
PLIThrottleConfig: conf.RTC.PLIThrottle,
Grants: grants,
EnabledCodecs: enabledCodecs,
ClientConf: opts.clientConf,
ClientInfo: ClientInfo{ClientInfo: opts.clientInfo},
Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false),
Telemetry: &telemetryfakes.FakeTelemetryService{},
VersionGenerator: utils.NewDefaultTimedVersionGenerator(),
SID: sid,
Identity: identity,
Config: rtcConf,
Sink: &routingfakes.FakeMessageSink{},
ProtocolVersion: opts.protocolVersion,
PLIThrottleConfig: conf.RTC.PLIThrottle,
Grants: grants,
PublishEnabledCodecs: enabledCodecs,
SubscribeEnabledCodecs: enabledCodecs,
ClientConf: opts.clientConf,
ClientInfo: ClientInfo{ClientInfo: opts.clientInfo},
Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false),
Telemetry: &telemetryfakes.FakeTelemetryService{},
VersionGenerator: utils.NewDefaultTimedVersionGenerator(),
})
p.isPublisher.Store(opts.publisher)
p.updateState(livekit.ParticipantInfo_ACTIVE)
+3 -3
View File
@@ -46,7 +46,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se
}
p.pendingTracksLock.RLock()
_, info := p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
_, info, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
// if RED is disabled for this track, don't prefer RED codec in offer
disableRed := info != nil && info.DisableRed
p.pendingTracksLock.RUnlock()
@@ -132,7 +132,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess
if mt != nil {
info = mt.ToProto()
} else {
_, info = p.getPendingTrack(streamID, livekit.TrackType_VIDEO)
_, info, _ = p.getPendingTrack(streamID, livekit.TrackType_VIDEO)
}
if info == nil {
@@ -239,7 +239,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript
track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack)
if track == nil {
p.pendingTracksLock.RLock()
_, ti = p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
_, ti, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO)
p.pendingTracksLock.RUnlock()
} else {
ti = track.TrackInfo(false)
+59 -9
View File
@@ -30,6 +30,7 @@ import (
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils"
"github.com/livekit/livekit-server/pkg/config"
@@ -77,11 +78,15 @@ type Room struct {
egressLauncher EgressLauncher
trackManager *RoomTrackManager
// agents
agentClient AgentClient
publisherAgentsEnabled bool
// map of identity -> Participant
participants map[livekit.ParticipantIdentity]types.LocalParticipant
participantOpts map[livekit.ParticipantIdentity]*ParticipantOptions
participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource
hasPublished sync.Map // map of identity -> bool
hasPublished map[livekit.ParticipantIdentity]bool
bufferFactory *buffer.FactoryOfBufferFactory
// batch update participant info for non-publishers
@@ -113,6 +118,7 @@ func NewRoom(
audioConfig *config.AudioConfig,
serverInfo *livekit.ServerInfo,
telemetry telemetry.TelemetryService,
agentClient AgentClient,
egressLauncher EgressLauncher,
) *Room {
r := &Room{
@@ -127,16 +133,19 @@ func NewRoom(
audioConfig: audioConfig,
telemetry: telemetry,
egressLauncher: egressLauncher,
agentClient: agentClient,
trackManager: NewRoomTrackManager(),
serverInfo: serverInfo,
participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant),
participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions),
participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource),
hasPublished: make(map[livekit.ParticipantIdentity]bool),
bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSize),
batchedUpdates: make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo),
closed: make(chan struct{}),
trailer: []byte(utils.RandomSecret()),
}
r.protoProxy = utils.NewProtoProxy[*livekit.Room](roomUpdateInterval, r.updateProto)
if r.protoRoom.EmptyTimeout == 0 {
r.protoRoom.EmptyTimeout = DefaultEmptyTimeout
@@ -145,6 +154,21 @@ func NewRoom(
r.protoRoom.CreationTime = time.Now().Unix()
}
if agentClient != nil {
go func() {
res := r.agentClient.CheckEnabled(context.Background(), &rpc.CheckEnabledRequest{})
if res.PublisherEnabled {
r.lock.Lock()
r.publisherAgentsEnabled = true
// if there are already published tracks, start the agents
for identity := range r.hasPublished {
r.launchPublisherAgent(r.participants[identity])
}
r.lock.Unlock()
}
}()
}
go r.audioUpdateWorker()
go r.connectionQualityWorker()
go r.changeUpdateWorker()
@@ -474,6 +498,7 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek
delete(r.participants, identity)
delete(r.participantOpts, identity)
delete(r.participantRequestSources, identity)
delete(r.hasPublished, identity)
if !p.Hidden() {
r.protoRoom.NumParticipants--
}
@@ -509,7 +534,6 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek
for _, t := range p.GetPublishedTracks() {
r.trackManager.RemoveTrack(t)
}
r.hasPublished.Delete(p.Identity())
p.OnTrackUpdated(nil)
p.OnTrackPublished(nil)
@@ -898,10 +922,19 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types.
r.trackManager.AddTrack(track, participant.Identity(), participant.ID())
// auto egress
if r.internal != nil {
if r.internal.ParticipantEgress != nil {
if _, hasPublished := r.hasPublished.Swap(participant.Identity(), true); !hasPublished {
// launch jobs
r.lock.Lock()
hasPublished := r.hasPublished[participant.Identity()]
r.hasPublished[participant.Identity()] = true
publisherAgentsEnabled := r.publisherAgentsEnabled
r.lock.Unlock()
if !hasPublished {
if publisherAgentsEnabled {
r.launchPublisherAgent(participant)
}
if r.internal != nil && r.internal.ParticipantEgress != nil {
go func() {
if err := StartParticipantEgress(
context.Background(),
r.egressLauncher,
@@ -913,9 +946,11 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types.
); err != nil {
r.Logger.Errorw("failed to launch participant egress", err)
}
}
}()
}
if r.internal.TrackEgress != nil {
}
if r.internal != nil && r.internal.TrackEgress != nil {
go func() {
if err := StartTrackEgress(
context.Background(),
r.egressLauncher,
@@ -927,7 +962,7 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types.
); err != nil {
r.Logger.Errorw("failed to launch track egress", err)
}
}
}()
}
}
@@ -1286,6 +1321,21 @@ func (r *Room) connectionQualityWorker() {
}
}
func (r *Room) launchPublisherAgent(p types.Participant) {
if p == nil || p.IsRecorder() || p.IsAgent() {
return
}
go func() {
r.agentClient.JobRequest(context.Background(), &livekit.Job{
Id: utils.NewGuid("JP_"),
Type: livekit.JobType_JT_PUBLISHER,
Room: r.ToProto(),
Participant: p.ToProto(),
})
}()
}
func (r *Room) DebugInfo() map[string]interface{} {
info := map[string]interface{}{
"Name": r.protoRoom.Name,
+1 -1
View File
@@ -739,7 +739,7 @@ func newRoomWithParticipants(t *testing.T, opts testRoomOpts) *Room {
Region: "testregion",
},
telemetry.NewTelemetryService(webhook.NewDefaultNotifier("", "", nil), &telemetryfakes.FakeAnalyticsService{}),
nil,
nil, nil,
)
for i := 0; i < opts.num+opts.numHidden; i++ {
identity := livekit.ParticipantIdentity(fmt.Sprintf("p%d", i))
+1 -1
View File
@@ -833,7 +833,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni
}
dcCloseHandler := func() {
t.params.Logger.Infow(dc.Label() + " data channel close")
t.params.Logger.Debugw(dc.Label() + " data channel close")
}
dcErrorHandler := func(err error) {
+5 -1
View File
@@ -104,6 +104,7 @@ const (
ParticipantCloseReasonPublicationError
ParticipantCloseReasonSubscriptionError
ParticipantCloseReasonDataChannelError
ParticipantCloseReasonMigrateCodecMismatch
)
func (p ParticipantCloseReason) String() string {
@@ -154,6 +155,8 @@ func (p ParticipantCloseReason) String() string {
return "SUBSCRIPTION_ERROR"
case ParticipantCloseReasonDataChannelError:
return "DATA_CHANNEL_ERROR"
case ParticipantCloseReasonMigrateCodecMismatch:
return "MIGRATE_CODEC_MISMATCH"
default:
return fmt.Sprintf("%d", int(p))
}
@@ -184,7 +187,7 @@ func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason {
return livekit.DisconnectReason_SERVER_SHUTDOWN
case ParticipantCloseReasonOvercommitted:
return livekit.DisconnectReason_SERVER_SHUTDOWN
case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError:
case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch:
return livekit.DisconnectReason_STATE_MISMATCH
default:
// the other types will map to unknown reason
@@ -261,6 +264,7 @@ type Participant interface {
// permissions
Hidden() bool
IsRecorder() bool
IsAgent() bool
Start()
Close(sendLeave bool, reason ParticipantCloseReason, isExpectedToResume bool) error
@@ -408,6 +408,16 @@ type FakeLocalParticipant struct {
identityReturnsOnCall map[int]struct {
result1 livekit.ParticipantIdentity
}
IsAgentStub func() bool
isAgentMutex sync.RWMutex
isAgentArgsForCall []struct {
}
isAgentReturns struct {
result1 bool
}
isAgentReturnsOnCall map[int]struct {
result1 bool
}
IsClosedStub func() bool
isClosedMutex sync.RWMutex
isClosedArgsForCall []struct {
@@ -2986,6 +2996,59 @@ func (fake *FakeLocalParticipant) IdentityReturnsOnCall(i int, result1 livekit.P
}{result1}
}
func (fake *FakeLocalParticipant) IsAgent() bool {
fake.isAgentMutex.Lock()
ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)]
fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct {
}{})
stub := fake.IsAgentStub
fakeReturns := fake.isAgentReturns
fake.recordInvocation("IsAgent", []interface{}{})
fake.isAgentMutex.Unlock()
if stub != nil {
return stub()
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeLocalParticipant) IsAgentCallCount() int {
fake.isAgentMutex.RLock()
defer fake.isAgentMutex.RUnlock()
return len(fake.isAgentArgsForCall)
}
func (fake *FakeLocalParticipant) IsAgentCalls(stub func() bool) {
fake.isAgentMutex.Lock()
defer fake.isAgentMutex.Unlock()
fake.IsAgentStub = stub
}
func (fake *FakeLocalParticipant) IsAgentReturns(result1 bool) {
fake.isAgentMutex.Lock()
defer fake.isAgentMutex.Unlock()
fake.IsAgentStub = nil
fake.isAgentReturns = struct {
result1 bool
}{result1}
}
func (fake *FakeLocalParticipant) IsAgentReturnsOnCall(i int, result1 bool) {
fake.isAgentMutex.Lock()
defer fake.isAgentMutex.Unlock()
fake.IsAgentStub = nil
if fake.isAgentReturnsOnCall == nil {
fake.isAgentReturnsOnCall = make(map[int]struct {
result1 bool
})
}
fake.isAgentReturnsOnCall[i] = struct {
result1 bool
}{result1}
}
func (fake *FakeLocalParticipant) IsClosed() bool {
fake.isClosedMutex.Lock()
ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)]
@@ -6031,6 +6094,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} {
defer fake.iDMutex.RUnlock()
fake.identityMutex.RLock()
defer fake.identityMutex.RUnlock()
fake.isAgentMutex.RLock()
defer fake.isAgentMutex.RUnlock()
fake.isClosedMutex.RLock()
defer fake.isClosedMutex.RUnlock()
fake.isDisconnectedMutex.RLock()
@@ -118,6 +118,16 @@ type FakeParticipant struct {
identityReturnsOnCall map[int]struct {
result1 livekit.ParticipantIdentity
}
IsAgentStub func() bool
isAgentMutex sync.RWMutex
isAgentArgsForCall []struct {
}
isAgentReturns struct {
result1 bool
}
isAgentReturnsOnCall map[int]struct {
result1 bool
}
IsPublisherStub func() bool
isPublisherMutex sync.RWMutex
isPublisherArgsForCall []struct {
@@ -780,6 +790,59 @@ func (fake *FakeParticipant) IdentityReturnsOnCall(i int, result1 livekit.Partic
}{result1}
}
// IsAgent and the helpers below are counterfeiter-generated fake plumbing
// for Participant.IsAgent: each call is recorded, then the return value is
// resolved in priority order — stub, per-call canned result, default canned
// result. Do not edit by hand; regenerate with counterfeiter.
func (fake *FakeParticipant) IsAgent() bool {
	fake.isAgentMutex.Lock()
	ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)]
	fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct {
	}{})
	stub := fake.IsAgentStub
	fakeReturns := fake.isAgentReturns
	fake.recordInvocation("IsAgent", []interface{}{})
	fake.isAgentMutex.Unlock()
	if stub != nil {
		return stub()
	}
	if specificReturn {
		return ret.result1
	}
	return fakeReturns.result1
}
// IsAgentCallCount reports how many times IsAgent has been called.
func (fake *FakeParticipant) IsAgentCallCount() int {
	fake.isAgentMutex.RLock()
	defer fake.isAgentMutex.RUnlock()
	return len(fake.isAgentArgsForCall)
}
// IsAgentCalls installs a stub that takes priority over canned results.
func (fake *FakeParticipant) IsAgentCalls(stub func() bool) {
	fake.isAgentMutex.Lock()
	defer fake.isAgentMutex.Unlock()
	fake.IsAgentStub = stub
}
// IsAgentReturns sets the default canned result and clears any stub.
func (fake *FakeParticipant) IsAgentReturns(result1 bool) {
	fake.isAgentMutex.Lock()
	defer fake.isAgentMutex.Unlock()
	fake.IsAgentStub = nil
	fake.isAgentReturns = struct {
		result1 bool
	}{result1}
}
// IsAgentReturnsOnCall sets the canned result for the i-th call (0-based)
// and clears any stub.
func (fake *FakeParticipant) IsAgentReturnsOnCall(i int, result1 bool) {
	fake.isAgentMutex.Lock()
	defer fake.isAgentMutex.Unlock()
	fake.IsAgentStub = nil
	if fake.isAgentReturnsOnCall == nil {
		fake.isAgentReturnsOnCall = make(map[int]struct {
			result1 bool
		})
	}
	fake.isAgentReturnsOnCall[i] = struct {
		result1 bool
	}{result1}
}
func (fake *FakeParticipant) IsPublisher() bool {
fake.isPublisherMutex.Lock()
ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)]
@@ -1318,6 +1381,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} {
defer fake.iDMutex.RUnlock()
fake.identityMutex.RLock()
defer fake.identityMutex.RUnlock()
fake.isAgentMutex.RLock()
defer fake.isAgentMutex.RUnlock()
fake.isPublisherMutex.RLock()
defer fake.isPublisherMutex.RUnlock()
fake.isRecorderMutex.RLock()
+6 -6
View File
@@ -160,8 +160,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission(
u.params.Logger.Debugw(
"skipping older subscription permission version",
"existingValue", perms,
"existingVersion", u.subscriptionPermissionVersion.ToProto().String(),
"requestingValue", subscriptionPermission.String(),
"existingVersion", u.subscriptionPermissionVersion.String(),
"requestingValue", logger.Proto(subscriptionPermission),
"requestingVersion", timedVersion.String(),
)
u.lock.Unlock()
@@ -178,7 +178,7 @@ func (u *UpTrackManager) UpdateSubscriptionPermission(
if subscriptionPermission == nil {
u.params.Logger.Debugw(
"updating subscription permission, setting to nil",
"version", u.subscriptionPermissionVersion.ToProto().String(),
"version", u.subscriptionPermissionVersion.String(),
)
// possible to get a nil when migrating
u.lock.Unlock()
@@ -187,8 +187,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission(
u.params.Logger.Debugw(
"updating subscription permission",
"permissions", u.subscriptionPermission.String(),
"version", u.subscriptionPermissionVersion.ToProto().String(),
"permissions", logger.Proto(u.subscriptionPermission),
"version", u.subscriptionPermissionVersion.String(),
)
if err := u.parseSubscriptionPermissionsLocked(subscriptionPermission, func(pID livekit.ParticipantID) types.LocalParticipant {
u.lock.Unlock()
@@ -247,7 +247,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) {
u.publishedTracks[track.ID()] = track
}
u.lock.Unlock()
u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", track.ToProto().String())
u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", logger.Proto(track.ToProto()))
track.AddOnClose(func() {
notifyClose := false
+468
View File
@@ -0,0 +1,468 @@
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package service
import (
"context"
"errors"
"io"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/psrpc"
)
// AgentServiceVersion is echoed back to workers in RegisterWorkerResponse.
const AgentServiceVersion = "0.1.0"
// AgentService is the HTTP endpoint that upgrades agent worker requests to
// websockets; protocol handling is delegated to the embedded AgentHandler.
type AgentService struct {
	upgrader websocket.Upgrader
	*AgentHandler
}
// AgentHandler tracks connected agent workers and dispatches job requests
// to them over their websocket connections.
type AgentHandler struct {
	agentServer rpc.AgentInternalServer
	roomTopic string
	publisherTopic string
	// mu guards all maps and the *Registered flags below.
	mu sync.Mutex
	// availability maps job ID -> channel carrying workers' yes/no answers
	// for an in-flight JobRequest.
	availability map[string]chan *availability
	// unregistered holds connections that have not yet sent Register.
	unregistered map[*websocket.Conn]*worker
	// roomRegistered/publisherRegistered track whether this node is
	// currently registered as a handler for the corresponding job topic.
	roomRegistered bool
	roomWorkers map[string]*worker
	publisherRegistered bool
	publisherWorkers map[string]*worker
}
// worker is the server-side state for one connected agent worker.
// mu guards status and activeJobs (NOTE(review): some readers access these
// holding only AgentHandler.mu — confirm the intended guard).
type worker struct {
	mu sync.Mutex
	conn *websocket.Conn
	sigConn *WSSignalConnection
	id string
	jobType livekit.JobType
	status livekit.WorkerStatus
	// activeJobs counts jobs currently assigned to this worker.
	activeJobs int
}
// availability is a worker's answer to an AvailabilityRequest for a job.
type availability struct {
	workerID string
	available bool
}
// NewAgentService creates the websocket endpoint agent workers connect to,
// backed by an AgentInternalServer on the given message bus.
func NewAgentService(bus psrpc.MessageBus) (*AgentService, error) {
	svc := &AgentService{
		upgrader: websocket.Upgrader{
			// allow connections from any origin, since script may be hosted anywhere
			// security is enforced by access tokens
			CheckOrigin: func(r *http.Request) bool { return true },
		},
	}
	server, err := rpc.NewAgentInternalServer(svc, bus)
	if err != nil {
		return nil, err
	}
	svc.AgentHandler = NewAgentHandler(server, rtc.RoomAgentTopic, rtc.PublisherAgentTopic)
	return svc, nil
}
// ServeHTTP upgrades an incoming agent worker request to a websocket and
// hands the connection to the AgentHandler. The request must be a websocket
// upgrade carrying a token with the agent video grant.
func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) {
	// reject non websocket requests
	if !websocket.IsWebSocketUpgrade(r) {
		// use the named constant instead of the magic number 404
		writer.WriteHeader(http.StatusNotFound)
		return
	}
	// require a claim with the agent grant
	claims := GetGrants(r.Context())
	if claims == nil || claims.Video == nil || !claims.Video.Agent {
		handleError(writer, http.StatusUnauthorized, rtc.ErrPermissionDenied)
		return
	}
	// upgrade
	conn, err := s.upgrader.Upgrade(writer, r, nil)
	if err != nil {
		handleError(writer, http.StatusInternalServerError, err)
		return
	}
	// blocks until the worker disconnects
	s.HandleConnection(conn)
}
// NewAgentHandler creates an AgentHandler serving the given room and
// publisher job topics on the supplied RPC server.
func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, publisherTopic string) *AgentHandler {
	h := &AgentHandler{
		agentServer:    agentServer,
		roomTopic:      roomTopic,
		publisherTopic: publisherTopic,
	}
	h.availability = make(map[string]chan *availability)
	h.unregistered = make(map[*websocket.Conn]*worker)
	h.roomWorkers = make(map[string]*worker)
	h.publisherWorkers = make(map[string]*worker)
	return h
}
// HandleConnection runs the read loop for a single agent worker websocket.
// The worker starts in the unregistered set and is moved into a typed pool
// by handleRegister; on disconnect it is removed from whichever set holds
// it, deregistering the job topic when no available worker of that type
// remains. Blocks until the connection closes or errors.
func (s *AgentHandler) HandleConnection(conn *websocket.Conn) {
	sigConn := NewWSSignalConnection(conn)
	w := &worker{
		conn:    conn,
		sigConn: sigConn,
	}
	s.mu.Lock()
	s.unregistered[conn] = w
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		if w.id == "" {
			// worker disconnected before ever registering
			delete(s.unregistered, conn)
		} else {
			// NOTE(review): w.jobType does not appear to be assigned anywhere
			// in this file — verify registration sets it, otherwise publisher
			// workers are cleaned up via the JT_ROOM branch (zero value).
			switch w.jobType {
			case livekit.JobType_JT_ROOM:
				delete(s.roomWorkers, w.id)
				if s.roomRegistered && !s.roomAvailableLocked() {
					s.roomRegistered = false
					s.agentServer.DeregisterJobRequestTopic(s.roomTopic)
				}
			case livekit.JobType_JT_PUBLISHER:
				delete(s.publisherWorkers, w.id)
				if s.publisherRegistered && !s.publisherAvailableLocked() {
					s.publisherRegistered = false
					s.agentServer.DeregisterJobRequestTopic(s.publisherTopic)
				}
			}
		}
		s.mu.Unlock()
	}()
	// handle incoming requests from websocket
	for {
		req, _, err := sigConn.ReadWorkerMessage()
		if err != nil {
			// normal/expected closure
			if err == io.EOF ||
				strings.HasSuffix(err.Error(), "use of closed network connection") ||
				strings.HasSuffix(err.Error(), "connection reset by peer") ||
				websocket.IsCloseError(
					err,
					websocket.CloseAbnormalClosure,
					websocket.CloseGoingAway,
					websocket.CloseNormalClosure,
					websocket.CloseNoStatusReceived,
				) {
				logger.Infow("exit ws read loop for closed connection", "wsError", err)
			} else {
				logger.Errorw("error reading from websocket", err)
			}
			return
		}
		// each message is handled on its own goroutine so a slow handler
		// does not stall the read loop
		switch m := req.Message.(type) {
		case *livekit.WorkerMessage_Register:
			go s.handleRegister(w, m.Register)
		case *livekit.WorkerMessage_Availability:
			go s.handleAvailability(w, m.Availability)
		case *livekit.WorkerMessage_JobUpdate:
			go s.handleJobUpdate(w, m.JobUpdate)
		case *livekit.WorkerMessage_Status:
			go s.handleStatus(w, m.Status)
		}
	}
}
// handleRegister moves a worker from the unregistered set into the pool
// matching its job type and, when it is the first worker of that type,
// registers the corresponding job-request topic with the RPC server.
// A RegisterWorkerResponse is written back in all cases (with an empty
// WorkerId when msg.Type matched no known job type).
func (s *AgentHandler) handleRegister(worker *worker, msg *livekit.RegisterWorkerRequest) {
	s.mu.Lock()
	defer s.mu.Unlock()
	switch msg.Type {
	case livekit.JobType_JT_ROOM:
		worker.id = msg.WorkerId
		// BUG FIX: record the job type; disconnect cleanup and handleStatus
		// switch on it, and it was previously never assigned, so publisher
		// workers were handled via the JT_ROOM (zero value) branch.
		worker.jobType = msg.Type
		delete(s.unregistered, worker.conn)
		s.roomWorkers[worker.id] = worker
		if !s.roomRegistered {
			if err := s.agentServer.RegisterJobRequestTopic(s.roomTopic); err != nil {
				logger.Errorw("failed to register room agents", err)
			} else {
				s.roomRegistered = true
			}
		}
	case livekit.JobType_JT_PUBLISHER:
		worker.id = msg.WorkerId
		worker.jobType = msg.Type
		delete(s.unregistered, worker.conn)
		s.publisherWorkers[worker.id] = worker
		if !s.publisherRegistered {
			if err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic); err != nil {
				logger.Errorw("failed to register publisher agents", err)
			} else {
				s.publisherRegistered = true
			}
		}
	}
	// acknowledge registration
	_, err := worker.sigConn.WriteServerMessage(&livekit.ServerMessage{
		Message: &livekit.ServerMessage_Register{
			Register: &livekit.RegisterWorkerResponse{
				WorkerId:      worker.id,
				ServerVersion: AgentServiceVersion,
			},
		},
	})
	if err != nil {
		logger.Errorw("failed to write server message", err)
	}
}
// handleAvailability forwards a worker's availability answer to the
// JobRequest call waiting on that job's channel, if one is still pending.
func (s *AgentHandler) handleAvailability(w *worker, msg *livekit.AvailabilityResponse) {
	s.mu.Lock()
	ch, pending := s.availability[msg.JobId]
	s.mu.Unlock()
	if !pending {
		return
	}
	ch <- &availability{
		workerID:  w.id,
		available: msg.Available,
	}
}
// handleJobUpdate processes a worker's job status report, logging terminal
// outcomes and releasing the worker's job slot when a job finishes.
func (s *AgentHandler) handleJobUpdate(w *worker, msg *livekit.JobStatusUpdate) {
	switch msg.Status {
	case livekit.JobStatus_JS_SUCCESS:
		logger.Debugw("job complete", "jobID", msg.JobId)
	case livekit.JobStatus_JS_FAILED:
		logger.Warnw("job failed", errors.New(msg.Error), "jobID", msg.JobId)
	default:
		// BUG FIX: non-terminal updates (e.g. pending/running) previously
		// also decremented activeJobs, driving the count negative when a
		// worker reported progress; only terminal states free the slot.
		return
	}
	w.mu.Lock()
	w.activeJobs--
	w.mu.Unlock()
}
// handleStatus records a worker's self-reported status, then keeps this
// node's topic registration in sync: a job topic stays registered while at
// least one worker of that type is available, and is deregistered once none
// are (and re-registered when one becomes available again).
func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// w.status is written under w.mu; the *AvailableLocked helpers below
	// read it holding only s.mu.
	w.mu.Lock()
	w.status = msg.Status
	w.mu.Unlock()
	// NOTE(review): w.jobType does not appear to be assigned anywhere in
	// this file, so publisher workers would fall into the JT_ROOM branch —
	// verify registration sets it.
	switch w.jobType {
	case livekit.JobType_JT_ROOM:
		if s.roomRegistered && !s.roomAvailableLocked() {
			// last available room worker went away
			s.roomRegistered = false
			s.agentServer.DeregisterJobRequestTopic(s.roomTopic)
		} else if !s.roomRegistered && s.roomAvailableLocked() {
			if err := s.agentServer.RegisterJobRequestTopic(s.roomTopic); err != nil {
				logger.Errorw("failed to register room agents", err)
			} else {
				s.roomRegistered = true
			}
		}
	case livekit.JobType_JT_PUBLISHER:
		if s.publisherRegistered && !s.publisherAvailableLocked() {
			s.publisherRegistered = false
			s.agentServer.DeregisterJobRequestTopic(s.publisherTopic)
		} else if !s.publisherRegistered && s.publisherAvailableLocked() {
			if err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic); err != nil {
				logger.Errorw("failed to register publisher agents", err)
			} else {
				s.publisherRegistered = true
			}
		}
	}
}
// CheckEnabled reports whether any room or publisher agent workers are
// currently registered with this node.
func (s *AgentHandler) CheckEnabled(_ context.Context, _ *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return &rpc.CheckEnabledResponse{
		RoomEnabled:      len(s.roomWorkers) > 0,
		PublisherEnabled: len(s.publisherWorkers) > 0,
	}, nil
}
// JobRequest handles an incoming job assignment. It asks one candidate
// worker at a time for availability and assigns the job to the first that
// accepts; workers already running jobs are preferred (packing work onto
// warm workers), and each worker is asked at most once. Returns
// DeadlineExceeded on ctx expiry or Unavailable when no untried worker of
// the right type remains.
func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) {
	s.mu.Lock()
	// buffered so stragglers writing late answers never block handleAvailability
	ac := make(chan *availability, 100)
	s.availability[job.Id] = ac
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.availability, job.Id)
		s.mu.Unlock()
	}()
	// NOTE(review): the pool map reference is read without holding s.mu;
	// safe only if these fields are never reassigned — confirm.
	var pool map[string]*worker
	switch job.Type {
	case livekit.JobType_JT_ROOM:
		pool = s.roomWorkers
	case livekit.JobType_JT_PUBLISHER:
		pool = s.publisherWorkers
	}
	// attempted tracks workers already asked, so each is tried at most once
	attempted := make(map[string]bool)
	for {
		select {
		case <-ctx.Done():
			return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out")
		default:
			s.mu.Lock()
			var selected *worker
			for _, w := range pool {
				if attempted[w.id] {
					continue
				}
				// NOTE(review): w.status/w.activeJobs are read under s.mu but
				// written under w.mu elsewhere — potential data race; verify.
				if w.status == livekit.WorkerStatus_WS_AVAILABLE {
					// prefer a worker that already has active jobs
					if w.activeJobs > 0 {
						selected = w
						break
					} else if selected == nil {
						selected = w
					}
				}
			}
			s.mu.Unlock()
			if selected == nil {
				return nil, psrpc.NewErrorf(psrpc.Unavailable, "no workers available")
			}
			attempted[selected.id] = true
			_, err := selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Availability{
				Availability: &livekit.AvailabilityRequest{Job: job},
			}})
			if err != nil {
				logger.Errorw("failed to send availability request", err)
				return nil, err
			}
			// wait for this worker's answer (or the overall deadline);
			// a decline falls through and the loop tries the next worker
			select {
			case <-ctx.Done():
				return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out")
			case res := <-ac:
				if res.available {
					_, err = selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Assignment{
						Assignment: &livekit.JobAssignment{Job: job},
					}})
					if err != nil {
						logger.Errorw("failed to assign job", err)
					} else {
						selected.mu.Lock()
						selected.activeJobs++
						selected.mu.Unlock()
						return &emptypb.Empty{}, nil
					}
				}
			}
		}
	}
}
// JobRequestAffinity scores this node for a job: 1 when an available worker
// of the right type already has active jobs, 0.5 when one is merely
// available, and 0 when none are.
func (s *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 {
	s.mu.Lock()
	defer s.mu.Unlock()
	var pool map[string]*worker
	switch job.Type {
	case livekit.JobType_JT_ROOM:
		pool = s.roomWorkers
	case livekit.JobType_JT_PUBLISHER:
		pool = s.publisherWorkers
	}
	score := float32(0)
	for _, w := range pool {
		if w.status != livekit.WorkerStatus_WS_AVAILABLE {
			continue
		}
		if w.activeJobs > 0 {
			// a warm worker wins outright
			return 1
		}
		score = 0.5
	}
	return score
}
// NumConnections returns the total number of agent websocket connections on
// this node, whether registered or not.
func (s *AgentHandler) NumConnections() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	total := len(s.unregistered)
	total += len(s.roomWorkers)
	total += len(s.publisherWorkers)
	return total
}
// DrainConnections closes every worker websocket at a fixed interval,
// jittering the start so multiple draining nodes do not disconnect all
// workers simultaneously.
//
// FIX: the connections are snapshotted under the lock and closed outside
// it. Previously s.mu was held across every ticker wait, blocking all
// other handlers (registration, job requests, CheckEnabled) for the entire
// drain duration of NumConnections*interval.
func (s *AgentHandler) DrainConnections(interval time.Duration) {
	// jitter drain start
	time.Sleep(time.Duration(rand.Int63n(int64(interval))))

	t := time.NewTicker(interval)
	defer t.Stop()

	s.mu.Lock()
	conns := make([]*websocket.Conn, 0, len(s.unregistered)+len(s.roomWorkers)+len(s.publisherWorkers))
	for conn := range s.unregistered {
		conns = append(conns, conn)
	}
	for _, w := range s.roomWorkers {
		conns = append(conns, w.conn)
	}
	for _, w := range s.publisherWorkers {
		conns = append(conns, w.conn)
	}
	s.mu.Unlock()

	for _, conn := range conns {
		_ = conn.Close() // per-connection cleanup happens in HandleConnection's defer
		<-t.C
	}
}
// roomAvailableLocked reports whether at least one room worker is currently
// available. Caller must hold s.mu.
func (s *AgentHandler) roomAvailableLocked() bool {
	for _, w := range s.roomWorkers {
		if w.status != livekit.WorkerStatus_WS_AVAILABLE {
			continue
		}
		return true
	}
	return false
}
// publisherAvailableLocked reports whether at least one publisher worker is
// currently available. Caller must hold s.mu.
func (s *AgentHandler) publisherAvailableLocked() bool {
	for _, w := range s.publisherWorkers {
		if w.status != livekit.WorkerStatus_WS_AVAILABLE {
			continue
		}
		return true
	}
	return false
}
+75
View File
@@ -0,0 +1,75 @@
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package service
import (
"context"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils"
)
// IOClient records egress lifecycle information with the IO service.
type IOClient interface {
	CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error)
	GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error)
	ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error)
}
// egressLauncher starts egresses via the egress RPC client and records them
// with the IO service.
type egressLauncher struct {
	client rpc.EgressClient
	io IOClient
}
// NewEgressLauncher returns an EgressLauncher backed by the given client,
// or nil when egress is not configured (client is nil).
func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher {
	if client == nil {
		return nil
	}
	launcher := &egressLauncher{io: io}
	launcher.client = client
	return launcher
}
// StartEgress launches an egress without cluster affinity and records the
// resulting EgressInfo with the IO service.
//
// NOTE(review): a CreateEgress failure is logged but not returned — the
// egress has already started by then, so this looks like deliberate
// best-effort bookkeeping; confirm callers do not rely on it succeeding.
func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
	info, err := s.StartEgressWithClusterId(ctx, "", req)
	if err != nil {
		return nil, err
	}
	_, err = s.io.CreateEgress(ctx, info)
	if err != nil {
		logger.Errorw("failed to create egress", err)
	}
	return info, nil
}
// StartEgressWithClusterId launches an egress on the given cluster,
// assigning a fresh egress ID when the request does not already carry one.
// Returns ErrEgressNotConnected when no egress client is configured.
func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
	client := s.client
	if client == nil {
		return nil, ErrEgressNotConnected
	}
	// make sure the egress can be tracked by ID
	if req.EgressId == "" {
		req.EgressId = utils.NewGuid(utils.EgressPrefix)
	}
	return client.StartEgress(ctx, clusterId, req)
}
+2 -46
View File
@@ -25,40 +25,23 @@ import (
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/protocol/egress"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils"
)
type EgressService struct {
launcher rtc.EgressLauncher
client rpc.EgressClient
io IOClient
roomService livekit.RoomService
store ServiceStore
launcher rtc.EgressLauncher
}
type egressLauncher struct {
client rpc.EgressClient
io IOClient
}
func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher {
if client == nil {
return nil
}
return &egressLauncher{
client: client,
io: io,
}
}
func NewEgressService(
client rpc.EgressClient,
launcher rtc.EgressLauncher,
store ServiceStore,
io IOClient,
rs livekit.RoomService,
launcher rtc.EgressLauncher,
) *EgressService {
return &EgressService{
client: client,
@@ -189,33 +172,6 @@ func (s *EgressService) startEgress(ctx context.Context, roomName livekit.RoomNa
return s.launcher.StartEgress(ctx, req)
}
func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
info, err := s.StartEgressWithClusterId(ctx, "", req)
if err != nil {
return nil, err
}
_, err = s.io.CreateEgress(ctx, info)
if err != nil {
logger.Errorw("failed to create egress", err)
}
return info, nil
}
func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) {
if s.client == nil {
return nil, ErrEgressNotConnected
}
// Ensure we have an Egress ID
if req.EgressId == "" {
req.EgressId = utils.NewGuid(utils.EgressPrefix)
}
return s.client.StartEgress(ctx, clusterId, req)
}
type LayoutMetadata struct {
Layout string `json:"layout"`
}
-10
View File
@@ -18,10 +18,7 @@ import (
"context"
"time"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/rpc"
)
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
@@ -74,13 +71,6 @@ type IngressStore interface {
DeleteIngress(ctx context.Context, info *livekit.IngressInfo) error
}
//counterfeiter:generate . IOClient
type IOClient interface {
CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error)
GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error)
ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error)
}
//counterfeiter:generate . RoomAllocator
type RoomAllocator interface {
CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, bool, error)
+3 -1
View File
@@ -60,8 +60,10 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre
}()
// find existing room and update it
var created bool
rm, internal, err := r.roomStore.LoadRoom(ctx, livekit.RoomName(req.Name), true)
if err == ErrRoomNotFound {
created = true
rm = &livekit.Room{
Sid: utils.NewGuid(utils.RoomPrefix),
Name: req.Name,
@@ -114,7 +116,7 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre
return nil, false, routing.ErrNodeLimitReached
}
return rm, false, nil
return rm, created, nil
}
// select a new node
+6 -2
View File
@@ -69,6 +69,7 @@ type RoomManager struct {
roomStore ObjectStore
telemetry telemetry.TelemetryService
clientConfManager clientconfiguration.ClientConfigurationManager
agentClient rtc.AgentClient
egressLauncher rtc.EgressLauncher
versionGenerator utils.TimedVersionGenerator
turnAuthHandler *TURNAuthHandler
@@ -89,6 +90,7 @@ func NewLocalRoomManager(
router routing.Router,
telemetry telemetry.TelemetryService,
clientConfManager clientconfiguration.ClientConfigurationManager,
agentClient rtc.AgentClient,
egressLauncher rtc.EgressLauncher,
versionGenerator utils.TimedVersionGenerator,
turnAuthHandler *TURNAuthHandler,
@@ -108,6 +110,7 @@ func NewLocalRoomManager(
telemetry: telemetry,
clientConfManager: clientConfManager,
egressLauncher: egressLauncher,
agentClient: agentClient,
versionGenerator: versionGenerator,
turnAuthHandler: turnAuthHandler,
bus: bus,
@@ -393,7 +396,8 @@ func (r *RoomManager) StartSession(
Trailer: room.Trailer(),
PLIThrottleConfig: r.config.RTC.PLIThrottle,
CongestionControlConfig: r.config.RTC.CongestionControl,
EnabledCodecs: protoRoom.EnabledCodecs,
PublishEnabledCodecs: protoRoom.EnabledCodecs,
SubscribeEnabledCodecs: protoRoom.EnabledCodecs,
Grants: pi.Grants,
Logger: pLogger,
ClientConf: clientConf,
@@ -526,7 +530,7 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room
}
// construct ice servers
newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.egressLauncher)
newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.agentClient, r.egressLauncher)
roomTopic := rpc.FormatRoomTopic(roomName)
roomServer := utils.Must(rpc.NewTypedRoomServer(r, r.bus))
+35 -9
View File
@@ -31,6 +31,7 @@ import (
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils"
)
// A rooms service that supports a single node
@@ -41,6 +42,7 @@ type RoomService struct {
router routing.MessageRouter
roomAllocator RoomAllocator
roomStore ServiceStore
agentClient rtc.AgentClient
egressLauncher rtc.EgressLauncher
topicFormatter rpc.TopicFormatter
roomClient rpc.TypedRoomClient
@@ -54,6 +56,7 @@ func NewRoomService(
router routing.MessageRouter,
roomAllocator RoomAllocator,
serviceStore ServiceStore,
agentClient rtc.AgentClient,
egressLauncher rtc.EgressLauncher,
topicFormatter rpc.TopicFormatter,
roomClient rpc.TypedRoomClient,
@@ -66,6 +69,7 @@ func NewRoomService(
router: router,
roomAllocator: roomAllocator,
roomStore: serviceStore,
agentClient: agentClient,
egressLauncher: egressLauncher,
topicFormatter: topicFormatter,
roomClient: roomClient,
@@ -112,17 +116,29 @@ func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomReq
return nil, err
}
if created && req.Egress != nil && req.Egress.Room != nil {
egress := &rpc.StartEgressRequest{
Request: &rpc.StartEgressRequest_RoomComposite{
RoomComposite: req.Egress.Room,
},
RoomId: rm.Sid,
if created {
go func() {
s.agentClient.JobRequest(ctx, &livekit.Job{
Id: utils.NewGuid("JR_"),
Type: livekit.JobType_JT_ROOM,
Room: rm,
})
}()
if req.Egress != nil && req.Egress.Room != nil {
_, err = s.egressLauncher.StartEgress(ctx, &rpc.StartEgressRequest{
Request: &rpc.StartEgressRequest_RoomComposite{
RoomComposite: req.Egress.Room,
},
RoomId: rm.Sid,
})
if err != nil {
return nil, err
}
}
_, err = s.egressLauncher.StartEgress(ctx, egress)
}
return rm, err
return rm, nil
}
func (s *RoomService) ListRooms(ctx context.Context, req *livekit.ListRoomsRequest) (*livekit.ListRoomsResponse, error) {
@@ -427,7 +443,7 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat
// no one has joined the room, would not have been created on an RTC node.
// in this case, we'd want to run create again
_, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{
room, created, err := s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{
Name: req.Room,
Metadata: req.Metadata,
})
@@ -465,6 +481,16 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat
return nil, err
}
if created {
go func() {
s.agentClient.JobRequest(ctx, &livekit.Job{
Id: utils.NewGuid("JR_"),
Type: livekit.JobType_JT_ROOM,
Room: room,
})
}()
}
return room, nil
}
+1
View File
@@ -136,6 +136,7 @@ func newTestRoomService(conf config.RoomConfig) *TestRoomService {
allocator,
store,
nil,
nil,
rpc.NewTopicFormatter(),
&rpcfakes.FakeTypedRoomClient{},
&rpcfakes.FakeTypedParticipantClient{},
+28 -8
View File
@@ -31,17 +31,17 @@ import (
"github.com/ua-parser/uap-go/uaparser"
"golang.org/x/exp/maps"
"github.com/livekit/livekit-server/pkg/utils"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/psrpc"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/routing/selector"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/livekit-server/pkg/telemetry"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/livekit-server/pkg/utils"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
putil "github.com/livekit/protocol/utils"
"github.com/livekit/psrpc"
)
type RTCService struct {
@@ -54,6 +54,7 @@ type RTCService struct {
isDev bool
limits config.LimitConfig
parser *uaparser.Parser
agentClient rtc.AgentClient
telemetry telemetry.TelemetryService
mu sync.Mutex
@@ -66,6 +67,7 @@ func NewRTCService(
store ServiceStore,
router routing.MessageRouter,
currentNode routing.LocalNode,
agentClient rtc.AgentClient,
telemetry telemetry.TelemetryService,
) *RTCService {
s := &RTCService{
@@ -78,6 +80,7 @@ func NewRTCService(
isDev: conf.Development,
limits: conf.Limit,
parser: uaparser.NewFromSaved(),
agentClient: agentClient,
telemetry: telemetry,
connections: map[*websocket.Conn]struct{}{},
}
@@ -229,6 +232,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logger.Warnw("failed to start connection, retrying", err, fieldsWithAttempt...)
}
}
if err != nil {
prometheus.IncrementParticipantJoinFail(1)
handleError(w, http.StatusInternalServerError, err, loggerFields...)
@@ -514,10 +518,16 @@ type connectionResult struct {
ResponseSource routing.MessageSource
}
func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomName, pi routing.ParticipantInit, timeout time.Duration) (connectionResult, *livekit.SignalResponse, error) {
func (s *RTCService) startConnection(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
timeout time.Duration,
) (connectionResult, *livekit.SignalResponse, error) {
var cr connectionResult
var created bool
var err error
cr.Room, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)})
cr.Room, created, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)})
if err != nil {
return cr, nil, err
}
@@ -538,6 +548,17 @@ func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomN
cr.ResponseSource.Close()
return cr, nil, err
}
if created && s.agentClient != nil {
go func() {
s.agentClient.JobRequest(ctx, &livekit.Job{
Id: putil.NewGuid("JR_"),
Type: livekit.JobType_JT_ROOM,
Room: cr.Room,
})
}()
}
return cr, initialResponse, nil
}
@@ -559,5 +580,4 @@ func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*
return res, nil
}
}
}
+4
View File
@@ -48,6 +48,7 @@ type LivekitServer struct {
config *config.Config
ioService *IOInfoService
rtcService *RTCService
agentService *AgentService
httpServer *http.Server
promServer *http.Server
router routing.Router
@@ -66,6 +67,7 @@ func NewLivekitServer(conf *config.Config,
ingressService *IngressService,
ioService *IOInfoService,
rtcService *RTCService,
agentService *AgentService,
keyProvider auth.KeyProvider,
router routing.Router,
roomManager *RoomManager,
@@ -77,6 +79,7 @@ func NewLivekitServer(conf *config.Config,
config: conf,
ioService: ioService,
rtcService: rtcService,
agentService: agentService,
router: router,
roomManager: roomManager,
signalServer: signalServer,
@@ -125,6 +128,7 @@ func NewLivekitServer(conf *config.Config,
mux.Handle(egressServer.PathPrefix(), egressServer)
mux.Handle(ingressServer.PathPrefix(), ingressServer)
mux.Handle("/rtc", rtcService)
mux.Handle("/agent", agentService)
mux.HandleFunc("/rtc/validate", rtcService.Validate)
mux.HandleFunc("/", s.defaultHandler)
+4 -1
View File
@@ -30,6 +30,7 @@ import (
"github.com/livekit/livekit-server/pkg/clientconfiguration"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/livekit-server/pkg/telemetry"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/protocol/auth"
@@ -62,16 +63,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
NewIOInfoService,
wire.Bind(new(IOClient), new(*IOInfoService)),
rpc.NewEgressClient,
rpc.NewIngressClient,
getEgressStore,
NewEgressLauncher,
NewEgressService,
rpc.NewIngressClient,
getIngressStore,
getIngressConfig,
NewIngressService,
NewRoomAllocator,
NewRoomService,
NewRTCService,
NewAgentService,
rtc.NewAgentClient,
getSignalRelayConfig,
NewDefaultSignalServer,
routing.NewSignalClient,
+14 -5
View File
@@ -11,6 +11,7 @@ import (
"github.com/livekit/livekit-server/pkg/clientconfiguration"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/livekit-server/pkg/telemetry"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/protocol/auth"
@@ -55,6 +56,10 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
agentClient, err := rtc.NewAgentClient(messageBus)
if err != nil {
return nil, err
}
egressClient, err := rpc.NewEgressClient(messageBus)
if err != nil {
return nil, err
@@ -86,22 +91,26 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient)
roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, agentClient, rtcEgressLauncher, topicFormatter, roomClient, participantClient)
if err != nil {
return nil, err
}
egressService := NewEgressService(egressClient, objectStore, ioInfoService, roomService, rtcEgressLauncher)
egressService := NewEgressService(egressClient, rtcEgressLauncher, objectStore, ioInfoService, roomService)
ingressConfig := getIngressConfig(conf)
ingressClient, err := rpc.NewIngressClient(messageBus)
if err != nil {
return nil, err
}
ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, roomService, telemetryService)
rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, telemetryService)
rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, agentClient, telemetryService)
agentService, err := NewAgentService(messageBus)
if err != nil {
return nil, err
}
clientConfigurationManager := createClientConfiguration()
timedVersionGenerator := utils.NewDefaultTimedVersionGenerator()
turnAuthHandler := NewTURNAuthHandler(keyProvider)
roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus)
roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, agentClient, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus)
if err != nil {
return nil, err
}
@@ -114,7 +123,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, signalServer, server, currentNode)
livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode)
if err != nil {
return nil, err
}
+56
View File
@@ -83,6 +83,40 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error)
}
}
// ReadWorkerMessage reads one WorkerMessage from the websocket, switching
// between protobuf (binary frames) and JSON (text frames) to mirror the
// encoding the worker uses. Returns the message and the payload size.
func (c *WSSignalConnection) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) {
	for {
		// handle special messages and pass on the rest
		messageType, payload, err := c.conn.ReadMessage()
		if err != nil {
			return nil, 0, err
		}

		msg := &livekit.WorkerMessage{}
		switch messageType {
		case websocket.BinaryMessage:
			if c.useJSON {
				c.mu.Lock()
				// switch to protobuf if client supports it
				c.useJSON = false
				c.mu.Unlock()
			}
			// protobuf encoded
			err := proto.Unmarshal(payload, msg)
			return msg, len(payload), err
		case websocket.TextMessage:
			c.mu.Lock()
			// json encoded, also write back JSON
			c.useJSON = true
			c.mu.Unlock()
			err := protojson.Unmarshal(payload, msg)
			return msg, len(payload), err
		default:
			// BUG FIX: previously returned (nil, nil), which made callers
			// dereference a nil message and panic; skip the unsupported
			// frame and keep reading instead.
			logger.Debugw("unsupported message", "message", messageType)
		}
	}
}
func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, error) {
var msgType int
var payload []byte
@@ -105,6 +139,28 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, er
return len(payload), c.conn.WriteMessage(msgType, payload)
}
// WriteServerMessage marshals a ServerMessage and writes it to the websocket,
// using JSON (text frame) when the peer last spoke JSON and protobuf (binary
// frame) otherwise. Returns the number of payload bytes written.
func (c *WSSignalConnection) WriteServerMessage(msg *livekit.ServerMessage) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	var (
		msgType int
		payload []byte
		err     error
	)
	if c.useJSON {
		msgType = websocket.TextMessage
		payload, err = protojson.Marshal(msg)
	} else {
		msgType = websocket.BinaryMessage
		payload, err = proto.Marshal(msg)
	}
	if err != nil {
		return 0, err
	}
	return len(payload), c.conn.WriteMessage(msgType, payload)
}
func (c *WSSignalConnection) pingWorker() {
for {
<-time.After(pingFrequency)
+39 -17
View File
@@ -16,8 +16,8 @@ package audio
import (
"math"
"go.uber.org/atomic"
"sync"
"time"
)
const (
@@ -40,11 +40,13 @@ type AudioLevel struct {
smoothFactor float64
activeThreshold float64
smoothedLevel atomic.Float64
lock sync.Mutex
smoothedLevel float64
loudestObservedLevel uint8
activeDuration uint32 // ms
observedDuration uint32 // ms
lastObservedAt time.Time
}
func NewAudioLevel(params AudioLevelParams) *AudioLevel {
@@ -64,8 +66,13 @@ func NewAudioLevel(params AudioLevelParams) *AudioLevel {
return l
}
// Observes a new frame, must be called from the same thread
func (l *AudioLevel) Observe(level uint8, durationMs uint32) {
// Observes a new frame
func (l *AudioLevel) Observe(level uint8, durationMs uint32, arrivalTime time.Time) {
l.lock.Lock()
defer l.lock.Unlock()
l.lastObservedAt = arrivalTime
l.observedDuration += durationMs
if level <= l.params.ActiveLevel {
@@ -76,6 +83,7 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) {
}
if l.observedDuration >= l.params.ObserveDuration {
smoothedLevel := float64(0.0)
// compute and reset
if l.activeDuration >= l.minActiveDuration {
// adjust loudest observed level by how much of the window was active.
@@ -87,25 +95,39 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) {
linearLevel := ConvertAudioLevel(adjustedLevel)
// exponential smoothing to dampen transients
smoothedLevel := l.smoothedLevel.Load()
smoothedLevel += (linearLevel - smoothedLevel) * l.smoothFactor
l.smoothedLevel.Store(smoothedLevel)
} else {
l.smoothedLevel.Store(0)
smoothedLevel = l.smoothedLevel + (linearLevel-l.smoothedLevel)*l.smoothFactor
}
l.loudestObservedLevel = silentAudioLevel
l.activeDuration = 0
l.observedDuration = 0
l.resetLocked(smoothedLevel)
}
}
// returns current smoothed audio level
func (l *AudioLevel) GetLevel() (float64, bool) {
smoothedLevel := l.smoothedLevel.Load()
active := smoothedLevel >= l.activeThreshold
return smoothedLevel, active
// GetLevel returns the current smoothed audio level and whether it is at or
// above the active threshold. If observations have gone stale relative to
// `now`, the level is reset to silent before being reported.
func (l *AudioLevel) GetLevel(now time.Time) (float64, bool) {
	l.lock.Lock()
	defer l.lock.Unlock()

	l.resetIfStaleLocked(now)
	level := l.smoothedLevel
	return level, level >= l.activeThreshold
}
// resetIfStaleLocked clears the smoothed level when no frame has been
// observed for at least two observe windows before arrivalTime.
// Caller must hold l.lock.
func (l *AudioLevel) resetIfStaleLocked(arrivalTime time.Time) {
	staleAfter := time.Duration(2*l.params.ObserveDuration) * time.Millisecond
	if arrivalTime.Sub(l.lastObservedAt) >= staleAfter {
		l.resetLocked(0.0)
	}
}
// resetLocked reinitializes the observation window accumulators, seeding the
// smoothed level with the supplied value. Caller must hold l.lock.
func (l *AudioLevel) resetLocked(smoothed float64) {
	l.loudestObservedLevel = silentAudioLevel
	l.activeDuration = 0
	l.observedDuration = 0
	l.smoothedLevel = smoothed
}
// ---------------------------------------------------
// convert decibel back to linear
func ConvertAudioLevel(level float64) float64 {
return math.Pow(10, level*negInv20)
+49 -15
View File
@@ -16,6 +16,7 @@ package audio
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
@@ -30,46 +31,79 @@ const (
func TestAudioLevel(t *testing.T) {
t.Run("initially to return not noisy, within a few samples", func(t *testing.T) {
clock := time.Now()
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
_, noisy := a.GetLevel()
_, noisy := a.GetLevel(clock)
require.False(t, noisy)
observeSamples(a, 28, 5)
_, noisy = a.GetLevel()
observeSamples(a, 28, 5, clock)
clock = clock.Add(5 * 20 * time.Millisecond)
_, noisy = a.GetLevel(clock)
require.False(t, noisy)
})
t.Run("not noisy when all samples are below threshold", func(t *testing.T) {
clock := time.Now()
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
observeSamples(a, 35, 100)
_, noisy := a.GetLevel()
observeSamples(a, 35, 100, clock)
clock = clock.Add(100 * 20 * time.Millisecond)
_, noisy := a.GetLevel(clock)
require.False(t, noisy)
})
t.Run("not noisy when less than percentile samples are above threshold", func(t *testing.T) {
clock := time.Now()
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
observeSamples(a, 35, samplesPerBatch-2)
observeSamples(a, 25, 1)
observeSamples(a, 35, 1)
observeSamples(a, 35, samplesPerBatch-2, clock)
clock = clock.Add((samplesPerBatch - 2) * 20 * time.Millisecond)
observeSamples(a, 25, 1, clock)
clock = clock.Add(20 * time.Millisecond)
observeSamples(a, 35, 1, clock)
clock = clock.Add(20 * time.Millisecond)
_, noisy := a.GetLevel()
_, noisy := a.GetLevel(clock)
require.False(t, noisy)
})
t.Run("noisy when higher than percentile samples are above threshold", func(t *testing.T) {
clock := time.Now()
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
observeSamples(a, 35, samplesPerBatch-16)
observeSamples(a, 25, 8)
observeSamples(a, 29, 8)
observeSamples(a, 35, samplesPerBatch-16, clock)
clock = clock.Add((samplesPerBatch - 16) * 20 * time.Millisecond)
observeSamples(a, 25, 8, clock)
clock = clock.Add(8 * 20 * time.Millisecond)
observeSamples(a, 29, 8, clock)
clock = clock.Add(8 * 20 * time.Millisecond)
level, noisy := a.GetLevel()
level, noisy := a.GetLevel(clock)
require.True(t, noisy)
require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel)))
require.Less(t, level, ConvertAudioLevel(float64(25)))
})
t.Run("not noisy when samples are stale", func(t *testing.T) {
clock := time.Now()
a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration)
observeSamples(a, 25, 100, clock)
clock = clock.Add(100 * 20 * time.Millisecond)
level, noisy := a.GetLevel(clock)
require.True(t, noisy)
require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel)))
require.Less(t, level, ConvertAudioLevel(float64(20)))
// let enough time pass to make the samples stale
clock = clock.Add(1500 * time.Millisecond)
level, noisy = a.GetLevel(clock)
require.Equal(t, float64(0.0), level)
require.False(t, noisy)
})
}
func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration uint32) *AudioLevel {
@@ -80,8 +114,8 @@ func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration ui
})
}
func observeSamples(a *AudioLevel, level uint8, count int) {
func observeSamples(a *AudioLevel, level uint8, count int, baseTime time.Time) {
for i := 0; i < count; i++ {
a.Observe(level, 20)
a.Observe(level, 20, baseTime.Add(+time.Duration(i*20)*time.Millisecond))
}
}
+3 -7
View File
@@ -31,6 +31,7 @@ import (
"go.uber.org/atomic"
"github.com/livekit/livekit-server/pkg/sfu/audio"
dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
"github.com/livekit/livekit-server/pkg/sfu/utils"
sutils "github.com/livekit/livekit-server/pkg/utils"
"github.com/livekit/mediatransportutil"
@@ -39,8 +40,6 @@ import (
"github.com/livekit/mediatransportutil/pkg/twcc"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
)
const (
@@ -593,7 +592,7 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime time.Time) {
if (p.Timestamp - b.latestTSForAudioLevel) < (1 << 31) {
duration := (int64(p.Timestamp) - int64(b.latestTSForAudioLevel)) * 1e3 / int64(b.clockRate)
if duration > 0 {
b.audioLevel.Observe(ext.Level, uint32(duration))
b.audioLevel.Observe(ext.Level, uint32(duration), arrivalTime)
}
b.latestTSForAudioLevel = p.Timestamp
@@ -624,9 +623,6 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime time.Time, flow
if b.ddParser != nil {
ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet)
if err != nil {
if err != ErrFrameEarlierThanKeyFrame {
b.logger.Warnw("could not parse dependency descriptor", err)
}
return nil
} else if ddVal != nil {
ep.DependencyDescriptor = ddVal
@@ -855,7 +851,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) {
return 0, false
}
return b.audioLevel.GetLevel()
return b.audioLevel.GetLevel(time.Now())
}
func (b *Buffer) OnFpsChanged(f func()) {
+23 -13
View File
@@ -27,7 +27,8 @@ import (
)
var (
ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe")
ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe")
ErrDDStructureAttachedToNonFirstPacket = fmt.Errorf("dependency descriptor structure is attached to non-first packet of a frame")
)
type DependencyDescriptorParser struct {
@@ -39,7 +40,6 @@ type DependencyDescriptorParser struct {
seqWrapAround *utils.WrapAround[uint16, uint64]
frameWrapAround *utils.WrapAround[uint16, uint64]
structureExtSeq uint64
structureExtFrameNum uint64
activeDecodeTargetsExtSeq uint64
activeDecodeTargetsMask uint32
@@ -66,12 +66,15 @@ type ExtDependencyDescriptor struct {
ActiveDecodeTargetsUpdated bool
Integrity bool
ExtFrameNum uint64
// the frame number of the keyframe which the current frame depends on
ExtKeyFrameNum uint64
}
func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescriptor, VideoLayer, error) {
var videoLayer VideoLayer
ddBuf := pkt.GetExtension(r.ddExtID)
if ddBuf == nil {
r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber)
return nil, videoLayer, nil
}
@@ -82,7 +85,9 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr
}
_, err := ext.Unmarshal(ddBuf)
if err != nil {
// r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
if err != dd.ErrDDReaderNoStructure {
r.logger.Warnw("failed to parse generic dependency descriptor", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
}
return nil, videoLayer, err
}
@@ -108,17 +113,21 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr
}
if ddVal.AttachedStructure != nil {
r.logger.Debugw("parsed dependency descriptor", "extSeq", extSeq, "extFN", extFN, "structureID", ddVal.AttachedStructure.StructureId, "descriptor", ddVal.String())
if extSeq > r.structureExtSeq {
r.structure = ddVal.AttachedStructure
r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure)
r.structureExtSeq = extSeq
r.structureExtFrameNum = extFN
extDD.StructureUpdated = true
extDD.ActiveDecodeTargetsUpdated = true
// The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present,
// so don't need to notify max layer change here.
if !ddVal.FirstPacketInFrame {
r.logger.Warnw("attached structure is not the first packet in frame", nil, "extSeq", extSeq, "extFN", extFN)
return nil, videoLayer, ErrDDStructureAttachedToNonFirstPacket
}
if r.structure == nil || ddVal.AttachedStructure.StructureId != r.structure.StructureId {
r.logger.Infow("structure updated", "structureID", ddVal.AttachedStructure.StructureId, "extSeq", extSeq, "extFN", extFN, "descriptor", ddVal.String())
}
r.structure = ddVal.AttachedStructure
r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure)
r.structureExtFrameNum = extFN
extDD.StructureUpdated = true
extDD.ActiveDecodeTargetsUpdated = true
// The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present,
// so don't need to notify max layer change here.
}
if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil && extSeq > r.activeDecodeTargetsExtSeq {
@@ -143,6 +152,7 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr
}
extDD.DecodeTargets = r.decodeTargets
extDD.ExtKeyFrameNum = r.structureExtFrameNum
return extDD, videoLayer, nil
}
+23 -24
View File
@@ -129,31 +129,30 @@ func (r *RTPStatsReceiver) Update(
pktSize := uint64(hdrSize + payloadSize + paddingSize)
gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest)
if gapSN <= 0 { // duplicate OR out-of-order
if payloadSize == 0 {
// do not start on a padding only packet
if resTS.IsRestart {
r.logger.Infow(
"rolling back timestamp restart",
"tsBefore", resTS.PreExtendedStart,
"tsAfter", r.timestamp.GetExtendedStart(),
"snBefore", resSN.PreExtendedStart,
"snAfter", r.sequenceNumber.GetExtendedStart(),
)
r.timestamp.RollbackRestart(resTS.PreExtendedStart)
}
if resSN.IsRestart {
r.logger.Infow(
"rolling back sequence number restart",
"snBefore", resSN.PreExtendedStart,
"snAfter", r.sequenceNumber.GetExtendedStart(),
"tsBefore", resTS.PreExtendedStart,
"tsAfter", r.timestamp.GetExtendedStart(),
)
r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart)
flowState.IsNotHandled = true
return
}
// before start, don't restart
if resTS.IsRestart {
r.logger.Infow(
"rolling back timestamp restart",
"tsBefore", resTS.PreExtendedStart,
"tsAfter", r.timestamp.GetExtendedStart(),
"snBefore", resSN.PreExtendedStart,
"snAfter", r.sequenceNumber.GetExtendedStart(),
)
r.timestamp.RollbackRestart(resTS.PreExtendedStart)
}
if resSN.IsRestart {
r.logger.Infow(
"rolling back sequence number restart",
"snBefore", resSN.PreExtendedStart,
"snAfter", r.sequenceNumber.GetExtendedStart(),
"tsBefore", resTS.PreExtendedStart,
"tsAfter", r.timestamp.GetExtendedStart(),
)
r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart)
flowState.IsNotHandled = true
return
}
if -gapSN >= cNumSequenceNumbers/2 {
r.logger.Warnw(
"large sequence number gap negative", nil,
+17 -15
View File
@@ -130,7 +130,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
require.Equal(t, timestamp, r.timestamp.GetHighest())
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
// out-of-order
// out-of-order, would cause a restart which is disallowed
packet = getPacket(sequenceNumber-10, timestamp-30000, 1000)
flowState = r.Update(
time.Now(),
@@ -142,14 +142,15 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
0,
)
require.False(t, flowState.HasLoss)
require.True(t, flowState.IsNotHandled)
require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest())
require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest()))
require.Equal(t, timestamp, r.timestamp.GetHighest())
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
require.Equal(t, uint64(1), r.packetsOutOfOrder)
require.Equal(t, uint64(0), r.packetsOutOfOrder)
require.Equal(t, uint64(0), r.packetsDuplicate)
// duplicate
// duplicate of the above out-of-order packet, but would not be handled as it causes a restart
packet = getPacket(sequenceNumber-10, timestamp-30000, 1000)
flowState = r.Update(
time.Now(),
@@ -161,12 +162,13 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
0,
)
require.False(t, flowState.HasLoss)
require.True(t, flowState.IsNotHandled)
require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest())
require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest()))
require.Equal(t, timestamp, r.timestamp.GetHighest())
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
require.Equal(t, uint64(2), r.packetsOutOfOrder)
require.Equal(t, uint64(1), r.packetsDuplicate)
require.Equal(t, uint64(0), r.packetsOutOfOrder)
require.Equal(t, uint64(0), r.packetsDuplicate)
// loss
sequenceNumber += 10
@@ -184,10 +186,10 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
require.True(t, flowState.HasLoss)
require.Equal(t, uint64(sequenceNumber-9), flowState.LossStartInclusive)
require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive)
require.Equal(t, uint64(17), r.packetsLost)
require.Equal(t, uint64(9), r.packetsLost)
// out-of-order should decrement number of lost packets
packet = getPacket(sequenceNumber-15, timestamp-45000, 1000)
packet = getPacket(sequenceNumber-6, timestamp-45000, 1000)
flowState = r.Update(
time.Now(),
packet.Header.SequenceNumber,
@@ -202,9 +204,9 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest()))
require.Equal(t, timestamp, r.timestamp.GetHighest())
require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest()))
require.Equal(t, uint64(3), r.packetsOutOfOrder)
require.Equal(t, uint64(1), r.packetsDuplicate)
require.Equal(t, uint64(16), r.packetsLost)
require.Equal(t, uint64(1), r.packetsOutOfOrder)
require.Equal(t, uint64(0), r.packetsDuplicate)
require.Equal(t, uint64(8), r.packetsLost)
// test sequence number history
// with a gap
@@ -223,7 +225,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
require.True(t, flowState.HasLoss)
require.Equal(t, uint64(sequenceNumber-1), flowState.LossStartInclusive)
require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive)
require.Equal(t, uint64(17), r.packetsLost)
require.Equal(t, uint64(9), r.packetsLost)
require.False(t, r.history.IsSet(uint64(sequenceNumber)-1))
// out-of-order
@@ -240,8 +242,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
0,
)
require.False(t, flowState.HasLoss)
require.Equal(t, uint64(16), r.packetsLost)
require.Equal(t, uint64(4), r.packetsOutOfOrder)
require.Equal(t, uint64(8), r.packetsLost)
require.Equal(t, uint64(2), r.packetsOutOfOrder)
require.True(t, r.history.IsSet(uint64(sequenceNumber)))
// padding only
@@ -257,8 +259,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) {
25,
)
require.False(t, flowState.HasLoss)
require.Equal(t, uint64(16), r.packetsLost)
require.Equal(t, uint64(4), r.packetsOutOfOrder)
require.Equal(t, uint64(8), r.packetsLost)
require.Equal(t, uint64(2), r.packetsOutOfOrder)
require.True(t, r.history.IsSet(uint64(sequenceNumber)))
require.True(t, r.history.IsSet(uint64(sequenceNumber)-1))
require.True(t, r.history.IsSet(uint64(sequenceNumber)-2))
@@ -18,6 +18,18 @@ import (
"errors"
)
var (
ErrDDReaderNoStructure = errors.New("DependencyDescriptorReader: Structure is nil")
ErrDDReaderTemplateWithoutStructure = errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil")
ErrDDReaderTooManyTemplates = errors.New("DependencyDescriptorReader: too many templates")
ErrDDReaderTooManyTemporalLayers = errors.New("DependencyDescriptorReader: too many temporal layers")
ErrDDReaderTooManySpatialLayers = errors.New("DependencyDescriptorReader: too many spatial layers")
ErrDDReaderInvalidTemplateIndex = errors.New("DependencyDescriptorReader: invalid template index")
ErrDDReaderInvalidSpatialLayer = errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions")
ErrDDReaderNumDTIMismatch = errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets")
ErrDDReaderNumChainDiffsMismatch = errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains")
)
type DependencyDescriptorReader struct {
// Output.
descriptor *DependencyDescriptor
@@ -59,7 +71,7 @@ func (r *DependencyDescriptorReader) Parse() (int, error) {
if r.structure == nil {
r.buffer.Invalidate()
return 0, errors.New("DependencyDescriptorReader: Structure is nil")
return 0, ErrDDReaderNoStructure
}
if r.activeDecodeTargetsPresentFlag {
@@ -140,7 +152,7 @@ func (r *DependencyDescriptorReader) readExtendedFields() error {
return err
}
if r.descriptor.AttachedStructure == nil {
return errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil")
return ErrDDReaderTemplateWithoutStructure
}
bitmask := uint32((uint64(1) << r.descriptor.AttachedStructure.NumDecodeTargets) - 1)
r.descriptor.ActiveDecodeTargetsBitmask = &bitmask
@@ -203,7 +215,7 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error {
)
for {
if len(templates) == MaxTemplates {
return errors.New("DependencyDescriptorReader: too many templates")
return ErrDDReaderTooManyTemplates
}
var lastTemplate FrameDependencyTemplate
@@ -220,13 +232,13 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error {
if nextLayerIdc == nextTemporalLayer {
temporalId++
if temporalId >= MaxTemporalIds {
return errors.New("DependencyDescriptorReader: too many temporal layers")
return ErrDDReaderTooManyTemporalLayers
}
} else if nextLayerIdc == nextSpatialLayer {
spatialId++
temporalId = 0
if spatialId >= MaxSpatialIds {
return errors.New("DependencyDescriptorReader: too many spatial layers")
return ErrDDReaderTooManySpatialLayers
}
}
@@ -340,7 +352,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error {
if templateIndex >= len(r.structure.Templates) {
r.buffer.Invalidate()
return errors.New("DependencyDescriptorReader: invalid template index")
return ErrDDReaderInvalidTemplateIndex
}
// Copy all the fields from the matching template
@@ -374,7 +386,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error {
// then each spatial layer got one.
if r.descriptor.FrameDependencies.SpatialId >= len(r.structure.Resolutions) {
r.buffer.Invalidate()
return errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions")
return ErrDDReaderInvalidSpatialLayer
}
res := r.structure.Resolutions[r.descriptor.FrameDependencies.SpatialId]
r.descriptor.Resolution = &res
@@ -385,7 +397,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error {
func (r *DependencyDescriptorReader) readFrameDtis() error {
if len(r.descriptor.FrameDependencies.DecodeTargetIndications) != r.structure.NumDecodeTargets {
return errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets")
return ErrDDReaderNumDTIMismatch
}
for i := range r.descriptor.FrameDependencies.DecodeTargetIndications {
@@ -420,7 +432,7 @@ func (r *DependencyDescriptorReader) readFrameFdiffs() error {
func (r *DependencyDescriptorReader) readFrameChains() error {
if len(r.descriptor.FrameDependencies.ChainDiffs) != r.structure.NumChains {
return errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains")
return ErrDDReaderNumChainDiffsMismatch
}
for i := range r.descriptor.FrameDependencies.ChainDiffs {
@@ -148,7 +148,7 @@ func (w *DependencyDescriptorWriter) calculateMatch(idx int, template *FrameDepe
result.NeedCustomDtis = w.descriptor.FrameDependencies.DecodeTargetIndications != nil && !reflect.DeepEqual(w.descriptor.FrameDependencies.DecodeTargetIndications, template.DecodeTargetIndications)
for i := 0; i < w.structure.NumChains; i++ {
if w.activeChains&(1<<i) != 0 && w.descriptor.FrameDependencies.ChainDiffs[i] != template.ChainDiffs[i] {
if w.activeChains&(1<<i) != 0 && (len(w.descriptor.FrameDependencies.ChainDiffs) <= i || len(template.ChainDiffs) <= i || w.descriptor.FrameDependencies.ChainDiffs[i] != template.ChainDiffs[i]) {
result.NeedCustomChains = true
break
}
+1 -1
View File
@@ -193,7 +193,7 @@ func (s *StreamTrackerDependencyDescriptor) Observe(temporalLayer int32, pktSize
for _, dt := range ddVal.DecodeTargets {
if len(dtis) <= dt.Target {
s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtls", dtis)
s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtis", dtis)
continue
}
// we are not dropping discardable frames now, so only ignore not present frames
+6 -2
View File
@@ -214,9 +214,9 @@ func (s *StreamTrackerManager) AddTracker(layer int32) streamtracker.StreamTrack
})
}
s.logger.Debugw("StreamTrackerManager add track", "layer", layer)
s.logger.Debugw("stream tracker add track", "layer", layer)
tracker.OnStatusChanged(func(status streamtracker.StreamStatus) {
s.logger.Debugw("StreamTrackerManager OnStatusChanged", "layer", layer, "status", status)
s.logger.Debugw("stream tracker status changed", "layer", layer, "status", status)
if status == streamtracker.StreamStatusStopped {
s.removeAvailableLayer(layer)
} else {
@@ -289,6 +289,10 @@ func (s *StreamTrackerManager) GetTracker(layer int32) streamtracker.StreamTrack
s.lock.RLock()
defer s.lock.RUnlock()
if int(layer) >= len(s.trackers) {
s.logger.Errorw("unexpected layer", nil, "layer", layer)
return nil
}
return s.trackers[layer]
}
@@ -16,6 +16,7 @@ package videolayerselector
import (
"fmt"
"runtime/debug"
"sync"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
@@ -31,6 +32,8 @@ type DependencyDescriptor struct {
previousActiveDecodeTargetsBitmask *uint32
activeDecodeTargetsBitmask *uint32
structure *dede.FrameDependencyStructure
extKeyFrameNum uint64
keyFrameValid bool
chains []*FrameChain
@@ -79,38 +82,76 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
Temporal: int32(fd.TemporalId),
}
if !d.keyFrameValid && dd.AttachedStructure == nil {
return
}
// early return if this frame is already forwarded or dropped
sd, err := d.decisions.GetDecision(extFrameNum)
if err != nil {
// do not mark as dropped as only error is an old frame
d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d",
incomingLayer,
dd.FrameNumber,
extFrameNum,
extPkt.Packet.SequenceNumber,
), "err", err)
// d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d",
// incomingLayer,
// dd.FrameNumber,
// extFrameNum,
// extPkt.Packet.SequenceNumber,
// ), "err", err)
return
}
switch sd {
case selectorDecisionDropped:
// a packet of an already dropped frame, maintain decision
d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d",
incomingLayer,
dd.FrameNumber,
extFrameNum,
extPkt.Packet.SequenceNumber,
))
// d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d",
// incomingLayer,
// dd.FrameNumber,
// extFrameNum,
// extPkt.Packet.SequenceNumber,
// ))
return
}
if ddwdt.StructureUpdated {
d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets)
// TODO-REMOVE: remove this log after stable
d.logger.Infow("update dependency structure",
"structureID", dd.AttachedStructure.StructureId,
"structure", dd.AttachedStructure,
"decodeTargets", ddwdt.DecodeTargets,
"efn", extFrameNum,
"sn", extPkt.Packet.SequenceNumber,
"isKeyFrame", extPkt.KeyFrame,
"currentKeyframe", d.extKeyFrameNum,
)
d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets, extFrameNum)
}
if ddwdt.ExtKeyFrameNum != d.extKeyFrameNum {
// keyframe mismatch, drop and reset chains
// TODO-REMOVE: remove this log after stable
d.logger.Infow("drop packet for keyframe mismatch", "incoming", incomingLayer, "efn", extFrameNum, "sn", extPkt.Packet.SequenceNumber, "requiredKeyFrame", ddwdt.ExtKeyFrameNum, "structureKeyFrame", d.extKeyFrameNum)
d.decisions.AddDropped(extFrameNum)
d.invalidateKeyFrame()
return
}
if ddwdt.ActiveDecodeTargetsUpdated {
d.updateActiveDecodeTargets(*dd.ActiveDecodeTargetsBitmask)
}
// TODO-REMOVE: remove this log after stable
if len(fd.ChainDiffs) != len(d.chains) {
d.logger.Warnw("frame chain diff length mismatch", nil,
"incoming", incomingLayer,
"efn", extFrameNum,
"sn", extPkt.Packet.SequenceNumber,
"chainDiffs", fd.ChainDiffs,
"chains", len(d.chains),
"requiredKeyFrame", ddwdt.ExtKeyFrameNum,
"structureKeyFrame", d.extKeyFrameNum)
d.decisions.AddDropped(extFrameNum)
return
}
for _, chain := range d.chains {
chain.OnFrame(extFrameNum, fd)
}
@@ -133,7 +174,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
if err != nil {
d.decodeTargetsLock.RUnlock()
// dtis error, dependency descriptor might lost
d.logger.Debugw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), "err", err)
d.logger.Warnw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), err)
d.decisions.AddDropped(extFrameNum)
return
}
@@ -148,34 +189,34 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
if highestDecodeTarget.Target < 0 {
// no active decode target, do not select
d.logger.Debugw(
"drop packet for no target found",
"highestDecodeTarget", highestDecodeTarget,
"decodeTargets", d.decodeTargets,
"tagetLayer", d.targetLayer,
"incoming", incomingLayer,
"fn", dd.FrameNumber,
"efn", extFrameNum,
"sn", extPkt.Packet.SequenceNumber,
"isKeyFrame", extPkt.KeyFrame,
)
// d.logger.Debugw(
// "drop packet for no target found",
// "highestDecodeTarget", highestDecodeTarget,
// "decodeTargets", d.decodeTargets,
// "tagetLayer", d.targetLayer,
// "incoming", incomingLayer,
// "fn", dd.FrameNumber,
// "efn", extFrameNum,
// "sn", extPkt.Packet.SequenceNumber,
// "isKeyFrame", extPkt.KeyFrame,
// )
d.decisions.AddDropped(extFrameNum)
return
}
// DD-TODO : if bandwidth in congest, could drop the 'Discardable' frame
if dti == dede.DecodeTargetNotPresent {
d.logger.Debugw(
"drop packet for decode target not present",
"highestDecodeTarget", highestDecodeTarget,
"decodeTargets", d.decodeTargets,
"tagetLayer", d.targetLayer,
"incoming", incomingLayer,
"fn", dd.FrameNumber,
"efn", extFrameNum,
"sn", extPkt.Packet.SequenceNumber,
"isKeyFrame", extPkt.KeyFrame,
)
// d.logger.Debugw(
// "drop packet for decode target not present",
// "highestDecodeTarget", highestDecodeTarget,
// "decodeTargets", d.decodeTargets,
// "tagetLayer", d.targetLayer,
// "incoming", incomingLayer,
// "fn", dd.FrameNumber,
// "efn", extFrameNum,
// "sn", extPkt.Packet.SequenceNumber,
// "isKeyFrame", extPkt.KeyFrame,
// )
d.decisions.AddDropped(extFrameNum)
return
}
@@ -195,17 +236,17 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
}
}
if !isDecodable {
d.logger.Debugw(
"drop packet for not decodable",
"highestDecodeTarget", highestDecodeTarget,
"decodeTargets", d.decodeTargets,
"tagetLayer", d.targetLayer,
"incoming", incomingLayer,
"fn", dd.FrameNumber,
"efn", extFrameNum,
"sn", extPkt.Packet.SequenceNumber,
"isKeyFrame", extPkt.KeyFrame,
)
// d.logger.Debugw(
// "drop packet for not decodable",
// "highestDecodeTarget", highestDecodeTarget,
// "decodeTargets", d.decodeTargets,
// "tagetLayer", d.targetLayer,
// "incoming", incomingLayer,
// "fn", dd.FrameNumber,
// "efn", extFrameNum,
// "sn", extPkt.Packet.SequenceNumber,
// "isKeyFrame", extPkt.KeyFrame,
// )
d.decisions.AddDropped(extFrameNum)
return
}
@@ -263,11 +304,33 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
// d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
}
}
bytes, err := ddExtension.Marshal()
if err != nil {
d.logger.Warnw("error marshalling dependency descriptor extension", err)
} else {
result.DependencyDescriptorExtension = bytes
var ddMarshaled bool
func() {
defer func() {
if r := recover(); r != nil {
d.logger.Errorw("panic marshalling dependency descriptor extension", nil,
"efn", extFrameNum,
"sn", extPkt.Packet.SequenceNumber,
"keyframeRequired", ddwdt.ExtKeyFrameNum,
"currentKeyframe", d.extKeyFrameNum,
"panic", r,
"stack", string(debug.Stack()))
}
}()
bytes, err := ddExtension.Marshal()
if err != nil {
d.logger.Warnw("error marshalling dependency descriptor extension", err)
} else {
result.DependencyDescriptorExtension = bytes
ddMarshaled = true
}
}()
if !ddMarshaled {
// drop packet if we can't marshal dependency descriptor
d.decisions.AddDropped(extFrameNum)
return
}
if ddwdt.Integrity {
@@ -284,8 +347,10 @@ func (d *DependencyDescriptor) Rollback() {
d.Base.Rollback()
}
func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget) {
func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget, extFrameNum uint64) {
d.structure = structure
d.extKeyFrameNum = extFrameNum
d.keyFrameValid = true
d.chains = d.chains[:0]
@@ -329,9 +394,17 @@ func (d *DependencyDescriptor) updateActiveDecodeTargets(activeDecodeTargetsBitm
}
}
func (d *DependencyDescriptor) invalidateKeyFrame() {
d.keyFrameValid = false
d.chains = d.chains[:0]
d.decodeTargetsLock.Lock()
d.decodeTargets = d.decodeTargets[:0]
d.decodeTargetsLock.Unlock()
}
func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) {
layer = d.GetRequestSpatial()
if !d.currentLayer.IsValid() {
if !d.currentLayer.IsValid() || !d.keyFrameValid {
// always declare not locked when trying to resume from nothing
return false, layer
}
@@ -339,7 +412,7 @@ func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) {
d.decodeTargetsLock.RLock()
defer d.decodeTargetsLock.RUnlock()
for _, dt := range d.decodeTargets {
if dt.Active() && dt.Layer.Spatial <= d.GetTarget().Spatial && dt.Valid() {
if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() {
d.logger.Debugw(fmt.Sprintf("checking sync, matching decode target, layer: %d, dt: %s, dts: %+v", layer, dt, d.decodeTargets))
return true, layer
}
@@ -253,11 +253,23 @@ func TestDependencyDescriptor(t *testing.T) {
}
require.True(t, switchToLower)
// sync with requested layer
// not sync with requested layer
ddSelector.SetRequestSpatial(targetLayer.Spatial)
locked, layer := ddSelector.CheckSync()
require.True(t, locked)
require.False(t, locked)
require.Equal(t, targetLayer.Spatial, layer)
// request to current layer, sync
ddSelector.SetRequestSpatial(ddSelector.GetCurrent().Spatial)
locked, _ = ddSelector.CheckSync()
require.True(t, locked)
// should drop frame that relies on a keyframe is not present in current selection
framesPrevious := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 1000)
ret = ddSelector.Select(framesPrevious[1], 0)
require.False(t, ret.IsSelected)
// keyframe lost, out of sync
locked, _ = ddSelector.CheckSync()
require.False(t, locked)
}
func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buffer.ExtPacket {
@@ -279,7 +291,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
return decodeTargets[i].Layer.GreaterThan(decodeTargets[j].Layer)
})
chainDiffs := make([]int, len(decodeTargets))
chainDiffs := make([]int, int(maxLayer.Spatial)+1)
dtis := make([]dd.DecodeTargetIndication, len(decodeTargets))
for _, dt := range decodeTargets {
dtis[dt.Target] = dd.DecodeTargetSwitch
@@ -319,6 +331,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
ActiveDecodeTargetsUpdated: true,
Integrity: true,
ExtFrameNum: uint64(startFrameNumber),
ExtKeyFrameNum: uint64(startFrameNumber),
},
Packet: &rtp.Packet{
Header: rtp.Header{
@@ -356,7 +369,6 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
}
frame := &buffer.ExtPacket{
KeyFrame: true,
DependencyDescriptor: &buffer.ExtDependencyDescriptor{
Descriptor: &dd.DependencyDescriptor{
FrameNumber: startFrameNumber,
@@ -367,9 +379,10 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff
DecodeTargetIndications: frameDtis,
},
},
DecodeTargets: decodeTargets,
Integrity: true,
ExtFrameNum: uint64(startFrameNumber),
DecodeTargets: decodeTargets,
Integrity: true,
ExtFrameNum: uint64(startFrameNumber),
ExtKeyFrameNum: keyFrame.DependencyDescriptor.ExtFrameNum,
},
Packet: &rtp.Packet{
Header: rtp.Header{
+158
View File
@@ -0,0 +1,158 @@
package test
import (
"fmt"
"net/http"
"net/url"
"sync"
"github.com/gorilla/websocket"
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/utils"
)
// agentClient is a test helper that connects to the server's agent
// websocket endpoint and counts the protocol messages it receives, so
// tests can assert on registration, availability, and job assignment.
type agentClient struct {
	mu   sync.Mutex      // serializes writes to conn (see write)
	conn *websocket.Conn // agent websocket connection to the server

	registered              atomic.Int32 // RegisterWorkerResponse messages received
	roomAvailability        atomic.Int32 // availability requests for room jobs
	roomJobs                atomic.Int32 // job assignments for room jobs
	participantAvailability atomic.Int32 // availability requests for non-room (participant) jobs
	participantJobs         atomic.Int32 // job assignments for non-room (participant) jobs

	done chan struct{} // closed by close() to stop the read loop
}
// newAgentClient dials the local server's /agent websocket endpoint,
// authenticating with the given bearer token, and returns a client
// wrapping the established connection.
func newAgentClient(token string) (*agentClient, error) {
	endpoint, err := url.Parse(fmt.Sprintf("ws://localhost:%d", defaultServerPort) + "/agent")
	if err != nil {
		return nil, err
	}

	header := make(http.Header)
	header.Set("Authorization", "Bearer "+token)

	conn, _, err := websocket.DefaultDialer.Dial(endpoint.String(), header)
	if err != nil {
		return nil, err
	}

	return &agentClient{
		conn: conn,
		done: make(chan struct{}),
	}, nil
}
// Run starts the background read loop and registers this worker for
// both room and publisher jobs under a single generated worker id.
// It returns the first write error encountered, if any.
func (c *agentClient) Run() error {
	go c.read()

	workerID := utils.NewGuid("W_")
	for _, jobType := range []livekit.JobType{livekit.JobType_JT_ROOM, livekit.JobType_JT_PUBLISHER} {
		msg := &livekit.WorkerMessage{
			Message: &livekit.WorkerMessage_Register{
				Register: &livekit.RegisterWorkerRequest{
					Type:     jobType,
					WorkerId: workerID,
					Version:  "version",
					Name:     "name",
				},
			},
		}
		if err := c.write(msg); err != nil {
			return err
		}
	}
	return nil
}
// read pumps messages off the websocket until a read/decode error
// occurs or the client is closed, dispatching each server message to
// its handler on a fresh goroutine.
func (c *agentClient) read() {
	for {
		// bail out promptly once close() has been called
		select {
		case <-c.done:
			return
		default:
		}

		_, payload, err := c.conn.ReadMessage()
		if err != nil {
			return
		}

		var msg livekit.ServerMessage
		if err := proto.Unmarshal(payload, &msg); err != nil {
			return
		}

		switch m := msg.Message.(type) {
		case *livekit.ServerMessage_Assignment:
			go c.handleAssignment(m.Assignment)
		case *livekit.ServerMessage_Availability:
			go c.handleAvailability(m.Availability)
		case *livekit.ServerMessage_Register:
			go c.handleRegister(m.Register)
		}
	}
}
// handleAssignment tallies an incoming job assignment by job type.
func (c *agentClient) handleAssignment(req *livekit.JobAssignment) {
	switch req.Job.Type {
	case livekit.JobType_JT_ROOM:
		c.roomJobs.Inc()
	default:
		c.participantJobs.Inc()
	}
}
// handleAvailability tallies an availability request by job type and
// replies that this worker is available to take the job.
func (c *agentClient) handleAvailability(req *livekit.AvailabilityRequest) {
	if req.Job.Type == livekit.JobType_JT_ROOM {
		c.roomAvailability.Inc()
	} else {
		c.participantAvailability.Inc()
	}

	// Best-effort reply: this test helper has no logger and no way to
	// surface the error from a message handler, so a failed write is
	// deliberately ignored (the connection may already be closing).
	_ = c.write(&livekit.WorkerMessage{
		Message: &livekit.WorkerMessage_Availability{
			Availability: &livekit.AvailabilityResponse{
				JobId:     req.Job.Id,
				Available: true,
			},
		},
	})
}
// handleRegister counts a worker registration acknowledgement from the server.
func (c *agentClient) handleRegister(req *livekit.RegisterWorkerResponse) {
	c.registered.Inc()
}
// write marshals msg and sends it as a binary websocket frame. The
// mutex serializes concurrent senders on the shared connection.
func (c *agentClient) write(msg *livekit.WorkerMessage) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	payload, err := proto.Marshal(msg)
	if err != nil {
		return err
	}
	return c.conn.WriteMessage(websocket.BinaryMessage, payload)
}
// close stops the read loop and shuts down the websocket, attempting a
// clean close handshake before releasing the connection.
func (c *agentClient) close() {
	close(c.done)
	// best-effort teardown: errors while closing are intentionally ignored
	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
	_ = c.conn.Close()
}
+83
View File
@@ -0,0 +1,83 @@
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/livekit/protocol/auth"
)
// TestAgents verifies that two agent workers registered for room and
// publisher jobs each get acknowledged, and that job assignments are
// distributed across them as participants publish tracks.
func TestAgents(t *testing.T) {
	_, finish := setupSingleNodeTest("TestAgents")
	defer finish()

	ac1, err := newAgentClient(agentToken())
	require.NoError(t, err)
	ac2, err := newAgentClient(agentToken())
	require.NoError(t, err)
	defer ac1.close()
	defer ac2.close()

	// each client registers for both JT_ROOM and JT_PUBLISHER, so each
	// expects two registration acknowledgements
	require.NoError(t, ac1.Run())
	require.NoError(t, ac2.Run())

	time.Sleep(time.Second * 3)
	require.Equal(t, int32(2), ac1.registered.Load())
	require.Equal(t, int32(2), ac2.registered.Load())

	c1 := createRTCClient("c1", defaultServerPort, nil)
	c2 := createRTCClient("c2", defaultServerPort, nil)
	waitUntilConnected(t, c1, c2)

	// publish 2 tracks
	t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam")
	require.NoError(t, err)
	defer t1.Stop()
	t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam")
	require.NoError(t, err)
	defer t2.Stop()

	time.Sleep(time.Second * 3)
	// exactly one room job and one participant job in total, assigned
	// to one of the two workers
	require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load())
	require.Equal(t, int32(1), ac1.participantJobs.Load()+ac2.participantJobs.Load())

	// publish 2 tracks from the second participant
	t3, err := c2.AddStaticTrack("audio/opus", "audio", "webcam")
	require.NoError(t, err)
	defer t3.Stop()
	t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam")
	require.NoError(t, err)
	defer t4.Stop()

	time.Sleep(time.Second * 3)
	// still one room job, but now one participant job per publishing participant
	require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load())
	require.Equal(t, int32(2), ac1.participantJobs.Load()+ac2.participantJobs.Load())
}
// agentToken builds a JWT carrying the agent grant for connecting to
// the agent endpoint. It panics on signing failure (test-only helper).
func agentToken() string {
	token, err := auth.NewAccessToken(testApiKey, testApiSecret).
		AddGrant(&auth.VideoGrant{Agent: true}).
		ToJWT()
	if err != nil {
		panic(err)
	}
	return token
}