From a3228c2ae96b90c1e694e6c513f18103bdb1df67 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Fri, 4 Jun 2021 14:57:55 -0700 Subject: [PATCH] resolve data race conditions, code quality --- magefile.go | 3 +- pkg/routing/localrouter.go | 8 +- pkg/routing/node.go | 27 +++-- pkg/rtc/datatrack.go | 211 ------------------------------------- pkg/rtc/mediaengine.go | 8 +- pkg/rtc/mediatrack.go | 22 ++-- pkg/rtc/participant.go | 30 ++---- pkg/rtc/room.go | 28 ++--- pkg/rtc/room_test.go | 10 +- pkg/rtc/stats.go | 12 +++ pkg/service/roommanager.go | 7 +- pkg/service/server.go | 3 +- pkg/service/wire_gen.go | 2 +- pkg/service/wsprotocol.go | 9 +- pkg/testutils/timeout.go | 3 +- test/client/client.go | 4 +- tools/tools.go | 1 + 17 files changed, 95 insertions(+), 293 deletions(-) delete mode 100644 pkg/rtc/datatrack.go diff --git a/magefile.go b/magefile.go index 4658d4d58..9454670c5 100644 --- a/magefile.go +++ b/magefile.go @@ -213,7 +213,8 @@ func Test() error { // run all tests including integration func TestAll() error { mg.Deps(Proto) - cmd := exec.Command("go", "test", "./...", "-count=1", "-timeout=5m") + // "-v", "-race", + cmd := exec.Command("go", "test", "./...", "-count=1", "-timeout=3m") connectStd(cmd) return cmd.Run() } diff --git a/pkg/routing/localrouter.go b/pkg/routing/localrouter.go index fc37c064a..9a08d3c43 100644 --- a/pkg/routing/localrouter.go +++ b/pkg/routing/localrouter.go @@ -5,6 +5,7 @@ import ( "time" "github.com/livekit/protocol/utils" + "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/logger" livekit "github.com/livekit/livekit-server/proto" @@ -35,7 +36,10 @@ func NewLocalRouter(currentNode LocalNode) *LocalRouter { } func (r *LocalRouter) GetNodeForRoom(roomName string) (*livekit.Node, error) { - return r.currentNode, nil + r.lock.Lock() + defer r.lock.Unlock() + node := proto.Clone((*livekit.Node)(r.currentNode)).(*livekit.Node) + return node, nil } func (r *LocalRouter) SetNodeForRoom(roomName string, nodeId string) error { @@ -132,7 +136,9 @@ func (r *LocalRouter) statsWorker() { } // update every 10 seconds <-time.After(statsUpdateInterval) + r.lock.Lock() r.currentNode.Stats.UpdatedAt = time.Now().Unix() + r.lock.Unlock() } } diff --git a/pkg/routing/node.go b/pkg/routing/node.go index 0211a8b39..2e6e06d7d 100644 --- a/pkg/routing/node.go +++ b/pkg/routing/node.go @@ -63,7 +63,8 @@ func GetLocalIP(stunServers []string) (string, error) { } var stunErr error - var nodeIp string + // sufficiently large buffer to not block it + ipChan := make(chan string, 20) err = c.Start(message, func(res stun.Event) { if res.Error != nil { stunErr = res.Error @@ -77,7 +78,7 @@ func GetLocalIP(stunServers []string) (string, error) { } ip := xorAddr.IP.To4() if ip != nil { - nodeIp = ip.String() + ipChan <- ip.String() } }) if err != nil { @@ -86,21 +87,17 @@ func GetLocalIP(stunServers []string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - for nodeIp == "" { - select { - case <-ctx.Done(): - msg := "could not determine public IP" - if stunErr != nil { - return "", errors.Wrap(stunErr, msg) - } else { - return "", fmt.Errorf(msg) - } - case <-time.After(100 * time.Millisecond): - continue + select { + case nodeIP := <-ipChan: + return nodeIP, nil + case <-ctx.Done(): + msg := "could not determine public IP" + if stunErr != nil { + return "", errors.Wrap(stunErr, msg) + } else { + return "", fmt.Errorf(msg) } } - - return nodeIp, nil } // Creates a hashed ID from a unique string diff --git a/pkg/rtc/datatrack.go b/pkg/rtc/datatrack.go deleted file mode 100644 index 030e73d46..000000000 --- a/pkg/rtc/datatrack.go +++ /dev/null @@ -1,211 +0,0 @@ -package rtc - -import ( - "sync" - - "github.com/pion/webrtc/v3" - - "github.com/livekit/livekit-server/pkg/logger" - "github.com/livekit/livekit-server/pkg/rtc/types" - livekit "github.com/livekit/livekit-server/proto" -) - -const ( - dataBufferSize = 50 -) - -// DataTrack wraps a WebRTC DataChannel to satisfy the PublishedTrack interface -// it shall forward publishedTracks to all of its subscribers -type DataTrack struct { - id string - name string - participantId string - dataChannel *webrtc.DataChannel - lock sync.RWMutex - once sync.Once - msgChan chan livekit.DataMessage - onClose func() - - // map of target participantId -> DownDataChannel - subscribers map[string]*DownDataChannel -} - -func NewDataTrack(trackId, participantId string, dc *webrtc.DataChannel) *DataTrack { - t := &DataTrack{ - // ctx: context.Background(), - id: trackId, - name: dc.Label(), - participantId: participantId, - dataChannel: dc, - msgChan: make(chan livekit.DataMessage, dataBufferSize), - lock: sync.RWMutex{}, - subscribers: make(map[string]*DownDataChannel), - } - - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - dm := messageFromDataChannelMessage(msg) - t.msgChan <- dm - }) - - dc.OnClose(func() { - t.RemoveAllSubscribers() - if t.onClose != nil { - t.onClose() - } - }) - - return t -} - -func (t *DataTrack) Start() { - t.once.Do(func() { - go t.forwardWorker() - }) -} - -func (t *DataTrack) ID() string { - return t.id -} - -func (t *DataTrack) Kind() livekit.TrackType { - return livekit.TrackType_DATA -} - -func (t *DataTrack) Name() string { - return t.name -} - -// DataTrack cannot be muted -func (t *DataTrack) IsMuted() bool { - return false -} - -func (t *DataTrack) SetMuted(muted bool) { - -} - -func (t *DataTrack) OnClose(f func()) { - t.onClose = f -} - -func (t *DataTrack) IsSubscriber(subId string) bool { - t.lock.RLock() - defer t.lock.RUnlock() - return t.subscribers[subId] != nil -} - -func (t *DataTrack) AddSubscriber(participant types.Participant) error { - t.lock.Lock() - defer t.lock.Unlock() - - if t.subscribers[participant.ID()] != nil { - return nil - } - - label := PackDataTrackLabel(t.participantId, t.ID(), t.dataChannel.Label()) - downChannel, err := participant.SubscriberPC().CreateDataChannel(label, t.dataChannelOptions()) - if err != nil { - return err - } - - sub := &DownDataChannel{ - participantId: participant.ID(), - dataChannel: downChannel, - } - - t.subscribers[participant.ID()] = sub - - downChannel.OnClose(func() { - t.RemoveSubscriber(sub.participantId) - }) - return nil -} - -func (t *DataTrack) RemoveSubscriber(participantId string) { - t.lock.Lock() - sub := t.subscribers[participantId] - delete(t.subscribers, participantId) - t.lock.Unlock() - - if sub != nil { - go sub.dataChannel.Close() - } -} - -func (t *DataTrack) RemoveAllSubscribers() { - t.lock.Lock() - defer t.lock.Unlock() - for _, sub := range t.subscribers { - go sub.dataChannel.Close() - } - t.subscribers = make(map[string]*DownDataChannel) -} - -func (t *DataTrack) forwardWorker() { - defer Recover() - - for { - msg := <-t.msgChan - - if msg.Value == nil { - // track closed - return - } - - t.lock.RLock() - for _, sub := range t.subscribers { - err := sub.SendMessage(msg) - if err != nil { - logger.Errorw("could not send data message", err, - "source", t.participantId, - "dest", sub.participantId) - } - } - t.lock.RUnlock() - } -} - -func (t *DataTrack) dataChannelOptions() *webrtc.DataChannelInit { - ordered := t.dataChannel.Ordered() - protocol := t.dataChannel.Protocol() - negotiated := false - - return &webrtc.DataChannelInit{ - Ordered: &ordered, - MaxPacketLifeTime: t.dataChannel.MaxPacketLifeTime(), - MaxRetransmits: t.dataChannel.MaxRetransmits(), - Protocol: &protocol, - Negotiated: &negotiated, - } -} - -type DownDataChannel struct { - participantId string - dataChannel *webrtc.DataChannel -} - -func (d *DownDataChannel) SendMessage(msg livekit.DataMessage) error { - var err error - switch val := msg.Value.(type) { - case *livekit.DataMessage_Binary: - err = d.dataChannel.Send(val.Binary) - case *livekit.DataMessage_Text: - err = d.dataChannel.SendText(val.Text) - } - return err -} - -func messageFromDataChannelMessage(msg webrtc.DataChannelMessage) livekit.DataMessage { - dm := livekit.DataMessage{} - if msg.IsString { - dm.Value = &livekit.DataMessage_Text{ - Text: string(msg.Data), - } - } else { - dm.Value = &livekit.DataMessage_Binary{ - Binary: msg.Data, - } - - } - return dm -} diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 9cef80354..5edc1fd4b 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -19,10 +19,10 @@ func createPubMediaEngine() (*webrtc.MediaEngine, error) { } videoRTCPFeedback := []webrtc.RTCPFeedback{ - {webrtc.TypeRTCPFBGoogREMB, ""}, - {webrtc.TypeRTCPFBCCM, "fir"}, - {webrtc.TypeRTCPFBNACK, ""}, - {webrtc.TypeRTCPFBNACK, "pli"}} + {Type: webrtc.TypeRTCPFBGoogREMB, Parameter: ""}, + {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: ""}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}} for _, codec := range []webrtc.RTPCodecParameters{ { RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8, ClockRate: 90000, RTCPFeedback: videoRTCPFeedback}, diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 2f95187e4..9e1b86380 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -22,9 +22,9 @@ import ( var ( feedbackTypes = []webrtc.RTCPFeedback{ - {webrtc.TypeRTCPFBGoogREMB, ""}, - {webrtc.TypeRTCPFBNACK, ""}, - {webrtc.TypeRTCPFBNACK, "pli"}} + {Type: webrtc.TypeRTCPFBGoogREMB}, + {Type: webrtc.TypeRTCPFBNACK}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}} ) // MediaTrack represents a WebRTC track that needs to be forwarded @@ -67,7 +67,6 @@ func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTra streamID: track.StreamID(), kind: ToProtoTrackKind(track.Kind()), codec: track.Codec(), - lock: sync.RWMutex{}, subscribedTracks: make(map[string]*SubscribedTrack), } @@ -170,7 +169,7 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { // when outtrack is bound, start loop to send reports downTrack.OnBind(func() { subTrack.SetPublisherMuted(t.IsMuted()) - go t.sendDownTrackBindingReports(sub.ID(), sub.RTCPChan()) + go t.sendDownTrackBindingReports(sub) }) downTrack.OnCloseHandler(func() { t.lock.Lock() @@ -302,8 +301,8 @@ func (t *MediaTrack) RemoveSubscriber(participantId string) { func (t *MediaTrack) RemoveAllSubscribers() { logger.Debugw("removing all subscribers", "track", t.params.TrackID) - t.lock.RLock() - defer t.lock.RUnlock() + t.lock.Lock() + defer t.lock.Unlock() for _, subTrack := range t.subscribedTracks { go subTrack.DownTrack().Close() } @@ -312,11 +311,11 @@ func (t *MediaTrack) RemoveAllSubscribers() { // TODO: send for all downtracks from the source participant // https://tools.ietf.org/html/rfc7941 -func (t *MediaTrack) sendDownTrackBindingReports(participantId string, rtcpCh chan []rtcp.Packet) { +func (t *MediaTrack) sendDownTrackBindingReports(sub types.Participant) { var sd []rtcp.SourceDescriptionChunk t.lock.RLock() - subTrack := t.subscribedTracks[participantId] + subTrack := t.subscribedTracks[sub.ID()] t.lock.RUnlock() if subTrack == nil { @@ -338,7 +337,10 @@ func (t *MediaTrack) sendDownTrackBindingReports(participantId string, rtcpCh ch batch := pkts i := 0 for { - rtcpCh <- batch + if err := sub.SubscriberPC().WriteRTCP(batch); err != nil { + logger.Errorw("could not write RTCP", err) + return + } if i > 5 { return } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index ad0783cda..c40c24364 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -27,7 +27,6 @@ import ( const ( lossyDataChannel = "_lossy" reliableDataChannel = "_reliable" - privateDataChannel = "_private" sdBatchSize = 20 ) @@ -90,7 +89,6 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { id: utils.NewGuid(utils.ParticipantPrefix), rtcpCh: make(chan []rtcp.Packet, 50), subscribedTracks: make(map[string][]types.SubscribedTrack), - lock: sync.RWMutex{}, publishedTracks: make(map[string]types.PublishedTrack, 0), pendingTracks: make(map[string]*livekit.TrackInfo), connectedAt: time.Now(), @@ -632,10 +630,13 @@ func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { } p.state.Store(state) logger.Debugw("updating participant state", "state", state.String(), "participant", p.Identity()) - if p.onStateChange != nil { + p.lock.RLock() + onStateChange := p.onStateChange + p.lock.RUnlock() + if onStateChange != nil { go func() { defer Recover() - p.onStateChange(p, oldState) + onStateChange(p, oldState) }() } } @@ -743,27 +744,8 @@ func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { dc.OnMessage(func(msg webrtc.DataChannelMessage) { p.handleDataMessage(livekit.DataPacket_LOSSY, msg.Data) }) - case privateDataChannel: - // ignore default: - logger.Debugw("dataChannel added", "participant", p.Identity(), "label", dc.Label()) - - if !p.CanPublish() { - logger.Warnw("no permission to publish dataTrack", nil, - "participant", p.Identity()) - return - } - - // data channels have numeric ids, so we use its label to identify - ti := p.getPendingTrack(dc.Label(), livekit.TrackType_DATA, true) - if ti == nil { - return - } - - dt := NewDataTrack(ti.Sid, p.id, dc) - dt.name = ti.Name - - p.handleTrackPublished(dt) + logger.Warnw("unsupported datachannel added", nil, "participant", p.Identity(), "label", dc.Label()) } } diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index d5a6eef9b..44d8e25bd 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -7,6 +7,7 @@ import ( "time" "github.com/livekit/protocol/utils" + "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/logger" "github.com/livekit/livekit-server/pkg/rtc/types" @@ -19,7 +20,7 @@ const ( ) type Room struct { - livekit.Room + Room *livekit.Room config WebRTCConfig iceServers []*livekit.ICEServer lock sync.RWMutex @@ -49,20 +50,19 @@ type ParticipantOptions struct { func NewRoom(room *livekit.Room, config WebRTCConfig, iceServers []*livekit.ICEServer, audioUpdateInterval uint32) *Room { r := &Room{ - Room: *room, + Room: proto.Clone(room).(*livekit.Room), config: config, iceServers: iceServers, audioUpdateInterval: audioUpdateInterval, - lock: sync.RWMutex{}, statsReporter: NewRoomStatsReporter(room.Name), participants: make(map[string]types.Participant), participantOpts: make(map[string]*ParticipantOptions), } - if r.EmptyTimeout == 0 { - r.EmptyTimeout = DefaultEmptyTimeout + if r.Room.EmptyTimeout == 0 { + r.Room.EmptyTimeout = DefaultEmptyTimeout } - if r.CreationTime == 0 { - r.CreationTime = time.Now().Unix() + if r.Room.CreationTime == 0 { + r.Room.CreationTime = time.Now().Unix() } r.statsReporter.RoomStarted() go r.audioUpdateWorker() @@ -138,7 +138,7 @@ func (r *Room) Join(participant types.Participant, opts *ParticipantOptions) err return ErrAlreadyJoined } - if r.MaxParticipants > 0 && int(r.MaxParticipants) == len(r.participants) { + if r.Room.MaxParticipants > 0 && int(r.Room.MaxParticipants) == len(r.participants) { return ErrMaxParticipantsExceeded } @@ -176,7 +176,7 @@ func (r *Room) Join(participant types.Participant, opts *ParticipantOptions) err logger.Infow("new participant joined", "id", participant.ID(), "identity", participant.Identity(), - "roomId", r.Sid) + "roomId", r.Room.Sid) r.participants[participant.Identity()] = participant r.participantOpts[participant.Identity()] = opts @@ -193,7 +193,7 @@ func (r *Room) Join(participant types.Participant, opts *ParticipantOptions) err r.onParticipantChanged(participant) } - return participant.SendJoinResponse(&r.Room, otherParticipants, r.iceServers) + return participant.SendJoinResponse(r.Room, otherParticipants, r.iceServers) } func (r *Room) RemoveParticipant(identity string) { @@ -221,9 +221,11 @@ func (r *Room) RemoveParticipant(identity string) { // close participant as well _ = p.Close() + r.lock.RLock() if len(r.participants) == 0 { r.leftAt.Store(time.Now().Unix()) } + r.lock.RUnlock() if sendUpdates { if r.onParticipantChanged != nil { @@ -274,7 +276,7 @@ func (r *Room) CloseIfEmpty() { return } - timeout := r.EmptyTimeout + timeout := r.Room.EmptyTimeout var elapsed int64 if r.FirstJoinedAt() > 0 { // exit 20s after @@ -283,7 +285,7 @@ func (r *Room) CloseIfEmpty() { timeout = DefaultRoomDepartureGrace } } else { - elapsed = time.Now().Unix() - r.CreationTime + elapsed = time.Now().Unix() - r.Room.CreationTime } if elapsed >= int64(timeout) { @@ -295,7 +297,7 @@ func (r *Room) Close() { if !r.isClosed.TrySet(true) { return } - logger.Infow("closing room", "room", r.Sid, "name", r.Name) + logger.Infow("closing room", "room", r.Room.Sid, "name", r.Room.Name) r.statsReporter.RoomEnded() if r.onClose != nil { diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index 735e6ad12..5d14ba5ba 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -63,7 +63,7 @@ func TestRoomJoin(t *testing.T) { // expect new participant to get a JoinReply info, participants, iceServers := pNew.SendJoinResponseArgsForCall(0) - require.Equal(t, info.Sid, rm.Sid) + require.Equal(t, info.Sid, rm.Room.Sid) require.Len(t, participants, numParticipants) require.Len(t, rm.GetParticipants(), numParticipants+1) require.NotEmpty(t, iceServers) @@ -125,7 +125,7 @@ func TestRoomJoin(t *testing.T) { t.Run("cannot exceed max participants", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) - rm.MaxParticipants = 1 + rm.Room.MaxParticipants = 1 p := newMockParticipant("second", types.ProtocolVersion(0)) err := rm.Join(p, nil) @@ -198,7 +198,7 @@ func TestRoomClosure(t *testing.T) { }) p := rm.GetParticipants()[0] // allows immediate close after - rm.EmptyTimeout = 0 + rm.Room.EmptyTimeout = 0 rm.RemoveParticipant(p.Identity()) time.Sleep(defaultDelay) @@ -216,7 +216,7 @@ func TestRoomClosure(t *testing.T) { rm.OnClose(func() { isClosed = true }) - require.NotZero(t, rm.EmptyTimeout) + require.NotZero(t, rm.Room.EmptyTimeout) rm.CloseIfEmpty() require.False(t, isClosed) }) @@ -227,7 +227,7 @@ func TestRoomClosure(t *testing.T) { rm.OnClose(func() { isClosed = true }) - rm.EmptyTimeout = 1 + rm.Room.EmptyTimeout = 1 time.Sleep(1010 * time.Millisecond) rm.CloseIfEmpty() diff --git a/pkg/rtc/stats.go b/pkg/rtc/stats.go index bdea343a9..52bf7e940 100644 --- a/pkg/rtc/stats.go +++ b/pkg/rtc/stats.go @@ -192,6 +192,18 @@ func (s *PacketStats) HandleRTCP(pkts []rtcp.Packet) { } } +func (s PacketStats) Copy() *PacketStats { + return &PacketStats{ + roomName: s.roomName, + direction: s.direction, + PacketBytes: atomic.LoadUint64(&s.PacketBytes), + PacketTotal: atomic.LoadUint64(&s.PacketTotal), + NackTotal: atomic.LoadUint64(&s.NackTotal), + PLITotal: atomic.LoadUint64(&s.PLITotal), + FIRTotal: atomic.LoadUint64(&s.FIRTotal), + } +} + // StatsBufferWrapper wraps a buffer factory so we could get information on // incoming packets type StatsBufferWrapper struct { diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 9386a78c3..c5e2d2a6c 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -307,10 +307,11 @@ func (r *RoomManager) getOrCreateRoom(roomName string) (*rtc.Room, error) { if err := r.DeleteRoom(roomName); err != nil { logger.Errorw("could not delete room", err) } + // print stats logger.Infow("room closed", - "incomingStats", room.GetIncomingStats(), - "outgoingStats", room.GetOutgoingStats(), + "incomingStats", room.GetIncomingStats().Copy(), + "outgoingStats", room.GetOutgoingStats().Copy(), ) }) room.OnParticipantChanged(func(p types.Participant) { @@ -336,7 +337,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici defer func() { logger.Debugw("RTC session finishing", "participant", participant.Identity(), - "room", room.Name, + "room", room.Room.Name, ) _ = participant.Close() }() diff --git a/pkg/service/server.go b/pkg/service/server.go index ecff72c28..5624c3534 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -177,7 +177,8 @@ func (s *LivekitServer) Start() error { } // wait for shutdown - ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() _ = s.httpServer.Shutdown(ctx) if s.turnServer != nil { diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 62d017ed7..c7bddcffb 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -1,6 +1,6 @@ // Code generated by Wire. DO NOT EDIT. -//go:generate go run github.com/google/wire/cmd/wire +//go:generate wire //+build !wireinject package service diff --git a/pkg/service/wsprotocol.go b/pkg/service/wsprotocol.go index 6ecc77b36..97abc7ab9 100644 --- a/pkg/service/wsprotocol.go +++ b/pkg/service/wsprotocol.go @@ -45,14 +45,18 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, error) { msg := &livekit.SignalRequest{} switch messageType { case websocket.BinaryMessage: + c.mu.Lock() // switch to protobuf if client supports it c.useJSON = false + c.mu.Unlock() // protobuf encoded err := proto.Unmarshal(payload, msg) return msg, err case websocket.TextMessage: + c.mu.Lock() // json encoded, also write back JSON c.useJSON = true + c.mu.Unlock() err := protojson.Unmarshal(payload, msg) return msg, err default: @@ -67,6 +71,9 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) error { var payload []byte var err error + c.mu.Lock() + defer c.mu.Unlock() + if c.useJSON { msgType = websocket.TextMessage payload, err = protojson.Marshal(msg) @@ -78,8 +85,6 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) error { return err } - c.mu.Lock() - defer c.mu.Unlock() return c.conn.WriteMessage(msgType, payload) } diff --git a/pkg/testutils/timeout.go b/pkg/testutils/timeout.go index d821c3b83..dba3d71e4 100644 --- a/pkg/testutils/timeout.go +++ b/pkg/testutils/timeout.go @@ -15,7 +15,8 @@ var ( func WithTimeout(t *testing.T, description string, f func() bool) bool { logger.Infow(description) - ctx, _ := context.WithTimeout(context.Background(), ConnectTimeout) + ctx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) + defer cancel() for { select { case <-ctx.Done(): diff --git a/test/client/client.go b/test/client/client.go index 64ad8c8ca..bd36e2197 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -189,7 +189,7 @@ func (c *RTCClient) Run() error { }) // create a data channel, in order to work - _, err := c.publisher.PeerConnection().CreateDataChannel("_private", nil) + _, err := c.publisher.PeerConnection().CreateDataChannel("_lossy", nil) if err != nil { return err } @@ -353,6 +353,8 @@ func (c *RTCClient) SubscribedTracks() map[string][]*webrtc.TrackRemote { } func (c *RTCClient) RemoteParticipants() []*livekit.ParticipantInfo { + c.lock.Lock() + defer c.lock.Unlock() return funk.Values(c.remoteParticipants).([]*livekit.ParticipantInfo) } diff --git a/tools/tools.go b/tools/tools.go index 808316bc9..ea857b376 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -3,6 +3,7 @@ package tools import ( + _ "github.com/google/wire" _ "github.com/maxbrunsfeld/counterfeiter/v6" _ "github.com/twitchtv/twirp/protoc-gen-twirp" _ "google.golang.org/protobuf/cmd/protoc-gen-go"