From ff390820e192748e8b98db0115824cf3bf98ad81 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sun, 5 Dec 2021 10:01:35 +0530 Subject: [PATCH] Make VP8 packet translation thread-safe. (#237) * Make VP8 packet translation thread-safe. Was using one packet from pool for all VP8 translation which was not thread safe. Grab packets from the pool when needed for VP8 translation and return to pool after done. Do not grab packet from pool if the header size between incoming and translated matches. That also saves copying the packet payload. * Keep Get/Put in the same function. --- pkg/sfu/downtrack.go | 129 +++++++++++++++++++++---------------------- pkg/sfu/forwarder.go | 22 ++++++++ 2 files changed, 84 insertions(+), 67 deletions(-) diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 21e59d9ff..a5b58524b 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -42,26 +42,6 @@ const ( RTPPaddingMaxPayloadSize = 255 RTPPaddingEstimatedHeaderSize = 20 RTPBlankFramesMax = 6 - - InvalidSpatialLayer = -1 - InvalidTemporalLayer = -1 -) - -type SequenceNumberOrdering int - -const ( - SequenceNumberOrderingContiguous SequenceNumberOrdering = iota - SequenceNumberOrderingOutOfOrder - SequenceNumberOrderingGap - SequenceNumberOrderingDuplicate -) - -type ForwardingStatus int - -const ( - ForwardingStatusOff ForwardingStatus = iota - ForwardingStatusPartial - ForwardingStatusOptimal ) var ( @@ -131,7 +111,6 @@ type DownTrack struct { sequencer *sequencer trackType DownTrackType bufferFactory *buffer.Factory - payload *[]byte forwarder *Forwarder @@ -197,10 +176,6 @@ func NewDownTrack(c webrtc.RTPCodecCapability, r TrackReceiver, bf *buffer.Facto forwarder: NewForwarder(c, kind), } - if strings.ToLower(c.MimeType) == "video/vp8" { - d.payload = PacketFactory.Get().(*[]byte) - } - return d, nil } @@ -287,6 +262,14 @@ func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { // WriteRTP writes a RTP Packet to the DownTrack func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { + var pool *[]byte + defer func() { + if pool != nil { + PacketFactory.Put(pool) + pool = nil + } + }() + d.lastRTP.set(time.Now().UnixNano()) if !d.bound.get() { @@ -303,10 +286,15 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { return err } - payload := extPkt.Packet.Payload + payload := &extPkt.Packet.Payload if tp.vp8 != nil { incomingVP8, _ := extPkt.Payload.(buffer.VP8) - payload, err = d.translateVP8Packet(&extPkt.Packet, &incomingVP8, tp.vp8.header) + + if incomingVP8.HeaderSize != tp.vp8.header.HeaderSize { + pool = PacketFactory.Get().(*[]byte) + payload = pool + } + err = d.translateVP8PacketTo(&extPkt.Packet, &incomingVP8, tp.vp8.header, payload) if err != nil { d.pktsDropped.add(1) return err @@ -326,17 +314,17 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { return err } - _, err = d.writeStream.WriteRTP(hdr, payload) + _, err = d.writeStream.WriteRTP(hdr, *payload) if err == nil { for _, f := range d.onPacketSent { - f(d, hdr.MarshalSize()+len(payload)) + f(d, hdr.MarshalSize()+len(*payload)) } } else { d.pktsDropped.add(1) } // LK-TODO maybe include RTP header size also - d.UpdateStats(uint32(len(payload))) + d.UpdateStats(uint32(len(*payload))) return err } @@ -460,9 +448,6 @@ func (d *DownTrack) Close() { d.closeOnce.Do(func() { Logger.V(1).Info("Closing sender", "peer_id", d.peerID, "kind", d.kind) - if d.payload != nil { - PacketFactory.Put(d.payload) - } if d.onCloseHandler != nil { d.onCloseHandler() } @@ -787,30 +772,24 @@ func (d *DownTrack) handleRTCP(bytes []byte) { } } -func (d *DownTrack) maybeTranslateVP8(pkt *rtp.Packet, meta packetMeta) error { - if d.mime != "video/vp8" || len(pkt.Payload) == 0 { - return nil - } - - var incomingVP8 buffer.VP8 - if err := incomingVP8.Unmarshal(pkt.Payload); err != nil { - return err - } - - translatedVP8 := meta.unpackVP8() - payload, err := d.translateVP8Packet(pkt, &incomingVP8, translatedVP8) - if err != nil { - return err - } - - pkt.Payload = payload - return nil -} - func (d *DownTrack) retransmitPackets(nackedPackets []packetMeta) { + var pool *[]byte + defer func() { + if pool != nil { + PacketFactory.Put(pool) + pool = nil + } + }() + src := PacketFactory.Get().(*[]byte) defer PacketFactory.Put(src) + for _, meta := range nackedPackets { + if pool != nil { + PacketFactory.Put(pool) + pool = nil + } + pktBuff := *src n, err := d.receiver.ReadRTP(pktBuff, meta.layer, meta.sourceSeqNo) if err != nil { @@ -828,10 +807,24 @@ func (d *DownTrack) retransmitPackets(nackedPackets []packetMeta) { pkt.Header.SSRC = d.ssrc pkt.Header.PayloadType = d.payloadType - err = d.maybeTranslateVP8(&pkt, meta) - if err != nil { - Logger.Error(err, "translating VP8 packet err") - continue + payload := &pkt.Payload + if d.mime == "video/vp8" && len(pkt.Payload) > 0 { + var incomingVP8 buffer.VP8 + if err = incomingVP8.Unmarshal(pkt.Payload); err != nil { + Logger.Error(err, "unmarshalling VP8 packet err") + continue + } + + translatedVP8 := meta.unpackVP8() + if incomingVP8.HeaderSize != translatedVP8.HeaderSize { + pool = PacketFactory.Get().(*[]byte) + payload = pool + } + err = d.translateVP8PacketTo(&pkt, &incomingVP8, translatedVP8, payload) + if err != nil { + Logger.Error(err, "translating VP8 packet err") + continue + } } err = d.writeRTPHeaderExtensions(&pkt.Header) @@ -840,7 +833,7 @@ func (d *DownTrack) retransmitPackets(nackedPackets []packetMeta) { continue } - if _, err = d.writeStream.WriteRTP(&pkt.Header, pkt.Payload); err != nil { + if _, err = d.writeStream.WriteRTP(&pkt.Header, *payload); err != nil { Logger.Error(err, "Writing rtx packet err") } else { d.UpdateStats(uint32(n)) @@ -895,18 +888,20 @@ func (d *DownTrack) getTranslatedRTPHeader(extPkt *buffer.ExtPacket, tpRTP *Tran return &hdr, nil } -func (d *DownTrack) translateVP8Packet(pkt *rtp.Packet, incomingVP8 *buffer.VP8, translatedVP8 *buffer.VP8) (buf []byte, err error) { - buf = *d.payload - buf = buf[:len(pkt.Payload)+translatedVP8.HeaderSize-incomingVP8.HeaderSize] +func (d *DownTrack) translateVP8PacketTo(pkt *rtp.Packet, incomingVP8 *buffer.VP8, translatedVP8 *buffer.VP8, outbuf *[]byte) error { + var buf []byte + if outbuf == &pkt.Payload { + buf = pkt.Payload + } else { + buf = (*outbuf)[:len(pkt.Payload)+translatedVP8.HeaderSize-incomingVP8.HeaderSize] - srcPayload := pkt.Payload[incomingVP8.HeaderSize:] - dstPayload := buf[translatedVP8.HeaderSize:] - copy(dstPayload, srcPayload) + srcPayload := pkt.Payload[incomingVP8.HeaderSize:] + dstPayload := buf[translatedVP8.HeaderSize:] + copy(dstPayload, srcPayload) + } hdr := buf[:translatedVP8.HeaderSize] - err = translatedVP8.MarshalTo(hdr) - - return + return translatedVP8.MarshalTo(hdr) } func (d *DownTrack) DebugInfo() map[string]interface{} { diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index 51f00dd9a..1e50c8e52 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -12,6 +12,28 @@ import ( // // Forwarder // +const ( + InvalidSpatialLayer = -1 + InvalidTemporalLayer = -1 +) + +type SequenceNumberOrdering int + +const ( + SequenceNumberOrderingContiguous SequenceNumberOrdering = iota + SequenceNumberOrderingOutOfOrder + SequenceNumberOrderingGap + SequenceNumberOrderingDuplicate +) + +type ForwardingStatus int + +const ( + ForwardingStatusOff ForwardingStatus = iota + ForwardingStatusPartial + ForwardingStatusOptimal +) + type VideoStreamingChange int const (