Resolve data race conditions and improve code quality

This commit is contained in:
David Zhao
2021-06-04 14:57:55 -07:00
parent 5baf97e99b
commit a3228c2ae9
17 changed files with 95 additions and 293 deletions
+2 -1
View File
@@ -213,7 +213,8 @@ func Test() error {
// run all tests including integration
func TestAll() error {
mg.Deps(Proto)
cmd := exec.Command("go", "test", "./...", "-count=1", "-timeout=5m")
// "-v", "-race",
cmd := exec.Command("go", "test", "./...", "-count=1", "-timeout=3m")
connectStd(cmd)
return cmd.Run()
}
+7 -1
View File
@@ -5,6 +5,7 @@ import (
"time"
"github.com/livekit/protocol/utils"
"google.golang.org/protobuf/proto"
"github.com/livekit/livekit-server/pkg/logger"
livekit "github.com/livekit/livekit-server/proto"
@@ -35,7 +36,10 @@ func NewLocalRouter(currentNode LocalNode) *LocalRouter {
}
func (r *LocalRouter) GetNodeForRoom(roomName string) (*livekit.Node, error) {
return r.currentNode, nil
r.lock.Lock()
defer r.lock.Unlock()
node := proto.Clone((*livekit.Node)(r.currentNode)).(*livekit.Node)
return node, nil
}
func (r *LocalRouter) SetNodeForRoom(roomName string, nodeId string) error {
@@ -132,7 +136,9 @@ func (r *LocalRouter) statsWorker() {
}
// update every 10 seconds
<-time.After(statsUpdateInterval)
r.lock.Lock()
r.currentNode.Stats.UpdatedAt = time.Now().Unix()
r.lock.Unlock()
}
}
+12 -15
View File
@@ -63,7 +63,8 @@ func GetLocalIP(stunServers []string) (string, error) {
}
var stunErr error
var nodeIp string
// sufficiently large buffer to not block it
ipChan := make(chan string, 20)
err = c.Start(message, func(res stun.Event) {
if res.Error != nil {
stunErr = res.Error
@@ -77,7 +78,7 @@ func GetLocalIP(stunServers []string) (string, error) {
}
ip := xorAddr.IP.To4()
if ip != nil {
nodeIp = ip.String()
ipChan <- ip.String()
}
})
if err != nil {
@@ -86,21 +87,17 @@ func GetLocalIP(stunServers []string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for nodeIp == "" {
select {
case <-ctx.Done():
msg := "could not determine public IP"
if stunErr != nil {
return "", errors.Wrap(stunErr, msg)
} else {
return "", fmt.Errorf(msg)
}
case <-time.After(100 * time.Millisecond):
continue
select {
case nodeIP := <-ipChan:
return nodeIP, nil
case <-ctx.Done():
msg := "could not determine public IP"
if stunErr != nil {
return "", errors.Wrap(stunErr, msg)
} else {
return "", fmt.Errorf(msg)
}
}
return nodeIp, nil
}
// Creates a hashed ID from a unique string
-211
View File
@@ -1,211 +0,0 @@
package rtc
import (
"sync"
"github.com/pion/webrtc/v3"
"github.com/livekit/livekit-server/pkg/logger"
"github.com/livekit/livekit-server/pkg/rtc/types"
livekit "github.com/livekit/livekit-server/proto"
)
const (
// dataBufferSize is the capacity of the buffered msgChan that NewDataTrack
// allocates for each DataTrack.
dataBufferSize = 50
)
// DataTrack wraps a WebRTC DataChannel to satisfy the PublishedTrack interface.
// Messages received on the wrapped channel are queued on msgChan and forwarded
// to every subscriber by forwardWorker.
type DataTrack struct {
id string
// name is initialised from the data channel's label in NewDataTrack
name string
// participantId is the ID passed to NewDataTrack; it is used when building
// subscriber channel labels and as the "source" field in error logs
participantId string
dataChannel *webrtc.DataChannel
// lock guards subscribers
lock sync.RWMutex
// once ensures Start launches forwardWorker only a single time
once sync.Once
// msgChan buffers incoming messages between the OnMessage callback and
// forwardWorker
msgChan chan livekit.DataMessage
// onClose, when set, is invoked after the data channel closes
onClose func()
// map of target participantId -> DownDataChannel
subscribers map[string]*DownDataChannel
}
// NewDataTrack builds a DataTrack around dc. The track takes its name from
// the channel's label and starts with an empty subscriber set.
//
// Two callbacks are installed on dc: every incoming message is converted and
// queued on the track's buffered msgChan, and when dc closes all subscribers
// are removed before the optional onClose hook is invoked. Forwarding does
// not begin until Start is called.
func NewDataTrack(trackId, participantId string, dc *webrtc.DataChannel) *DataTrack {
	track := &DataTrack{
		id:            trackId,
		name:          dc.Label(),
		participantId: participantId,
		dataChannel:   dc,
		msgChan:       make(chan livekit.DataMessage, dataBufferSize),
		subscribers:   make(map[string]*DownDataChannel),
	}

	// queue each inbound message for the forwarding worker
	dc.OnMessage(func(msg webrtc.DataChannelMessage) {
		track.msgChan <- messageFromDataChannelMessage(msg)
	})

	dc.OnClose(func() {
		track.RemoveAllSubscribers()
		if track.onClose == nil {
			return
		}
		track.onClose()
	})

	return track
}
// Start launches forwardWorker in its own goroutine. The sync.Once guard
// makes repeated calls no-ops.
func (t *DataTrack) Start() {
t.once.Do(func() {
go t.forwardWorker()
})
}
// ID returns the track ID that was supplied to NewDataTrack.
func (t *DataTrack) ID() string {
return t.id
}
// Kind always reports livekit.TrackType_DATA.
func (t *DataTrack) Kind() livekit.TrackType {
return livekit.TrackType_DATA
}
// Name returns the track's name.
func (t *DataTrack) Name() string {
return t.name
}
// IsMuted always returns false: a DataTrack cannot be muted.
func (t *DataTrack) IsMuted() bool {
return false
}
// SetMuted is a no-op; the body is empty, so muted is ignored.
func (t *DataTrack) SetMuted(muted bool) {
}
// OnClose registers f as the hook called after the underlying data channel
// closes, replacing any previously registered hook. The assignment is made
// without holding t.lock.
func (t *DataTrack) OnClose(f func()) {
t.onClose = f
}
// IsSubscriber reports whether a down channel is currently registered for
// the participant identified by subId.
func (t *DataTrack) IsSubscriber(subId string) bool {
	t.lock.RLock()
	down := t.subscribers[subId]
	t.lock.RUnlock()

	return down != nil
}
// AddSubscriber creates a data channel on the participant's subscriber peer
// connection and registers it so forwarded messages reach that participant.
//
// It returns nil without doing anything when the participant is already
// subscribed, and returns the error from CreateDataChannel if the channel
// cannot be created. The new channel removes itself from the subscriber set
// when it closes.
func (t *DataTrack) AddSubscriber(participant types.Participant) error {
	t.lock.Lock()
	defer t.lock.Unlock()

	if existing := t.subscribers[participant.ID()]; existing != nil {
		// already subscribed
		return nil
	}

	channelLabel := PackDataTrackLabel(t.participantId, t.ID(), t.dataChannel.Label())
	dc, err := participant.SubscriberPC().CreateDataChannel(channelLabel, t.dataChannelOptions())
	if err != nil {
		return err
	}

	down := &DownDataChannel{
		participantId: participant.ID(),
		dataChannel:   dc,
	}
	t.subscribers[participant.ID()] = down

	// drop the registration once the down channel goes away
	dc.OnClose(func() {
		t.RemoveSubscriber(down.participantId)
	})

	return nil
}
// RemoveSubscriber unregisters the down channel for participantId, if any,
// and closes it in a separate goroutine after the lock has been released.
func (t *DataTrack) RemoveSubscriber(participantId string) {
	t.lock.Lock()
	removed := t.subscribers[participantId]
	delete(t.subscribers, participantId)
	t.lock.Unlock()

	if removed == nil {
		return
	}
	go removed.dataChannel.Close()
}
// RemoveAllSubscribers closes every registered down channel, each in its own
// goroutine, and replaces the subscriber set with an empty map.
func (t *DataTrack) RemoveAllSubscribers() {
	t.lock.Lock()
	defer t.lock.Unlock()

	for _, down := range t.subscribers {
		channel := down.dataChannel
		go channel.Close()
	}

	t.subscribers = make(map[string]*DownDataChannel)
}
// forwardWorker drains msgChan and sends each message to every current
// subscriber while holding the read lock. A send failure is logged and does
// not stop delivery to the remaining subscribers. The worker exits when it
// receives a message whose Value is nil, or when msgChan is closed.
func (t *DataTrack) forwardWorker() {
	defer Recover()

	for msg := range t.msgChan {
		if msg.Value == nil {
			// track closed
			return
		}

		t.lock.RLock()
		for _, sub := range t.subscribers {
			if err := sub.SendMessage(msg); err != nil {
				logger.Errorw("could not send data message", err,
					"source", t.participantId,
					"dest", sub.participantId)
			}
		}
		t.lock.RUnlock()
	}
}
// dataChannelOptions builds the init options for a subscriber's down channel
// by copying the ordered, max-packet-lifetime, max-retransmits and protocol
// settings of the wrapped data channel. Negotiated is always false.
func (t *DataTrack) dataChannelOptions() *webrtc.DataChannelInit {
	var (
		ordered    = t.dataChannel.Ordered()
		protocol   = t.dataChannel.Protocol()
		negotiated = false
	)

	opts := &webrtc.DataChannelInit{
		Ordered:           &ordered,
		MaxPacketLifeTime: t.dataChannel.MaxPacketLifeTime(),
		MaxRetransmits:    t.dataChannel.MaxRetransmits(),
		Protocol:          &protocol,
		Negotiated:        &negotiated,
	}
	return opts
}
// DownDataChannel pairs a subscriber's participant ID with the data channel
// that AddSubscriber created on that subscriber's peer connection.
type DownDataChannel struct {
participantId string
dataChannel *webrtc.DataChannel
}
// SendMessage writes msg to the down channel: binary payloads go through
// Send and text payloads through SendText. Any other payload kind, including
// a nil Value, is ignored and nil is returned.
func (d *DownDataChannel) SendMessage(msg livekit.DataMessage) error {
	switch payload := msg.Value.(type) {
	case *livekit.DataMessage_Binary:
		return d.dataChannel.Send(payload.Binary)
	case *livekit.DataMessage_Text:
		return d.dataChannel.SendText(payload.Text)
	default:
		return nil
	}
}
// messageFromDataChannelMessage converts a WebRTC data channel message into a
// livekit.DataMessage, using the Text variant when msg.IsString is set and
// the Binary variant otherwise.
func messageFromDataChannelMessage(msg webrtc.DataChannelMessage) livekit.DataMessage {
	if msg.IsString {
		return livekit.DataMessage{
			Value: &livekit.DataMessage_Text{Text: string(msg.Data)},
		}
	}
	return livekit.DataMessage{
		Value: &livekit.DataMessage_Binary{Binary: msg.Data},
	}
}
+4 -4
View File
@@ -19,10 +19,10 @@ func createPubMediaEngine() (*webrtc.MediaEngine, error) {
}
videoRTCPFeedback := []webrtc.RTCPFeedback{
{webrtc.TypeRTCPFBGoogREMB, ""},
{webrtc.TypeRTCPFBCCM, "fir"},
{webrtc.TypeRTCPFBNACK, ""},
{webrtc.TypeRTCPFBNACK, "pli"}}
{Type: webrtc.TypeRTCPFBGoogREMB, Parameter: ""},
{Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"},
{Type: webrtc.TypeRTCPFBNACK, Parameter: ""},
{Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}}
for _, codec := range []webrtc.RTPCodecParameters{
{
RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8, ClockRate: 90000, RTCPFeedback: videoRTCPFeedback},
+12 -10
View File
@@ -22,9 +22,9 @@ import (
var (
feedbackTypes = []webrtc.RTCPFeedback{
{webrtc.TypeRTCPFBGoogREMB, ""},
{webrtc.TypeRTCPFBNACK, ""},
{webrtc.TypeRTCPFBNACK, "pli"}}
{Type: webrtc.TypeRTCPFBGoogREMB},
{Type: webrtc.TypeRTCPFBNACK},
{Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}}
)
// MediaTrack represents a WebRTC track that needs to be forwarded
@@ -67,7 +67,6 @@ func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTra
streamID: track.StreamID(),
kind: ToProtoTrackKind(track.Kind()),
codec: track.Codec(),
lock: sync.RWMutex{},
subscribedTracks: make(map[string]*SubscribedTrack),
}
@@ -170,7 +169,7 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error {
// when outtrack is bound, start loop to send reports
downTrack.OnBind(func() {
subTrack.SetPublisherMuted(t.IsMuted())
go t.sendDownTrackBindingReports(sub.ID(), sub.RTCPChan())
go t.sendDownTrackBindingReports(sub)
})
downTrack.OnCloseHandler(func() {
t.lock.Lock()
@@ -302,8 +301,8 @@ func (t *MediaTrack) RemoveSubscriber(participantId string) {
func (t *MediaTrack) RemoveAllSubscribers() {
logger.Debugw("removing all subscribers", "track", t.params.TrackID)
t.lock.RLock()
defer t.lock.RUnlock()
t.lock.Lock()
defer t.lock.Unlock()
for _, subTrack := range t.subscribedTracks {
go subTrack.DownTrack().Close()
}
@@ -312,11 +311,11 @@ func (t *MediaTrack) RemoveAllSubscribers() {
// TODO: send for all downtracks from the source participant
// https://tools.ietf.org/html/rfc7941
func (t *MediaTrack) sendDownTrackBindingReports(participantId string, rtcpCh chan []rtcp.Packet) {
func (t *MediaTrack) sendDownTrackBindingReports(sub types.Participant) {
var sd []rtcp.SourceDescriptionChunk
t.lock.RLock()
subTrack := t.subscribedTracks[participantId]
subTrack := t.subscribedTracks[sub.ID()]
t.lock.RUnlock()
if subTrack == nil {
@@ -338,7 +337,10 @@ func (t *MediaTrack) sendDownTrackBindingReports(participantId string, rtcpCh ch
batch := pkts
i := 0
for {
rtcpCh <- batch
if err := sub.SubscriberPC().WriteRTCP(batch); err != nil {
logger.Errorw("could not write RTCP", err)
return
}
if i > 5 {
return
}
+6 -24
View File
@@ -27,7 +27,6 @@ import (
const (
lossyDataChannel = "_lossy"
reliableDataChannel = "_reliable"
privateDataChannel = "_private"
sdBatchSize = 20
)
@@ -90,7 +89,6 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) {
id: utils.NewGuid(utils.ParticipantPrefix),
rtcpCh: make(chan []rtcp.Packet, 50),
subscribedTracks: make(map[string][]types.SubscribedTrack),
lock: sync.RWMutex{},
publishedTracks: make(map[string]types.PublishedTrack, 0),
pendingTracks: make(map[string]*livekit.TrackInfo),
connectedAt: time.Now(),
@@ -632,10 +630,13 @@ func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) {
}
p.state.Store(state)
logger.Debugw("updating participant state", "state", state.String(), "participant", p.Identity())
if p.onStateChange != nil {
p.lock.RLock()
onStateChange := p.onStateChange
p.lock.RUnlock()
if onStateChange != nil {
go func() {
defer Recover()
p.onStateChange(p, oldState)
onStateChange(p, oldState)
}()
}
}
@@ -743,27 +744,8 @@ func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) {
dc.OnMessage(func(msg webrtc.DataChannelMessage) {
p.handleDataMessage(livekit.DataPacket_LOSSY, msg.Data)
})
case privateDataChannel:
// ignore
default:
logger.Debugw("dataChannel added", "participant", p.Identity(), "label", dc.Label())
if !p.CanPublish() {
logger.Warnw("no permission to publish dataTrack", nil,
"participant", p.Identity())
return
}
// data channels have numeric ids, so we use its label to identify
ti := p.getPendingTrack(dc.Label(), livekit.TrackType_DATA, true)
if ti == nil {
return
}
dt := NewDataTrack(ti.Sid, p.id, dc)
dt.name = ti.Name
p.handleTrackPublished(dt)
logger.Warnw("unsupported datachannel added", nil, "participant", p.Identity(), "label", dc.Label())
}
}
+15 -13
View File
@@ -7,6 +7,7 @@ import (
"time"
"github.com/livekit/protocol/utils"
"google.golang.org/protobuf/proto"
"github.com/livekit/livekit-server/pkg/logger"
"github.com/livekit/livekit-server/pkg/rtc/types"
@@ -19,7 +20,7 @@ const (
)
type Room struct {
livekit.Room
Room *livekit.Room
config WebRTCConfig
iceServers []*livekit.ICEServer
lock sync.RWMutex
@@ -49,20 +50,19 @@ type ParticipantOptions struct {
func NewRoom(room *livekit.Room, config WebRTCConfig, iceServers []*livekit.ICEServer, audioUpdateInterval uint32) *Room {
r := &Room{
Room: *room,
Room: proto.Clone(room).(*livekit.Room),
config: config,
iceServers: iceServers,
audioUpdateInterval: audioUpdateInterval,
lock: sync.RWMutex{},
statsReporter: NewRoomStatsReporter(room.Name),
participants: make(map[string]types.Participant),
participantOpts: make(map[string]*ParticipantOptions),
}
if r.EmptyTimeout == 0 {
r.EmptyTimeout = DefaultEmptyTimeout
if r.Room.EmptyTimeout == 0 {
r.Room.EmptyTimeout = DefaultEmptyTimeout
}
if r.CreationTime == 0 {
r.CreationTime = time.Now().Unix()
if r.Room.CreationTime == 0 {
r.Room.CreationTime = time.Now().Unix()
}
r.statsReporter.RoomStarted()
go r.audioUpdateWorker()
@@ -138,7 +138,7 @@ func (r *Room) Join(participant types.Participant, opts *ParticipantOptions) err
return ErrAlreadyJoined
}
if r.MaxParticipants > 0 && int(r.MaxParticipants) == len(r.participants) {
if r.Room.MaxParticipants > 0 && int(r.Room.MaxParticipants) == len(r.participants) {
return ErrMaxParticipantsExceeded
}
@@ -176,7 +176,7 @@ func (r *Room) Join(participant types.Participant, opts *ParticipantOptions) err
logger.Infow("new participant joined",
"id", participant.ID(),
"identity", participant.Identity(),
"roomId", r.Sid)
"roomId", r.Room.Sid)
r.participants[participant.Identity()] = participant
r.participantOpts[participant.Identity()] = opts
@@ -193,7 +193,7 @@ func (r *Room) Join(participant types.Participant, opts *ParticipantOptions) err
r.onParticipantChanged(participant)
}
return participant.SendJoinResponse(&r.Room, otherParticipants, r.iceServers)
return participant.SendJoinResponse(r.Room, otherParticipants, r.iceServers)
}
func (r *Room) RemoveParticipant(identity string) {
@@ -221,9 +221,11 @@ func (r *Room) RemoveParticipant(identity string) {
// close participant as well
_ = p.Close()
r.lock.RLock()
if len(r.participants) == 0 {
r.leftAt.Store(time.Now().Unix())
}
r.lock.RUnlock()
if sendUpdates {
if r.onParticipantChanged != nil {
@@ -274,7 +276,7 @@ func (r *Room) CloseIfEmpty() {
return
}
timeout := r.EmptyTimeout
timeout := r.Room.EmptyTimeout
var elapsed int64
if r.FirstJoinedAt() > 0 {
// exit 20s after
@@ -283,7 +285,7 @@ func (r *Room) CloseIfEmpty() {
timeout = DefaultRoomDepartureGrace
}
} else {
elapsed = time.Now().Unix() - r.CreationTime
elapsed = time.Now().Unix() - r.Room.CreationTime
}
if elapsed >= int64(timeout) {
@@ -295,7 +297,7 @@ func (r *Room) Close() {
if !r.isClosed.TrySet(true) {
return
}
logger.Infow("closing room", "room", r.Sid, "name", r.Name)
logger.Infow("closing room", "room", r.Room.Sid, "name", r.Room.Name)
r.statsReporter.RoomEnded()
if r.onClose != nil {
+5 -5
View File
@@ -63,7 +63,7 @@ func TestRoomJoin(t *testing.T) {
// expect new participant to get a JoinReply
info, participants, iceServers := pNew.SendJoinResponseArgsForCall(0)
require.Equal(t, info.Sid, rm.Sid)
require.Equal(t, info.Sid, rm.Room.Sid)
require.Len(t, participants, numParticipants)
require.Len(t, rm.GetParticipants(), numParticipants+1)
require.NotEmpty(t, iceServers)
@@ -125,7 +125,7 @@ func TestRoomJoin(t *testing.T) {
t.Run("cannot exceed max participants", func(t *testing.T) {
rm := newRoomWithParticipants(t, testRoomOpts{num: 1})
rm.MaxParticipants = 1
rm.Room.MaxParticipants = 1
p := newMockParticipant("second", types.ProtocolVersion(0))
err := rm.Join(p, nil)
@@ -198,7 +198,7 @@ func TestRoomClosure(t *testing.T) {
})
p := rm.GetParticipants()[0]
// allows immediate close after
rm.EmptyTimeout = 0
rm.Room.EmptyTimeout = 0
rm.RemoveParticipant(p.Identity())
time.Sleep(defaultDelay)
@@ -216,7 +216,7 @@ func TestRoomClosure(t *testing.T) {
rm.OnClose(func() {
isClosed = true
})
require.NotZero(t, rm.EmptyTimeout)
require.NotZero(t, rm.Room.EmptyTimeout)
rm.CloseIfEmpty()
require.False(t, isClosed)
})
@@ -227,7 +227,7 @@ func TestRoomClosure(t *testing.T) {
rm.OnClose(func() {
isClosed = true
})
rm.EmptyTimeout = 1
rm.Room.EmptyTimeout = 1
time.Sleep(1010 * time.Millisecond)
rm.CloseIfEmpty()
+12
View File
@@ -192,6 +192,18 @@ func (s *PacketStats) HandleRTCP(pkts []rtcp.Packet) {
}
}
// Copy returns a newly allocated PacketStats carrying the same roomName,
// direction and counter values as s.
//
// NOTE(review): the receiver is a value, so s is already a copy of the
// caller's struct by the time the atomic loads below run. Confirm whether a
// pointer receiver was intended; changing it alters the method set, so
// callers must be checked first.
func (s PacketStats) Copy() *PacketStats {
	snapshot := new(PacketStats)
	snapshot.roomName = s.roomName
	snapshot.direction = s.direction
	snapshot.PacketBytes = atomic.LoadUint64(&s.PacketBytes)
	snapshot.PacketTotal = atomic.LoadUint64(&s.PacketTotal)
	snapshot.NackTotal = atomic.LoadUint64(&s.NackTotal)
	snapshot.PLITotal = atomic.LoadUint64(&s.PLITotal)
	snapshot.FIRTotal = atomic.LoadUint64(&s.FIRTotal)
	return snapshot
}
// StatsBufferWrapper wraps a buffer factory so we could get information on
// incoming packets
type StatsBufferWrapper struct {
+4 -3
View File
@@ -307,10 +307,11 @@ func (r *RoomManager) getOrCreateRoom(roomName string) (*rtc.Room, error) {
if err := r.DeleteRoom(roomName); err != nil {
logger.Errorw("could not delete room", err)
}
// print stats
logger.Infow("room closed",
"incomingStats", room.GetIncomingStats(),
"outgoingStats", room.GetOutgoingStats(),
"incomingStats", room.GetIncomingStats().Copy(),
"outgoingStats", room.GetOutgoingStats().Copy(),
)
})
room.OnParticipantChanged(func(p types.Participant) {
@@ -336,7 +337,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici
defer func() {
logger.Debugw("RTC session finishing",
"participant", participant.Identity(),
"room", room.Name,
"room", room.Room.Name,
)
_ = participant.Close()
}()
+2 -1
View File
@@ -177,7 +177,8 @@ func (s *LivekitServer) Start() error {
}
// wait for shutdown
ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_ = s.httpServer.Shutdown(ctx)
if s.turnServer != nil {
+1 -1
View File
@@ -1,6 +1,6 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate go run github.com/google/wire/cmd/wire
//go:generate wire
//+build !wireinject
package service
+7 -2
View File
@@ -45,14 +45,18 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, error) {
msg := &livekit.SignalRequest{}
switch messageType {
case websocket.BinaryMessage:
c.mu.Lock()
// switch to protobuf if client supports it
c.useJSON = false
c.mu.Unlock()
// protobuf encoded
err := proto.Unmarshal(payload, msg)
return msg, err
case websocket.TextMessage:
c.mu.Lock()
// json encoded, also write back JSON
c.useJSON = true
c.mu.Unlock()
err := protojson.Unmarshal(payload, msg)
return msg, err
default:
@@ -67,6 +71,9 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) error {
var payload []byte
var err error
c.mu.Lock()
defer c.mu.Unlock()
if c.useJSON {
msgType = websocket.TextMessage
payload, err = protojson.Marshal(msg)
@@ -78,8 +85,6 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) error {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
return c.conn.WriteMessage(msgType, payload)
}
+2 -1
View File
@@ -15,7 +15,8 @@ var (
func WithTimeout(t *testing.T, description string, f func() bool) bool {
logger.Infow(description)
ctx, _ := context.WithTimeout(context.Background(), ConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
for {
select {
case <-ctx.Done():
+3 -1
View File
@@ -189,7 +189,7 @@ func (c *RTCClient) Run() error {
})
// create a data channel, in order to work
_, err := c.publisher.PeerConnection().CreateDataChannel("_private", nil)
_, err := c.publisher.PeerConnection().CreateDataChannel("_lossy", nil)
if err != nil {
return err
}
@@ -353,6 +353,8 @@ func (c *RTCClient) SubscribedTracks() map[string][]*webrtc.TrackRemote {
}
func (c *RTCClient) RemoteParticipants() []*livekit.ParticipantInfo {
c.lock.Lock()
defer c.lock.Unlock()
return funk.Values(c.remoteParticipants).([]*livekit.ParticipantInfo)
}
+1
View File
@@ -3,6 +3,7 @@
package tools
import (
_ "github.com/google/wire"
_ "github.com/maxbrunsfeld/counterfeiter/v6"
_ "github.com/twitchtv/twirp/protoc-gen-twirp"
_ "google.golang.org/protobuf/cmd/protoc-gen-go"