Fix race condition with Transport negotiations

This commit is contained in:
David Zhao
2021-06-04 12:26:23 -07:00
parent bf281b1994
commit c510ea2e1a
8 changed files with 78 additions and 55 deletions

View File

@@ -2,7 +2,6 @@ package rtc
import (
"sync"
"sync/atomic"
"time"
"github.com/bep/debounce"
@@ -35,9 +34,8 @@ type PCTransport struct {
pendingCandidates []webrtc.ICECandidateInit
debouncedNegotiate func(func())
onOffer func(offer webrtc.SessionDescription)
restartAfterGathering atomic.Value
negotiationState atomic.Value
restartAfterGathering bool
negotiationState int
}
type TransportParams struct {
@@ -90,11 +88,13 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) {
pc: pc,
me: me,
debouncedNegotiate: debounce.New(negotiationFrequency),
negotiationState: negotiationStateNone,
}
t.negotiationState.Store(negotiationStateNone)
t.pc.OnICEGatheringStateChange(func(state webrtc.ICEGathererState) {
if state == webrtc.ICEGathererStateComplete {
if restart, ok := t.restartAfterGathering.Load().(bool); ok && restart {
t.lock.Lock()
defer t.lock.Unlock()
if t.restartAfterGathering {
if err := t.CreateAndSendOffer(&webrtc.OfferOptions{ICERestart: true}); err != nil {
logger.Warnw("could not restart ICE", err)
}
@@ -125,22 +125,23 @@ func (t *PCTransport) Close() {
}
func (t *PCTransport) SetRemoteDescription(sd webrtc.SessionDescription) error {
t.lock.Lock()
defer t.lock.Unlock()
if err := t.pc.SetRemoteDescription(sd); err != nil {
return err
}
t.lock.Lock()
for _, c := range t.pendingCandidates {
if err := t.pc.AddICECandidate(c); err != nil {
return err
}
}
t.pendingCandidates = nil
t.lock.Unlock()
// negotiated, reset flag
state := t.negotiationState.Load().(int)
t.negotiationState.Store(negotiationStateNone)
state := t.negotiationState
t.negotiationState = negotiationStateNone
if state == negotiationRetry {
// need to Negotiate again
t.Negotiate()
@@ -170,22 +171,22 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error {
return nil
}
t.lock.Lock()
defer t.lock.Unlock()
iceRestart := options != nil && options.ICERestart
// if restart is requested, and we are not ready, then continue afterwards
if iceRestart {
if t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering {
logger.Debugw("restart ICE after gathering")
t.restartAfterGathering.Store(true)
t.restartAfterGathering = true
return nil
}
logger.Debugw("restarting ICE")
}
state := t.negotiationState.Load().(int)
// when there's an ongoing negotiation, let it finish and not disrupt its state
if state == negotiationStateClient {
if t.negotiationState == negotiationStateClient {
currentSD := t.pc.CurrentRemoteDescription()
if iceRestart && currentSD != nil {
logger.Debugw("recovering from client negotiation state")
@@ -194,7 +195,7 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error {
}
} else {
logger.Debugw("skipping negotiation, trying again later")
t.negotiationState.Store(negotiationRetry)
t.negotiationState = negotiationRetry
return nil
}
}
@@ -212,9 +213,9 @@ func (t *PCTransport) CreateAndSendOffer(options *webrtc.OfferOptions) error {
}
// indicate waiting for client
t.negotiationState.Store(negotiationStateClient)
t.restartAfterGathering.Store(false)
t.negotiationState = negotiationStateClient
t.restartAfterGathering = false
t.onOffer(offer)
go t.onOffer(offer)
return nil
}

View File

@@ -2,8 +2,8 @@ package rtc
import (
"testing"
"time"
"github.com/livekit/livekit-server/pkg/testutils"
livekit "github.com/livekit/livekit-server/proto"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
@@ -45,7 +45,10 @@ func TestMissingAnswerDuringICERestart(t *testing.T) {
require.NoError(t, transportA.CreateAndSendOffer(nil))
// ensure we are connected the first time
time.Sleep(100 * time.Millisecond)
testutils.WithTimeout(t, "initial ICE connectivity", func() bool {
return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected &&
transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected
})
require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState())
require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState())
@@ -53,7 +56,7 @@ func TestMissingAnswerDuringICERestart(t *testing.T) {
transportA.OnOffer(func(sd webrtc.SessionDescription) {})
require.NoError(t, transportA.CreateAndSendOffer(nil))
require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState())
require.Equal(t, negotiationStateClient, transportA.negotiationState.Load().(int))
require.Equal(t, negotiationStateClient, transportA.negotiationState)
// now restart ICE
t.Logf("creating offer with ICE restart")
@@ -62,9 +65,10 @@ func TestMissingAnswerDuringICERestart(t *testing.T) {
ICERestart: true,
}))
time.Sleep(100 * time.Millisecond)
require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState())
require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState())
testutils.WithTimeout(t, "restarted ICE connectivity", func() bool {
return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected &&
transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected
})
}
func handleOfferFunc(t *testing.T, current, other *PCTransport) func(sd webrtc.SessionDescription) {

View File

@@ -344,7 +344,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici
for {
select {
case <-time.After(time.Millisecond * 100):
case <-time.After(time.Millisecond * 50):
// periodic check to ensure participant didn't become disconnected
if participant.State() == livekit.ParticipantInfo_DISCONNECTED {
return

30
pkg/testutils/timeout.go Normal file
View File

@@ -0,0 +1,30 @@
package testutils
import (
"context"
"testing"
"time"
"github.com/livekit/livekit-server/pkg/logger"
)
// Shared timing values for tests.
var (
// SyncDelay is a 100ms duration. It is not referenced in this file;
// presumably a settle delay used by tests — confirm against callers.
SyncDelay = 100 * time.Millisecond
// ConnectTimeout is the deadline WithTimeout gives its polling loop
// before it fails the test.
ConnectTimeout = 5 * time.Second
)
// WithTimeout polls f about every 10ms until it returns true or
// ConnectTimeout elapses.
//
// description is logged when polling starts and is included in the failure
// message. On success WithTimeout returns true. On timeout it fails the test
// via t.Fatal, which stops the calling test goroutine; the false return only
// matters if WithTimeout is invoked from some other goroutine.
func WithTimeout(t *testing.T, description string, f func() bool) bool {
	// Attribute a timeout failure to the calling test line, not this helper.
	t.Helper()
	logger.Infow(description)

	ctx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
	// Release the context's timer as soon as we return; deferred calls also
	// run when t.Fatal exits the goroutine.
	defer cancel()

	// A single ticker replaces allocating a fresh time.After timer per loop.
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			t.Fatal("timed out: " + description)
			return false
		case <-ticker.C:
			if f() {
				return true
			}
		}
	}
}