From 6c20c7eb152018ab6b2cc0cb99bdf63fbb82177f Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Thu, 20 Jul 2023 21:21:40 -0700 Subject: [PATCH] add test for removing disconnected participants on signal close (#1896) * add test for removing disconnected participants on signal close * cleanup --- pkg/service/wire_gen.go | 2 +- test/client/client.go | 201 +++++++++++++++++++++++----------------- test/multinode_test.go | 44 +++++++++ 3 files changed, 163 insertions(+), 84 deletions(-) diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index ec817449d..b051fb954 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -136,7 +136,7 @@ func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) { if st, err := os.Stat(conf.KeyFile); err != nil { return nil, err } else if st.Mode().Perm()&otherFilter != 0000 { - return nil, fmt.Errorf("key file others permission must be set to 0") + return nil, fmt.Errorf("key file others permissions must be set to 0") } f, err := os.Open(conf.KeyFile) if err != nil { diff --git a/test/client/client.go b/test/client/client.go index fa9b485a1..776d356a7 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -28,6 +28,11 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" ) +type SignalRequestHandler func(msg *livekit.SignalRequest) error +type SignalRequestInterceptor func(msg *livekit.SignalRequest, next SignalRequestHandler) error +type SignalResponseHandler func(msg *livekit.SignalResponse) error +type SignalResponseInterceptor func(msg *livekit.SignalResponse, next SignalResponseHandler) error + type RTCClient struct { id livekit.ParticipantID conn *websocket.Conn @@ -45,6 +50,9 @@ type RTCClient struct { localParticipant *livekit.ParticipantInfo remoteParticipants map[livekit.ParticipantID]*livekit.ParticipantInfo + signalRequestInterceptor SignalRequestInterceptor + signalResponseInterceptor SignalResponseInterceptor + subscriberAsPrimary atomic.Bool publisherFullyEstablished atomic.Bool subscriberFullyEstablished atomic.Bool @@ -83,10 +91,12 @@ var ( ) type Options struct { - AutoSubscribe bool - Publish string - ClientInfo *livekit.ClientInfo - DisabledCodecs []webrtc.RTPCodecCapability + AutoSubscribe bool + Publish string + ClientInfo *livekit.ClientInfo + DisabledCodecs []webrtc.RTPCodecCapability + SignalRequestInterceptor SignalRequestInterceptor + SignalResponseInterceptor SignalResponseInterceptor } func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) { @@ -265,6 +275,11 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { }) }) + if opts != nil { + c.signalRequestInterceptor = opts.SignalRequestInterceptor + c.signalResponseInterceptor = opts.SignalResponseInterceptor + } + return c, nil } @@ -290,89 +305,101 @@ func (c *RTCClient) Run() error { logger.Errorw("error while reading", err) return err } - switch msg := res.Message.(type) { - case *livekit.SignalResponse_Join: - c.localParticipant = msg.Join.Participant - c.id = livekit.ParticipantID(msg.Join.Participant.Sid) - c.lock.Lock() - for _, p := range msg.Join.OtherParticipants { - c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p - } - c.lock.Unlock() - // if publish only, negotiate - if !msg.Join.SubscriberPrimary { - c.subscriberAsPrimary.Store(false) - c.publisher.Negotiate(false) - } else { - c.subscriberAsPrimary.Store(true) - } - - logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity) - case *livekit.SignalResponse_Answer: - // logger.Debugw("received server answer", - // "participant", c.localParticipant.Identity, - // "answer", msg.Answer.Sdp) - c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer)) - case *livekit.SignalResponse_Offer: - logger.Infow("received server offer", - "participant", c.localParticipant.Identity, - ) - desc := rtc.FromProtoSessionDescription(msg.Offer) - c.handleOffer(desc) - case *livekit.SignalResponse_Trickle: - candidateInit, err := rtc.FromProtoTrickle(msg.Trickle) - if err != nil { - return err - } - if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER { - c.publisher.AddICECandidate(candidateInit) - } else { - c.subscriber.AddICECandidate(candidateInit) - } - case *livekit.SignalResponse_Update: - c.lock.Lock() - for _, p := range msg.Update.Participants { - if livekit.ParticipantID(p.Sid) != c.id { - if p.State != livekit.ParticipantInfo_DISCONNECTED { - c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p - } else { - delete(c.remoteParticipants, livekit.ParticipantID(p.Sid)) - } - } - } - c.lock.Unlock() - - case *livekit.SignalResponse_TrackPublished: - logger.Debugw("track published", "trackID", msg.TrackPublished.Track.Name, "participant", c.localParticipant.Sid, - "cid", msg.TrackPublished.Cid, "trackSid", msg.TrackPublished.Track.Sid) - c.lock.Lock() - c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track - c.lock.Unlock() - case *livekit.SignalResponse_RefreshToken: - c.lock.Lock() - c.refreshToken = msg.RefreshToken - c.lock.Unlock() - case *livekit.SignalResponse_TrackUnpublished: - sid := msg.TrackUnpublished.TrackSid - c.lock.Lock() - sender := c.trackSenders[sid] - if sender != nil { - if err := c.publisher.RemoveTrack(sender); err != nil { - logger.Errorw("Could not unpublish track", err) - } - c.publisher.Negotiate(false) - } - delete(c.trackSenders, sid) - delete(c.localTracks, sid) - c.lock.Unlock() - case *livekit.SignalResponse_Pong: - c.pongReceivedAt.Store(msg.Pong) - case *livekit.SignalResponse_SubscriptionResponse: - c.subscriptionResponse.Store(msg.SubscriptionResponse) + if c.signalResponseInterceptor != nil { + err = c.signalResponseInterceptor(res, c.handleSignalResponse) + } else { + err = c.handleSignalResponse(res) + } + if err != nil { + return err } } } +func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error { + switch msg := res.Message.(type) { + case *livekit.SignalResponse_Join: + c.localParticipant = msg.Join.Participant + c.id = livekit.ParticipantID(msg.Join.Participant.Sid) + c.lock.Lock() + for _, p := range msg.Join.OtherParticipants { + c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p + } + c.lock.Unlock() + // if publish only, negotiate + if !msg.Join.SubscriberPrimary { + c.subscriberAsPrimary.Store(false) + c.publisher.Negotiate(false) + } else { + c.subscriberAsPrimary.Store(true) + } + + logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity) + case *livekit.SignalResponse_Answer: + // logger.Debugw("received server answer", + // "participant", c.localParticipant.Identity, + // "answer", msg.Answer.Sdp) + c.handleAnswer(rtc.FromProtoSessionDescription(msg.Answer)) + case *livekit.SignalResponse_Offer: + logger.Infow("received server offer", + "participant", c.localParticipant.Identity, + ) + desc := rtc.FromProtoSessionDescription(msg.Offer) + c.handleOffer(desc) + case *livekit.SignalResponse_Trickle: + candidateInit, err := rtc.FromProtoTrickle(msg.Trickle) + if err != nil { + return err + } + if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER { + c.publisher.AddICECandidate(candidateInit) + } else { + c.subscriber.AddICECandidate(candidateInit) + } + case *livekit.SignalResponse_Update: + c.lock.Lock() + for _, p := range msg.Update.Participants { + if livekit.ParticipantID(p.Sid) != c.id { + if p.State != livekit.ParticipantInfo_DISCONNECTED { + c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p + } else { + delete(c.remoteParticipants, livekit.ParticipantID(p.Sid)) + } + } + } + c.lock.Unlock() + + case *livekit.SignalResponse_TrackPublished: + logger.Debugw("track published", "trackID", msg.TrackPublished.Track.Name, "participant", c.localParticipant.Sid, + "cid", msg.TrackPublished.Cid, "trackSid", msg.TrackPublished.Track.Sid) + c.lock.Lock() + c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track + c.lock.Unlock() + case *livekit.SignalResponse_RefreshToken: + c.lock.Lock() + c.refreshToken = msg.RefreshToken + c.lock.Unlock() + case *livekit.SignalResponse_TrackUnpublished: + sid := msg.TrackUnpublished.TrackSid + c.lock.Lock() + sender := c.trackSenders[sid] + if sender != nil { + if err := c.publisher.RemoveTrack(sender); err != nil { + logger.Errorw("Could not unpublish track", err) + } + c.publisher.Negotiate(false) + } + delete(c.trackSenders, sid) + delete(c.localTracks, sid) + c.lock.Unlock() + case *livekit.SignalResponse_Pong: + c.pongReceivedAt.Store(msg.Pong) + case *livekit.SignalResponse_SubscriptionResponse: + c.subscriptionResponse.Store(msg.SubscriptionResponse) + } + return nil +} + func (c *RTCClient) WaitUntilConnected() error { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() @@ -486,6 +513,14 @@ func (c *RTCClient) SendPing() error { } func (c *RTCClient) SendRequest(msg *livekit.SignalRequest) error { + if c.signalRequestInterceptor != nil { + return c.signalRequestInterceptor(msg, c.sendRequest) + } else { + return c.sendRequest(msg) + } +} + +func (c *RTCClient) sendRequest(msg *livekit.SignalRequest) error { payload, err := proto.Marshal(msg) if err != nil { return err diff --git a/test/multinode_test.go b/test/multinode_test.go index 57a2ed57e..fe973aefd 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -11,6 +11,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/livekit-server/test/client" ) func TestMultiNodeRouting(t *testing.T) { @@ -261,3 +262,46 @@ func TestMultiNodeRevokePublishPermission(t *testing.T) { return "" }) } + +func TestCloseDisconnectedParticipantOnSignalClose(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestCloseDisconnectedParticipantOnSignalClose") + defer finish() + + c1 := createRTCClient("c1", secondServerPort, nil) + waitUntilConnected(t, c1) + + c2 := createRTCClient("c2", defaultServerPort, &client.Options{ + SignalRequestInterceptor: func(msg *livekit.SignalRequest, next client.SignalRequestHandler) error { + switch msg.Message.(type) { + case *livekit.SignalRequest_Offer, *livekit.SignalRequest_Answer, *livekit.SignalRequest_Leave: + return nil + default: + return next(msg) + } + }, + SignalResponseInterceptor: func(msg *livekit.SignalResponse, next client.SignalResponseHandler) error { + switch msg.Message.(type) { + case *livekit.SignalResponse_Offer, *livekit.SignalResponse_Answer: + return nil + default: + return next(msg) + } + }, + }) + + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) != 1 { + return "c1 did not see c2 join" + } + return "" + }) + + c2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) != 0 { + return "c1 did not see c2 removed" + } + return "" + }) +}