From d4c4bc1100ad2067dc2eb8b96ff4bf90e5edd8aa Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Mon, 4 Dec 2023 19:11:55 -0800 Subject: [PATCH] fix signal response delivery after session start failure (#2294) * fix signal response delivery after session start failure * tidy --- pkg/service/signal.go | 5 +- pkg/service/signal_test.go | 137 +++++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 47 deletions(-) diff --git a/pkg/service/signal.go b/pkg/service/signal.go index 4c9e085be..a181987e9 100644 --- a/pkg/service/signal.go +++ b/pkg/service/signal.go @@ -148,6 +148,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe "connID", ss.ConnectionId, ) + stream.Hijack() sink := routing.NewSignalMessageSink(routing.SignalSinkParams[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]{ Logger: l, Stream: stream, @@ -176,11 +177,9 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe err = r.sessionHandler(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) if err != nil { + sink.Close() l.Errorw("could not handle new participant", err) - return } - - stream.Hijack() return } diff --git a/pkg/service/signal_test.go b/pkg/service/signal_test.go index 3e202636e..83bf391e9 100644 --- a/pkg/service/signal_test.go +++ b/pkg/service/signal_test.go @@ -16,6 +16,7 @@ package service import ( "context" + "errors" "testing" "time" @@ -35,7 +36,6 @@ func init() { } func TestSignal(t *testing.T) { - bus := psrpc.NewLocalMessageBus() cfg := config.SignalRelayConfig{ Enabled: false, RetryTimeout: 30 * time.Second, @@ -44,56 +44,105 @@ func TestSignal(t *testing.T) { StreamBufferSize: 1000, } - reqMessageIn := &livekit.SignalRequest{ - Message: &livekit.SignalRequest_Ping{Ping: 123}, - } - resMessageIn := &livekit.SignalResponse{ - Message: &livekit.SignalResponse_Pong{Pong: 321}, - } + t.Run("messages are delivered", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() - var reqMessageOut proto.Message - var resErr error - done := make(chan struct{}) + reqMessageIn := &livekit.SignalRequest{ + Message: &livekit.SignalRequest_Ping{Ping: 123}, + } + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } - client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) - require.NoError(t, err) + var reqMessageOut proto.Message + var resErr error + done := make(chan struct{}) - server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, - ) error { - go func() { - reqMessageOut = <-requestSource.ReadChan() - resErr = responseSink.WriteMessage(resMessageIn) - responseSink.Close() - close(done) - }() - return nil + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) + + server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + go func() { + reqMessageOut = <-requestSource.ReadChan() + resErr = responseSink.WriteMessage(resMessageIn) + responseSink.Close() + close(done) + }() + return nil + }) + require.NoError(t, err) + + err = server.Start() + require.NoError(t, err) + + _, reqSink, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + err = reqSink.WriteMessage(reqMessageIn) + require.NoError(t, err) + + <-done + require.True(t, proto.Equal(reqMessageIn, reqMessageOut), "req message should match %s %s", protojson.Format(reqMessageIn), protojson.Format(reqMessageOut)) + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) }) - require.NoError(t, err) - err = server.Start() - require.NoError(t, err) + t.Run("messages are delivered when session handler fails", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() - _, reqSink, resSource, err := client.StartParticipantSignal( - context.Background(), - livekit.RoomName("room1"), - routing.ParticipantInit{}, - livekit.NodeID("node1"), - ) - require.NoError(t, err) + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } - err = reqSink.WriteMessage(reqMessageIn) - require.NoError(t, err) + var resErr error + done := make(chan struct{}) - <-done - require.True(t, proto.Equal(reqMessageIn, reqMessageOut), "req message should match %s %s", protojson.Format(reqMessageIn), protojson.Format(reqMessageOut)) - require.NoError(t, resErr) + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) - resMessageOut := <-resSource.ReadChan() - require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) + server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + defer close(done) + resErr = responseSink.WriteMessage(resMessageIn) + return errors.New("start session failed") + }) + require.NoError(t, err) + + err = server.Start() + require.NoError(t, err) + + _, _, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + <-done + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) + }) }