mirror of
https://github.com/livekit/livekit.git
synced 2026-04-26 13:07:39 +00:00
add test for removing disconnected participants on signal close (#1896)
* add test for removing disconnected participants on signal close * cleanup
This commit is contained in:
@@ -136,7 +136,7 @@ func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) {
|
||||
if st, err := os.Stat(conf.KeyFile); err != nil {
|
||||
return nil, err
|
||||
} else if st.Mode().Perm()&otherFilter != 0000 {
|
||||
return nil, fmt.Errorf("key file others permission must be set to 0")
|
||||
return nil, fmt.Errorf("key file others permissions must be set to 0")
|
||||
}
|
||||
f, err := os.Open(conf.KeyFile)
|
||||
if err != nil {
|
||||
|
||||
@@ -28,6 +28,11 @@ import (
|
||||
"github.com/livekit/livekit-server/pkg/rtc/types"
|
||||
)
|
||||
|
||||
type SignalRequestHandler func(msg *livekit.SignalRequest) error
|
||||
type SignalRequestInterceptor func(msg *livekit.SignalRequest, next SignalRequestHandler) error
|
||||
type SignalResponseHandler func(msg *livekit.SignalResponse) error
|
||||
type SignalResponseInterceptor func(msg *livekit.SignalResponse, next SignalResponseHandler) error
|
||||
|
||||
type RTCClient struct {
|
||||
id livekit.ParticipantID
|
||||
conn *websocket.Conn
|
||||
@@ -45,6 +50,9 @@ type RTCClient struct {
|
||||
localParticipant *livekit.ParticipantInfo
|
||||
remoteParticipants map[livekit.ParticipantID]*livekit.ParticipantInfo
|
||||
|
||||
signalRequestInterceptor SignalRequestInterceptor
|
||||
signalResponseInterceptor SignalResponseInterceptor
|
||||
|
||||
subscriberAsPrimary atomic.Bool
|
||||
publisherFullyEstablished atomic.Bool
|
||||
subscriberFullyEstablished atomic.Bool
|
||||
@@ -83,10 +91,12 @@ var (
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
AutoSubscribe bool
|
||||
Publish string
|
||||
ClientInfo *livekit.ClientInfo
|
||||
DisabledCodecs []webrtc.RTPCodecCapability
|
||||
AutoSubscribe bool
|
||||
Publish string
|
||||
ClientInfo *livekit.ClientInfo
|
||||
DisabledCodecs []webrtc.RTPCodecCapability
|
||||
SignalRequestInterceptor SignalRequestInterceptor
|
||||
SignalResponseInterceptor SignalResponseInterceptor
|
||||
}
|
||||
|
||||
func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) {
|
||||
@@ -265,6 +275,11 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) {
|
||||
})
|
||||
})
|
||||
|
||||
if opts != nil {
|
||||
c.signalRequestInterceptor = opts.SignalRequestInterceptor
|
||||
c.signalResponseInterceptor = opts.SignalResponseInterceptor
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -290,89 +305,101 @@ func (c *RTCClient) Run() error {
|
||||
logger.Errorw("error while reading", err)
|
||||
return err
|
||||
}
|
||||
switch msg := res.Message.(type) {
|
||||
case *livekit.SignalResponse_Join:
|
||||
c.localParticipant = msg.Join.Participant
|
||||
c.id = livekit.ParticipantID(msg.Join.Participant.Sid)
|
||||
c.lock.Lock()
|
||||
for _, p := range msg.Join.OtherParticipants {
|
||||
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
|
||||
}
|
||||
c.lock.Unlock()
|
||||
// if publish only, negotiate
|
||||
if !msg.Join.SubscriberPrimary {
|
||||
c.subscriberAsPrimary.Store(false)
|
||||
c.publisher.Negotiate(false)
|
||||
} else {
|
||||
c.subscriberAsPrimary.Store(true)
|
||||
}
|
||||
|
||||
logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity)
|
||||
case *livekit.SignalResponse_Answer:
|
||||
// logger.Debugw("received server answer",
|
||||
// "participant", c.localParticipant.Identity,
|
||||
// "answer", msg.Answer.Sdp)
|
||||
c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer))
|
||||
case *livekit.SignalResponse_Offer:
|
||||
logger.Infow("received server offer",
|
||||
"participant", c.localParticipant.Identity,
|
||||
)
|
||||
desc := rtc.FromProtoSessionDescription(msg.Offer)
|
||||
c.handleOffer(desc)
|
||||
case *livekit.SignalResponse_Trickle:
|
||||
candidateInit, err := rtc.FromProtoTrickle(msg.Trickle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER {
|
||||
c.publisher.AddICECandidate(candidateInit)
|
||||
} else {
|
||||
c.subscriber.AddICECandidate(candidateInit)
|
||||
}
|
||||
case *livekit.SignalResponse_Update:
|
||||
c.lock.Lock()
|
||||
for _, p := range msg.Update.Participants {
|
||||
if livekit.ParticipantID(p.Sid) != c.id {
|
||||
if p.State != livekit.ParticipantInfo_DISCONNECTED {
|
||||
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
|
||||
} else {
|
||||
delete(c.remoteParticipants, livekit.ParticipantID(p.Sid))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
case *livekit.SignalResponse_TrackPublished:
|
||||
logger.Debugw("track published", "trackID", msg.TrackPublished.Track.Name, "participant", c.localParticipant.Sid,
|
||||
"cid", msg.TrackPublished.Cid, "trackSid", msg.TrackPublished.Track.Sid)
|
||||
c.lock.Lock()
|
||||
c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_RefreshToken:
|
||||
c.lock.Lock()
|
||||
c.refreshToken = msg.RefreshToken
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_TrackUnpublished:
|
||||
sid := msg.TrackUnpublished.TrackSid
|
||||
c.lock.Lock()
|
||||
sender := c.trackSenders[sid]
|
||||
if sender != nil {
|
||||
if err := c.publisher.RemoveTrack(sender); err != nil {
|
||||
logger.Errorw("Could not unpublish track", err)
|
||||
}
|
||||
c.publisher.Negotiate(false)
|
||||
}
|
||||
delete(c.trackSenders, sid)
|
||||
delete(c.localTracks, sid)
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_Pong:
|
||||
c.pongReceivedAt.Store(msg.Pong)
|
||||
case *livekit.SignalResponse_SubscriptionResponse:
|
||||
c.subscriptionResponse.Store(msg.SubscriptionResponse)
|
||||
if c.signalResponseInterceptor != nil {
|
||||
err = c.signalResponseInterceptor(res, c.handleSignalResponse)
|
||||
} else {
|
||||
err = c.handleSignalResponse(res)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error {
|
||||
switch msg := res.Message.(type) {
|
||||
case *livekit.SignalResponse_Join:
|
||||
c.localParticipant = msg.Join.Participant
|
||||
c.id = livekit.ParticipantID(msg.Join.Participant.Sid)
|
||||
c.lock.Lock()
|
||||
for _, p := range msg.Join.OtherParticipants {
|
||||
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
|
||||
}
|
||||
c.lock.Unlock()
|
||||
// if publish only, negotiate
|
||||
if !msg.Join.SubscriberPrimary {
|
||||
c.subscriberAsPrimary.Store(false)
|
||||
c.publisher.Negotiate(false)
|
||||
} else {
|
||||
c.subscriberAsPrimary.Store(true)
|
||||
}
|
||||
|
||||
logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity)
|
||||
case *livekit.SignalResponse_Answer:
|
||||
// logger.Debugw("received server answer",
|
||||
// "participant", c.localParticipant.Identity,
|
||||
// "answer", msg.Answer.Sdp)
|
||||
c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer))
|
||||
case *livekit.SignalResponse_Offer:
|
||||
logger.Infow("received server offer",
|
||||
"participant", c.localParticipant.Identity,
|
||||
)
|
||||
desc := rtc.FromProtoSessionDescription(msg.Offer)
|
||||
c.handleOffer(desc)
|
||||
case *livekit.SignalResponse_Trickle:
|
||||
candidateInit, err := rtc.FromProtoTrickle(msg.Trickle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER {
|
||||
c.publisher.AddICECandidate(candidateInit)
|
||||
} else {
|
||||
c.subscriber.AddICECandidate(candidateInit)
|
||||
}
|
||||
case *livekit.SignalResponse_Update:
|
||||
c.lock.Lock()
|
||||
for _, p := range msg.Update.Participants {
|
||||
if livekit.ParticipantID(p.Sid) != c.id {
|
||||
if p.State != livekit.ParticipantInfo_DISCONNECTED {
|
||||
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
|
||||
} else {
|
||||
delete(c.remoteParticipants, livekit.ParticipantID(p.Sid))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
case *livekit.SignalResponse_TrackPublished:
|
||||
logger.Debugw("track published", "trackID", msg.TrackPublished.Track.Name, "participant", c.localParticipant.Sid,
|
||||
"cid", msg.TrackPublished.Cid, "trackSid", msg.TrackPublished.Track.Sid)
|
||||
c.lock.Lock()
|
||||
c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_RefreshToken:
|
||||
c.lock.Lock()
|
||||
c.refreshToken = msg.RefreshToken
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_TrackUnpublished:
|
||||
sid := msg.TrackUnpublished.TrackSid
|
||||
c.lock.Lock()
|
||||
sender := c.trackSenders[sid]
|
||||
if sender != nil {
|
||||
if err := c.publisher.RemoveTrack(sender); err != nil {
|
||||
logger.Errorw("Could not unpublish track", err)
|
||||
}
|
||||
c.publisher.Negotiate(false)
|
||||
}
|
||||
delete(c.trackSenders, sid)
|
||||
delete(c.localTracks, sid)
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_Pong:
|
||||
c.pongReceivedAt.Store(msg.Pong)
|
||||
case *livekit.SignalResponse_SubscriptionResponse:
|
||||
c.subscriptionResponse.Store(msg.SubscriptionResponse)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RTCClient) WaitUntilConnected() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
@@ -486,6 +513,14 @@ func (c *RTCClient) SendPing() error {
|
||||
}
|
||||
|
||||
func (c *RTCClient) SendRequest(msg *livekit.SignalRequest) error {
|
||||
if c.signalRequestInterceptor != nil {
|
||||
return c.signalRequestInterceptor(msg, c.sendRequest)
|
||||
} else {
|
||||
return c.sendRequest(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RTCClient) sendRequest(msg *livekit.SignalRequest) error {
|
||||
payload, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/rtc"
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
"github.com/livekit/livekit-server/test/client"
|
||||
)
|
||||
|
||||
func TestMultiNodeRouting(t *testing.T) {
|
||||
@@ -261,3 +262,46 @@ func TestMultiNodeRevokePublishPermission(t *testing.T) {
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseDisconnectedParticipantOnSignalClose(t *testing.T) {
|
||||
_, _, finish := setupMultiNodeTest("TestCloseDisconnectedParticipantOnSignalClose")
|
||||
defer finish()
|
||||
|
||||
c1 := createRTCClient("c1", secondServerPort, nil)
|
||||
waitUntilConnected(t, c1)
|
||||
|
||||
c2 := createRTCClient("c2", defaultServerPort, &client.Options{
|
||||
SignalRequestInterceptor: func(msg *livekit.SignalRequest, next client.SignalRequestHandler) error {
|
||||
switch msg.Message.(type) {
|
||||
case *livekit.SignalRequest_Offer, *livekit.SignalRequest_Answer, *livekit.SignalRequest_Leave:
|
||||
return nil
|
||||
default:
|
||||
return next(msg)
|
||||
}
|
||||
},
|
||||
SignalResponseInterceptor: func(msg *livekit.SignalResponse, next client.SignalResponseHandler) error {
|
||||
switch msg.Message.(type) {
|
||||
case *livekit.SignalResponse_Offer, *livekit.SignalResponse_Answer:
|
||||
return nil
|
||||
default:
|
||||
return next(msg)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
testutils.WithTimeout(t, func() string {
|
||||
if len(c1.RemoteParticipants()) != 1 {
|
||||
return "c1 did not see c2 join"
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
c2.Stop()
|
||||
|
||||
testutils.WithTimeout(t, func() string {
|
||||
if len(c1.RemoteParticipants()) != 0 {
|
||||
return "c1 did not see c2 removed"
|
||||
}
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user