mirror of
https://github.com/livekit/livekit.git
synced 2026-03-31 02:25:39 +00:00
Fix race condition with Transport negotiations
This commit is contained in:
@@ -2,7 +2,6 @@ package rtc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bep/debounce"
|
||||
@@ -35,9 +34,8 @@ type PCTransport struct {
|
||||
pendingCandidates []webrtc.ICECandidateInit
|
||||
debouncedNegotiate func(func())
|
||||
onOffer func(offer webrtc.SessionDescription)
|
||||
restartAfterGathering atomic.Value
|
||||
|
||||
negotiationState atomic.Value
|
||||
restartAfterGathering bool
|
||||
negotiationState int
|
||||
}
|
||||
|
||||
type TransportParams struct {
|
||||
@@ -90,11 +88,13 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) {
|
||||
pc: pc,
|
||||
me: me,
|
||||
debouncedNegotiate: debounce.New(negotiationFrequency),
|
||||
negotiationState: negotiationStateNone,
|
||||
}
|
||||
t.negotiationState.Store(negotiationStateNone)
|
||||
t.pc.OnICEGatheringStateChange(func(state webrtc.ICEGathererState) {
|
||||
if state == webrtc.ICEGathererStateComplete {
|
||||
if restart, ok := t.restartAfterGathering.Load().(bool); ok && restart {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
if t.restartAfterGathering {
|
||||
if err := t.CreateAndSendOffer(&webrtc.OfferOptions{ICERestart: true}); err != nil {
|
||||
logger.Warnw("could not restart ICE", err)
|
||||
}
|
||||
@@ -125,22 +125,23 @@ func (t *PCTransport) Close() {
|
||||
}
|
||||
|
||||
func (t *PCTransport) SetRemoteDescription(sd webrtc.SessionDescription) error {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
if err := t.pc.SetRemoteDescription(sd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.lock.Lock()
|
||||
for _, c := range t.pendingCandidates {
|
||||
if err := t.pc.AddICECandidate(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
t.pendingCandidates = nil
|
||||
t.lock.Unlock()
|
||||
|
||||
// negotiated, reset flag
|
||||
state := t.negotiationState.Load().(int)
|
||||
t.negotiationState.Store(negotiationStateNone)
|
||||
state := t.negotiationState
|
||||
t.negotiationState = negotiationStateNone
|
||||
if state == negotiationRetry {
|
||||
// need to Negotiate again
|
||||
t.Negotiate()
|
||||
@@ -170,22 +171,22 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
iceRestart := options != nil && options.ICERestart
|
||||
|
||||
// if restart is requested, and we are not ready, then continue afterwards
|
||||
if iceRestart {
|
||||
if t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering {
|
||||
logger.Debugw("restart ICE after gathering")
|
||||
t.restartAfterGathering.Store(true)
|
||||
t.restartAfterGathering = true
|
||||
return nil
|
||||
}
|
||||
logger.Debugw("restarting ICE")
|
||||
}
|
||||
|
||||
state := t.negotiationState.Load().(int)
|
||||
|
||||
// when there's an ongoing negotiation, let it finish and not disrupt its state
|
||||
if state == negotiationStateClient {
|
||||
if t.negotiationState == negotiationStateClient {
|
||||
currentSD := t.pc.CurrentRemoteDescription()
|
||||
if iceRestart && currentSD != nil {
|
||||
logger.Debugw("recovering from client negotiation state")
|
||||
@@ -194,7 +195,7 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error {
|
||||
}
|
||||
} else {
|
||||
logger.Debugw("skipping negotiation, trying again later")
|
||||
t.negotiationState.Store(negotiationRetry)
|
||||
t.negotiationState = negotiationRetry
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -212,9 +213,9 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error {
|
||||
}
|
||||
|
||||
// indicate waiting for client
|
||||
t.negotiationState.Store(negotiationStateClient)
|
||||
t.restartAfterGathering.Store(false)
|
||||
t.negotiationState = negotiationStateClient
|
||||
t.restartAfterGathering = false
|
||||
|
||||
t.onOffer(offer)
|
||||
go t.onOffer(offer)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package rtc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
livekit "github.com/livekit/livekit-server/proto"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -45,7 +45,10 @@ func TestMissingAnswerDuringICERestart(t *testing.T) {
|
||||
require.NoError(t, transportA.CreateAndSendOffer(nil))
|
||||
|
||||
// ensure we are connected the first time
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
testutils.WithTimeout(t, "initial ICE connectivity", func() bool {
|
||||
return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected &&
|
||||
transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected
|
||||
})
|
||||
require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState())
|
||||
require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState())
|
||||
|
||||
@@ -53,7 +56,7 @@ func TestMissingAnswerDuringICERestart(t *testing.T) {
|
||||
transportA.OnOffer(func(sd webrtc.SessionDescription) {})
|
||||
require.NoError(t, transportA.CreateAndSendOffer(nil))
|
||||
require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState())
|
||||
require.Equal(t, negotiationStateClient, transportA.negotiationState.Load().(int))
|
||||
require.Equal(t, negotiationStateClient, transportA.negotiationState)
|
||||
|
||||
// now restart ICE
|
||||
t.Logf("creating offer with ICE restart")
|
||||
@@ -62,9 +65,10 @@ func TestMissingAnswerDuringICERestart(t *testing.T) {
|
||||
ICERestart: true,
|
||||
}))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState())
|
||||
require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState())
|
||||
testutils.WithTimeout(t, "restarted ICE connectivity", func() bool {
|
||||
return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected &&
|
||||
transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected
|
||||
})
|
||||
}
|
||||
|
||||
func handleOfferFunc(t *testing.T, current, other *PCTransport) func(sd webrtc.SessionDescription) {
|
||||
|
||||
@@ -344,7 +344,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
case <-time.After(time.Millisecond * 50):
|
||||
// periodic check to ensure participant didn't become disconnected
|
||||
if participant.State() == livekit.ParticipantInfo_DISCONNECTED {
|
||||
return
|
||||
|
||||
30
pkg/testutils/timeout.go
Normal file
30
pkg/testutils/timeout.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// SyncDelay is a fixed 100ms duration.
// NOTE(review): presumably the pause tests use to let state settle — confirm against callers.
SyncDelay = 100 * time.Millisecond
|
||||
// ConnectTimeout is the 5s deadline WithTimeout gives its condition to become true.
ConnectTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func WithTimeout(t *testing.T, description string, f func() bool) bool {
|
||||
logger.Infow(description)
|
||||
ctx, _ := context.WithTimeout(context.Background(), ConnectTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out: " + description)
|
||||
return false
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if f() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
testclient "github.com/livekit/livekit-server/test/client"
|
||||
"github.com/twitchtv/twirp"
|
||||
|
||||
@@ -30,8 +31,7 @@ const (
|
||||
nodeId1 = "node-1"
|
||||
nodeId2 = "node-2"
|
||||
|
||||
syncDelay = 100 * time.Millisecond
|
||||
connectTimeout = 10 * time.Second
|
||||
syncDelay = 100 * time.Millisecond
|
||||
// if there are deadlocks, it's helpful to set a short test timeout (e.g. go test -timeout=30s)
|
||||
// let connection timeout happen
|
||||
//connectTimeout = 5000 * time.Second
|
||||
@@ -95,7 +95,8 @@ func contextWithCreateRoomToken() context.Context {
|
||||
|
||||
func waitForServerToStart(s *service.LivekitServer) {
|
||||
// wait till ready
|
||||
ctx, _ := context.WithTimeout(context.Background(), connectTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutils.ConnectTimeout)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -108,22 +109,6 @@ func waitForServerToStart(s *service.LivekitServer) {
|
||||
}
|
||||
}
|
||||
|
||||
func withTimeout(t *testing.T, description string, f func() bool) bool {
|
||||
logger.Infow(description)
|
||||
ctx, _ := context.WithTimeout(context.Background(), connectTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out: " + description)
|
||||
return false
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if f() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitUntilConnected(t *testing.T, clients ...*testclient.RTCClient) {
|
||||
logger.Infow("waiting for clients to become connected")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
livekit "github.com/livekit/livekit-server/proto"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -35,7 +36,7 @@ func TestMultiNodeRouting(t *testing.T) {
|
||||
defer t1.Stop()
|
||||
}
|
||||
|
||||
withTimeout(t, "c2 should receive one track", func() bool {
|
||||
testutils.WithTimeout(t, "c2 should receive one track", func() bool {
|
||||
if len(c2.SubscribedTracks()) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
testclient "github.com/livekit/livekit-server/test/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -26,7 +27,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
|
||||
defer stopWriters(writers...)
|
||||
|
||||
logger.Infow("waiting to receive tracks from c1 and c2")
|
||||
success := withTimeout(t, "c3 should receive tracks from both clients", func() bool {
|
||||
success := testutils.WithTimeout(t, "c3 should receive tracks from both clients", func() bool {
|
||||
tracks := c3.SubscribedTracks()
|
||||
if len(tracks[c1.ID()]) != 2 {
|
||||
return false
|
||||
@@ -46,7 +47,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
|
||||
c2.Stop()
|
||||
|
||||
logger.Infow("waiting for c2 tracks to be gone")
|
||||
success = withTimeout(t, "c2 tracks should be gone", func() bool {
|
||||
success = testutils.WithTimeout(t, "c2 tracks should be gone", func() bool {
|
||||
tracks := c3.SubscribedTracks()
|
||||
if len(tracks[c1.ID()]) != 2 {
|
||||
return false
|
||||
@@ -71,7 +72,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
|
||||
writers = publishTracksForClients(t, c2)
|
||||
defer stopWriters(writers...)
|
||||
|
||||
success = withTimeout(t, "new c2 tracks should be published again", func() bool {
|
||||
success = testutils.WithTimeout(t, "new c2 tracks should be published again", func() bool {
|
||||
tracks := c3.SubscribedTracks()
|
||||
if len(tracks[c2.ID()]) != 2 {
|
||||
return false
|
||||
@@ -98,7 +99,7 @@ func scenarioReceiveBeforePublish(t *testing.T) {
|
||||
defer stopWriters(writers...)
|
||||
|
||||
// c2 should see some bytes flowing through
|
||||
success := withTimeout(t, "waiting to receive bytes on c2", func() bool {
|
||||
success := testutils.WithTimeout(t, "waiting to receive bytes on c2", func() bool {
|
||||
return c2.BytesReceived() > 20
|
||||
})
|
||||
if !success {
|
||||
@@ -109,7 +110,7 @@ func scenarioReceiveBeforePublish(t *testing.T) {
|
||||
writers = publishTracksForClients(t, c2)
|
||||
defer stopWriters(writers...)
|
||||
|
||||
success = withTimeout(t, "waiting to receive c2 tracks on c1", func() bool {
|
||||
success = testutils.WithTimeout(t, "waiting to receive c2 tracks on c1", func() bool {
|
||||
return len(c1.SubscribedTracks()[c2.ID()]) == 2
|
||||
})
|
||||
require.True(t, success)
|
||||
@@ -117,7 +118,7 @@ func scenarioReceiveBeforePublish(t *testing.T) {
|
||||
// now leave, and ensure that it's immediate
|
||||
c2.Stop()
|
||||
|
||||
time.Sleep(connectTimeout)
|
||||
time.Sleep(testutils.ConnectTimeout)
|
||||
require.Empty(t, c1.RemoteParticipants())
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
testclient "github.com/livekit/livekit-server/test/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ func TestClientCouldConnect(t *testing.T) {
|
||||
waitUntilConnected(t, c1, c2)
|
||||
|
||||
// ensure they both see each other
|
||||
withTimeout(t, "c1 and c2 could connect", func() bool {
|
||||
testutils.WithTimeout(t, "c1 and c2 could connect", func() bool {
|
||||
if len(c1.RemoteParticipants()) == 0 || len(c2.RemoteParticipants()) == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -56,7 +57,7 @@ func TestSinglePublisher(t *testing.T) {
|
||||
// a new client joins and should get the initial stream
|
||||
c3 := createRTCClient("c3", defaultServerPort, nil)
|
||||
|
||||
success := withTimeout(t, "c2 should receive two tracks", func() bool {
|
||||
success := testutils.WithTimeout(t, "c2 should receive two tracks", func() bool {
|
||||
if len(c2.SubscribedTracks()) == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -75,7 +76,7 @@ func TestSinglePublisher(t *testing.T) {
|
||||
|
||||
// ensure that new client that has joined also received tracks
|
||||
waitUntilConnected(t, c3)
|
||||
success = withTimeout(t, "c2 should receive two tracks", func() bool {
|
||||
success = testutils.WithTimeout(t, "c2 should receive two tracks", func() bool {
|
||||
if len(c3.SubscribedTracks()) == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -98,7 +99,7 @@ func TestSinglePublisher(t *testing.T) {
|
||||
// when c3 disconnects.. ensure subscriber is cleaned up correctly
|
||||
c3.Stop()
|
||||
|
||||
success = withTimeout(t, "c3 is cleaned up as a subscriber", func() bool {
|
||||
success = testutils.WithTimeout(t, "c3 is cleaned up as a subscriber", func() bool {
|
||||
room := s.RoomManager().GetRoom(testRoom)
|
||||
require.NotNil(t, room)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user