From 6464ae3cd6881091047fce2a79883b3f2c4c5d3a Mon Sep 17 00:00:00 2001 From: David Zhao Date: Sun, 10 Jan 2021 22:34:02 -0800 Subject: [PATCH] send downtrack binding reports fixed tests --- pkg/auth/interfaces.go | 8 -- pkg/rtc/mediatrack.go | 70 +++++----- pkg/rtc/participant.go | 6 +- pkg/rtc/types/interfaces.go | 4 + pkg/rtc/types/typesfakes/fake_down_track.go | 131 +++++++++++++++++++ pkg/rtc/types/typesfakes/fake_participant.go | 66 ++++++++++ 6 files changed, 238 insertions(+), 47 deletions(-) diff --git a/pkg/auth/interfaces.go b/pkg/auth/interfaces.go index 187ac9505..18727697a 100644 --- a/pkg/auth/interfaces.go +++ b/pkg/auth/interfaces.go @@ -10,14 +10,6 @@ var ( ErrKeysMissing = errors.New("missing API key or secret key") ) -////counterfeiter:generate . AccessToken -//type AccessToken interface { -// SetExpiration(time.Duration) AccessToken -// Identity(string) AccessToken -// AddGrant(*VideoGrant) AccessToken -// ToJWT() (string, error) -//} - //counterfeiter:generate . TokenVerifier type TokenVerifier interface { Identity() string diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 528a4d480..d885ec978 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -134,7 +134,7 @@ func (t *MediaTrack) AddSubscriber(participant types.Participant) error { t.handleRTCP(outTrack, pkt) }) } - //go t.scheduleDownTrackBindingReports(recv.Name()) + go t.sendDownTrackBindingReports(participant.ID(), participant.RTCPChan()) }) outTrack.OnCloseHandler(func() { t.lock.Lock() @@ -187,43 +187,37 @@ func (t *MediaTrack) RemoveAllSubscribers() { t.downtracks = make(map[string]types.DownTrack) } -//func (t *MediaTrack) scheduleDownTrackBindingReports(streamId string) { -// var sd []rtcp.SourceDescriptionChunk -// -// p.lock.RLock() -// dts := p.subscribedTracks[streamId] -// for _, dt := range dts { -// if !dt.IsBound() { -// continue -// } -// chunks := dt.CreateSourceDescriptionChunks() -// if chunks != nil { -// sd = append(sd, chunks...) -// } -// } -// p.lock.RUnlock() -// -// pkts := []rtcp.Packet{ -// &rtcp.SourceDescription{Chunks: sd}, -// } -// -// go func() { -// batch := pkts -// i := 0 -// for { -// if err := p.peerConn.WriteRTCP(batch); err != nil { -// logger.GetLogger().Debugw("error sending track binding reports", -// "participant", p.id, -// "err", err) -// } -// if i > 5 { -// return -// } -// i++ -// time.Sleep(20 * time.Millisecond) -// } -// }() -//} +func (t *MediaTrack) sendDownTrackBindingReports(participantId string, rtcpCh chan<- []rtcp.Packet) { + var sd []rtcp.SourceDescriptionChunk + + t.lock.RLock() + defer t.lock.RUnlock() + dt := t.downtracks[participantId] + if !dt.IsBound() { + return + } + chunks := dt.CreateSourceDescriptionChunks() + if chunks != nil { + sd = append(sd, chunks...) + } + + pkts := []rtcp.Packet{ + &rtcp.SourceDescription{Chunks: sd}, + } + + go func() { + batch := pkts + i := 0 + for { + rtcpCh <- batch + if i > 5 { + return + } + i++ + time.Sleep(20 * time.Millisecond) + } + }() +} // b reads from the receiver and writes to each sender func (t *MediaTrack) forwardRTPWorker() { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 4ffdb87cb..4ef8c3cc8 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -159,6 +159,10 @@ func (p *ParticipantImpl) IsReady() bool { return p.state == livekit.ParticipantInfo_JOINED || p.state == livekit.ParticipantInfo_ACTIVE } +func (p *ParticipantImpl) RTCPChan() chan<- []rtcp.Packet { + return p.rtcpCh +} + func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { info := &livekit.ParticipantInfo{ Sid: p.id, @@ -407,7 +411,7 @@ func (p *ParticipantImpl) AddDownTrack(streamId string, dt *sfu.DownTrack) { p.subscribedTracks[streamId] = append(p.subscribedTracks[streamId], dt) p.lock.Unlock() //dt.OnBind(func() { - // go p.scheduleDownTrackBindingReports(streamId) + // go p.sendDownTrackBindingReports(streamId) //}) } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 21fc1d4dd..e1fbfc4a1 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -56,6 +56,8 @@ type Participant interface { State() livekit.ParticipantInfo_State IsReady() bool ToProto() *livekit.ParticipantInfo + RTCPChan() chan<- []rtcp.Packet + AddTrack(clientId, name string, trackType livekit.TrackType) RemoveTrack(sid string) error Answer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) @@ -112,6 +114,7 @@ type Receiver interface { //counterfeiter:generate . DownTrack type DownTrack interface { WriteRTP(p rtp.Packet) error + IsBound() bool Close() OnCloseHandler(fn func()) OnBind(fn func()) @@ -120,6 +123,7 @@ type DownTrack interface { SnOffset() uint16 TsOffset() uint32 GetNACKSeqNo(seqNo []uint16) []uint16 + CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk } // interface for properties of webrtc.TrackRemote diff --git a/pkg/rtc/types/typesfakes/fake_down_track.go b/pkg/rtc/types/typesfakes/fake_down_track.go index 8c09317a9..1d62796fa 100644 --- a/pkg/rtc/types/typesfakes/fake_down_track.go +++ b/pkg/rtc/types/typesfakes/fake_down_track.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/pion/rtcp" "github.com/pion/rtp" ) @@ -13,6 +14,16 @@ type FakeDownTrack struct { closeMutex sync.RWMutex closeArgsForCall []struct { } + CreateSourceDescriptionChunksStub func() []rtcp.SourceDescriptionChunk + createSourceDescriptionChunksMutex sync.RWMutex + createSourceDescriptionChunksArgsForCall []struct { + } + createSourceDescriptionChunksReturns struct { + result1 []rtcp.SourceDescriptionChunk + } + createSourceDescriptionChunksReturnsOnCall map[int]struct { + result1 []rtcp.SourceDescriptionChunk + } GetNACKSeqNoStub func([]uint16) []uint16 getNACKSeqNoMutex sync.RWMutex getNACKSeqNoArgsForCall []struct { @@ -24,6 +35,16 @@ type FakeDownTrack struct { getNACKSeqNoReturnsOnCall map[int]struct { result1 []uint16 } + IsBoundStub func() bool + isBoundMutex sync.RWMutex + isBoundArgsForCall []struct { + } + isBoundReturns struct { + result1 bool + } + isBoundReturnsOnCall map[int]struct { + result1 bool + } LastSSRCStub func() uint32 lastSSRCMutex sync.RWMutex lastSSRCArgsForCall []struct { @@ -113,6 +134,59 @@ func (fake *FakeDownTrack) CloseCalls(stub func()) { fake.CloseStub = stub } +func (fake *FakeDownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk { + fake.createSourceDescriptionChunksMutex.Lock() + ret, specificReturn := fake.createSourceDescriptionChunksReturnsOnCall[len(fake.createSourceDescriptionChunksArgsForCall)] + fake.createSourceDescriptionChunksArgsForCall = append(fake.createSourceDescriptionChunksArgsForCall, struct { + }{}) + stub := fake.CreateSourceDescriptionChunksStub + fakeReturns := fake.createSourceDescriptionChunksReturns + fake.recordInvocation("CreateSourceDescriptionChunks", []interface{}{}) + fake.createSourceDescriptionChunksMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDownTrack) CreateSourceDescriptionChunksCallCount() int { + fake.createSourceDescriptionChunksMutex.RLock() + defer fake.createSourceDescriptionChunksMutex.RUnlock() + return len(fake.createSourceDescriptionChunksArgsForCall) +} + +func (fake *FakeDownTrack) CreateSourceDescriptionChunksCalls(stub func() []rtcp.SourceDescriptionChunk) { + fake.createSourceDescriptionChunksMutex.Lock() + defer fake.createSourceDescriptionChunksMutex.Unlock() + fake.CreateSourceDescriptionChunksStub = stub +} + +func (fake *FakeDownTrack) CreateSourceDescriptionChunksReturns(result1 []rtcp.SourceDescriptionChunk) { + fake.createSourceDescriptionChunksMutex.Lock() + defer fake.createSourceDescriptionChunksMutex.Unlock() + fake.CreateSourceDescriptionChunksStub = nil + fake.createSourceDescriptionChunksReturns = struct { + result1 []rtcp.SourceDescriptionChunk + }{result1} +} + +func (fake *FakeDownTrack) CreateSourceDescriptionChunksReturnsOnCall(i int, result1 []rtcp.SourceDescriptionChunk) { + fake.createSourceDescriptionChunksMutex.Lock() + defer fake.createSourceDescriptionChunksMutex.Unlock() + fake.CreateSourceDescriptionChunksStub = nil + if fake.createSourceDescriptionChunksReturnsOnCall == nil { + fake.createSourceDescriptionChunksReturnsOnCall = make(map[int]struct { + result1 []rtcp.SourceDescriptionChunk + }) + } + fake.createSourceDescriptionChunksReturnsOnCall[i] = struct { + result1 []rtcp.SourceDescriptionChunk + }{result1} +} + func (fake *FakeDownTrack) GetNACKSeqNo(arg1 []uint16) []uint16 { var arg1Copy []uint16 if arg1 != nil { @@ -179,6 +253,59 @@ func (fake *FakeDownTrack) GetNACKSeqNoReturnsOnCall(i int, result1 []uint16) { }{result1} } +func (fake *FakeDownTrack) IsBound() bool { + fake.isBoundMutex.Lock() + ret, specificReturn := fake.isBoundReturnsOnCall[len(fake.isBoundArgsForCall)] + fake.isBoundArgsForCall = append(fake.isBoundArgsForCall, struct { + }{}) + stub := fake.IsBoundStub + fakeReturns := fake.isBoundReturns + fake.recordInvocation("IsBound", []interface{}{}) + fake.isBoundMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDownTrack) IsBoundCallCount() int { + fake.isBoundMutex.RLock() + defer fake.isBoundMutex.RUnlock() + return len(fake.isBoundArgsForCall) +} + +func (fake *FakeDownTrack) IsBoundCalls(stub func() bool) { + fake.isBoundMutex.Lock() + defer fake.isBoundMutex.Unlock() + fake.IsBoundStub = stub +} + +func (fake *FakeDownTrack) IsBoundReturns(result1 bool) { + fake.isBoundMutex.Lock() + defer fake.isBoundMutex.Unlock() + fake.IsBoundStub = nil + fake.isBoundReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeDownTrack) IsBoundReturnsOnCall(i int, result1 bool) { + fake.isBoundMutex.Lock() + defer fake.isBoundMutex.Unlock() + fake.IsBoundStub = nil + if fake.isBoundReturnsOnCall == nil { + fake.isBoundReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isBoundReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeDownTrack) LastSSRC() uint32 { fake.lastSSRCMutex.Lock() ret, specificReturn := fake.lastSSRCReturnsOnCall[len(fake.lastSSRCArgsForCall)] @@ -521,8 +648,12 @@ func (fake *FakeDownTrack) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() + fake.createSourceDescriptionChunksMutex.RLock() + defer fake.createSourceDescriptionChunksMutex.RUnlock() fake.getNACKSeqNoMutex.RLock() defer fake.getNACKSeqNoMutex.RUnlock() + fake.isBoundMutex.RLock() + defer fake.isBoundMutex.RUnlock() fake.lastSSRCMutex.RLock() defer fake.lastSSRCMutex.RUnlock() fake.onBindMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index 18aa134c8..55adf7ec3 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -7,6 +7,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/proto/livekit" + "github.com/pion/rtcp" webrtc "github.com/pion/webrtc/v3" ) @@ -145,6 +146,16 @@ type FakeParticipant struct { peerConnectionReturnsOnCall map[int]struct { result1 types.PeerConnection } + RTCPChanStub func() chan<- []rtcp.Packet + rTCPChanMutex sync.RWMutex + rTCPChanArgsForCall []struct { + } + rTCPChanReturns struct { + result1 chan<- []rtcp.Packet + } + rTCPChanReturnsOnCall map[int]struct { + result1 chan<- []rtcp.Packet + } RemoveDownTrackStub func(string, *sfu.DownTrack) removeDownTrackMutex sync.RWMutex removeDownTrackArgsForCall []struct { @@ -963,6 +974,59 @@ func (fake *FakeParticipant) PeerConnectionReturnsOnCall(i int, result1 types.Pe }{result1} } +func (fake *FakeParticipant) RTCPChan() chan<- []rtcp.Packet { + fake.rTCPChanMutex.Lock() + ret, specificReturn := fake.rTCPChanReturnsOnCall[len(fake.rTCPChanArgsForCall)] + fake.rTCPChanArgsForCall = append(fake.rTCPChanArgsForCall, struct { + }{}) + stub := fake.RTCPChanStub + fakeReturns := fake.rTCPChanReturns + fake.recordInvocation("RTCPChan", []interface{}{}) + fake.rTCPChanMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) RTCPChanCallCount() int { + fake.rTCPChanMutex.RLock() + defer fake.rTCPChanMutex.RUnlock() + return len(fake.rTCPChanArgsForCall) +} + +func (fake *FakeParticipant) RTCPChanCalls(stub func() chan<- []rtcp.Packet) { + fake.rTCPChanMutex.Lock() + defer fake.rTCPChanMutex.Unlock() + fake.RTCPChanStub = stub +} + +func (fake *FakeParticipant) RTCPChanReturns(result1 chan<- []rtcp.Packet) { + fake.rTCPChanMutex.Lock() + defer fake.rTCPChanMutex.Unlock() + fake.RTCPChanStub = nil + fake.rTCPChanReturns = struct { + result1 chan<- []rtcp.Packet + }{result1} +} + +func (fake *FakeParticipant) RTCPChanReturnsOnCall(i int, result1 chan<- []rtcp.Packet) { + fake.rTCPChanMutex.Lock() + defer fake.rTCPChanMutex.Unlock() + fake.RTCPChanStub = nil + if fake.rTCPChanReturnsOnCall == nil { + fake.rTCPChanReturnsOnCall = make(map[int]struct { + result1 chan<- []rtcp.Packet + }) + } + fake.rTCPChanReturnsOnCall[i] = struct { + result1 chan<- []rtcp.Packet + }{result1} +} + func (fake *FakeParticipant) RemoveDownTrack(arg1 string, arg2 *sfu.DownTrack) { fake.removeDownTrackMutex.Lock() fake.removeDownTrackArgsForCall = append(fake.removeDownTrackArgsForCall, struct { @@ -1420,6 +1484,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.onTrackUpdatedMutex.RUnlock() fake.peerConnectionMutex.RLock() defer fake.peerConnectionMutex.RUnlock() + fake.rTCPChanMutex.RLock() + defer fake.rTCPChanMutex.RUnlock() fake.removeDownTrackMutex.RLock() defer fake.removeDownTrackMutex.RUnlock() fake.removeSubscriberMutex.RLock()