package rtc import ( "sync" "github.com/pion/webrtc/v3" "github.com/livekit/livekit-server/pkg/logger" "github.com/livekit/livekit-server/pkg/rtc/types" livekit "github.com/livekit/livekit-server/proto" ) const ( dataBufferSize = 50 ) // DataTrack wraps a WebRTC DataChannel to satisfy the PublishedTrack interface // it shall forward publishedTracks to all of its subscribers type DataTrack struct { id string name string participantId string dataChannel *webrtc.DataChannel lock sync.RWMutex once sync.Once msgChan chan livekit.DataMessage onClose func() // map of target participantId -> DownDataChannel subscribers map[string]*DownDataChannel } func NewDataTrack(trackId, participantId string, dc *webrtc.DataChannel) *DataTrack { t := &DataTrack{ //ctx: context.Background(), id: trackId, name: dc.Label(), participantId: participantId, dataChannel: dc, msgChan: make(chan livekit.DataMessage, dataBufferSize), lock: sync.RWMutex{}, subscribers: make(map[string]*DownDataChannel), } dc.OnMessage(func(msg webrtc.DataChannelMessage) { dm := messageFromDataChannelMessage(msg) t.msgChan <- dm }) dc.OnClose(func() { t.RemoveAllSubscribers() if t.onClose != nil { t.onClose() } }) return t } func (t *DataTrack) Start() { t.once.Do(func() { go t.forwardWorker() }) } func (t *DataTrack) ID() string { return t.id } func (t *DataTrack) Kind() livekit.TrackType { return livekit.TrackType_DATA } func (t *DataTrack) Name() string { return t.name } // DataTrack cannot be muted func (t *DataTrack) IsMuted() bool { return false } func (t *DataTrack) SetMuted(muted bool) { } func (t *DataTrack) OnClose(f func()) { t.onClose = f } func (t *DataTrack) IsSubscriber(subId string) bool { t.lock.RLock() defer t.lock.RUnlock() return t.subscribers[subId] != nil } func (t *DataTrack) AddSubscriber(participant types.Participant) error { t.lock.Lock() defer t.lock.Unlock() if t.subscribers[participant.ID()] != nil { return nil } label := PackDataTrackLabel(t.participantId, t.ID(), t.dataChannel.Label()) downChannel, err := participant.SubscriberPC().CreateDataChannel(label, t.dataChannelOptions()) if err != nil { return err } sub := &DownDataChannel{ participantId: participant.ID(), dataChannel: downChannel, } t.subscribers[participant.ID()] = sub downChannel.OnClose(func() { t.RemoveSubscriber(sub.participantId) }) return nil } func (t *DataTrack) RemoveSubscriber(participantId string) { t.lock.Lock() sub := t.subscribers[participantId] delete(t.subscribers, participantId) t.lock.Unlock() if sub != nil { go sub.dataChannel.Close() } } func (t *DataTrack) RemoveAllSubscribers() { t.lock.Lock() defer t.lock.Unlock() for _, sub := range t.subscribers { go sub.dataChannel.Close() } t.subscribers = make(map[string]*DownDataChannel) } func (t *DataTrack) forwardWorker() { defer Recover() for { msg := <-t.msgChan if msg.Value == nil { // track closed return } t.lock.RLock() for _, sub := range t.subscribers { err := sub.SendMessage(msg) if err != nil { logger.Errorw("could not send data message", err, "source", t.participantId, "dest", sub.participantId) } } t.lock.RUnlock() } } func (t *DataTrack) dataChannelOptions() *webrtc.DataChannelInit { ordered := t.dataChannel.Ordered() protocol := t.dataChannel.Protocol() negotiated := false return &webrtc.DataChannelInit{ Ordered: &ordered, MaxPacketLifeTime: t.dataChannel.MaxPacketLifeTime(), MaxRetransmits: t.dataChannel.MaxRetransmits(), Protocol: &protocol, Negotiated: &negotiated, } } type DownDataChannel struct { participantId string dataChannel *webrtc.DataChannel } func (d *DownDataChannel) SendMessage(msg livekit.DataMessage) error { var err error switch val := msg.Value.(type) { case *livekit.DataMessage_Binary: err = d.dataChannel.Send(val.Binary) case *livekit.DataMessage_Text: err = d.dataChannel.SendText(val.Text) } return err } func messageFromDataChannelMessage(msg webrtc.DataChannelMessage) livekit.DataMessage { dm := livekit.DataMessage{} if msg.IsString { dm.Value = &livekit.DataMessage_Text{ Text: string(msg.Data), } } else { dm.Value = &livekit.DataMessage_Binary{ Binary: msg.Data, } } return dm }