From c510ea2e1aa798fa2000c0b7fe00a114d042cb0f Mon Sep 17 00:00:00 2001 From: David Zhao Date: Fri, 4 Jun 2021 12:26:23 -0700 Subject: [PATCH] Fix race condition with Transport negotiations --- pkg/rtc/transport.go | 37 +++++++++++++++++++------------------ pkg/rtc/transport_test.go | 16 ++++++++++------ pkg/service/roommanager.go | 2 +- pkg/testutils/timeout.go | 30 ++++++++++++++++++++++++++++++ test/integration_helpers.go | 23 ++++------------------- test/multinode_test.go | 3 ++- test/scenarios.go | 13 +++++++------ test/singlenode_test.go | 9 +++++---- 8 files changed, 78 insertions(+), 55 deletions(-) create mode 100644 pkg/testutils/timeout.go diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 4a4a2a381..994851e26 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -2,7 +2,6 @@ package rtc import ( "sync" - "sync/atomic" "time" "github.com/bep/debounce" @@ -35,9 +34,8 @@ type PCTransport struct { pendingCandidates []webrtc.ICECandidateInit debouncedNegotiate func(func()) onOffer func(offer webrtc.SessionDescription) - restartAfterGathering atomic.Value - - negotiationState atomic.Value + restartAfterGathering bool + negotiationState int } type TransportParams struct { @@ -90,11 +88,13 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { pc: pc, me: me, debouncedNegotiate: debounce.New(negotiationFrequency), + negotiationState: negotiationStateNone, } - t.negotiationState.Store(negotiationStateNone) t.pc.OnICEGatheringStateChange(func(state webrtc.ICEGathererState) { if state == webrtc.ICEGathererStateComplete { - if restart, ok := t.restartAfterGathering.Load().(bool); ok && restart { + t.lock.Lock() + defer t.lock.Unlock() + if t.restartAfterGathering { if err := t.CreateAndSendOffer(&webrtc.OfferOptions{ICERestart: true}); err != nil { logger.Warnw("could not restart ICE", err) } @@ -125,22 +125,23 @@ func (t *PCTransport) Close() { } func (t *PCTransport) SetRemoteDescription(sd webrtc.SessionDescription) error { + t.lock.Lock() + defer t.lock.Unlock() + if err := t.pc.SetRemoteDescription(sd); err != nil { return err } - t.lock.Lock() for _, c := range t.pendingCandidates { if err := t.pc.AddICECandidate(c); err != nil { return err } } t.pendingCandidates = nil - t.lock.Unlock() // negotiated, reset flag - state := t.negotiationState.Load().(int) - t.negotiationState.Store(negotiationStateNone) + state := t.negotiationState + t.negotiationState = negotiationStateNone if state == negotiationRetry { // need to Negotiate again t.Negotiate() @@ -170,22 +171,22 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error { return nil } + t.lock.Lock() + defer t.lock.Unlock() iceRestart := options != nil && options.ICERestart // if restart is requested, and we are not ready, then continue afterwards if iceRestart { if t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { logger.Debugw("restart ICE after gathering") - t.restartAfterGathering.Store(true) + t.restartAfterGathering = true return nil } logger.Debugw("restarting ICE") } - state := t.negotiationState.Load().(int) - // when there's an ongoing negotiation, let it finish and not disrupt its state - if state == negotiationStateClient { + if t.negotiationState == negotiationStateClient { currentSD := t.pc.CurrentRemoteDescription() if iceRestart && currentSD != nil { logger.Debugw("recovering from client negotiation state") @@ -194,7 +195,7 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error { } } else { logger.Debugw("skipping negotiation, trying again later") - t.negotiationState.Store(negotiationRetry) + t.negotiationState = negotiationRetry return nil } } @@ -212,9 +213,9 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error { } // indicate waiting for client - t.negotiationState.Store(negotiationStateClient) - t.restartAfterGathering.Store(false) + t.negotiationState = negotiationStateClient + t.restartAfterGathering = false - t.onOffer(offer) + go t.onOffer(offer) return nil } diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 5e5b8b1d6..747109ec8 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -2,8 +2,8 @@ package rtc import ( "testing" - "time" + "github.com/livekit/livekit-server/pkg/testutils" livekit "github.com/livekit/livekit-server/proto" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" @@ -45,7 +45,10 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { require.NoError(t, transportA.CreateAndSendOffer(nil)) // ensure we are connected the first time - time.Sleep(100 * time.Millisecond) + testutils.WithTimeout(t, "initial ICE connectivity", func() bool { + return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && + transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected + }) require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) @@ -53,7 +56,7 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { transportA.OnOffer(func(sd webrtc.SessionDescription) {}) require.NoError(t, transportA.CreateAndSendOffer(nil)) require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState()) - require.Equal(t, negotiationStateClient, transportA.negotiationState.Load().(int)) + require.Equal(t, negotiationStateClient, transportA.negotiationState) // now restart ICE t.Logf("creating offer with ICE restart") @@ -62,9 +65,10 @@ func TestMissingAnswerDuringICERestart(t *testing.T) { ICERestart: true, })) - time.Sleep(100 * time.Millisecond) - require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) - require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) + testutils.WithTimeout(t, "restarted ICE connectivity", func() bool { + return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && + transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected + }) } func handleOfferFunc(t *testing.T, current, other *PCTransport) func(sd webrtc.SessionDescription) { diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index c8404ecd7..4700b2a11 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -344,7 +344,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici for { select { - case <-time.After(time.Millisecond * 100): + case <-time.After(time.Millisecond * 50): // periodic check to ensure participant didn't become disconnected if participant.State() == livekit.ParticipantInfo_DISCONNECTED { return diff --git a/pkg/testutils/timeout.go b/pkg/testutils/timeout.go new file mode 100644 index 000000000..d821c3b83 --- /dev/null +++ b/pkg/testutils/timeout.go @@ -0,0 +1,30 @@ +package testutils + +import ( + "context" + "testing" + "time" + + "github.com/livekit/livekit-server/pkg/logger" +) + +var ( + SyncDelay = 100 * time.Millisecond + ConnectTimeout = 5 * time.Second +) + +func WithTimeout(t *testing.T, description string, f func() bool) bool { + logger.Infow(description) + ctx, _ := context.WithTimeout(context.Background(), ConnectTimeout) + for { + select { + case <-ctx.Done(): + t.Fatal("timed out: " + description) + return false + case <-time.After(10 * time.Millisecond): + if f() { + return true + } + } + } +} diff --git a/test/integration_helpers.go b/test/integration_helpers.go index 9a815a705..1d885d5c8 100644 --- a/test/integration_helpers.go +++ b/test/integration_helpers.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-redis/redis/v8" + "github.com/livekit/livekit-server/pkg/testutils" testclient "github.com/livekit/livekit-server/test/client" "github.com/twitchtv/twirp" @@ -30,8 +31,7 @@ const ( nodeId1 = "node-1" nodeId2 = "node-2" - syncDelay = 100 * time.Millisecond - connectTimeout = 10 * time.Second + syncDelay = 100 * time.Millisecond // if there are deadlocks, it's helpful to set a short test timeout (i.e. go test -timeout=30s) // let connection timeout happen //connectTimeout = 5000 * time.Second @@ -95,7 +95,8 @@ func contextWithCreateRoomToken() context.Context { func waitForServerToStart(s *service.LivekitServer) { // wait till ready - ctx, _ := context.WithTimeout(context.Background(), connectTimeout) + ctx, cancel := context.WithTimeout(context.Background(), testutils.ConnectTimeout) + defer cancel() for { select { case <-ctx.Done(): @@ -108,22 +109,6 @@ func waitForServerToStart(s *service.LivekitServer) { } } -func withTimeout(t *testing.T, description string, f func() bool) bool { - logger.Infow(description) - ctx, _ := context.WithTimeout(context.Background(), connectTimeout) - for { - select { - case <-ctx.Done(): - t.Fatal("timed out: " + description) - return false - case <-time.After(10 * time.Millisecond): - if f() { - return true - } - } - } -} - func waitUntilConnected(t *testing.T, clients ...*testclient.RTCClient) { logger.Infow("waiting for clients to become connected") wg := sync.WaitGroup{} diff --git a/test/multinode_test.go b/test/multinode_test.go index caf99cc70..ad91c4d8b 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/livekit/livekit-server/pkg/testutils" livekit "github.com/livekit/livekit-server/proto" "github.com/stretchr/testify/require" ) @@ -35,7 +36,7 @@ func TestMultiNodeRouting(t *testing.T) { defer t1.Stop() } - withTimeout(t, "c2 should receive one track", func() bool { + testutils.WithTimeout(t, "c2 should receive one track", func() bool { if len(c2.SubscribedTracks()) == 0 { return false } diff --git a/test/scenarios.go b/test/scenarios.go index c6ed59de0..d76c544e7 100644 --- a/test/scenarios.go +++ b/test/scenarios.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/livekit/livekit-server/pkg/testutils" testclient "github.com/livekit/livekit-server/test/client" "github.com/stretchr/testify/require" @@ -26,7 +27,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) { defer stopWriters(writers...) logger.Infow("waiting to receive tracks from c1 and c2") - success := withTimeout(t, "c3 should receive tracks from both clients", func() bool { + success := testutils.WithTimeout(t, "c3 should receive tracks from both clients", func() bool { tracks := c3.SubscribedTracks() if len(tracks[c1.ID()]) != 2 { return false @@ -46,7 +47,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) { c2.Stop() logger.Infow("waiting for c2 tracks to be gone") - success = withTimeout(t, "c2 tracks should be gone", func() bool { + success = testutils.WithTimeout(t, "c2 tracks should be gone", func() bool { tracks := c3.SubscribedTracks() if len(tracks[c1.ID()]) != 2 { return false @@ -71,7 +72,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) { writers = publishTracksForClients(t, c2) defer stopWriters(writers...) - success = withTimeout(t, "new c2 tracks should be published again", func() bool { + success = testutils.WithTimeout(t, "new c2 tracks should be published again", func() bool { tracks := c3.SubscribedTracks() if len(tracks[c2.ID()]) != 2 { return false @@ -98,7 +99,7 @@ func scenarioReceiveBeforePublish(t *testing.T) { defer stopWriters(writers...) // c2 should see some bytes flowing through - success := withTimeout(t, "waiting to receive bytes on c2", func() bool { + success := testutils.WithTimeout(t, "waiting to receive bytes on c2", func() bool { return c2.BytesReceived() > 20 }) if !success { @@ -109,7 +110,7 @@ func scenarioReceiveBeforePublish(t *testing.T) { writers = publishTracksForClients(t, c2) defer stopWriters(writers...) - success = withTimeout(t, "waiting to receive c2 tracks on c1", func() bool { + success = testutils.WithTimeout(t, "waiting to receive c2 tracks on c1", func() bool { return len(c1.SubscribedTracks()[c2.ID()]) == 2 }) require.True(t, success) @@ -117,7 +118,7 @@ func scenarioReceiveBeforePublish(t *testing.T) { // now leave, and ensure that it's immediate c2.Stop() - time.Sleep(connectTimeout) + time.Sleep(testutils.ConnectTimeout) require.Empty(t, c1.RemoteParticipants()) } diff --git a/test/singlenode_test.go b/test/singlenode_test.go index 2402b3007..0e127a396 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/livekit/livekit-server/pkg/testutils" testclient "github.com/livekit/livekit-server/test/client" "github.com/stretchr/testify/require" ) @@ -23,7 +24,7 @@ func TestClientCouldConnect(t *testing.T) { waitUntilConnected(t, c1, c2) // ensure they both see each other - withTimeout(t, "c1 and c2 could connect", func() bool { + testutils.WithTimeout(t, "c1 and c2 could connect", func() bool { if len(c1.RemoteParticipants()) == 0 || len(c2.RemoteParticipants()) == 0 { return false } @@ -56,7 +57,7 @@ func TestSinglePublisher(t *testing.T) { // a new client joins and should get the initial stream c3 := createRTCClient("c3", defaultServerPort, nil) - success := withTimeout(t, "c2 should receive two tracks", func() bool { + success := testutils.WithTimeout(t, "c2 should receive two tracks", func() bool { if len(c2.SubscribedTracks()) == 0 { return false } @@ -75,7 +76,7 @@ func TestSinglePublisher(t *testing.T) { // ensure that new client that has joined also received tracks waitUntilConnected(t, c3) - success = withTimeout(t, "c2 should receive two tracks", func() bool { + success = testutils.WithTimeout(t, "c2 should receive two tracks", func() bool { if len(c3.SubscribedTracks()) == 0 { return false } @@ -98,7 +99,7 @@ func TestSinglePublisher(t *testing.T) { // when c3 disconnects.. ensure subscriber is cleaned up correctly c3.Stop() - success = withTimeout(t, "c3 is cleaned up as a subscriber", func() bool { + success = testutils.WithTimeout(t, "c3 is cleaned up as a subscriber", func() bool { room := s.RoomManager().GetRoom(testRoom) require.NotNil(t, room)