From db4bc127e87b772c066d3325a1a6d3310eaeb0ec Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 1 Aug 2025 18:50:28 +0530 Subject: [PATCH] Get to the point of connecting publisher PC and using it for async signalling (#3822) * starting signalling DC work * WIP * plumbing data channel * add datachannel message sink file * mage generate * clean up --- pkg/rtc/participant.go | 32 ++++- pkg/rtc/signalling/datachannel_messagesink.go | 77 +++++++++++ pkg/rtc/signalling/interfaces.go | 3 +- pkg/rtc/signalling/signalhandler.go | 11 +- .../signalling/signalhandlerunimplemented.go | 6 +- pkg/rtc/signalling/signalhandlerv2.go | 13 +- pkg/rtc/transport.go | 50 ++++++- pkg/rtc/transport/handler.go | 11 +- .../transport/transportfakes/fake_handler.go | 123 ++++++++++++++++++ pkg/rtc/types/interfaces.go | 2 +- .../typesfakes/fake_local_participant.go | 78 +++++------ pkg/service/roommanager.go | 2 +- pkg/service/roommanager_service.go | 2 +- pkg/service/rtcv2service.go | 9 +- 14 files changed, 359 insertions(+), 60 deletions(-) create mode 100644 pkg/rtc/signalling/datachannel_messagesink.go diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index b1358f9a0..f56855fb1 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -54,6 +54,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/livekit-server/pkg/sfu/pacer" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" @@ -147,6 +148,8 @@ type reliableDataInfo struct { // --------------------------------------------------------------- +var _ types.LocalParticipant = (*ParticipantImpl)(nil) + type ParticipantParams struct { Identity livekit.ParticipantIdentity Name livekit.ParticipantName @@ -314,7 +317,7 @@ type ParticipantImpl struct { metricsReporter *metric.MetricsReporter signalling signalling.ParticipantSignalling - signalhandler signalling.ParticipantSignalHandler + signalHandler signalling.ParticipantSignalHandler signaller signalling.ParticipantSignaller // loggers for publisher and subscriber @@ -1733,6 +1736,8 @@ func (p *ParticipantImpl) UpdateMediaRTT(rtt uint32) { // ---------------------------------------------------------- +var _ transport.Handler = (*AnyTransportHandler)(nil) + type AnyTransportHandler struct { transport.UnimplementedHandler p *ParticipantImpl @@ -1776,6 +1781,23 @@ func (h PublisherTransportHandler) OnDataMessageUnlabeled(data []byte) { h.p.onReceivedDataMessageUnlabeled(data) } +func (h PublisherTransportHandler) OnDataChannelOpenSignalling(dc *datachannel.DataChannelWriter[*webrtc.DataChannel]) { + sink := signalling.NewDataChannelMessageSink(signalling.DataChannelMessageSinkParams{ + Logger: h.p.params.Logger, + DataChannel: dc, + }) + h.p.signaller.SetResponseSink(sink) +} + +func (h PublisherTransportHandler) OnDataChannelCloseSignalling(dc *datachannel.DataChannelWriter[*webrtc.DataChannel]) { + // SIGNALLING-V2-TODO: check that the closed data channel is actually the same as response sink + h.p.signaller.SetResponseSink(nil) +} + +func (h PublisherTransportHandler) OnDataMessageSignalling(data []byte) { + h.p.signalHandler.HandleEncodedMessage(data) +} + func (h PublisherTransportHandler) OnDataSendError(err error) { h.p.onDataSendError(err) } @@ -1824,7 +1846,7 @@ func (p *ParticipantImpl) setupSignalling() { p.signalling = signalling.NewSignalling(signalling.SignallingParams{ Logger: p.params.Logger, }) - p.signalhandler = signalling.NewSignalHandler(signalling.SignalHandlerParams{ + p.signalHandler = signalling.NewSignalHandler(signalling.SignalHandlerParams{ Logger: p.params.Logger, Participant: p, }) @@ -1836,7 +1858,7 @@ func (p *ParticipantImpl) setupSignalling() { p.signalling = signalling.NewSignallingv2(signalling.Signallingv2Params{ Logger: p.params.Logger, }) - p.signalhandler = signalling.NewSignalHandlerv2(signalling.SignalHandlerv2Params{ + p.signalHandler = signalling.NewSignalHandlerv2(signalling.SignalHandlerv2Params{ Logger: p.params.Logger, Participant: p, Signalling: p.signalling, @@ -3782,6 +3804,6 @@ func (p *ParticipantImpl) HandleLeaveRequest(reason types.ParticipantCloseReason } } -func (p *ParticipantImpl) HandleSignalRequest(msg proto.Message) error { - return p.signalhandler.HandleRequest(msg) +func (p *ParticipantImpl) HandleSignalMessage(msg proto.Message) error { + return p.signalHandler.HandleMessage(msg) } diff --git a/pkg/rtc/signalling/datachannel_messagesink.go b/pkg/rtc/signalling/datachannel_messagesink.go new file mode 100644 index 000000000..7848af611 --- /dev/null +++ b/pkg/rtc/signalling/datachannel_messagesink.go @@ -0,0 +1,77 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" + + "google.golang.org/protobuf/proto" +) + +var _ routing.MessageSink = (*dataChannelMessageSink)(nil) + +type DataChannelMessageSinkParams struct { + Logger logger.Logger + DataChannel *datachannel.DataChannelWriter[*webrtc.DataChannel] +} + +type dataChannelMessageSink struct { + params DataChannelMessageSinkParams +} + +func NewDataChannelMessageSink(params DataChannelMessageSinkParams) routing.MessageSink { + return &dataChannelMessageSink{ + params: params, + } +} + +func (d *dataChannelMessageSink) WriteMessage(msg proto.Message) error { + if msg == nil { + return nil + } + + protoMsg, err := proto.Marshal(msg) + if err != nil { + d.params.Logger.Errorw("could not marshal message", err) + return err + } + + if _, err := d.params.DataChannel.Write(protoMsg); err != nil { + // SIGNALLING-V2-TODO: filter out logging expected errors + d.params.Logger.Errorw("could not send message", err) + return err + } + + return nil +} + +func (d *dataChannelMessageSink) IsClosed() bool { + // SIGNALLING-V2-TODO + return false +} + +func (d *dataChannelMessageSink) Close() { + // SIGNALLING-V2-TODO +} + +func (d *dataChannelMessageSink) ConnectionID() livekit.ConnectionID { + // SIGNALLING-V2-TODO + return "" +} diff --git a/pkg/rtc/signalling/interfaces.go b/pkg/rtc/signalling/interfaces.go index 50f5a1257..4b31b92ec 100644 --- a/pkg/rtc/signalling/interfaces.go +++ b/pkg/rtc/signalling/interfaces.go @@ -24,7 +24,8 @@ import ( ) type ParticipantSignalHandler interface { - HandleRequest(msg proto.Message) error + HandleMessage(msg proto.Message) error + HandleEncodedMessage(data []byte) error PruneStaleReassemblies() } diff --git a/pkg/rtc/signalling/signalhandler.go b/pkg/rtc/signalling/signalhandler.go index 9f4b445e1..4e4e08c65 100644 --- a/pkg/rtc/signalling/signalhandler.go +++ b/pkg/rtc/signalling/signalhandler.go @@ -45,7 +45,7 @@ func NewSignalHandler(params SignalHandlerParams) ParticipantSignalHandler { } // SIGNALLING-V2-TODO: consolidate base message handling for messages common to different signalling versions -func (s *signalhandler) HandleRequest(msg proto.Message) error { +func (s *signalhandler) HandleMessage(msg proto.Message) error { req, ok := msg.(*livekit.SignalRequest) if !ok { s.params.Logger.Warnw( @@ -194,3 +194,12 @@ func (s *signalhandler) HandleRequest(msg proto.Message) error { return nil } + +func (s *signalhandler) HandleEncodedMessage(data []byte) error { + signalRequest := &livekit.SignalRequest{} + if err := proto.Unmarshal(data, signalRequest); err != nil { + return err + } + + return s.HandleMessage(signalRequest) +} diff --git a/pkg/rtc/signalling/signalhandlerunimplemented.go b/pkg/rtc/signalling/signalhandlerunimplemented.go index bfcaf74c3..91fb85502 100644 --- a/pkg/rtc/signalling/signalhandlerunimplemented.go +++ b/pkg/rtc/signalling/signalhandlerunimplemented.go @@ -22,7 +22,11 @@ var _ ParticipantSignalHandler = (*signalhandlerUnimplemented)(nil) type signalhandlerUnimplemented struct{} -func (u *signalhandlerUnimplemented) HandleRequest(msg proto.Message) error { +func (u *signalhandlerUnimplemented) HandleMessage(msg proto.Message) error { + return nil +} + +func (u *signalhandlerUnimplemented) HandleEncodedMessage(data []byte) error { return nil } diff --git a/pkg/rtc/signalling/signalhandlerv2.go b/pkg/rtc/signalling/signalhandlerv2.go index 3ca35951f..b6e8a4421 100644 --- a/pkg/rtc/signalling/signalhandlerv2.go +++ b/pkg/rtc/signalling/signalhandlerv2.go @@ -54,7 +54,7 @@ func NewSignalHandlerv2(params SignalHandlerv2Params) ParticipantSignalHandler { } } -func (s *signalhandlerv2) HandleRequest(msg proto.Message) error { +func (s *signalhandlerv2) HandleMessage(msg proto.Message) error { req, ok := msg.(*livekit.Signalv2WireMessage) if !ok { s.params.Logger.Warnw( @@ -129,13 +129,22 @@ func (s *signalhandlerv2) HandleRequest(msg proto.Message) error { return err } - s.HandleRequest(wireMessage) + s.HandleMessage(wireMessage) } } return nil } +func (s *signalhandlerv2) HandleEncodedMessage(data []byte) error { + wireMessage := &livekit.Signalv2WireMessage{} + if err := proto.Unmarshal(data, wireMessage); err != nil { + return err + } + + return s.HandleMessage(wireMessage) +} + func (s *signalhandlerv2) PruneStaleReassemblies() { s.signalReassembler.Prune() } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index eba3b4bd0..b56f6c17e 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -62,8 +62,9 @@ import ( ) const ( - LossyDataChannel = "_lossy" - ReliableDataChannel = "_reliable" + LossyDataChannel = "_lossy" + ReliableDataChannel = "_reliable" + SignallingDataChannel = "_signalling" fastNegotiationFrequency = 10 * time.Millisecond negotiationFrequency = 150 * time.Millisecond @@ -204,6 +205,7 @@ type PCTransport struct { lossyDC *datachannel.DataChannelWriter[*webrtc.DataChannel] lossyDCOpened bool unlabeledDataChannels []*datachannel.DataChannelWriter[*webrtc.DataChannel] + signallingDataChannel *datachannel.DataChannelWriter[*webrtc.DataChannel] iceStartedAt time.Time iceConnectedAt time.Time @@ -553,7 +555,8 @@ func (t *PCTransport) createPeerConnection() (cc.BandwidthEstimator, error) { } t.pc = pc - if !t.params.UseOneShotSignallingMode && !t.params.SynchronousLocalCandidatesMode { + // SIGNALLING-V2-TODO: have to support both sync and async candidates, so has to be a check at function level + if !t.params.UseOneShotSignallingMode /* SIGNALLING-V2-TODO && !t.params.SynchronousLocalCandidatesMode */ { // one shot signalling mode gathers all candidates and sends in answer t.pc.OnICEGatheringStateChange(t.onICEGatheringStateChange) t.pc.OnICECandidate(t.onICECandidateTrickle) @@ -806,6 +809,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.params.Logger.Debugw(dc.Label() + " data channel open") var kind livekit.DataPacket_Kind var isUnlabeled bool + var isSignalling bool switch dc.Label() { case ReliableDataChannel: kind = livekit.DataPacket_RELIABLE @@ -813,6 +817,10 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { case LossyDataChannel: kind = livekit.DataPacket_LOSSY + case SignallingDataChannel: + t.params.Logger.Infow("signalling datachannel added", "label", dc.Label()) + isSignalling = true + default: t.params.Logger.Infow("unlabeled datachannel added", "label", dc.Label()) isUnlabeled = true @@ -833,6 +841,13 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { ) t.lock.Unlock() + case isSignalling: + t.lock.Lock() + signallingDataChannel := datachannel.NewDataChannelWriter(dc, rawDC, 0) + t.signallingDataChannel = signallingDataChannel + t.lock.Unlock() + t.params.Handler.OnDataChannelOpenSignalling(signallingDataChannel) + case kind == livekit.DataPacket_RELIABLE: t.lock.Lock() if t.reliableDC != nil { @@ -864,9 +879,14 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { return } - if isUnlabeled { + switch { + case isUnlabeled: t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) - } else { + + case isSignalling: + t.params.Handler.OnDataMessageSignalling(buffer[:n]) + + default: t.params.Handler.OnDataMessage(kind, buffer[:n]) } } @@ -874,6 +894,18 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.maybeNotifyFullyEstablished() }) + + dc.OnClose(func() { + t.params.Logger.Debugw(dc.Label() + " data channel close") + switch dc.Label() { + case SignallingDataChannel: + t.lock.RLock() + signallingDataChannel := t.signallingDataChannel + t.lock.RUnlock() + + t.params.Handler.OnDataChannelCloseSignalling(signallingDataChannel) + } + }) } func (t *PCTransport) maybeNotifyFullyEstablished() { @@ -1298,8 +1330,10 @@ func (t *PCTransport) clearConnTimer() { } } +// SIGNALLING-V2-TODO: this needs both sync and async support when not in one shot mode, +// cannot use the state `SynchronousLocalCandidatesMode`, needs a flag at function level func (t *PCTransport) HandleRemoteDescription(sd webrtc.SessionDescription, remoteId uint32) error { - if t.params.UseOneShotSignallingMode || t.params.SynchronousLocalCandidatesMode { + if t.params.UseOneShotSignallingMode /* SIGNALLING-V2-TODO || t.params.SynchronousLocalCandidatesMode */ { if sd.Type == webrtc.SDPTypeOffer { remoteOfferId := t.remoteOfferId.Load() if remoteOfferId != 0 && remoteOfferId != t.localAnswerId.Load() { @@ -1366,6 +1400,8 @@ func (t *PCTransport) HandleRemoteDescription(sd webrtc.SessionDescription, remo return nil } +// SIGNALLING-V2-TODO: use a flag at function level for sync vs async rather +// then state `SynchronousLocalCandidatesMode` func (t *PCTransport) GetAnswer() (webrtc.SessionDescription, uint32, error) { if !t.params.UseOneShotSignallingMode && !t.params.SynchronousLocalCandidatesMode { return webrtc.SessionDescription{}, 0, ErrNotSynchronousLocalCandidatesMode @@ -1417,6 +1453,8 @@ func (t *PCTransport) GetAnswer() (webrtc.SessionDescription, uint32, error) { return *cld, answerId, nil } +// SIGNALLING-V2-TODO: use a flag at function level for sync vs async rather +// then state `SynchronousLocalCandidatesMode` func (t *PCTransport) GetOffer() (webrtc.SessionDescription, uint32, error) { if !t.params.SynchronousLocalCandidatesMode { return webrtc.SessionDescription{}, 0, ErrNotSynchronousLocalCandidatesMode diff --git a/pkg/rtc/transport/handler.go b/pkg/rtc/transport/handler.go index e43e3bd7b..82a47dfb3 100644 --- a/pkg/rtc/transport/handler.go +++ b/pkg/rtc/transport/handler.go @@ -20,6 +20,7 @@ import ( "github.com/pion/webrtc/v4" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/protocol/livekit" ) @@ -41,6 +42,9 @@ type Handler interface { OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) OnDataMessage(kind livekit.DataPacket_Kind, data []byte) OnDataMessageUnlabeled(data []byte) + OnDataChannelOpenSignalling(dc *datachannel.DataChannelWriter[*webrtc.DataChannel]) + OnDataChannelCloseSignalling(dc *datachannel.DataChannelWriter[*webrtc.DataChannel]) + OnDataMessageSignalling(data []byte) OnDataSendError(err error) OnOffer(sd webrtc.SessionDescription, offerId uint32) error OnAnswer(sd webrtc.SessionDescription, answerId uint32) error @@ -60,7 +64,12 @@ func (h UnimplementedHandler) OnFailed(isShortLived bool) func (h UnimplementedHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) {} func (h UnimplementedHandler) OnDataMessage(kind livekit.DataPacket_Kind, data []byte) {} func (h UnimplementedHandler) OnDataMessageUnlabeled(data []byte) {} -func (h UnimplementedHandler) OnDataSendError(err error) {} +func (h UnimplementedHandler) OnDataChannelOpenSignalling(dc *datachannel.DataChannelWriter[*webrtc.DataChannel]) { +} +func (h UnimplementedHandler) OnDataChannelCloseSignalling(dc *datachannel.DataChannelWriter[*webrtc.DataChannel]) { +} +func (h UnimplementedHandler) OnDataMessageSignalling(data []byte) {} +func (h UnimplementedHandler) OnDataSendError(err error) {} func (h UnimplementedHandler) OnOffer(sd webrtc.SessionDescription, offerId uint32) error { return ErrNoOfferHandler } diff --git a/pkg/rtc/transport/transportfakes/fake_handler.go b/pkg/rtc/transport/transportfakes/fake_handler.go index fd46b60ba..64d650ded 100644 --- a/pkg/rtc/transport/transportfakes/fake_handler.go +++ b/pkg/rtc/transport/transportfakes/fake_handler.go @@ -6,6 +6,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/protocol/livekit" webrtc "github.com/pion/webrtc/v4" @@ -24,12 +25,27 @@ type FakeHandler struct { onAnswerReturnsOnCall map[int]struct { result1 error } + OnDataChannelCloseSignallingStub func(*datachannel.DataChannelWriter[*webrtc.DataChannel]) + onDataChannelCloseSignallingMutex sync.RWMutex + onDataChannelCloseSignallingArgsForCall []struct { + arg1 *datachannel.DataChannelWriter[*webrtc.DataChannel] + } + OnDataChannelOpenSignallingStub func(*datachannel.DataChannelWriter[*webrtc.DataChannel]) + onDataChannelOpenSignallingMutex sync.RWMutex + onDataChannelOpenSignallingArgsForCall []struct { + arg1 *datachannel.DataChannelWriter[*webrtc.DataChannel] + } OnDataMessageStub func(livekit.DataPacket_Kind, []byte) onDataMessageMutex sync.RWMutex onDataMessageArgsForCall []struct { arg1 livekit.DataPacket_Kind arg2 []byte } + OnDataMessageSignallingStub func([]byte) + onDataMessageSignallingMutex sync.RWMutex + onDataMessageSignallingArgsForCall []struct { + arg1 []byte + } OnDataMessageUnlabeledStub func([]byte) onDataMessageUnlabeledMutex sync.RWMutex onDataMessageUnlabeledArgsForCall []struct { @@ -170,6 +186,70 @@ func (fake *FakeHandler) OnAnswerReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeHandler) OnDataChannelCloseSignalling(arg1 *datachannel.DataChannelWriter[*webrtc.DataChannel]) { + fake.onDataChannelCloseSignallingMutex.Lock() + fake.onDataChannelCloseSignallingArgsForCall = append(fake.onDataChannelCloseSignallingArgsForCall, struct { + arg1 *datachannel.DataChannelWriter[*webrtc.DataChannel] + }{arg1}) + stub := fake.OnDataChannelCloseSignallingStub + fake.recordInvocation("OnDataChannelCloseSignalling", []interface{}{arg1}) + fake.onDataChannelCloseSignallingMutex.Unlock() + if stub != nil { + fake.OnDataChannelCloseSignallingStub(arg1) + } +} + +func (fake *FakeHandler) OnDataChannelCloseSignallingCallCount() int { + fake.onDataChannelCloseSignallingMutex.RLock() + defer fake.onDataChannelCloseSignallingMutex.RUnlock() + return len(fake.onDataChannelCloseSignallingArgsForCall) +} + +func (fake *FakeHandler) OnDataChannelCloseSignallingCalls(stub func(*datachannel.DataChannelWriter[*webrtc.DataChannel])) { + fake.onDataChannelCloseSignallingMutex.Lock() + defer fake.onDataChannelCloseSignallingMutex.Unlock() + fake.OnDataChannelCloseSignallingStub = stub +} + +func (fake *FakeHandler) OnDataChannelCloseSignallingArgsForCall(i int) *datachannel.DataChannelWriter[*webrtc.DataChannel] { + fake.onDataChannelCloseSignallingMutex.RLock() + defer fake.onDataChannelCloseSignallingMutex.RUnlock() + argsForCall := fake.onDataChannelCloseSignallingArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnDataChannelOpenSignalling(arg1 *datachannel.DataChannelWriter[*webrtc.DataChannel]) { + fake.onDataChannelOpenSignallingMutex.Lock() + fake.onDataChannelOpenSignallingArgsForCall = append(fake.onDataChannelOpenSignallingArgsForCall, struct { + arg1 *datachannel.DataChannelWriter[*webrtc.DataChannel] + }{arg1}) + stub := fake.OnDataChannelOpenSignallingStub + fake.recordInvocation("OnDataChannelOpenSignalling", []interface{}{arg1}) + fake.onDataChannelOpenSignallingMutex.Unlock() + if stub != nil { + fake.OnDataChannelOpenSignallingStub(arg1) + } +} + +func (fake *FakeHandler) OnDataChannelOpenSignallingCallCount() int { + fake.onDataChannelOpenSignallingMutex.RLock() + defer fake.onDataChannelOpenSignallingMutex.RUnlock() + return len(fake.onDataChannelOpenSignallingArgsForCall) +} + +func (fake *FakeHandler) OnDataChannelOpenSignallingCalls(stub func(*datachannel.DataChannelWriter[*webrtc.DataChannel])) { + fake.onDataChannelOpenSignallingMutex.Lock() + defer fake.onDataChannelOpenSignallingMutex.Unlock() + fake.OnDataChannelOpenSignallingStub = stub +} + +func (fake *FakeHandler) OnDataChannelOpenSignallingArgsForCall(i int) *datachannel.DataChannelWriter[*webrtc.DataChannel] { + fake.onDataChannelOpenSignallingMutex.RLock() + defer fake.onDataChannelOpenSignallingMutex.RUnlock() + argsForCall := fake.onDataChannelOpenSignallingArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeHandler) OnDataMessage(arg1 livekit.DataPacket_Kind, arg2 []byte) { var arg2Copy []byte if arg2 != nil { @@ -208,6 +288,43 @@ func (fake *FakeHandler) OnDataMessageArgsForCall(i int) (livekit.DataPacket_Kin return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeHandler) OnDataMessageSignalling(arg1 []byte) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.onDataMessageSignallingMutex.Lock() + fake.onDataMessageSignallingArgsForCall = append(fake.onDataMessageSignallingArgsForCall, struct { + arg1 []byte + }{arg1Copy}) + stub := fake.OnDataMessageSignallingStub + fake.recordInvocation("OnDataMessageSignalling", []interface{}{arg1Copy}) + fake.onDataMessageSignallingMutex.Unlock() + if stub != nil { + fake.OnDataMessageSignallingStub(arg1) + } +} + +func (fake *FakeHandler) OnDataMessageSignallingCallCount() int { + fake.onDataMessageSignallingMutex.RLock() + defer fake.onDataMessageSignallingMutex.RUnlock() + return len(fake.onDataMessageSignallingArgsForCall) +} + +func (fake *FakeHandler) OnDataMessageSignallingCalls(stub func([]byte)) { + fake.onDataMessageSignallingMutex.Lock() + defer fake.onDataMessageSignallingMutex.Unlock() + fake.OnDataMessageSignallingStub = stub +} + +func (fake *FakeHandler) OnDataMessageSignallingArgsForCall(i int) []byte { + fake.onDataMessageSignallingMutex.RLock() + defer fake.onDataMessageSignallingMutex.RUnlock() + argsForCall := fake.onDataMessageSignallingArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeHandler) OnDataMessageUnlabeled(arg1 []byte) { var arg1Copy []byte if arg1 != nil { @@ -637,8 +754,14 @@ func (fake *FakeHandler) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.onAnswerMutex.RLock() defer fake.onAnswerMutex.RUnlock() + fake.onDataChannelCloseSignallingMutex.RLock() + defer fake.onDataChannelCloseSignallingMutex.RUnlock() + fake.onDataChannelOpenSignallingMutex.RLock() + defer fake.onDataChannelOpenSignallingMutex.RUnlock() fake.onDataMessageMutex.RLock() defer fake.onDataMessageMutex.RUnlock() + fake.onDataMessageSignallingMutex.RLock() + defer fake.onDataMessageSignallingMutex.RUnlock() fake.onDataMessageUnlabeledMutex.RLock() defer fake.onDataMessageUnlabeledMutex.RUnlock() fake.onDataSendErrorMutex.RLock() diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 132bb0f0a..61f3a4e69 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -534,7 +534,7 @@ type LocalParticipant interface { HandleSimulateScenario(*livekit.SimulateScenario) error HandleLeaveRequest(reason ParticipantCloseReason) - HandleSignalRequest(msg proto.Message) error + HandleSignalMessage(msg proto.Message) error } // Room is a container of participants, and can provide room-level actions diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 8d61eed90..ab1ba0dfb 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -588,15 +588,15 @@ type FakeLocalParticipant struct { handleReconnectAndSendResponseReturnsOnCall map[int]struct { result1 error } - HandleSignalRequestStub func(proto.Message) error - handleSignalRequestMutex sync.RWMutex - handleSignalRequestArgsForCall []struct { + HandleSignalMessageStub func(proto.Message) error + handleSignalMessageMutex sync.RWMutex + handleSignalMessageArgsForCall []struct { arg1 proto.Message } - handleSignalRequestReturns struct { + handleSignalMessageReturns struct { result1 error } - handleSignalRequestReturnsOnCall map[int]struct { + handleSignalMessageReturnsOnCall map[int]struct { result1 error } HandleSignalSourceCloseStub func() @@ -4397,16 +4397,16 @@ func (fake *FakeLocalParticipant) HandleReconnectAndSendResponseReturnsOnCall(i }{result1} } -func (fake *FakeLocalParticipant) HandleSignalRequest(arg1 proto.Message) error { - fake.handleSignalRequestMutex.Lock() - ret, specificReturn := fake.handleSignalRequestReturnsOnCall[len(fake.handleSignalRequestArgsForCall)] - fake.handleSignalRequestArgsForCall = append(fake.handleSignalRequestArgsForCall, struct { +func (fake *FakeLocalParticipant) HandleSignalMessage(arg1 proto.Message) error { + fake.handleSignalMessageMutex.Lock() + ret, specificReturn := fake.handleSignalMessageReturnsOnCall[len(fake.handleSignalMessageArgsForCall)] + fake.handleSignalMessageArgsForCall = append(fake.handleSignalMessageArgsForCall, struct { arg1 proto.Message }{arg1}) - stub := fake.HandleSignalRequestStub - fakeReturns := fake.handleSignalRequestReturns - fake.recordInvocation("HandleSignalRequest", []interface{}{arg1}) - fake.handleSignalRequestMutex.Unlock() + stub := fake.HandleSignalMessageStub + fakeReturns := fake.handleSignalMessageReturns + fake.recordInvocation("HandleSignalMessage", []interface{}{arg1}) + fake.handleSignalMessageMutex.Unlock() if stub != nil { return stub(arg1) } @@ -4416,44 +4416,44 @@ func (fake *FakeLocalParticipant) HandleSignalRequest(arg1 proto.Message) error return fakeReturns.result1 } -func (fake *FakeLocalParticipant) HandleSignalRequestCallCount() int { - fake.handleSignalRequestMutex.RLock() - defer fake.handleSignalRequestMutex.RUnlock() - return len(fake.handleSignalRequestArgsForCall) +func (fake *FakeLocalParticipant) HandleSignalMessageCallCount() int { + fake.handleSignalMessageMutex.RLock() + defer fake.handleSignalMessageMutex.RUnlock() + return len(fake.handleSignalMessageArgsForCall) } -func (fake *FakeLocalParticipant) HandleSignalRequestCalls(stub func(proto.Message) error) { - fake.handleSignalRequestMutex.Lock() - defer fake.handleSignalRequestMutex.Unlock() - fake.HandleSignalRequestStub = stub +func (fake *FakeLocalParticipant) HandleSignalMessageCalls(stub func(proto.Message) error) { + fake.handleSignalMessageMutex.Lock() + defer fake.handleSignalMessageMutex.Unlock() + fake.HandleSignalMessageStub = stub } -func (fake *FakeLocalParticipant) HandleSignalRequestArgsForCall(i int) proto.Message { - fake.handleSignalRequestMutex.RLock() - defer fake.handleSignalRequestMutex.RUnlock() - argsForCall := fake.handleSignalRequestArgsForCall[i] +func (fake *FakeLocalParticipant) HandleSignalMessageArgsForCall(i int) proto.Message { + fake.handleSignalMessageMutex.RLock() + defer fake.handleSignalMessageMutex.RUnlock() + argsForCall := fake.handleSignalMessageArgsForCall[i] return argsForCall.arg1 } -func (fake *FakeLocalParticipant) HandleSignalRequestReturns(result1 error) { - fake.handleSignalRequestMutex.Lock() - defer fake.handleSignalRequestMutex.Unlock() - fake.HandleSignalRequestStub = nil - fake.handleSignalRequestReturns = struct { +func (fake *FakeLocalParticipant) HandleSignalMessageReturns(result1 error) { + fake.handleSignalMessageMutex.Lock() + defer fake.handleSignalMessageMutex.Unlock() + fake.HandleSignalMessageStub = nil + fake.handleSignalMessageReturns = struct { result1 error }{result1} } -func (fake *FakeLocalParticipant) HandleSignalRequestReturnsOnCall(i int, result1 error) { - fake.handleSignalRequestMutex.Lock() - defer fake.handleSignalRequestMutex.Unlock() - fake.HandleSignalRequestStub = nil - if fake.handleSignalRequestReturnsOnCall == nil { - fake.handleSignalRequestReturnsOnCall = make(map[int]struct { +func (fake *FakeLocalParticipant) HandleSignalMessageReturnsOnCall(i int, result1 error) { + fake.handleSignalMessageMutex.Lock() + defer fake.handleSignalMessageMutex.Unlock() + fake.HandleSignalMessageStub = nil + if fake.handleSignalMessageReturnsOnCall == nil { + fake.handleSignalMessageReturnsOnCall = make(map[int]struct { result1 error }) } - fake.handleSignalRequestReturnsOnCall[i] = struct { + fake.handleSignalMessageReturnsOnCall[i] = struct { result1 error }{result1} } @@ -9383,8 +9383,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.handleReceiverReportMutex.RUnlock() fake.handleReconnectAndSendResponseMutex.RLock() defer fake.handleReconnectAndSendResponseMutex.RUnlock() - fake.handleSignalRequestMutex.RLock() - defer fake.handleSignalRequestMutex.RUnlock() + fake.handleSignalMessageMutex.RLock() + defer fake.handleSignalMessageMutex.RUnlock() fake.handleSignalSourceCloseMutex.RLock() defer fake.handleSignalSourceCloseMutex.RUnlock() fake.handleSimulateScenarioMutex.RLock() diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 753ce6208..74b87b792 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -945,7 +945,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.LocalPa return } - if err := participant.HandleSignalRequest(obj); err != nil { + if err := participant.HandleSignalMessage(obj); err != nil { // more specific errors are already logged // treat errors returned as fatal return diff --git a/pkg/service/roommanager_service.go b/pkg/service/roommanager_service.go index 422e875ff..9b178879b 100644 --- a/pkg/service/roommanager_service.go +++ b/pkg/service/roommanager_service.go @@ -207,7 +207,7 @@ func (s signalv2ParticipantService) RelaySignalv2Participant(ctx context.Context return nil, ErrParticipantNotFound } - err := lp.HandleSignalRequest(req.WireMessage) + err := lp.HandleSignalMessage(req.WireMessage) if err != nil { return nil, err } diff --git a/pkg/service/rtcv2service.go b/pkg/service/rtcv2service.go index 95a703543..92bd61e38 100644 --- a/pkg/service/rtcv2service.go +++ b/pkg/service/rtcv2service.go @@ -223,6 +223,13 @@ func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Req HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err)) return } + logger.Debugw( + "participant request", + "room", roomName, + "participant", participantIdentity, + "pID", pID, + "participantRequest", logger.Proto(wireMessage), + ) res, err := s.signalv2ParticipantClient.RelaySignalv2Participant( r.Context(), @@ -257,7 +264,7 @@ func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Req "room", roomName, "participant", participantIdentity, "pID", pID, - "participantResponse", logger.Proto(res), + "participantResponse", logger.Proto(res.WireMessage), ) marshalled, err := proto.Marshal(res.WireMessage)