From 302facc60dae494f04ec2a94933b9e8a0379427e Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Mon, 6 Nov 2023 21:11:39 +0800 Subject: [PATCH] Reject migration if codec mismatch with published tracks (#2225) * Reject migrated/published track mismatch codec with track info * Check potential codecs * Issue full connect if mismatch * fix codec finding --- pkg/rtc/mediatrack.go | 13 ++++++++-- pkg/rtc/participant.go | 36 ++++++++++++++++++++++------ pkg/rtc/participant_internal_test.go | 2 +- pkg/rtc/participant_sdp.go | 6 ++--- pkg/rtc/types/interfaces.go | 5 +++- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 9b7f1c311..584c953c3 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -239,13 +239,22 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra t.params.Logger.Debugw("AddReceiver", "mime", track.Codec().MimeType) wr := t.MediaTrackReceiver.Receiver(mime) if wr == nil { - var priority int + priority := -1 for idx, c := range t.params.TrackInfo.Codecs { - if strings.HasSuffix(mime, c.MimeType) { + if strings.EqualFold(mime, c.MimeType) { priority = idx break } } + if len(t.params.TrackInfo.Codecs) == 0 { + priority = 0 + } + if priority < 0 { + t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(t.params.TrackInfo)) + t.lock.Unlock() + return false + } + newWR := sfu.NewWebRTCReceiver( receiver, track, diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index b04dd2561..21134b0f1 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1624,8 +1624,10 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l seenCodecs := make(map[string]struct{}) for _, codec := range req.SimulcastCodecs { mime := codec.Codec - if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") { - mime = "video/" + mime + if req.Type == livekit.TrackType_VIDEO { + if !strings.HasPrefix(mime, "video/") { + mime = "video/" + mime + } if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow("falling back to alternative codec", @@ -1762,12 +1764,32 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei // use existing media track to handle simulcast mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) if !ok { - signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) + signalCid, ti, migrated := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) if ti == nil { p.pendingTracksLock.Unlock() return nil, false } + // check if the migrated track has correct codec + if migrated && len(ti.Codecs) > 0 { + parameters := rtpReceiver.GetParameters() + var codecFound int + for _, c := range ti.Codecs { + for _, nc := range parameters.Codecs { + if strings.EqualFold(nc.MimeType, c.MimeType) { + codecFound++ + break + } + } + } + if codecFound != len(ti.Codecs) { + p.params.Logger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters) + p.pendingTracksLock.Unlock() + p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch) + return nil, false + } + } + ti.MimeType = track.Codec().MimeType mt = p.addMediaTrack(signalCid, track.ID(), ti) newTrack = true @@ -1976,7 +1998,7 @@ func (p *ParticipantImpl) onUpTrackManagerClose() { p.postRtcp(nil) } -func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { +func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo, bool) { signalCid := clientId pendingInfo := p.pendingTracks[clientId] if pendingInfo == nil { @@ -2012,10 +2034,10 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp // if still not found, we are done if pendingInfo == nil { p.pubLogger.Errorw("track info not published prior to track", nil, "clientId", clientId) - return signalCid, nil + return signalCid, nil, false } - return signalCid, pendingInfo.trackInfos[0] + return signalCid, pendingInfo.trackInfos[0], pendingInfo.migrated } // setStableTrackID either generates a new TrackID or reuses a previously used one @@ -2206,7 +2228,7 @@ func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason scr := types.SignallingCloseReasonUnknown switch reason { - case types.ParticipantCloseReasonPublicationError: + case types.ParticipantCloseReasonPublicationError, types.ParticipantCloseReasonMigrateCodecMismatch: scr = types.SignallingCloseReasonFullReconnectPublicationError case types.ParticipantCloseReasonSubscriptionError: scr = types.SignallingCloseReasonFullReconnectSubscriptionError diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 2256a70ca..afd312bc8 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -284,7 +284,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - _, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO) + _, ti, _ := p.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) }) diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 70829a34c..51e98c478 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -46,7 +46,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se } p.pendingTracksLock.RLock() - _, info := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + _, info, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) // if RED is disabled for this track, don't prefer RED codec in offer disableRed := info != nil && info.DisableRed p.pendingTracksLock.RUnlock() @@ -132,7 +132,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess if mt != nil { info = mt.ToProto() } else { - _, info = p.getPendingTrack(streamID, livekit.TrackType_VIDEO) + _, info, _ = p.getPendingTrack(streamID, livekit.TrackType_VIDEO) } if info == nil { @@ -239,7 +239,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack) if track == nil { p.pendingTracksLock.RLock() - _, ti = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + _, ti, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) p.pendingTracksLock.RUnlock() } else { ti = track.TrackInfo(false) diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index da3ee7178..6c2210613 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -104,6 +104,7 @@ const ( ParticipantCloseReasonPublicationError ParticipantCloseReasonSubscriptionError ParticipantCloseReasonDataChannelError + ParticipantCloseReasonMigrateCodecMismatch ) func (p ParticipantCloseReason) String() string { @@ -154,6 +155,8 @@ func (p ParticipantCloseReason) String() string { return "SUBSCRIPTION_ERROR" case ParticipantCloseReasonDataChannelError: return "DATA_CHANNEL_ERROR" + case ParticipantCloseReasonMigrateCodecMismatch: + return "MIGRATE_CODEC_MISMATCH" default: return fmt.Sprintf("%d", int(p)) } @@ -184,7 +187,7 @@ func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason { return livekit.DisconnectReason_SERVER_SHUTDOWN case ParticipantCloseReasonOvercommitted: return livekit.DisconnectReason_SERVER_SHUTDOWN - case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError: + case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch: return livekit.DisconnectReason_STATE_MISMATCH default: // the other types will map to unknown reason