From 2a79bdb678460bd032f6ca59211098719d17fab0 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Mon, 21 Dec 2020 23:00:48 -0800 Subject: [PATCH] switch participant to interface --- go.mod | 1 + pkg/rtc/datatrack.go | 4 +- pkg/rtc/interfaces.go | 44 ++++ pkg/rtc/mediatrack.go | 8 +- pkg/rtc/mock_helper_test.go | 15 ++ pkg/rtc/mock_test.go | 468 ++++++++++++++++++++++++++++++++++-- pkg/rtc/participant.go | 157 +++++++----- pkg/rtc/publishedtrack.go | 24 -- pkg/rtc/room.go | 36 +-- pkg/rtc/room_test.go | 44 ++++ pkg/rtc/utils.go | 10 +- pkg/service/rtc.go | 8 +- 12 files changed, 681 insertions(+), 138 deletions(-) create mode 100644 pkg/rtc/mock_helper_test.go create mode 100644 pkg/rtc/room_test.go diff --git a/go.mod b/go.mod index 4b5075488..69ddd74ef 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/pion/stun v0.3.5 github.com/pion/webrtc/v3 v3.0.0-beta.16 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.6.1 github.com/twitchtv/twirp v7.1.0+incompatible github.com/urfave/cli/v2 v2.2.0 github.com/urfave/negroni v1.0.0 diff --git a/pkg/rtc/datatrack.go b/pkg/rtc/datatrack.go index dc4f0aca5..3f107cdc4 100644 --- a/pkg/rtc/datatrack.go +++ b/pkg/rtc/datatrack.go @@ -65,9 +65,9 @@ func (t *DataTrack) StreamID() string { return t.dataChannel.Label() } -func (t *DataTrack) AddSubscriber(participant *Participant) error { +func (t *DataTrack) AddSubscriber(participant Participant) error { label := PackDataTrackLabel(t.participantId, t.ID(), t.dataChannel.Label()) - downChannel, err := participant.peerConn.CreateDataChannel(label, t.dataChannelOptions()) + downChannel, err := participant.peerConnection().CreateDataChannel(label, t.dataChannelOptions()) if err != nil { return err } diff --git a/pkg/rtc/interfaces.go b/pkg/rtc/interfaces.go index 64ddc384c..15d876b14 100644 --- a/pkg/rtc/interfaces.go +++ b/pkg/rtc/interfaces.go @@ -6,6 +6,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" + "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/proto/livekit" ) @@ -41,3 +42,46 @@ type PeerConnection interface { ConnectionState() webrtc.PeerConnectionState RemoveTrack(sender *webrtc.RTPSender) error } + +type Participant interface { + ID() string + Name() string + State() livekit.ParticipantInfo_State + ToProto() *livekit.ParticipantInfo + Answer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) + HandleNegotiate(sd webrtc.SessionDescription) error + SetRemoteDescription(sdp webrtc.SessionDescription) error + AddICECandidate(candidate webrtc.ICECandidateInit) error + + AddSubscriber(op Participant) error + RemoveSubscriber(peerId string) + SendJoinResponse(otherParticipants []Participant) error + SendParticipantUpdate(participants []*livekit.ParticipantInfo) error + + Start() + Close() error + + // callbacks + OnOffer(func(webrtc.SessionDescription)) + OnICECandidate(func(c *webrtc.ICECandidateInit)) + OnStateChange(func(p Participant, oldState livekit.ParticipantInfo_State)) + OnTrackPublished(func(Participant, PublishedTrack)) + OnClose(func(Participant)) + + // package methods + addDownTrack(streamId string, dt *sfu.DownTrack) + removeDownTrack(streamId string, dt *sfu.DownTrack) + peerConnection() PeerConnection +} + +// PublishedTrack is the main interface representing a track published to the room +// it's responsible for managing subscribers and forwarding data from the input track to all subscribers +type PublishedTrack interface { + Start() + ID() string + Kind() livekit.TrackInfo_Type + StreamID() string + AddSubscriber(participant Participant) error + RemoveSubscriber(participantId string) + RemoveAllSubscribers() +} diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 70f651e4d..59d3ebd24 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -84,7 +84,7 @@ func (t *MediaTrack) StreamID() string { // subscribes participant to current remoteTrack // creates and add necessary forwarders and starts them -func (t *MediaTrack) AddSubscriber(participant *Participant) error { +func (t *MediaTrack) AddSubscriber(participant Participant) error { codec := t.remoteTrack.Codec() // pack ID to identify all tracks packedId := PackTrackId(t.participantId, t.id) @@ -101,7 +101,7 @@ func (t *MediaTrack) AddSubscriber(participant *Participant) error { return err } - transceiver, err := participant.peerConn.AddTransceiverFromTrack(outTrack, webrtc.RTPTransceiverInit{ + transceiver, err := participant.peerConnection().AddTransceiverFromTrack(outTrack, webrtc.RTPTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionSendonly, }) if err != nil { @@ -121,12 +121,12 @@ func (t *MediaTrack) AddSubscriber(participant *Participant) error { delete(t.forwarders, participant.ID()) t.lock.Unlock() - if participant.peerConn.ConnectionState() == webrtc.PeerConnectionStateClosed { + if participant.peerConnection().ConnectionState() == webrtc.PeerConnectionStateClosed { return } sender := transceiver.Sender() if sender != nil { - if err := participant.peerConn.RemoveTrack(sender); err != nil { + if err := participant.peerConnection().RemoveTrack(sender); err != nil { if _, ok := err.(*rtcerr.InvalidStateError); !ok { logger.GetLogger().Warnw("could not remove remoteTrack from forwarder", "participant", participant.ID(), diff --git a/pkg/rtc/mock_helper_test.go b/pkg/rtc/mock_helper_test.go new file mode 100644 index 000000000..039386632 --- /dev/null +++ b/pkg/rtc/mock_helper_test.go @@ -0,0 +1,15 @@ +package rtc + +import ( + "github.com/golang/mock/gomock" +) + +func newMockPeerConnection(mockCtrl *gomock.Controller) *MockPeerConnection { + pc := NewMockPeerConnection(mockCtrl) + pc.EXPECT().OnDataChannel(gomock.Any()).AnyTimes() + pc.EXPECT().OnICECandidate(gomock.Any()).AnyTimes() + pc.EXPECT().OnICEConnectionStateChange(gomock.Any()).AnyTimes() + pc.EXPECT().OnNegotiationNeeded(gomock.Any()).AnyTimes() + pc.EXPECT().OnTrack(gomock.Any()).AnyTimes() + return pc +} diff --git a/pkg/rtc/mock_test.go b/pkg/rtc/mock_test.go index 6bacffa57..fd71521ee 100644 --- a/pkg/rtc/mock_test.go +++ b/pkg/rtc/mock_test.go @@ -6,9 +6,10 @@ package rtc import ( gomock "github.com/golang/mock/gomock" + sfu "github.com/livekit/livekit-server/pkg/sfu" livekit "github.com/livekit/livekit-server/proto/livekit" rtcp "github.com/pion/rtcp" - webrtc "github.com/pion/webrtc/v3" + v3 "github.com/pion/webrtc/v3" reflect "reflect" time "time" ) @@ -156,7 +157,7 @@ func (m *MockPeerConnection) EXPECT() *MockPeerConnectionMockRecorder { } // OnICECandidate mocks base method -func (m *MockPeerConnection) OnICECandidate(f func(*webrtc.ICECandidate)) { +func (m *MockPeerConnection) OnICECandidate(f func(*v3.ICECandidate)) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnICECandidate", f) } @@ -168,7 +169,7 @@ func (mr *MockPeerConnectionMockRecorder) OnICECandidate(f interface{}) *gomock. } // OnICEConnectionStateChange mocks base method -func (m *MockPeerConnection) OnICEConnectionStateChange(arg0 func(webrtc.ICEConnectionState)) { +func (m *MockPeerConnection) OnICEConnectionStateChange(arg0 func(v3.ICEConnectionState)) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnICEConnectionStateChange", arg0) } @@ -180,7 +181,7 @@ func (mr *MockPeerConnectionMockRecorder) OnICEConnectionStateChange(arg0 interf } // OnTrack mocks base method -func (m *MockPeerConnection) OnTrack(f func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { +func (m *MockPeerConnection) OnTrack(f func(*v3.TrackRemote, *v3.RTPReceiver)) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnTrack", f) } @@ -192,7 +193,7 @@ func (mr *MockPeerConnectionMockRecorder) OnTrack(f interface{}) *gomock.Call { } // OnDataChannel mocks base method -func (m *MockPeerConnection) OnDataChannel(arg0 func(*webrtc.DataChannel)) { +func (m *MockPeerConnection) OnDataChannel(arg0 func(*v3.DataChannel)) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnDataChannel", arg0) } @@ -230,7 +231,7 @@ func (mr *MockPeerConnectionMockRecorder) Close() *gomock.Call { } // SetRemoteDescription mocks base method -func (m *MockPeerConnection) SetRemoteDescription(desc webrtc.SessionDescription) error { +func (m *MockPeerConnection) SetRemoteDescription(desc v3.SessionDescription) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetRemoteDescription", desc) ret0, _ := ret[0].(error) @@ -244,7 +245,7 @@ func (mr *MockPeerConnectionMockRecorder) SetRemoteDescription(desc interface{}) } // SetLocalDescription mocks base method -func (m *MockPeerConnection) SetLocalDescription(desc webrtc.SessionDescription) error { +func (m *MockPeerConnection) SetLocalDescription(desc v3.SessionDescription) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetLocalDescription", desc) ret0, _ := ret[0].(error) @@ -258,10 +259,10 @@ func (mr *MockPeerConnectionMockRecorder) SetLocalDescription(desc interface{}) } // CreateOffer mocks base method -func (m *MockPeerConnection) CreateOffer(options *webrtc.OfferOptions) (webrtc.SessionDescription, error) { +func (m *MockPeerConnection) CreateOffer(options *v3.OfferOptions) (v3.SessionDescription, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateOffer", options) - ret0, _ := ret[0].(webrtc.SessionDescription) + ret0, _ := ret[0].(v3.SessionDescription) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -273,10 +274,10 @@ func (mr *MockPeerConnectionMockRecorder) CreateOffer(options interface{}) *gomo } // CreateAnswer mocks base method -func (m *MockPeerConnection) CreateAnswer(options *webrtc.AnswerOptions) (webrtc.SessionDescription, error) { +func (m *MockPeerConnection) CreateAnswer(options *v3.AnswerOptions) (v3.SessionDescription, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateAnswer", options) - ret0, _ := ret[0].(webrtc.SessionDescription) + ret0, _ := ret[0].(v3.SessionDescription) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -288,7 +289,7 @@ func (mr *MockPeerConnectionMockRecorder) CreateAnswer(options interface{}) *gom } // AddICECandidate mocks base method -func (m *MockPeerConnection) AddICECandidate(candidate webrtc.ICECandidateInit) error { +func (m *MockPeerConnection) AddICECandidate(candidate v3.ICECandidateInit) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddICECandidate", candidate) ret0, _ := ret[0].(error) @@ -316,10 +317,10 @@ func (mr *MockPeerConnectionMockRecorder) WriteRTCP(pkts interface{}) *gomock.Ca } // CreateDataChannel mocks base method -func (m *MockPeerConnection) CreateDataChannel(label string, options *webrtc.DataChannelInit) (*webrtc.DataChannel, error) { +func (m *MockPeerConnection) CreateDataChannel(label string, options *v3.DataChannelInit) (*v3.DataChannel, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateDataChannel", label, options) - ret0, _ := ret[0].(*webrtc.DataChannel) + ret0, _ := ret[0].(*v3.DataChannel) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -331,14 +332,14 @@ func (mr *MockPeerConnectionMockRecorder) CreateDataChannel(label, options inter } // AddTransceiverFromTrack mocks base method -func (m *MockPeerConnection) AddTransceiverFromTrack(track webrtc.TrackLocal, init ...webrtc.RtpTransceiverInit) (*webrtc.RTPTransceiver, error) { +func (m *MockPeerConnection) AddTransceiverFromTrack(track v3.TrackLocal, init ...v3.RtpTransceiverInit) (*v3.RTPTransceiver, error) { m.ctrl.T.Helper() varargs := []interface{}{track} for _, a := range init { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "AddTransceiverFromTrack", varargs...) - ret0, _ := ret[0].(*webrtc.RTPTransceiver) + ret0, _ := ret[0].(*v3.RTPTransceiver) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -351,10 +352,10 @@ func (mr *MockPeerConnectionMockRecorder) AddTransceiverFromTrack(track interfac } // ConnectionState mocks base method -func (m *MockPeerConnection) ConnectionState() webrtc.PeerConnectionState { +func (m *MockPeerConnection) ConnectionState() v3.PeerConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(webrtc.PeerConnectionState) + ret0, _ := ret[0].(v3.PeerConnectionState) return ret0 } @@ -365,7 +366,7 @@ func (mr *MockPeerConnectionMockRecorder) ConnectionState() *gomock.Call { } // RemoveTrack mocks base method -func (m *MockPeerConnection) RemoveTrack(sender *webrtc.RTPSender) error { +func (m *MockPeerConnection) RemoveTrack(sender *v3.RTPSender) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoveTrack", sender) ret0, _ := ret[0].(error) @@ -377,3 +378,432 @@ func (mr *MockPeerConnectionMockRecorder) RemoveTrack(sender interface{}) *gomoc mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveTrack", reflect.TypeOf((*MockPeerConnection)(nil).RemoveTrack), sender) } + +// MockParticipant is a mock of Participant interface +type MockParticipant struct { + ctrl *gomock.Controller + recorder *MockParticipantMockRecorder +} + +// MockParticipantMockRecorder is the mock recorder for MockParticipant +type MockParticipantMockRecorder struct { + mock *MockParticipant +} + +// NewMockParticipant creates a new mock instance +func NewMockParticipant(ctrl *gomock.Controller) *MockParticipant { + mock := &MockParticipant{ctrl: ctrl} + mock.recorder = &MockParticipantMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockParticipant) EXPECT() *MockParticipantMockRecorder { + return m.recorder +} + +// ID mocks base method +func (m *MockParticipant) ID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(string) + return ret0 +} + +// ID indicates an expected call of ID +func (mr *MockParticipantMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockParticipant)(nil).ID)) +} + +// Name mocks base method +func (m *MockParticipant) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name +func (mr *MockParticipantMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockParticipant)(nil).Name)) +} + +// State mocks base method +func (m *MockParticipant) State() livekit.ParticipantInfo_State { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "State") + ret0, _ := ret[0].(livekit.ParticipantInfo_State) + return ret0 +} + +// State indicates an expected call of State +func (mr *MockParticipantMockRecorder) State() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockParticipant)(nil).State)) +} + +// ToProto mocks base method +func (m *MockParticipant) ToProto() *livekit.ParticipantInfo { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ToProto") + ret0, _ := ret[0].(*livekit.ParticipantInfo) + return ret0 +} + +// ToProto indicates an expected call of ToProto +func (mr *MockParticipantMockRecorder) ToProto() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ToProto", reflect.TypeOf((*MockParticipant)(nil).ToProto)) +} + +// Answer mocks base method +func (m *MockParticipant) Answer(sdp v3.SessionDescription) (v3.SessionDescription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Answer", sdp) + ret0, _ := ret[0].(v3.SessionDescription) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Answer indicates an expected call of Answer +func (mr *MockParticipantMockRecorder) Answer(sdp interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Answer", reflect.TypeOf((*MockParticipant)(nil).Answer), sdp) +} + +// HandleNegotiate mocks base method +func (m *MockParticipant) HandleNegotiate(sd v3.SessionDescription) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleNegotiate", sd) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleNegotiate indicates an expected call of HandleNegotiate +func (mr *MockParticipantMockRecorder) HandleNegotiate(sd interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNegotiate", reflect.TypeOf((*MockParticipant)(nil).HandleNegotiate), sd) +} + +// SetRemoteDescription mocks base method +func (m *MockParticipant) SetRemoteDescription(sdp v3.SessionDescription) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetRemoteDescription", sdp) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetRemoteDescription indicates an expected call of SetRemoteDescription +func (mr *MockParticipantMockRecorder) SetRemoteDescription(sdp interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRemoteDescription", reflect.TypeOf((*MockParticipant)(nil).SetRemoteDescription), sdp) +} + +// AddICECandidate mocks base method +func (m *MockParticipant) AddICECandidate(candidate v3.ICECandidateInit) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddICECandidate", candidate) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddICECandidate indicates an expected call of AddICECandidate +func (mr *MockParticipantMockRecorder) AddICECandidate(candidate interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddICECandidate", reflect.TypeOf((*MockParticipant)(nil).AddICECandidate), candidate) +} + +// AddSubscriber mocks base method +func (m *MockParticipant) AddSubscriber(op Participant) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddSubscriber", op) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddSubscriber indicates an expected call of AddSubscriber +func (mr *MockParticipantMockRecorder) AddSubscriber(op interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSubscriber", reflect.TypeOf((*MockParticipant)(nil).AddSubscriber), op) +} + +// RemoveSubscriber mocks base method +func (m *MockParticipant) RemoveSubscriber(peerId string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveSubscriber", peerId) +} + +// RemoveSubscriber indicates an expected call of RemoveSubscriber +func (mr *MockParticipantMockRecorder) RemoveSubscriber(peerId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveSubscriber", reflect.TypeOf((*MockParticipant)(nil).RemoveSubscriber), peerId) +} + +// SendJoinResponse mocks base method +func (m *MockParticipant) SendJoinResponse(otherParticipants []Participant) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendJoinResponse", otherParticipants) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendJoinResponse indicates an expected call of SendJoinResponse +func (mr *MockParticipantMockRecorder) SendJoinResponse(otherParticipants interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendJoinResponse", reflect.TypeOf((*MockParticipant)(nil).SendJoinResponse), otherParticipants) +} + +// SendParticipantUpdate mocks base method +func (m *MockParticipant) SendParticipantUpdate(participants []*livekit.ParticipantInfo) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendParticipantUpdate", participants) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendParticipantUpdate indicates an expected call of SendParticipantUpdate +func (mr *MockParticipantMockRecorder) SendParticipantUpdate(participants interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendParticipantUpdate", reflect.TypeOf((*MockParticipant)(nil).SendParticipantUpdate), participants) +} + +// Start mocks base method +func (m *MockParticipant) Start() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start") +} + +// Start indicates an expected call of Start +func (mr *MockParticipantMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockParticipant)(nil).Start)) +} + +// Close mocks base method +func (m *MockParticipant) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockParticipantMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockParticipant)(nil).Close)) +} + +// OnOffer mocks base method +func (m *MockParticipant) OnOffer(arg0 func(v3.SessionDescription)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnOffer", arg0) +} + +// OnOffer indicates an expected call of OnOffer +func (mr *MockParticipantMockRecorder) OnOffer(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnOffer", reflect.TypeOf((*MockParticipant)(nil).OnOffer), arg0) +} + +// OnICECandidate mocks base method +func (m *MockParticipant) OnICECandidate(arg0 func(*v3.ICECandidateInit)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnICECandidate", arg0) +} + +// OnICECandidate indicates an expected call of OnICECandidate +func (mr *MockParticipantMockRecorder) OnICECandidate(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnICECandidate", reflect.TypeOf((*MockParticipant)(nil).OnICECandidate), arg0) +} + +// OnStateChange mocks base method +func (m *MockParticipant) OnStateChange(arg0 func(Participant, livekit.ParticipantInfo_State)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnStateChange", arg0) +} + +// OnStateChange indicates an expected call of OnStateChange +func (mr *MockParticipantMockRecorder) OnStateChange(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnStateChange", reflect.TypeOf((*MockParticipant)(nil).OnStateChange), arg0) +} + +// OnTrackPublished mocks base method +func (m *MockParticipant) OnTrackPublished(arg0 func(Participant, PublishedTrack)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnTrackPublished", arg0) +} + +// OnTrackPublished indicates an expected call of OnTrackPublished +func (mr *MockParticipantMockRecorder) OnTrackPublished(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnTrackPublished", reflect.TypeOf((*MockParticipant)(nil).OnTrackPublished), arg0) +} + +// OnClose mocks base method +func (m *MockParticipant) OnClose(arg0 func(Participant)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnClose", arg0) +} + +// OnClose indicates an expected call of OnClose +func (mr *MockParticipantMockRecorder) OnClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnClose", reflect.TypeOf((*MockParticipant)(nil).OnClose), arg0) +} + +// addDownTrack mocks base method +func (m *MockParticipant) addDownTrack(streamId string, dt *sfu.DownTrack) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "addDownTrack", streamId, dt) +} + +// addDownTrack indicates an expected call of addDownTrack +func (mr *MockParticipantMockRecorder) addDownTrack(streamId, dt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "addDownTrack", reflect.TypeOf((*MockParticipant)(nil).addDownTrack), streamId, dt) +} + +// removeDownTrack mocks base method +func (m *MockParticipant) removeDownTrack(streamId string, dt *sfu.DownTrack) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "removeDownTrack", streamId, dt) +} + +// removeDownTrack indicates an expected call of removeDownTrack +func (mr *MockParticipantMockRecorder) removeDownTrack(streamId, dt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeDownTrack", reflect.TypeOf((*MockParticipant)(nil).removeDownTrack), streamId, dt) +} + +// peerConnection mocks base method +func (m *MockParticipant) peerConnection() PeerConnection { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "peerConnection") + ret0, _ := ret[0].(PeerConnection) + return ret0 +} + +// peerConnection indicates an expected call of peerConnection +func (mr *MockParticipantMockRecorder) peerConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "peerConnection", reflect.TypeOf((*MockParticipant)(nil).peerConnection)) +} + +// MockPublishedTrack is a mock of PublishedTrack interface +type MockPublishedTrack struct { + ctrl *gomock.Controller + recorder *MockPublishedTrackMockRecorder +} + +// MockPublishedTrackMockRecorder is the mock recorder for MockPublishedTrack +type MockPublishedTrackMockRecorder struct { + mock *MockPublishedTrack +} + +// NewMockPublishedTrack creates a new mock instance +func NewMockPublishedTrack(ctrl *gomock.Controller) *MockPublishedTrack { + mock := &MockPublishedTrack{ctrl: ctrl} + mock.recorder = &MockPublishedTrackMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPublishedTrack) EXPECT() *MockPublishedTrackMockRecorder { + return m.recorder +} + +// Start mocks base method +func (m *MockPublishedTrack) Start() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start") +} + +// Start indicates an expected call of Start +func (mr *MockPublishedTrackMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockPublishedTrack)(nil).Start)) +} + +// ID mocks base method +func (m *MockPublishedTrack) ID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(string) + return ret0 +} + +// ID indicates an expected call of ID +func (mr *MockPublishedTrackMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockPublishedTrack)(nil).ID)) +} + +// Kind mocks base method +func (m *MockPublishedTrack) Kind() livekit.TrackInfo_Type { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Kind") + ret0, _ := ret[0].(livekit.TrackInfo_Type) + return ret0 +} + +// Kind indicates an expected call of Kind +func (mr *MockPublishedTrackMockRecorder) Kind() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kind", reflect.TypeOf((*MockPublishedTrack)(nil).Kind)) +} + +// StreamID mocks base method +func (m *MockPublishedTrack) StreamID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(string) + return ret0 +} + +// StreamID indicates an expected call of StreamID +func (mr *MockPublishedTrackMockRecorder) StreamID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockPublishedTrack)(nil).StreamID)) +} + +// AddSubscriber mocks base method +func (m *MockPublishedTrack) AddSubscriber(participant Participant) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddSubscriber", participant) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddSubscriber indicates an expected call of AddSubscriber +func (mr *MockPublishedTrackMockRecorder) AddSubscriber(participant interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSubscriber", reflect.TypeOf((*MockPublishedTrack)(nil).AddSubscriber), participant) +} + +// RemoveSubscriber mocks base method +func (m *MockPublishedTrack) RemoveSubscriber(participantId string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveSubscriber", participantId) +} + +// RemoveSubscriber indicates an expected call of RemoveSubscriber +func (mr *MockPublishedTrackMockRecorder) RemoveSubscriber(participantId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveSubscriber", reflect.TypeOf((*MockPublishedTrack)(nil).RemoveSubscriber), participantId) +} + +// RemoveAllSubscribers mocks base method +func (m *MockPublishedTrack) RemoveAllSubscribers() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveAllSubscribers") +} + +// RemoveAllSubscribers indicates an expected call of RemoveAllSubscribers +func (mr *MockPublishedTrackMockRecorder) RemoveAllSubscribers() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAllSubscribers", reflect.TypeOf((*MockPublishedTrack)(nil).RemoveAllSubscribers)) +} diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 757c4613c..8075af506 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -23,7 +23,7 @@ const ( sdBatchSize = 20 ) -type Participant struct { +type ParticipantImpl struct { id string peerConn PeerConnection sigConn SignalConnection @@ -41,14 +41,14 @@ type Participant struct { once sync.Once // callbacks & handlers - // OnTrackPublished - remote peer added a remoteTrack - OnTrackPublished func(*Participant, PublishedTrack) - // OnOffer - offer is ready for remote peer - OnOffer func(webrtc.SessionDescription) + // onTrackPublished - remote peer added a remoteTrack + onTrackPublished func(Participant, PublishedTrack) + // onOffer - offer is ready for remote peer + onOffer func(webrtc.SessionDescription) // OnIceCandidate - ice candidate discovered for local peer - OnICECandidate func(c *webrtc.ICECandidateInit) - OnStateChange func(p *Participant, oldState livekit.ParticipantInfo_State) - OnClose func(*Participant) + onICECandidate func(c *webrtc.ICECandidateInit) + onStateChange func(p Participant, oldState livekit.ParticipantInfo_State) + onClose func(Participant) } func NewPeerConnection(conf *WebRTCConfig) (*webrtc.PeerConnection, error) { @@ -58,7 +58,7 @@ func NewPeerConnection(conf *WebRTCConfig) (*webrtc.PeerConnection, error) { return api.NewPeerConnection(conf.Configuration) } -func NewParticipant(pc PeerConnection, sc SignalConnection, name string) (*Participant, error) { +func NewParticipant(pc PeerConnection, sc SignalConnection, name string) (*ParticipantImpl, error) { me := &webrtc.MediaEngine{} me.RegisterDefaultCodecs() @@ -67,7 +67,7 @@ func NewParticipant(pc PeerConnection, sc SignalConnection, name string) (*Parti ir.Add(bi) ctx, cancel := context.WithCancel(context.Background()) - participant := &Participant{ + participant := &ParticipantImpl{ id: utils.NewGuid(utils.ParticipantPrefix), name: name, peerConn: pc, @@ -104,8 +104,8 @@ func NewParticipant(pc PeerConnection, sc SignalConnection, name string) (*Parti log.Errorw("could not send trickle", "err", err) } - if participant.OnICECandidate != nil { - participant.OnICECandidate(&ci) + if participant.onICECandidate != nil { + participant.onICECandidate(&ci) } }) @@ -122,19 +122,19 @@ func NewParticipant(pc PeerConnection, sc SignalConnection, name string) (*Parti return participant, nil } -func (p *Participant) ID() string { +func (p *ParticipantImpl) ID() string { return p.id } -func (p *Participant) Name() string { +func (p *ParticipantImpl) Name() string { return p.name } -func (p *Participant) State() livekit.ParticipantInfo_State { +func (p *ParticipantImpl) State() livekit.ParticipantInfo_State { return p.state } -func (p *Participant) ToProto() *livekit.ParticipantInfo { +func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { info := &livekit.ParticipantInfo{ Sid: p.id, Name: p.name, @@ -142,13 +142,34 @@ func (p *Participant) ToProto() *livekit.ParticipantInfo { } for _, t := range p.tracks { - info.Tracks = append(info.Tracks, TrackToProto(t)) + info.Tracks = append(info.Tracks, ToProtoTrack(t)) } return info } +// callbacks for clients +func (p *ParticipantImpl) OnTrackPublished(callback func(Participant, PublishedTrack)) { + p.onTrackPublished = callback +} + +func (p *ParticipantImpl) OnOffer(callback func(webrtc.SessionDescription)) { + p.onOffer = callback +} + +func (p *ParticipantImpl) OnICECandidate(callback func(c *webrtc.ICECandidateInit)) { + p.onICECandidate = callback +} + +func (p *ParticipantImpl) OnStateChange(callback func(p Participant, oldState livekit.ParticipantInfo_State)) { + p.onStateChange = callback +} + +func (p *ParticipantImpl) OnClose(callback func(Participant)) { + p.onClose = callback +} + // Answer an offer from remote participant -func (p *Participant) Answer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) { +func (p *ParticipantImpl) Answer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) { if err = p.peerConn.SetRemoteDescription(sdp); err != nil { return } @@ -190,8 +211,8 @@ func (p *Participant) Answer(sdp webrtc.SessionDescription) (answer webrtc.Sessi "err", err) } - if p.OnOffer != nil { - p.OnOffer(offer) + if p.onOffer != nil { + p.onOffer(offer) } }) @@ -208,7 +229,7 @@ func (p *Participant) Answer(sdp webrtc.SessionDescription) (answer webrtc.Sessi } // HandleNegotiate when receiving session description from client -func (p *Participant) HandleNegotiate(sd webrtc.SessionDescription) error { +func (p *ParticipantImpl) HandleNegotiate(sd webrtc.SessionDescription) error { if err := p.peerConn.SetRemoteDescription(sd); err != nil { return errors.Wrap(err, "could not set remote description") } @@ -234,7 +255,7 @@ func (p *Participant) HandleNegotiate(sd webrtc.SessionDescription) error { return nil } -func (p *Participant) SetRemoteDescription(sdp webrtc.SessionDescription) error { +func (p *ParticipantImpl) SetRemoteDescription(sdp webrtc.SessionDescription) error { logger.GetLogger().Debugw("setting remote description", "type", sdp.Type) if err := p.peerConn.SetRemoteDescription(sdp); err != nil { return errors.Wrap(err, "could not set remote description") @@ -243,57 +264,35 @@ func (p *Participant) SetRemoteDescription(sdp webrtc.SessionDescription) error } // AddICECandidate adds candidates for remote peer -func (p *Participant) AddICECandidate(candidate webrtc.ICECandidateInit) error { +func (p *ParticipantImpl) AddICECandidate(candidate webrtc.ICECandidateInit) error { if err := p.peerConn.AddICECandidate(candidate); err != nil { return err } return nil } -func (p *Participant) addDownTrack(streamId string, dt *sfu.DownTrack) { - p.lock.Lock() - p.downTracks[streamId] = append(p.downTracks[streamId], dt) - p.lock.Unlock() - dt.OnBind(func() { - go p.scheduleDownTrackBindingReports(streamId) - }) -} - -func (p *Participant) removeDownTrack(streamId string, dt *sfu.DownTrack) { - p.lock.Lock() - defer p.lock.Unlock() - tracks := p.downTracks[streamId] - newTracks := make([]*sfu.DownTrack, 0, len(tracks)) - for _, track := range tracks { - if track != dt { - newTracks = append(newTracks, track) - } - } - p.downTracks[streamId] = newTracks -} - -func (p *Participant) Start() { +func (p *ParticipantImpl) Start() { p.once.Do(func() { go p.rtcpSendWorker() go p.downTracksRTCPWorker() }) } -func (p *Participant) Close() error { +func (p *ParticipantImpl) Close() error { if p.ctx.Err() != nil { return p.ctx.Err() } close(p.rtcpCh) p.updateState(livekit.ParticipantInfo_DISCONNECTED) - if p.OnClose != nil { - p.OnClose(p) + if p.onClose != nil { + p.onClose(p) } p.cancel() return p.peerConn.Close() } // Subscribes otherPeer to all of the tracks -func (p *Participant) AddSubscriber(op *Participant) error { +func (p *ParticipantImpl) AddSubscriber(op Participant) error { p.lock.RLock() defer p.lock.RUnlock() @@ -309,17 +308,17 @@ func (p *Participant) AddSubscriber(op *Participant) error { return nil } -func (p *Participant) RemoveSubscriber(peerId string) { +func (p *ParticipantImpl) RemoveSubscriber(participantId string) { p.lock.RLock() defer p.lock.RUnlock() for _, track := range p.tracks { - track.RemoveSubscriber(peerId) + track.RemoveSubscriber(participantId) } } // signal connection methods -func (p *Participant) SendJoinResponse(otherParticipants []*Participant) error { +func (p *ParticipantImpl) SendJoinResponse(otherParticipants []Participant) error { // send Join response return p.sigConn.WriteResponse(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Join{ @@ -331,7 +330,7 @@ func (p *Participant) SendJoinResponse(otherParticipants []*Participant) error { }) } -func (p *Participant) SendParticipantUpdate(participants []*livekit.ParticipantInfo) error { +func (p *ParticipantImpl) SendParticipantUpdate(participants []*livekit.ParticipantInfo) error { return p.sigConn.WriteResponse(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Update{ Update: &livekit.ParticipantUpdate{ @@ -341,22 +340,48 @@ func (p *Participant) SendParticipantUpdate(participants []*livekit.ParticipantI }) } -func (p *Participant) updateState(state livekit.ParticipantInfo_State) { +func (p *ParticipantImpl) peerConnection() PeerConnection { + return p.peerConn +} + +func (p *ParticipantImpl) addDownTrack(streamId string, dt *sfu.DownTrack) { + p.lock.Lock() + p.downTracks[streamId] = append(p.downTracks[streamId], dt) + p.lock.Unlock() + dt.OnBind(func() { + go p.scheduleDownTrackBindingReports(streamId) + }) +} + +func (p *ParticipantImpl) removeDownTrack(streamId string, dt *sfu.DownTrack) { + p.lock.Lock() + defer p.lock.Unlock() + tracks := p.downTracks[streamId] + newTracks := make([]*sfu.DownTrack, 0, len(tracks)) + for _, track := range tracks { + if track != dt { + newTracks = append(newTracks, track) + } + } + p.downTracks[streamId] = newTracks +} + +func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { if state == p.state { return } oldState := p.state p.state = state - if p.OnStateChange != nil { + if p.onStateChange != nil { go func() { - p.OnStateChange(p, oldState) + p.onStateChange(p, oldState) }() } } // when a new remoteTrack is created, creates a Track and adds it to room -func (p *Participant) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { +func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { logger.GetLogger().Debugw("remoteTrack added", "participantId", p.ID(), "remoteTrack", track.ID()) // create Receiver @@ -366,7 +391,7 @@ func (p *Participant) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *webrt p.handleTrackPublished(mt) } -func (p *Participant) onDataChannel(dc *webrtc.DataChannel) { +func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { if dc.Label() == placeholderDataChannel { return } @@ -382,7 +407,7 @@ func (p *Participant) onDataChannel(dc *webrtc.DataChannel) { p.handleTrackPublished(dt) } -func (p *Participant) handleTrackPublished(track PublishedTrack) { +func (p *ParticipantImpl) handleTrackPublished(track PublishedTrack) { p.lock.Lock() p.tracks[track.ID()] = track p.lock.Unlock() @@ -392,15 +417,15 @@ func (p *Participant) handleTrackPublished(track PublishedTrack) { // confirm publication p.sigConn.WriteResponse(&livekit.SignalResponse{ Message: &livekit.SignalResponse_TrackPublished{ - TrackPublished: TrackToProto(track), + TrackPublished: ToProtoTrack(track), }, }) - if p.OnTrackPublished != nil { - go p.OnTrackPublished(p, track) + if p.onTrackPublished != nil { + go p.onTrackPublished(p, track) } } -func (p *Participant) scheduleDownTrackBindingReports(streamId string) { +func (p *ParticipantImpl) scheduleDownTrackBindingReports(streamId string) { var sd []rtcp.SourceDescriptionChunk p.lock.RLock() @@ -440,7 +465,7 @@ func (p *Participant) scheduleDownTrackBindingReports(streamId string) { // downTracksRTCPWorker sends SenderReports periodically when the participant is subscribed to // other tracks in the room. -func (p *Participant) downTracksRTCPWorker() { +func (p *ParticipantImpl) downTracksRTCPWorker() { for { time.Sleep(5 * time.Second) @@ -485,7 +510,7 @@ func (p *Participant) downTracksRTCPWorker() { } } -func (p *Participant) rtcpSendWorker() { +func (p *ParticipantImpl) rtcpSendWorker() { // read from rtcpChan for pkts := range p.rtcpCh { for _, pkt := range pkts { diff --git a/pkg/rtc/publishedtrack.go b/pkg/rtc/publishedtrack.go index 143fe1372..f89c1a34a 100644 --- a/pkg/rtc/publishedtrack.go +++ b/pkg/rtc/publishedtrack.go @@ -1,25 +1 @@ package rtc - -import ( - "github.com/livekit/livekit-server/proto/livekit" -) - -// PublishedTrack is the main interface representing a track published to the room -// it's responsible for managing subscribers and forwarding data from the input track to all subscribers -type PublishedTrack interface { - Start() - ID() string - Kind() livekit.TrackInfo_Type - StreamID() string - AddSubscriber(participant *Participant) error - RemoveSubscriber(participantId string) - RemoveAllSubscribers() -} - -func TrackToProto(t PublishedTrack) *livekit.TrackInfo { - return &livekit.TrackInfo{ - Sid: t.ID(), - Type: t.Kind(), - Name: t.StreamID(), - } -} diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index be1920c12..716d9dc8b 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -16,7 +16,7 @@ type Room struct { config WebRTCConfig lock sync.RWMutex // map of participantId -> Participant - participants map[string]*Participant + participants map[string]Participant // Client ID => list of tracks they are publishing //tracks map[string][]Track } @@ -37,11 +37,11 @@ func NewRoomForRequest(req *livekit.CreateRoomRequest, config *WebRTCConfig) (*R }, config: *config, lock: sync.RWMutex{}, - participants: make(map[string]*Participant), + participants: make(map[string]Participant), }, nil } -func (r *Room) GetParticipant(id string) *Participant { +func (r *Room) GetParticipant(id string) Participant { r.lock.RLock() defer r.lock.RUnlock() return r.participants[id] @@ -56,22 +56,22 @@ func (r *Room) ToRoomInfo(node *livekit.Node) *livekit.RoomInfo { } } -func (r *Room) Join(participant *Participant) error { +func (r *Room) Join(participant Participant) error { r.lock.Lock() defer r.lock.Unlock() log := logger.GetLogger() // it's important to set this before connection, we don't want to miss out on any tracks - participant.OnTrackPublished = r.onTrackAdded - participant.OnStateChange = func(p *Participant, oldState livekit.ParticipantInfo_State) { - log.Debugw("participant state changed", "state", p.state, "participant", p.id) + participant.OnTrackPublished(r.onTrackAdded) + participant.OnStateChange(func(p Participant, oldState livekit.ParticipantInfo_State) { + log.Debugw("participant state changed", "state", p.State(), "participant", p.ID()) r.broadcastParticipantState(p) - if oldState == livekit.ParticipantInfo_JOINING && p.state == livekit.ParticipantInfo_JOINED { + if oldState == livekit.ParticipantInfo_JOINING && p.State() == livekit.ParticipantInfo_JOINED { // subscribe participant to existing tracks for _, op := range r.participants { - if p.id == op.id { + if p.ID() == op.ID() { // don't send to itself continue } @@ -85,7 +85,7 @@ func (r *Room) Join(participant *Participant) error { // start the workers once connectivity is established p.Start() } - } + }) log.Infow("new participant joined", "id", participant.ID(), @@ -95,9 +95,9 @@ func (r *Room) Join(participant *Participant) error { r.participants[participant.ID()] = participant // gather other participants and send join response - otherParticipants := make([]*Participant, 0, len(r.participants)) + otherParticipants := make([]Participant, 0, len(r.participants)) for _, p := range r.participants { - if p.id != participant.id { + if p.ID() != participant.ID() { otherParticipants = append(otherParticipants, p) } } @@ -122,8 +122,8 @@ func (r *Room) RemoveParticipant(id string) { delete(r.participants, id) } -// a Participant in the room added a new remoteTrack, subscribe other participants to it -func (r *Room) onTrackAdded(participant *Participant, track PublishedTrack) { +// a ParticipantImpl in the room added a new remoteTrack, subscribe other participants to it +func (r *Room) onTrackAdded(participant Participant, track PublishedTrack) { // publish participant update, since track state is changed r.broadcastParticipantState(participant) @@ -146,21 +146,21 @@ func (r *Room) onTrackAdded(participant *Participant, track PublishedTrack) { } } -func (r *Room) broadcastParticipantState(p *Participant) { +func (r *Room) broadcastParticipantState(p Participant) { r.lock.RLock() defer r.lock.RUnlock() - updates := ToProtoParticipants([]*Participant{p}) + updates := ToProtoParticipants([]Participant{p}) for _, op := range r.participants { // skip itself - if p.id == op.id { + if p.ID() == op.ID() { continue } err := op.SendParticipantUpdate(updates) if err != nil { logger.GetLogger().Errorw("could not send update to participant", - "participant", p.id, + "participant", p.ID(), "err", err) } } diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go new file mode 100644 index 000000000..2d54c274f --- /dev/null +++ b/pkg/rtc/room_test.go @@ -0,0 +1,44 @@ +package rtc + +import ( + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/livekit/livekit-server/proto/livekit" +) + +func TestRoomJoin(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + t.Run("joining returns existing participant data", func(t *testing.T) { + rm := newRoomWithParticipants(t, mockCtrl, 3) + assert.NotNil(t, rm) + pc := newMockPeerConnection(mockCtrl) + sc := NewMockSignalConnection(mockCtrl) + p, err := NewParticipant(pc, sc, "newparticipant") + assert.NoError(t, err) + assert.NotNil(t, p) + + // expect new participant to get a JoinReply + }) +} + +func newRoomWithParticipants(t *testing.T, mockCtrl *gomock.Controller, num int) *Room { + rm, err := NewRoomForRequest(&livekit.CreateRoomRequest{}, &WebRTCConfig{}) + if err != nil { + panic("could not create a room") + } + for i := 0; i < num; i++ { + pc := newMockPeerConnection(mockCtrl) + participant, err := NewParticipant(pc, + NewMockSignalConnection(mockCtrl), + fmt.Sprintf("p%d", i)) + assert.NoError(t, err, "could not create participant for room") + rm.participants[participant.ID()] = participant + } + return rm +} diff --git a/pkg/rtc/utils.go b/pkg/rtc/utils.go index f65b6239f..1e7463b9a 100644 --- a/pkg/rtc/utils.go +++ b/pkg/rtc/utils.go @@ -41,7 +41,7 @@ func UnpackDataTrackLabel(packed string) (peerId string, trackId string, label s return } -func ToProtoParticipants(participants []*Participant) []*livekit.ParticipantInfo { +func ToProtoParticipants(participants []Participant) []*livekit.ParticipantInfo { infos := make([]*livekit.ParticipantInfo, 0, len(participants)) for _, op := range participants { infos = append(infos, op.ToProto()) @@ -87,6 +87,14 @@ func FromProtoTrickle(trickle *livekit.Trickle) webrtc.ICECandidateInit { return ci } +func ToProtoTrack(t PublishedTrack) *livekit.TrackInfo { + return &livekit.TrackInfo{ + Sid: t.ID(), + Type: t.Kind(), + Name: t.StreamID(), + } +} + func IsEOF(err error) bool { return err == io.ErrClosedPipe || err == io.EOF } diff --git a/pkg/service/rtc.go b/pkg/service/rtc.go index e5dbaf2b3..d3f218a7c 100644 --- a/pkg/service/rtc.go +++ b/pkg/service/rtc.go @@ -151,7 +151,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *RTCService) handleOffer(participant *rtc.Participant, offer *livekit.SessionDescription) error { +func (s *RTCService) handleOffer(participant rtc.Participant, offer *livekit.SessionDescription) error { log := logger.GetLogger() _, err := participant.Answer(rtc.FromProtoSessionDescription(offer)) @@ -163,10 +163,10 @@ func (s *RTCService) handleOffer(participant *rtc.Participant, offer *livekit.Se return nil } -func (s *RTCService) handleTrickle(peer *rtc.Participant, trickle *livekit.Trickle) error { +func (s *RTCService) handleTrickle(participant rtc.Participant, trickle *livekit.Trickle) error { candidateInit := rtc.FromProtoTrickle(trickle) - logger.GetLogger().Debugw("adding peer candidate", "participantId", peer.ID()) - if err := peer.AddICECandidate(candidateInit); err != nil { + logger.GetLogger().Debugw("adding peer candidate", "participant", participant.ID()) + if err := participant.AddICECandidate(candidateInit); err != nil { return err }