diff --git a/test/client/client.go b/test/client/client.go index be8317cd2..0cc49c4a4 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -25,6 +25,11 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" ) +const ( + lossyDataChannel = "_lossy" + reliableDataChannel = "_reliable" +) + type RTCClient struct { id string conn *websocket.Conn @@ -43,11 +48,18 @@ type RTCClient struct { localParticipant *livekit.ParticipantInfo remoteParticipants map[string]*livekit.ParticipantInfo + reliableDC *webrtc.DataChannel + reliableDCSub *webrtc.DataChannel + lossyDC *webrtc.DataChannel + lossyDCSub *webrtc.DataChannel + publisherConnected utils.AtomicFlag + // tracks waiting to be acked, cid => trackInfo pendingPublishedTracks map[string]*livekit.TrackInfo pendingTrackWriters []*TrackWriter OnConnected func() + OnDataReceived func(data []byte, sid string) // map of track Id and last packet lastPackets map[string]*rtp.Packet @@ -138,6 +150,22 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { return nil, err } + ordered := true + c.reliableDC, err = c.publisher.PeerConnection().CreateDataChannel(reliableDataChannel, + &webrtc.DataChannelInit{Ordered: &ordered}, + ) + if err != nil { + return nil, err + } + + maxRetransmits := uint16(0) + c.lossyDC, err = c.publisher.PeerConnection().CreateDataChannel(lossyDataChannel, + &webrtc.DataChannelInit{Ordered: &ordered, MaxRetransmits: &maxRetransmits}, + ) + if err != nil { + return nil, err + } + c.publisher.PeerConnection().OnICECandidate(func(ic *webrtc.ICECandidate) { if ic == nil { return @@ -155,6 +183,14 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { go c.processTrack(track) }) c.subscriber.PeerConnection().OnDataChannel(func(channel *webrtc.DataChannel) { + if channel.Label() == reliableDataChannel { + c.reliableDCSub = channel + } else if channel.Label() == lossyDataChannel { + c.lossyDCSub = channel + } else { + return + } + channel.OnMessage(c.handleDataMessage) }) c.publisher.OnOffer(c.onOffer) @@ -182,6 +218,14 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { } }) + c.publisher.PeerConnection().OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + if state == webrtc.ICEConnectionStateConnected { + c.publisherConnected.TrySet(true) + } else { + c.publisherConnected.TrySet(false) + } + }) + return c, nil } @@ -466,6 +510,66 @@ func (c *RTCClient) SendAddTrack(cid string, name string, trackType livekit.Trac }) } +func (c *RTCClient) PublishData(data []byte, kind livekit.DataPacket_Kind) error { + if err := c.ensurePublisherConnected(); err != nil { + return err + } + + dp := &livekit.DataPacket{ + Kind: kind, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{Payload: data}, + }, + } + payload, err := proto.Marshal(dp) + if err != nil { + return err + } + if kind == livekit.DataPacket_RELIABLE { + return c.reliableDC.Send(payload) + } else { + return c.lossyDC.Send(payload) + } +} + +func (c *RTCClient) ensurePublisherConnected() error { + if c.publisherConnected.Get() { + return nil + } + + if c.publisher.PeerConnection().ConnectionState() == webrtc.PeerConnectionStateNew { + // start negotiating + c.publisher.Negotiate() + } + + // wait until connected + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + for { + select { + case <-ctx.Done(): + return fmt.Errorf("could not connect publisher after timeout") + case <-time.After(10 * time.Millisecond): + if c.publisherConnected.Get() { + return nil + } + } + } +} + +func (c *RTCClient) handleDataMessage(msg webrtc.DataChannelMessage) { + dp := &livekit.DataPacket{} + err := proto.Unmarshal(msg.Data, dp) + if err != nil { + return + } + if val, ok := dp.Value.(*livekit.DataPacket_User); ok { + if c.OnDataReceived != nil { + c.OnDataReceived(val.User.Payload, val.User.ParticipantSid) + } + } +} + // handles a server initiated offer, handle on subscriber PC func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) error { if err := c.subscriber.SetRemoteDescription(desc); err != nil { diff --git a/test/multinode_test.go b/test/multinode_test.go index b6287c41f..97d95f41d 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -125,4 +125,14 @@ func TestMultinodeReconnectAfterNodeShutdown(t *testing.T) { waitUntilConnected(t, c3) } -// TODO: test room with protocol version 1 and 0 participants +func TestMultinodeDataPublishing(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("TestMultinodeDataPublishing") + defer finish() + + scenarioDataPublish(t) +} diff --git a/test/scenarios.go b/test/scenarios.go index dd79a9cd9..0ff28515d 100644 --- a/test/scenarios.go +++ b/test/scenarios.go @@ -5,6 +5,8 @@ import ( "time" "github.com/livekit/protocol/logger" + livekit "github.com/livekit/protocol/proto" + "github.com/livekit/protocol/utils" "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/testutils" @@ -13,11 +15,9 @@ import ( // a scenario with lots of clients connecting, publishing, and leaving at random periods func scenarioPublishingUponJoining(t *testing.T, ports ...int) { - firstPort := ports[0] - lastPort := ports[len(ports)-1] - c1 := createRTCClient("puj_1", firstPort, nil) - c2 := createRTCClient("puj_2", lastPort, &testclient.Options{AutoSubscribe: true}) - c3 := createRTCClient("puj_3", firstPort, &testclient.Options{AutoSubscribe: true}) + c1 := createRTCClient("puj_1", defaultServerPort, nil) + c2 := createRTCClient("puj_2", secondServerPort, &testclient.Options{AutoSubscribe: true}) + c3 := createRTCClient("puj_3", defaultServerPort, &testclient.Options{AutoSubscribe: true}) defer stopClients(c1, c2, c3) waitUntilConnected(t, c1, c2, c3) @@ -66,7 +66,7 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) { logger.Infow("c2 reconnecting") // connect to a diff port - c2 = createRTCClient("puj_2", firstPort, nil) + c2 = createRTCClient("puj_2", defaultServerPort, nil) defer c2.Stop() waitUntilConnected(t, c2) writers = publishTracksForClients(t, c2) @@ -119,6 +119,24 @@ func scenarioReceiveBeforePublish(t *testing.T) { require.Empty(t, c1.RemoteParticipants()) } +func scenarioDataPublish(t *testing.T) { + c1 := createRTCClient("dp1", defaultServerPort, nil) + c2 := createRTCClient("dp2", secondServerPort, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) + + payload := "test bytes" + + received := utils.AtomicFlag{} + c2.OnDataReceived = func(data []byte, sid string) { + if string(data) == payload && sid == c2.ID() { + received.TrySet(true) + } + } + + require.NoError(t, c1.PublishData([]byte(payload), livekit.DataPacket_LOSSY)) +} + // websocket reconnects func scenarioWSReconnect(t *testing.T) { c1 := createRTCClient("wsr_1", defaultServerPort, nil)