mirror of
https://github.com/livekit/livekit.git
synced 2026-03-30 17:45:40 +00:00
195 lines
4.9 KiB
Go
195 lines
4.9 KiB
Go
// Copyright 2024 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package interceptor
|
|
|
|
import (
|
|
"sync"
|
|
|
|
"github.com/pion/interceptor"
|
|
"github.com/pion/sdp/v3"
|
|
"github.com/pion/webrtc/v4"
|
|
|
|
"github.com/livekit/livekit-server/pkg/sfu/utils"
|
|
"github.com/livekit/protocol/logger"
|
|
)
|
|
|
|
// SDESRepairRTPStreamIDURI is the RTP header extension URI of the repaired
// RTP stream id (RFC 8852). A packet carrying it identifies, by rid, the
// stream it repairs.
const SDESRepairRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id"

// rtxProbeCount is the initial value of the per-stream probe counter that the
// remote-stream reader decrements for each non-padding packet lacking the
// mid/rid/rsid information it is looking for.
const rtxProbeCount = 10
|
|
|
|
// streamInfo holds the identifiers read from a remote stream's RTP header
// extensions: its media id, its RTP stream id (rid), and, for a repair
// stream, the rid of the stream it repairs (rsid).
type streamInfo struct {
	mid, rid, rsid string
}
|
|
|
|
type RTXInfoExtractorFactory struct {
|
|
onStreamFound func(*interceptor.StreamInfo)
|
|
onRTXPairFound func(repair, base uint32, rsid string)
|
|
lock sync.Mutex
|
|
streams map[uint32]streamInfo
|
|
logger logger.Logger
|
|
}
|
|
|
|
func NewRTXInfoExtractorFactory(
|
|
onStreamFound func(*interceptor.StreamInfo),
|
|
onRTXPairFound func(repair, base uint32, rsid string),
|
|
logger logger.Logger,
|
|
) *RTXInfoExtractorFactory {
|
|
return &RTXInfoExtractorFactory{
|
|
onStreamFound: onStreamFound,
|
|
onRTXPairFound: onRTXPairFound,
|
|
streams: make(map[uint32]streamInfo),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (f *RTXInfoExtractorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) {
|
|
return &RTXInfoExtractor{
|
|
factory: f,
|
|
logger: f.logger,
|
|
}, nil
|
|
}
|
|
|
|
func (f *RTXInfoExtractorFactory) SetStreamInfo(ssrc uint32, mid, rid, rsid string) {
|
|
var repairSsrc, baseSsrc uint32
|
|
var repairSid string
|
|
f.lock.Lock()
|
|
|
|
if mid == "" || (rid == "" && rsid == "") {
|
|
f.lock.Unlock()
|
|
return
|
|
}
|
|
|
|
if rsid != "" {
|
|
// repair stream found, find base stream
|
|
for base, info := range f.streams {
|
|
if info.mid == mid && info.rid == rsid {
|
|
repairSsrc = ssrc
|
|
baseSsrc = base
|
|
repairSid = rsid
|
|
delete(f.streams, base)
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
// base stream found, find repair stream
|
|
for repair, info := range f.streams {
|
|
if info.mid == mid && info.rsid == rid {
|
|
repairSsrc = repair
|
|
baseSsrc = ssrc
|
|
repairSid = info.rid
|
|
delete(f.streams, repair)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// no rtx pair found, save it for later
|
|
if repairSsrc == 0 || baseSsrc == 0 {
|
|
f.streams[ssrc] = streamInfo{
|
|
mid: mid,
|
|
rid: rid,
|
|
rsid: rsid,
|
|
}
|
|
}
|
|
|
|
f.lock.Unlock()
|
|
|
|
if repairSsrc != 0 && baseSsrc != 0 {
|
|
f.onRTXPairFound(repairSsrc, baseSsrc, repairSid)
|
|
}
|
|
}
|
|
|
|
type RTXInfoExtractor struct {
|
|
interceptor.NoOp
|
|
|
|
factory *RTXInfoExtractorFactory
|
|
logger logger.Logger
|
|
}
|
|
|
|
func (u *RTXInfoExtractor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
|
|
u.factory.onStreamFound(info)
|
|
|
|
midExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI})
|
|
streamIDExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESRTPStreamIDURI})
|
|
repairStreamIDExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: SDESRepairRTPStreamIDURI})
|
|
if midExtensionID == 0 || streamIDExtensionID == 0 || repairStreamIDExtensionID == 0 {
|
|
return reader
|
|
}
|
|
|
|
return &rtxInfoReader{
|
|
tryTimes: rtxProbeCount,
|
|
reader: reader,
|
|
midExtID: uint8(midExtensionID),
|
|
ridExtID: uint8(streamIDExtensionID),
|
|
rsidExtID: uint8(repairStreamIDExtensionID),
|
|
factory: u.factory,
|
|
logger: u.logger,
|
|
}
|
|
}
|
|
|
|
type rtxInfoReader struct {
|
|
tryTimes int
|
|
reader interceptor.RTPReader
|
|
midExtID uint8
|
|
ridExtID uint8
|
|
rsidExtID uint8
|
|
factory *RTXInfoExtractorFactory
|
|
logger logger.Logger
|
|
}
|
|
|
|
func (r *rtxInfoReader) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
|
|
n, a, err := r.reader.Read(b, a)
|
|
if r.tryTimes < 0 || err != nil {
|
|
return n, a, err
|
|
}
|
|
|
|
if a == nil {
|
|
a = make(interceptor.Attributes)
|
|
}
|
|
header, err := a.GetRTPHeader(b[:n])
|
|
if err != nil {
|
|
return n, a, nil
|
|
}
|
|
|
|
var mid, rid, rsid string
|
|
if payload := header.GetExtension(r.midExtID); payload != nil {
|
|
mid = string(payload)
|
|
}
|
|
|
|
if payload := header.GetExtension(r.ridExtID); payload != nil {
|
|
rid = string(payload)
|
|
}
|
|
|
|
if payload := header.GetExtension(r.rsidExtID); payload != nil {
|
|
rsid = string(payload)
|
|
}
|
|
|
|
if mid != "" && (rid != "" || rsid != "") {
|
|
r.logger.Debugw("stream found", "mid", mid, "rid", rid, "rsid", rsid, "ssrc", header.SSRC)
|
|
r.tryTimes = -1
|
|
go r.factory.SetStreamInfo(header.SSRC, mid, rid, rsid)
|
|
} else {
|
|
// ignore padding only packet for probe count
|
|
if !(header.Padding && n-header.MarshalSize()-int(b[n-1]) == 0) {
|
|
r.tryTimes--
|
|
}
|
|
}
|
|
return n, a, nil
|
|
}
|