From dc60e274130f9fab0b74bc408eb7d5baabe1e11f Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Fri, 26 Nov 2021 15:40:10 +0800 Subject: [PATCH] create TrackSender & TrackReceiver (#211) * create TrackSender & TrackReceiver change WebRtcReceiver & DownTrack to use corresponding interface --- pkg/sfu/downtrack.go | 108 +++++++++++++++++++++++++++++-------------- pkg/sfu/receiver.go | 102 ++++++++++++---------------------------- 2 files changed, 103 insertions(+), 107 deletions(-) diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 28df055d3..2c136d5c3 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "strings" "sync" "time" @@ -19,6 +20,17 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/buffer" ) +// TrackSender defines a interface send media to remote peer +type TrackSender interface { + UptrackLayersChange(availableLayers []uint16, layerAdded bool) (int32, error) + WriteRTP(p *buffer.ExtPacket, layer int32) error + Close() + // ID is the globally unique identifier for this Track. + ID() string + SetTrackType(isSimulcast bool) + PeerID() string +} + // DownTrackType determines the type of track type DownTrackType int @@ -106,7 +118,7 @@ type DownTrack struct { codec webrtc.RTPCodecCapability rtpHeaderExtensions []webrtc.RTPHeaderExtensionParameter - receiver Receiver + receiver TrackReceiver transceiver *webrtc.RTPTransceiver writeStream webrtc.TrackLocalWriter onCloseHandler func() @@ -149,7 +161,7 @@ type DownTrack struct { } // NewDownTrack returns a DownTrack. -func NewDownTrack(c webrtc.RTPCodecCapability, r Receiver, bf *buffer.Factory, peerID string, mt int) (*DownTrack, error) { +func NewDownTrack(c webrtc.RTPCodecCapability, r TrackReceiver, bf *buffer.Factory, peerID string, mt int) (*DownTrack, error) { d := &DownTrack{ id: r.TrackID(), peerID: peerID, @@ -234,6 +246,8 @@ func (d *DownTrack) Codec() webrtc.RTPCodecCapability { return d.codec } // StreamID is the group this track belongs too. This must be unique func (d *DownTrack) StreamID() string { return d.streamID } +func (d *DownTrack) PeerID() string { return d.peerID } + // Sets RTP header extensions for this track func (d *DownTrack) SetRTPHeaderExtensions(rtpHeaderExtensions []webrtc.RTPHeaderExtensionParameter) { d.rtpHeaderExtensions = rtpHeaderExtensions @@ -266,7 +280,7 @@ func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { d.transceiver = transceiver } -func (d *DownTrack) MaybeTranslateVP8(pkt *rtp.Packet, meta packetMeta) error { +func (d *DownTrack) maybeTranslateVP8(pkt *rtp.Packet, meta packetMeta) error { if d.vp8Munger == nil || len(pkt.Payload) == 0 { return nil } @@ -287,7 +301,7 @@ func (d *DownTrack) MaybeTranslateVP8(pkt *rtp.Packet, meta packetMeta) error { } // Writes RTP header extensions of track -func (d *DownTrack) WriteRTPHeaderExtensions(hdr *rtp.Header) error { +func (d *DownTrack) writeRTPHeaderExtensions(hdr *rtp.Header) error { // clear out extensions that may have been in the forwarded header hdr.Extension = false hdr.ExtensionProfile = 0 @@ -402,7 +416,7 @@ func (d *DownTrack) WritePaddingRTP(bytesToSend int) int { CSRC: []uint32{}, } - err = d.WriteRTPHeaderExtensions(&hdr) + err = d.writeRTPHeaderExtensions(&hdr) if err != nil { return bytesSent } @@ -837,10 +851,8 @@ func (d *DownTrack) writeSimpleRTP(extPkt *buffer.ExtPacket) error { if d.reSync.get() { if d.Kind() == webrtc.RTPCodecTypeVideo { if !extPkt.KeyFrame { - d.receiver.SendRTCP([]rtcp.Packet{ - &rtcp.PictureLossIndication{SenderSSRC: d.ssrc, MediaSSRC: extPkt.Packet.SSRC}, - }) d.lastPli.set(time.Now().UnixNano()) + d.receiver.SendPLI(0) d.pktsDropped.add(1) return nil } @@ -930,7 +942,7 @@ func (d *DownTrack) writeSimpleRTP(extPkt *buffer.ExtPacket) error { hdr.SequenceNumber = newSN hdr.SSRC = d.ssrc - err = d.WriteRTPHeaderExtensions(&hdr) + err = d.writeRTPHeaderExtensions(&hdr) if err != nil { return err } @@ -1001,9 +1013,7 @@ func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket, layer int32) err // all the packets to down tracks and down track should be // the only one deciding whether to switch/forward/drop // LK-TODO-END - d.receiver.SendRTCP([]rtcp.Packet{ - &rtcp.PictureLossIndication{SenderSSRC: d.ssrc, MediaSSRC: extPkt.Packet.SSRC}, - }) + d.receiver.SendPLI(layer) d.lastPli.set(time.Now().UnixNano()) d.pktsDropped.add(1) return nil @@ -1122,7 +1132,7 @@ func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket, layer int32) err hdr.SSRC = d.ssrc hdr.PayloadType = d.payloadType - err = d.WriteRTPHeaderExtensions(&hdr) + err = d.writeRTPHeaderExtensions(&hdr) if err != nil { return err } @@ -1177,7 +1187,7 @@ func (d *DownTrack) writeBlankFrameRTP() error { CSRC: []uint32{}, } - err = d.WriteRTPHeaderExtensions(&hdr) + err = d.writeRTPHeaderExtensions(&hdr) if err != nil { return err } @@ -1260,9 +1270,7 @@ func (d *DownTrack) handleRTCP(bytes []byte) { } } - var fwdPkts []rtcp.Packet pliOnce := true - firOnce := true var ( maxRatePacketLoss uint8 @@ -1273,23 +1281,20 @@ func (d *DownTrack) handleRTCP(bytes []byte) { return } + sendPliOnce := func() { + if pliOnce { + d.lastPli.set(time.Now().UnixNano()) + d.receiver.SendPLI(d.TargetSpatialLayer()) + pliOnce = false + } + } + for _, pkt := range pkts { switch p := pkt.(type) { case *rtcp.PictureLossIndication: - if pliOnce { - d.lastPli.set(time.Now().UnixNano()) - p.MediaSSRC = ssrc - p.SenderSSRC = d.ssrc - fwdPkts = append(fwdPkts, p) - pliOnce = false - } + sendPliOnce() case *rtcp.FullIntraRequest: - if firOnce { - p.MediaSSRC = ssrc - p.SenderSSRC = d.ssrc - fwdPkts = append(fwdPkts, p) - firOnce = false - } + sendPliOnce() case *rtcp.ReceiverEstimatedMaximumBitrate: if d.onREMB != nil { d.onREMB(d, p) @@ -1322,14 +1327,49 @@ func (d *DownTrack) handleRTCP(bytes []byte) { for _, pair := range p.Nacks { nackedPackets = append(nackedPackets, d.sequencer.getSeqNoPairs(pair.PacketList())...) } - if err = d.receiver.RetransmitPackets(d, nackedPackets); err != nil { - return - } + go d.retransmitPackets(nackedPackets) } } +} - if len(fwdPkts) > 0 { - d.receiver.SendRTCP(fwdPkts) +func (d *DownTrack) retransmitPackets(nackedPackets []packetMeta) { + src := packetFactory.Get().(*[]byte) + defer packetFactory.Put(src) + for _, meta := range nackedPackets { + pktBuff := *src + n, err := d.receiver.ReadRTP(pktBuff, meta.layer, meta.sourceSeqNo) + if err != nil { + if err == io.EOF { + break + } + continue + } + var pkt rtp.Packet + if err = pkt.Unmarshal(pktBuff[:n]); err != nil { + continue + } + pkt.Header.SequenceNumber = meta.targetSeqNo + pkt.Header.Timestamp = meta.timestamp + pkt.Header.SSRC = d.ssrc + pkt.Header.PayloadType = d.payloadType + + err = d.maybeTranslateVP8(&pkt, meta) + if err != nil { + Logger.Error(err, "translating VP8 packet err") + continue + } + + err = d.writeRTPHeaderExtensions(&pkt.Header) + if err != nil { + Logger.Error(err, "writing rtp header extensions err") + continue + } + + if _, err = d.writeStream.WriteRTP(&pkt.Header, pkt.Payload); err != nil { + Logger.Error(err, "Writing rtx packet err") + } else { + d.UpdateStats(uint32(n)) + } } } diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 8e9b5c216..49d679354 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -8,32 +8,38 @@ import ( "sync/atomic" "time" - "github.com/gammazero/workerpool" "github.com/pion/rtcp" - "github.com/pion/rtp" "github.com/pion/webrtc/v3" "github.com/rs/zerolog/log" "github.com/livekit/livekit-server/pkg/sfu/buffer" ) +// TrackReceiver defines a interface receive media from remote peer +type TrackReceiver interface { + TrackID() string + StreamID() string + GetBitrateTemporalCumulative() [3][4]uint64 + ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) + AddDownTrack(track TrackSender) + DeleteDownTrack(peerID string) + SendPLI(layer int32) + GetSenderReportTime(layer int32) (rtpTS uint32, ntpTS uint64) +} + // Receiver defines a interface for a track receivers type Receiver interface { TrackID() string StreamID() string Codec() webrtc.RTPCodecParameters - Kind() webrtc.RTPCodecType - SSRC(layer int) uint32 - SetTrackMeta(trackID, streamID string) AddUpTrack(track *webrtc.TrackRemote, buffer *buffer.Buffer) - AddDownTrack(track *DownTrack) + AddDownTrack(track TrackSender) SetUpTrackPaused(paused bool) NumAvailableSpatialLayers() int GetBitrateTemporalCumulative() [3][4]uint64 - RetransmitPackets(track *DownTrack, packets []packetMeta) error + ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) DeleteDownTrack(peerID string) OnCloseHandler(fn func()) - SendRTCP(p []rtcp.Packet) SendPLI(layer int32) SetRTCPCh(ch chan []rtcp.Packet) @@ -58,7 +64,6 @@ type WebRTCReceiver struct { stream string receiver *webrtc.RTPReceiver codec webrtc.RTPCodecParameters - nackWorker *workerpool.WorkerPool isSimulcast bool availableLayers atomic.Value onCloseHandler func() @@ -79,7 +84,7 @@ type WebRTCReceiver struct { upTracks [3]*webrtc.TrackRemote downTrackMu sync.RWMutex - downTracks []*DownTrack + downTracks []TrackSender index map[string]int free map[int]struct{} numProcs int @@ -129,10 +134,9 @@ func NewWebRTCReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote, streamID: track.StreamID(), codec: track.Codec(), kind: track.Kind(), - nackWorker: workerpool.New(1), isSimulcast: len(track.RID()) > 0, pliThrottle: 500e6, - downTracks: make([]*DownTrack, 0), + downTracks: make([]TrackSender, 0), index: make(map[string]int), free: make(map[int]struct{}), numProcs: runtime.NumCPU(), @@ -231,13 +235,13 @@ func (w *WebRTCReceiver) SetUpTrackPaused(paused bool) { } } -func (w *WebRTCReceiver) AddDownTrack(track *DownTrack) { +func (w *WebRTCReceiver) AddDownTrack(track TrackSender) { if w.closed.get() { return } w.downTrackMu.RLock() - _, ok := w.index[track.peerID] + _, ok := w.index[track.PeerID()] w.downTrackMu.RUnlock() if ok { return @@ -413,58 +417,11 @@ func (w *WebRTCReceiver) GetSenderReportTime(layer int32) (rtpTS uint32, ntpTS u return } -func (w *WebRTCReceiver) RetransmitPackets(track *DownTrack, packets []packetMeta) error { - if w.nackWorker.Stopped() { - return io.ErrClosedPipe - } - // LK-TODO: should move down track specific bits into there - w.nackWorker.Submit(func() { - src := packetFactory.Get().(*[]byte) - for _, meta := range packets { - pktBuff := *src - w.bufferMu.RLock() - buff := w.buffers[meta.layer] - w.bufferMu.RUnlock() - if buff == nil { - break - } - i, err := buff.GetPacket(pktBuff, meta.sourceSeqNo) - if err != nil { - if err == io.EOF { - break - } - continue - } - var pkt rtp.Packet - if err = pkt.Unmarshal(pktBuff[:i]); err != nil { - continue - } - pkt.Header.SequenceNumber = meta.targetSeqNo - pkt.Header.Timestamp = meta.timestamp - pkt.Header.SSRC = track.ssrc - pkt.Header.PayloadType = track.payloadType - - err = track.MaybeTranslateVP8(&pkt, meta) - if err != nil { - Logger.Error(err, "translating VP8 packet err") - continue - } - - err = track.WriteRTPHeaderExtensions(&pkt.Header) - if err != nil { - Logger.Error(err, "writing rtp header extensions err") - continue - } - - if _, err = track.writeStream.WriteRTP(&pkt.Header, pkt.Payload); err != nil { - Logger.Error(err, "Writing rtx packet err") - } else { - track.UpdateStats(uint32(i)) - } - } - packetFactory.Put(src) - }) - return nil +func (w *WebRTCReceiver) ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) { + w.bufferMu.RLock() + buff := w.buffers[layer] + w.bufferMu.RUnlock() + return buff.GetPacket(buf, sn) } func (w *WebRTCReceiver) forwardRTP(layer int32) { @@ -536,9 +493,9 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { } } -func (w *WebRTCReceiver) writeRTP(layer int32, dt *DownTrack, pkt *buffer.ExtPacket) { +func (w *WebRTCReceiver) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPacket) { if err := dt.WriteRTP(pkt, layer); err != nil { - log.Error().Err(err).Str("id", dt.id).Msg("Error writing to down track") + log.Error().Err(err).Str("id", dt.ID()).Msg("Error writing to down track") } } @@ -550,29 +507,28 @@ func (w *WebRTCReceiver) closeTracks() { dt.Close() } } - w.downTracks = make([]*DownTrack, 0) + w.downTracks = make([]TrackSender, 0) w.index = make(map[string]int) w.free = make(map[int]struct{}) w.downTrackMu.Unlock() - w.nackWorker.StopWait() if w.onCloseHandler != nil { w.onCloseHandler() } } -func (w *WebRTCReceiver) storeDownTrack(track *DownTrack) { +func (w *WebRTCReceiver) storeDownTrack(track TrackSender) { w.downTrackMu.Lock() defer w.downTrackMu.Unlock() for idx := range w.free { - w.index[track.peerID] = idx + w.index[track.PeerID()] = idx w.downTracks[idx] = track delete(w.free, idx) return } - w.index[track.peerID] = len(w.downTracks) + w.index[track.PeerID()] = len(w.downTracks) w.downTracks = append(w.downTracks, track) }