diff --git a/go.mod b/go.mod index 406e93972..0429e06fa 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/google/wire v0.5.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.11.14-0.20220223230744-2d72f8bc52aa + github.com/livekit/protocol v0.11.14-0.20220225092016-4b44edff9ed7 github.com/magefile/mage v1.11.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0 github.com/mitchellh/go-homedir v1.1.0 diff --git a/go.sum b/go.sum index 2bba105bf..c72e6ccd3 100644 --- a/go.sum +++ b/go.sum @@ -134,6 +134,8 @@ github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= github.com/livekit/protocol v0.11.14-0.20220223230744-2d72f8bc52aa h1:3eJ92nyh9krLpbkz2TB4Dtc4oDLO86sznNxQR2sb614= github.com/livekit/protocol v0.11.14-0.20220223230744-2d72f8bc52aa/go.mod h1:3pHsWUtQmWaH8mG0cXrQWpbf3Vo+kj0U+In77CEXu90= +github.com/livekit/protocol v0.11.14-0.20220225092016-4b44edff9ed7 h1:SrUiL7cKfSMGGQBgzZlNGL+wcKXEO//C7j6CQ3zjxHA= +github.com/livekit/protocol v0.11.14-0.20220225092016-4b44edff9ed7/go.mod h1:3pHsWUtQmWaH8mG0cXrQWpbf3Vo+kj0U+In77CEXu90= github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls= github.com/magefile/mage v1.11.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 72d9014a7..fa1eef8d4 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -116,10 +116,11 @@ type ParticipantImpl struct { onMetadataUpdate func(types.LocalParticipant) onDataPacket func(types.LocalParticipant, *livekit.DataPacket) - migrateState atomic.Value // types.MigrateState - pendingOffer *webrtc.SessionDescription - onClose func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID) - onClaimsChanged func(participant types.LocalParticipant) + migrateState atomic.Value // types.MigrateState + pendingOffer *webrtc.SessionDescription + pendingDataChannels []*livekit.DataChannelInfo + onClose func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID) + onClaimsChanged func(participant types.LocalParticipant) } func NewParticipant(params ParticipantParams, perms *livekit.ParticipantPermission) (*ParticipantImpl, error) { @@ -436,11 +437,14 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { }) } -func (p *ParticipantImpl) AddMigratedTrack(cid string, ti *livekit.TrackInfo) { +func (p *ParticipantImpl) SetMigrateInfo(mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo) { p.pendingTracksLock.Lock() defer p.pendingTracksLock.Unlock() - p.pendingTracks[cid] = &pendingTrackInfo{ti, true} + for _, t := range mediaTracks { + p.pendingTracks[t.GetCid()] = &pendingTrackInfo{t.GetTrack(), true} + } + p.pendingDataChannels = dataChannels } // HandleAnswer handles a client answer response, with subscriber PC, server initiates the @@ -560,6 +564,9 @@ func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { p.pendingOffer = nil } p.lock.Unlock() + if s == types.MigrateStateComplete { + p.handlePendingDataChannels() + } if pendingOffer != nil { p.HandleOffer(*pendingOffer) } @@ -1605,3 +1612,46 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { return info } + +func (p *ParticipantImpl) handlePendingDataChannels() { + p.lock.Lock() + defer p.lock.Unlock() + ordered := true + negotiated := true + for _, ci := range p.pendingDataChannels { + if ci.Label == lossyDataChannel && p.lossyDC == nil { + retransmits := uint16(0) + id := uint16(ci.GetId()) + dc, err := p.publisher.pc.CreateDataChannel(lossyDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &retransmits, + Negotiated: &negotiated, + ID: &id, + }) + if err != nil { + p.params.Logger.Errorw("create migrated data channel failed", err, "label", lossyDataChannel) + } else { + p.lossyDC = dc + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + p.handleDataMessage(livekit.DataPacket_LOSSY, msg.Data) + }) + } + } else if ci.Label == reliableDataChannel && p.reliableDC == nil { + id := uint16(ci.GetId()) + dc, err := p.publisher.pc.CreateDataChannel(reliableDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + Negotiated: &negotiated, + ID: &id, + }) + if err != nil { + p.params.Logger.Errorw("create migrated data channel failed", err, "label", reliableDataChannel) + } else { + p.reliableDC = dc + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + p.handleDataMessage(livekit.DataPacket_RELIABLE, msg.Data) + }) + } + } + } + p.pendingDataChannels = nil +} diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 2d75fe9d4..87964fd69 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -151,7 +151,7 @@ type LocalParticipant interface { // session migration SetMigrateState(s MigrateState) MigrateState() MigrateState - AddMigratedTrack(cid string, ti *livekit.TrackInfo) + SetMigrateInfo(mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo) SetPreviousAnswer(previous *webrtc.SessionDescription) UpdateRTT(rtt uint32) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 7fb0fe01b..1a71a2145 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -26,12 +26,6 @@ type FakeLocalParticipant struct { addICECandidateReturnsOnCall map[int]struct { result1 error } - AddMigratedTrackStub func(string, *livekit.TrackInfo) - addMigratedTrackMutex sync.RWMutex - addMigratedTrackArgsForCall []struct { - arg1 string - arg2 *livekit.TrackInfo - } AddSubscribedTrackStub func(types.SubscribedTrack) addSubscribedTrackMutex sync.RWMutex addSubscribedTrackArgsForCall []struct { @@ -439,6 +433,12 @@ type FakeLocalParticipant struct { setMetadataArgsForCall []struct { arg1 string } + SetMigrateInfoStub func([]*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo) + setMigrateInfoMutex sync.RWMutex + setMigrateInfoArgsForCall []struct { + arg1 []*livekit.TrackPublishedResponse + arg2 []*livekit.DataChannelInfo + } SetMigrateStateStub func(types.MigrateState) setMigrateStateMutex sync.RWMutex setMigrateStateArgsForCall []struct { @@ -669,39 +669,6 @@ func (fake *FakeLocalParticipant) AddICECandidateReturnsOnCall(i int, result1 er }{result1} } -func (fake *FakeLocalParticipant) AddMigratedTrack(arg1 string, arg2 *livekit.TrackInfo) { - fake.addMigratedTrackMutex.Lock() - fake.addMigratedTrackArgsForCall = append(fake.addMigratedTrackArgsForCall, struct { - arg1 string - arg2 *livekit.TrackInfo - }{arg1, arg2}) - stub := fake.AddMigratedTrackStub - fake.recordInvocation("AddMigratedTrack", []interface{}{arg1, arg2}) - fake.addMigratedTrackMutex.Unlock() - if stub != nil { - fake.AddMigratedTrackStub(arg1, arg2) - } -} - -func (fake *FakeLocalParticipant) AddMigratedTrackCallCount() int { - fake.addMigratedTrackMutex.RLock() - defer fake.addMigratedTrackMutex.RUnlock() - return len(fake.addMigratedTrackArgsForCall) -} - -func (fake *FakeLocalParticipant) AddMigratedTrackCalls(stub func(string, *livekit.TrackInfo)) { - fake.addMigratedTrackMutex.Lock() - defer fake.addMigratedTrackMutex.Unlock() - fake.AddMigratedTrackStub = stub -} - -func (fake *FakeLocalParticipant) AddMigratedTrackArgsForCall(i int) (string, *livekit.TrackInfo) { - fake.addMigratedTrackMutex.RLock() - defer fake.addMigratedTrackMutex.RUnlock() - argsForCall := fake.addMigratedTrackArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 -} - func (fake *FakeLocalParticipant) AddSubscribedTrack(arg1 types.SubscribedTrack) { fake.addSubscribedTrackMutex.Lock() fake.addSubscribedTrackArgsForCall = append(fake.addSubscribedTrackArgsForCall, struct { @@ -2936,6 +2903,49 @@ func (fake *FakeLocalParticipant) SetMetadataArgsForCall(i int) string { return argsForCall.arg1 } +func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 []*livekit.TrackPublishedResponse, arg2 []*livekit.DataChannelInfo) { + var arg1Copy []*livekit.TrackPublishedResponse + if arg1 != nil { + arg1Copy = make([]*livekit.TrackPublishedResponse, len(arg1)) + copy(arg1Copy, arg1) + } + var arg2Copy []*livekit.DataChannelInfo + if arg2 != nil { + arg2Copy = make([]*livekit.DataChannelInfo, len(arg2)) + copy(arg2Copy, arg2) + } + fake.setMigrateInfoMutex.Lock() + fake.setMigrateInfoArgsForCall = append(fake.setMigrateInfoArgsForCall, struct { + arg1 []*livekit.TrackPublishedResponse + arg2 []*livekit.DataChannelInfo + }{arg1Copy, arg2Copy}) + stub := fake.SetMigrateInfoStub + fake.recordInvocation("SetMigrateInfo", []interface{}{arg1Copy, arg2Copy}) + fake.setMigrateInfoMutex.Unlock() + if stub != nil { + fake.SetMigrateInfoStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) SetMigrateInfoCallCount() int { + fake.setMigrateInfoMutex.RLock() + defer fake.setMigrateInfoMutex.RUnlock() + return len(fake.setMigrateInfoArgsForCall) +} + +func (fake *FakeLocalParticipant) SetMigrateInfoCalls(stub func([]*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo)) { + fake.setMigrateInfoMutex.Lock() + defer fake.setMigrateInfoMutex.Unlock() + fake.SetMigrateInfoStub = stub +} + +func (fake *FakeLocalParticipant) SetMigrateInfoArgsForCall(i int) ([]*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo) { + fake.setMigrateInfoMutex.RLock() + defer fake.setMigrateInfoMutex.RUnlock() + argsForCall := fake.setMigrateInfoArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeLocalParticipant) SetMigrateState(arg1 types.MigrateState) { fake.setMigrateStateMutex.Lock() fake.setMigrateStateArgsForCall = append(fake.setMigrateStateArgsForCall, struct { @@ -3822,8 +3832,6 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.addICECandidateMutex.RLock() defer fake.addICECandidateMutex.RUnlock() - fake.addMigratedTrackMutex.RLock() - defer fake.addMigratedTrackMutex.RUnlock() fake.addSubscribedTrackMutex.RLock() defer fake.addSubscribedTrackMutex.RUnlock() fake.addSubscriberMutex.RLock() @@ -3914,6 +3922,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.sendSpeakerUpdateMutex.RUnlock() fake.setMetadataMutex.RLock() defer fake.setMetadataMutex.RUnlock() + fake.setMigrateInfoMutex.RLock() + defer fake.setMigrateInfoMutex.RUnlock() fake.setMigrateStateMutex.RLock() defer fake.setMigrateStateMutex.RUnlock() fake.setPermissionMutex.RLock()