diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 9da1d3d6e..caeafed45 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -219,6 +219,16 @@ func (t *MediaTrack) GetMimeTypeForSdpCid(cid string) mime.MimeType { return mime.MimeTypeUnknown } +func (t *MediaTrack) GetMimeTypeForMid(mid string) mime.MimeType { + ti := t.MediaTrackReceiver.TrackInfoClone() + for _, c := range ti.Codecs { + if c.Mid == mid { + return mime.NormalizeMimeType(c.MimeType) + } + } + return mime.MimeTypeUnknown +} + func (t *MediaTrack) GetCidsForMimeType(mimeType mime.MimeType) (string, string) { ti := t.MediaTrackReceiver.TrackInfoClone() for _, c := range ti.Codecs { diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index b0ba9c3be..d86720287 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -741,6 +741,20 @@ func (t *MediaTrackReceiver) UpdateCodecSdpCid(mimeType mime.MimeType, sdpCid st t.updateTrackInfoOfReceivers() } +func (t *MediaTrackReceiver) UpdateCodecMid(mimeType mime.MimeType, mid string) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + for _, origin := range trackInfo.Codecs { + if mime.NormalizeMimeType(origin.MimeType) == mimeType { + origin.Mid = mid + } + } + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() +} + func (t *MediaTrackReceiver) UpdateCodecRids(mimeType mime.MimeType, rids buffer.VideoLayersRid) { t.lock.Lock() trackInfo := t.TrackInfoClone() diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 56835d16c..6a9e97747 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -40,6 +40,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/protocol/observability" "github.com/livekit/protocol/observability/roomobs" + lksdp "github.com/livekit/protocol/sdp" sdpHelper "github.com/livekit/protocol/sdp" "github.com/livekit/protocol/utils" "github.com/livekit/protocol/utils/guid" @@ -1097,7 +1098,7 @@ func (p *ParticipantImpl) updateRidsFromSDP(parsed *sdp.SessionDescription, unma } p.pendingTracksLock.Lock() - pti := p.getPendingTrackPrimaryBySdpCid(mst, true) + pti := p.getPendingTrackPrimaryBySdpCid(mst) if pti != nil { pti.sdpRids = getRids(pti.sdpRids) p.pubLogger.Debugw( @@ -1147,6 +1148,85 @@ func (p *ParticipantImpl) updateRidsFromSDP(parsed *sdp.SessionDescription, unma } } +func (p *ParticipantImpl) updateRidsFromSDPByMid(parsed *sdp.SessionDescription) { + getRids := func(m *sdp.MediaDescription, inRids buffer.VideoLayersRid) buffer.VideoLayersRid { + var outRids buffer.VideoLayersRid + rids, ok := sdpHelper.GetSimulcastRids(m) + if ok { + n := min(len(rids), len(inRids)) + for i := 0; i < n; i++ { + outRids[i] = rids[i] + } + for i := n; i < len(inRids); i++ { + outRids[i] = "" + } + outRids = buffer.NormalizeVideoLayersRid(outRids) + } else { + for i := 0; i < len(inRids); i++ { + outRids[i] = "" + } + } + + return outRids + } + + for _, md := range parsed.MediaDescriptions { + mid := lksdp.GetMidValue(md) + if mid == "" { + continue + } + + p.pendingTracksLock.Lock() + pti := p.getPendingTrackPrimaryByMid(mid) + if pti != nil { + pti.sdpRids = getRids(md, pti.sdpRids) + p.pubLogger.Debugw( + "pending track rids updated", + "trackID", pti.trackInfos[0].Sid, + "pendingTrack", pti, + ) + + ti := pti.trackInfos[0] + for _, codec := range ti.Codecs { + if codec.Mid == mid { + mimeType := mime.NormalizeMimeType(codec.MimeType) + for _, layer := range codec.Layers { + layer.SpatialLayer = buffer.VideoQualityToSpatialLayer(mimeType, layer.Quality, ti) + layer.Rid = buffer.VideoQualityToRid(mimeType, layer.Quality, ti, pti.sdpRids) + } + } + } + } + p.pendingTracksLock.Unlock() + + if pti == nil { + // track could already be published, but this could be back up codec offer, + // so check in published tracks also + mt := p.getPublishedTrackByMid(mid) + if mt != nil { + mimeType := mt.(*MediaTrack).GetMimeTypeForMid(mid) + if mimeType != mime.MimeTypeUnknown { + rids := getRids(md, buffer.DefaultVideoLayersRid) + mt.(*MediaTrack).UpdateCodecRids(mimeType, rids) + p.pubLogger.Debugw( + "published track rids updated", + "trackID", mt.ID(), + "mime", mimeType, + "track", logger.Proto(mt.ToProto()), + ) + } else { + p.pubLogger.Warnw( + "could not get mime type for mid", nil, + "trackID", mt.ID(), + "mid", mid, + "track", logger.Proto(mt.ToProto()), + ) + } + } + } + } +} + // HandleOffer an offer from remote participant, used when clients make the initial connection func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription, offerId uint32) error { p.pubLogger.Debugw( @@ -1256,8 +1336,21 @@ func (p *ParticipantImpl) HandleAnswer(answer webrtc.SessionDescription, answerI signalConnCost := time.Since(p.ConnectedAt()).Milliseconds() p.TransportManager.UpdateSignalingRTT(uint32(signalConnCost)) - // SINGLE-PEER-CONNECTION-TODO: have to run `populateSdpCid` and `updateRidsFromSDP` - // SINGLE-PEER-CONNECTION-TODO: there won't be unmatched media though, maybe need to store Mid in trackInfo and use that + if p.ProtocolVersion().SupportsSinglePeerConnection() { + parsedAnswer, err := answer.Unmarshal() + if err != nil { + p.pubLogger.Warnw( + "could not parse answer", err, + "transport", livekit.SignalTarget_SUBSCRIBER, + "answer", answer, + "answerId", answerId, + ) + return + } + + p.populateSdpCidByMid(parsedAnswer) + p.updateRidsFromSDPByMid(parsedAnswer) + } p.TransportManager.HandleAnswer(answer, answerId) } @@ -1322,16 +1415,6 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { return } - if p.ProtocolVersion().SupportsSinglePeerConnection() { - if err := p.TransportManager.AddRemoteTrackAndNegotiate( - ti, - p.getDisabledPublishCodecs(), - p.params.Config.Publisher.RTCPFeedback, - ); err != nil { - return - } - } - p.sendTrackPublished(req.Cid, ti) p.handlePendingRemoteTracks() @@ -2146,6 +2229,21 @@ func (p *ParticipantImpl) setIsPublisher(isPublisher bool) { // when the server has an offer for participant func (p *ParticipantImpl) onSubscriberOffer(offer webrtc.SessionDescription, offerId uint32) error { + if p.ProtocolVersion().SupportsSinglePeerConnection() { + parsedOffer, err := offer.Unmarshal() + if err != nil { + p.pubLogger.Warnw( + "could not parse offer", err, + "transport", livekit.SignalTarget_PUBLISHER, + "offer", offer, + "offerId", offerId, + ) + return err + } + + p.populateMid(parsedOffer) + } + p.subLogger.Debugw( "sending offer", "transport", livekit.SignalTarget_SUBSCRIBER, @@ -2877,6 +2975,16 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l return nil } + if p.ProtocolVersion().SupportsSinglePeerConnection() { + if err := p.TransportManager.AddRemoteTrackAndNegotiate( + ti, + p.getDisabledPublishCodecs(), + p.params.Config.Publisher.RTCPFeedback, + ); err != nil { + return nil + } + } + p.pendingTracks[req.Cid] = &pendingTrackInfo{ trackInfos: []*livekit.TrackInfo{ti}, sdpRids: buffer.DefaultVideoLayersRid, // could get updated from SDP @@ -3218,6 +3326,7 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, ti *livekit.TrackInfo) p.pendingTracksLock.Lock() if pti := p.pendingTracks[signalCid]; pti != nil { + // SINGLE-PEER-CONNECTION-TODO: need to add remote track when dequeuing p.sendTrackPublished(signalCid, pti.trackInfos[0]) pti.queued = false } @@ -3326,6 +3435,7 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp return signalCid, utils.CloneProto(pendingInfo.trackInfos[0]), pendingInfo.sdpRids, pendingInfo.migrated, pendingInfo.createdAt } +// SINGLE-PEER-CONNECTION-TODO: this may not be needed func (p *ParticipantImpl) getPendingTracksByTrackType(trackType livekit.TrackType) []*livekit.TrackInfo { var pendingTracks []*livekit.TrackInfo for _, pti := range p.pendingTracks { @@ -3337,7 +3447,38 @@ func (p *ParticipantImpl) getPendingTracksByTrackType(trackType livekit.TrackTyp return pendingTracks } -func (p *ParticipantImpl) getPendingTrackPrimaryBySdpCid(sdpCid string, skipQueued bool) *pendingTrackInfo { +func (p *ParticipantImpl) getPendingTrackByTrackTypeWithoutMid(trackType livekit.TrackType) (string, *livekit.TrackInfo, bool) { + for cid, pti := range p.pendingTracks { + ti := pti.trackInfos[0] + if ti.Type == trackType { + for _, c := range ti.Codecs { + if c.Mid == "" { + return cid, utils.CloneProto(ti), pti.migrated + } + } + } + } + return "", nil, false +} + +func (p *ParticipantImpl) getPendingTrackByMid(mid string, skipQueued bool) (string, *livekit.TrackInfo, bool) { + for cid, pti := range p.pendingTracks { + if skipQueued && pti.queued { + continue + } + + ti := pti.trackInfos[0] + for _, c := range ti.Codecs { + if c.Mid == mid { + return cid, utils.CloneProto(ti), pti.migrated + } + } + } + + return "", nil, false +} + +func (p *ParticipantImpl) getPendingTrackPrimaryBySdpCid(sdpCid string) *pendingTrackInfo { for _, pti := range p.pendingTracks { ti := pti.trackInfos[0] if len(ti.Codecs) == 0 { @@ -3351,6 +3492,20 @@ func (p *ParticipantImpl) getPendingTrackPrimaryBySdpCid(sdpCid string, skipQueu return nil } +func (p *ParticipantImpl) getPendingTrackPrimaryByMid(mid string) *pendingTrackInfo { + for _, pti := range p.pendingTracks { + ti := pti.trackInfos[0] + if len(ti.Codecs) == 0 { + continue + } + if ti.Codecs[0].Mid == mid { + return pti + } + } + + return nil +} + // setTrackID either generates a new TrackID for an AddTrackRequest func (p *ParticipantImpl) setTrackID(cid string, info *livekit.TrackInfo) { var trackID string @@ -3405,6 +3560,38 @@ func (p *ParticipantImpl) getPublishedTrackBySdpCid(clientId string) types.Media return nil } +func (p *ParticipantImpl) getPublishedTrackPendingMid() types.MediaTrack { + for _, publishedTrack := range p.GetPublishedTracks() { + ti := publishedTrack.ToProto() + for _, c := range ti.Codecs { + if c.Mid == "" && c.Cid != "" { + p.pubLogger.Debugw( + "found track pending mid", + "trackID", publishedTrack.ID(), + "track", logger.Proto(publishedTrack.ToProto()), + ) + return publishedTrack + } + } + } + + return nil +} + +func (p *ParticipantImpl) getPublishedTrackByMid(mid string) types.MediaTrack { + for _, publishedTrack := range p.GetPublishedTracks() { + ti := publishedTrack.ToProto() + for _, c := range ti.Codecs { + if c.Mid == mid { + p.pubLogger.Debugw("found track by mid", "mid", mid, "trackID", publishedTrack.ID()) + return publishedTrack + } + } + } + + return nil +} + func (p *ParticipantImpl) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "ID": p.ID(), diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 3a0b9d027..3cea576a4 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -26,6 +26,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/mime" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" lksdp "github.com/livekit/protocol/sdp" "github.com/livekit/protocol/utils" ) @@ -92,6 +93,11 @@ func (p *ParticipantImpl) populateSdpCid(parsedOffer *sdp.SessionDescription) ([ ) } unmatchedTrack.(*MediaTrack).UpdateCodecSdpCid(unmatchedSdpMimeType, streamID) + p.pubLogger.Debugw( + "published track SDP cid updated", + "trackID", unmatchedTrack.ID(), + "track", logger.Proto(unmatchedTrack.ToProto()), + ) } continue } @@ -165,6 +171,143 @@ func (p *ParticipantImpl) populateSdpCid(parsedOffer *sdp.SessionDescription) ([ return unmatchAudios, unmatchVideos } +func (p *ParticipantImpl) populateMid(parsedOffer *sdp.SessionDescription) { + processUnmatch := func(unmatches []*sdp.MediaDescription, trackType livekit.TrackType) { + for _, unmatch := range unmatches { + mid := lksdp.GetMidValue(unmatch) + if mid == "" { + continue + } + + p.pendingTracksLock.Lock() + signalCid, ti, migrated := p.getPendingTrackByTrackTypeWithoutMid(trackType) + if migrated || ti == nil || signalCid == "" { + p.pendingTracksLock.Unlock() + + // check for back up codec pending publish + publishedTrack := p.getPublishedTrackByMid(mid) + if publishedTrack != nil { + var mimeType mime.MimeType + updated := false + + ti := publishedTrack.ToProto() + for _, c := range ti.Codecs { + if c.Mid == "" && c.Cid != "" { + mimeType = mime.NormalizeMimeType(c.MimeType) + updated = true + } + } + publishedTrack.(*MediaTrack).UpdateCodecMid(mimeType, mid) + if updated { + p.pubLogger.Debugw( + "published track mid updated", + "trackID", publishedTrack.ID(), + "track", logger.Proto(publishedTrack.ToProto()), + ) + } + } + continue + } + + updated := false + for _, c := range ti.Codecs { + if c.Mid == "" { + c.Mid = mid + updated = true + } + } + + if updated { + p.pendingTracks[signalCid].trackInfos[0] = utils.CloneProto(ti) + p.pubLogger.Debugw( + "pending track mid updated", + "signalCid", signalCid, + "trackID", ti.Sid, + "pendingTrack", p.pendingTracks[signalCid], + ) + } + p.pendingTracksLock.Unlock() + } + } + + unmatchAudios, err := p.TransportManager.GetUnmatchMediaForOffer(parsedOffer, "audio") + if err != nil { + p.pubLogger.Warnw("could not get unmatch audios", err) + return + } + + unmatchVideos, err := p.TransportManager.GetUnmatchMediaForOffer(parsedOffer, "video") + if err != nil { + p.pubLogger.Warnw("could not get unmatch audios", err) + return + } + + processUnmatch(unmatchAudios, livekit.TrackType_AUDIO) + processUnmatch(unmatchVideos, livekit.TrackType_VIDEO) +} + +func (p *ParticipantImpl) populateSdpCidByMid(parsedAnswer *sdp.SessionDescription) { + for _, md := range parsedAnswer.MediaDescriptions { + mid := lksdp.GetMidValue(md) + if mid == "" { + continue + } + + streamID, ok := lksdp.ExtractStreamID(md) + if !ok { + continue + } + + p.pendingTracksLock.Lock() + signalCid, ti, migrated := p.getPendingTrackByMid(mid, true) + if migrated || ti == nil || signalCid == "" { + p.pendingTracksLock.Unlock() + + publishedTrack := p.getPublishedTrackByMid(mid) + if publishedTrack != nil { + var mimeType mime.MimeType + updated := false + + ti := publishedTrack.ToProto() + for _, c := range ti.Codecs { + if c.Mid == mid && c.Cid != streamID { + mimeType = mime.NormalizeMimeType(c.MimeType) + updated = true + } + } + publishedTrack.(*MediaTrack).UpdateCodecSdpCid(mimeType, streamID) + if updated { + p.pubLogger.Debugw( + "published track SDP cid updated", + "trackID", publishedTrack.ID(), + "track", logger.Proto(publishedTrack.ToProto()), + ) + } + } + continue + } + + updated := false + for _, c := range ti.Codecs { + if c.Mid == mid && c.Cid != streamID { + c.SdpCid = streamID + updated = true + } + } + + if updated { + p.pendingTracks[signalCid].trackInfos[0] = utils.CloneProto(ti) + p.pubLogger.Debugw( + "pending track SDP cid updated", + "signalCid", signalCid, + "trackID", ti.Sid, + "pendingTrack", p.pendingTracks[signalCid], + ) + } + p.pendingTracksLock.Unlock() + } +} + func (p *ParticipantImpl) setCodecPreferencesForPublisher( parsedOffer *sdp.SessionDescription, unmatchAudios []*sdp.MediaDescription, diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index a3bf91e68..ca7fe9efb 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -984,10 +984,6 @@ func (t *PCTransport) AddRemoteTrackAndNegotiate( publishDisabledCodecs []*livekit.Codec, rtcpFeedbackConfig RTCPFeedbackConfig, ) error { - if ti == nil { - return nil - } - rtpCodecType := webrtc.RTPCodecTypeVideo if ti.Type == livekit.TrackType_AUDIO { rtpCodecType = webrtc.RTPCodecTypeAudio