add test for removing disconnected participants on signal close (#1896)

* add test for removing disconnected participants on signal close

* cleanup
This commit is contained in:
Paul Wells
2023-07-20 21:21:40 -07:00
committed by GitHub
parent 3980d049c9
commit 6c20c7eb15
3 changed files with 163 additions and 84 deletions

View File

@@ -136,7 +136,7 @@ func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) {
if st, err := os.Stat(conf.KeyFile); err != nil {
return nil, err
} else if st.Mode().Perm()&otherFilter != 0000 {
return nil, fmt.Errorf("key file others permission must be set to 0")
return nil, fmt.Errorf("key file others permissions must be set to 0")
}
f, err := os.Open(conf.KeyFile)
if err != nil {

View File

@@ -28,6 +28,11 @@ import (
"github.com/livekit/livekit-server/pkg/rtc/types"
)
type SignalRequestHandler func(msg *livekit.SignalRequest) error
type SignalRequestInterceptor func(msg *livekit.SignalRequest, next SignalRequestHandler) error
type SignalResponseHandler func(msg *livekit.SignalResponse) error
type SignalResponseInterceptor func(msg *livekit.SignalResponse, next SignalResponseHandler) error
type RTCClient struct {
id livekit.ParticipantID
conn *websocket.Conn
@@ -45,6 +50,9 @@ type RTCClient struct {
localParticipant *livekit.ParticipantInfo
remoteParticipants map[livekit.ParticipantID]*livekit.ParticipantInfo
signalRequestInterceptor SignalRequestInterceptor
signalResponseInterceptor SignalResponseInterceptor
subscriberAsPrimary atomic.Bool
publisherFullyEstablished atomic.Bool
subscriberFullyEstablished atomic.Bool
@@ -83,10 +91,12 @@ var (
)
type Options struct {
AutoSubscribe bool
Publish string
ClientInfo *livekit.ClientInfo
DisabledCodecs []webrtc.RTPCodecCapability
AutoSubscribe bool
Publish string
ClientInfo *livekit.ClientInfo
DisabledCodecs []webrtc.RTPCodecCapability
SignalRequestInterceptor SignalRequestInterceptor
SignalResponseInterceptor SignalResponseInterceptor
}
func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) {
@@ -265,6 +275,11 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) {
})
})
if opts != nil {
c.signalRequestInterceptor = opts.SignalRequestInterceptor
c.signalResponseInterceptor = opts.SignalResponseInterceptor
}
return c, nil
}
@@ -290,89 +305,101 @@ func (c *RTCClient) Run() error {
logger.Errorw("error while reading", err)
return err
}
switch msg := res.Message.(type) {
case *livekit.SignalResponse_Join:
c.localParticipant = msg.Join.Participant
c.id = livekit.ParticipantID(msg.Join.Participant.Sid)
c.lock.Lock()
for _, p := range msg.Join.OtherParticipants {
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
}
c.lock.Unlock()
// if publish only, negotiate
if !msg.Join.SubscriberPrimary {
c.subscriberAsPrimary.Store(false)
c.publisher.Negotiate(false)
} else {
c.subscriberAsPrimary.Store(true)
}
logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity)
case *livekit.SignalResponse_Answer:
// logger.Debugw("received server answer",
// "participant", c.localParticipant.Identity,
// "answer", msg.Answer.Sdp)
c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer))
case *livekit.SignalResponse_Offer:
logger.Infow("received server offer",
"participant", c.localParticipant.Identity,
)
desc := rtc.FromProtoSessionDescription(msg.Offer)
c.handleOffer(desc)
case *livekit.SignalResponse_Trickle:
candidateInit, err := rtc.FromProtoTrickle(msg.Trickle)
if err != nil {
return err
}
if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER {
c.publisher.AddICECandidate(candidateInit)
} else {
c.subscriber.AddICECandidate(candidateInit)
}
case *livekit.SignalResponse_Update:
c.lock.Lock()
for _, p := range msg.Update.Participants {
if livekit.ParticipantID(p.Sid) != c.id {
if p.State != livekit.ParticipantInfo_DISCONNECTED {
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
} else {
delete(c.remoteParticipants, livekit.ParticipantID(p.Sid))
}
}
}
c.lock.Unlock()
case *livekit.SignalResponse_TrackPublished:
logger.Debugw("track published", "trackID", msg.TrackPublished.Track.Name, "participant", c.localParticipant.Sid,
"cid", msg.TrackPublished.Cid, "trackSid", msg.TrackPublished.Track.Sid)
c.lock.Lock()
c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track
c.lock.Unlock()
case *livekit.SignalResponse_RefreshToken:
c.lock.Lock()
c.refreshToken = msg.RefreshToken
c.lock.Unlock()
case *livekit.SignalResponse_TrackUnpublished:
sid := msg.TrackUnpublished.TrackSid
c.lock.Lock()
sender := c.trackSenders[sid]
if sender != nil {
if err := c.publisher.RemoveTrack(sender); err != nil {
logger.Errorw("Could not unpublish track", err)
}
c.publisher.Negotiate(false)
}
delete(c.trackSenders, sid)
delete(c.localTracks, sid)
c.lock.Unlock()
case *livekit.SignalResponse_Pong:
c.pongReceivedAt.Store(msg.Pong)
case *livekit.SignalResponse_SubscriptionResponse:
c.subscriptionResponse.Store(msg.SubscriptionResponse)
if c.signalResponseInterceptor != nil {
err = c.signalResponseInterceptor(res, c.handleSignalResponse)
} else {
err = c.handleSignalResponse(res)
}
if err != nil {
return err
}
}
}
func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error {
switch msg := res.Message.(type) {
case *livekit.SignalResponse_Join:
c.localParticipant = msg.Join.Participant
c.id = livekit.ParticipantID(msg.Join.Participant.Sid)
c.lock.Lock()
for _, p := range msg.Join.OtherParticipants {
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
}
c.lock.Unlock()
// if publish only, negotiate
if !msg.Join.SubscriberPrimary {
c.subscriberAsPrimary.Store(false)
c.publisher.Negotiate(false)
} else {
c.subscriberAsPrimary.Store(true)
}
logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity)
case *livekit.SignalResponse_Answer:
// logger.Debugw("received server answer",
// "participant", c.localParticipant.Identity,
// "answer", msg.Answer.Sdp)
c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer))
case *livekit.SignalResponse_Offer:
logger.Infow("received server offer",
"participant", c.localParticipant.Identity,
)
desc := rtc.FromProtoSessionDescription(msg.Offer)
c.handleOffer(desc)
case *livekit.SignalResponse_Trickle:
candidateInit, err := rtc.FromProtoTrickle(msg.Trickle)
if err != nil {
return err
}
if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER {
c.publisher.AddICECandidate(candidateInit)
} else {
c.subscriber.AddICECandidate(candidateInit)
}
case *livekit.SignalResponse_Update:
c.lock.Lock()
for _, p := range msg.Update.Participants {
if livekit.ParticipantID(p.Sid) != c.id {
if p.State != livekit.ParticipantInfo_DISCONNECTED {
c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p
} else {
delete(c.remoteParticipants, livekit.ParticipantID(p.Sid))
}
}
}
c.lock.Unlock()
case *livekit.SignalResponse_TrackPublished:
logger.Debugw("track published", "trackID", msg.TrackPublished.Track.Name, "participant", c.localParticipant.Sid,
"cid", msg.TrackPublished.Cid, "trackSid", msg.TrackPublished.Track.Sid)
c.lock.Lock()
c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track
c.lock.Unlock()
case *livekit.SignalResponse_RefreshToken:
c.lock.Lock()
c.refreshToken = msg.RefreshToken
c.lock.Unlock()
case *livekit.SignalResponse_TrackUnpublished:
sid := msg.TrackUnpublished.TrackSid
c.lock.Lock()
sender := c.trackSenders[sid]
if sender != nil {
if err := c.publisher.RemoveTrack(sender); err != nil {
logger.Errorw("Could not unpublish track", err)
}
c.publisher.Negotiate(false)
}
delete(c.trackSenders, sid)
delete(c.localTracks, sid)
c.lock.Unlock()
case *livekit.SignalResponse_Pong:
c.pongReceivedAt.Store(msg.Pong)
case *livekit.SignalResponse_SubscriptionResponse:
c.subscriptionResponse.Store(msg.SubscriptionResponse)
}
return nil
}
func (c *RTCClient) WaitUntilConnected() error {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
@@ -486,6 +513,14 @@ func (c *RTCClient) SendPing() error {
}
func (c *RTCClient) SendRequest(msg *livekit.SignalRequest) error {
if c.signalRequestInterceptor != nil {
return c.signalRequestInterceptor(msg, c.sendRequest)
} else {
return c.sendRequest(msg)
}
}
func (c *RTCClient) sendRequest(msg *livekit.SignalRequest) error {
payload, err := proto.Marshal(msg)
if err != nil {
return err

View File

@@ -11,6 +11,7 @@ import (
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/livekit-server/pkg/testutils"
"github.com/livekit/livekit-server/test/client"
)
func TestMultiNodeRouting(t *testing.T) {
@@ -261,3 +262,46 @@ func TestMultiNodeRevokePublishPermission(t *testing.T) {
return ""
})
}
func TestCloseDisconnectedParticipantOnSignalClose(t *testing.T) {
_, _, finish := setupMultiNodeTest("TestCloseDisconnectedParticipantOnSignalClose")
defer finish()
c1 := createRTCClient("c1", secondServerPort, nil)
waitUntilConnected(t, c1)
c2 := createRTCClient("c2", defaultServerPort, &client.Options{
SignalRequestInterceptor: func(msg *livekit.SignalRequest, next client.SignalRequestHandler) error {
switch msg.Message.(type) {
case *livekit.SignalRequest_Offer, *livekit.SignalRequest_Answer, *livekit.SignalRequest_Leave:
return nil
default:
return next(msg)
}
},
SignalResponseInterceptor: func(msg *livekit.SignalResponse, next client.SignalResponseHandler) error {
switch msg.Message.(type) {
case *livekit.SignalResponse_Offer, *livekit.SignalResponse_Answer:
return nil
default:
return next(msg)
}
},
})
testutils.WithTimeout(t, func() string {
if len(c1.RemoteParticipants()) != 1 {
return "c1 did not see c2 join"
}
return ""
})
c2.Stop()
testutils.WithTimeout(t, func() string {
if len(c1.RemoteParticipants()) != 0 {
return "c1 did not see c2 removed"
}
return ""
})
}