From c38d4df52f284a4ab15ac1b4af1eaa96da9e4254 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Mon, 15 Aug 2022 18:46:24 +0800 Subject: [PATCH] server side codec preference for publish (#916) --- pkg/rtc/participant.go | 100 ++++++++++++++++++++++++++++++++++++ pkg/rtc/transportmanager.go | 51 +++++++++++++++++- 2 files changed, 150 insertions(+), 1 deletion(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index db3f0cff3..4e8ed4e09 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -3,12 +3,14 @@ package rtc import ( "context" "io" + "strconv" "strings" "sync" "time" lru "github.com/hashicorp/golang-lru" "github.com/pion/rtcp" + "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/pkg/errors" "go.uber.org/atomic" @@ -392,6 +394,8 @@ func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription) error { shouldPend = true } + offer = p.setCodecPreferencesForPublisher(offer) + if err := p.TransportManager.HandleOffer(offer, shouldPend); err != nil { prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "remote_description").Add(1) return err @@ -400,6 +404,77 @@ func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription) error { return nil } +func (p *ParticipantImpl) setCodecPreferencesForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { + parsed, lastVideo, err := p.TransportManager.GetLastUnmatchedMediaForOffer(offer, "video") + if err != nil || lastVideo == nil { + return offer + } + // last video is pending for publish, set codec preference + var streamID string + msid, ok := lastVideo.Attribute(sdp.AttrKeyMsid) + if !ok { + return offer + } + ids := strings.Split(msid, " ") + if len(ids) < 2 { + streamID = msid + } else { + streamID = ids[1] + } + + p.pendingTracksLock.RLock() + _, info := p.getPendingTrack(streamID, livekit.TrackType_VIDEO) + if info == nil { + p.pendingTracksLock.RUnlock() + return offer + } + var mime string + for _, c := range info.Codecs { + if c.Cid == streamID { + mime = c.MimeType + break + } + } + if mime == "" && len(info.Codecs) > 0 { + mime = info.Codecs[0].MimeType + } + p.pendingTracksLock.RUnlock() + + if mime == "" { + return offer + } + + codecs, err := codecsFromMediaDescription(lastVideo) + if err != nil { + return offer + } + + mime = strings.ToUpper(mime) + var preferredCodecs, leftCodecs []string + for _, c := range codecs { + if strings.HasSuffix(mime, strings.ToUpper(c.Name)) { + preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) + } else { + leftCodecs = append(leftCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) + } + } + + lastVideo.MediaName.Formats = append(lastVideo.MediaName.Formats[:0], preferredCodecs...) + lastVideo.MediaName.Formats = append(lastVideo.MediaName.Formats, leftCodecs...) + + bytes, err := parsed.Marshal() + if err != nil { + p.params.Logger.Infow("failed to marshal offer", "error", err) + p.params.Logger.Errorw("failed to marshal offer", err) + return offer + } + + return webrtc.SessionDescription{ + Type: offer.Type, + SDP: string(bytes), + } +} + func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription) { p.params.Logger.Infow("sending answer", "transport", livekit.SignalTarget_PUBLISHER) if err := p.writeMessage(&livekit.SignalResponse{ @@ -1904,3 +1979,28 @@ func (p *ParticipantImpl) UpdateMediaLoss(nodeID livekit.NodeID, trackID livekit track.(types.LocalMediaTrack).NotifySubscriberNodeMediaLoss(nodeID, uint8(fractionalLoss)) return nil } + +func codecsFromMediaDescription(m *sdp.MediaDescription) (out []sdp.Codec, err error) { + s := &sdp.SessionDescription{ + MediaDescriptions: []*sdp.MediaDescription{m}, + } + + for _, payloadStr := range m.MediaName.Formats { + payloadType, err := strconv.ParseUint(payloadStr, 10, 8) + if err != nil { + return nil, err + } + + codec, err := s.GetCodecForPayloadType(uint8(payloadType)) + if err != nil { + if payloadType == 0 { + continue + } + return nil, err + } + + out = append(out, codec) + } + + return out, nil +} diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 4cf23059e..856d141fc 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -3,8 +3,10 @@ package rtc import ( "strings" "sync" + "sync/atomic" "github.com/pion/rtcp" + "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/pkg/errors" @@ -43,6 +45,7 @@ type TransportManager struct { pendingOfferPublisher *webrtc.SessionDescription pendingDataChannelsPublisher []*livekit.DataChannelInfo + lastPublisherAnswer atomic.Value onPublisherGetDTX func() bool @@ -154,7 +157,10 @@ func (t *TransportManager) OnPublisherGetDTX(f func() bool) { } func (t *TransportManager) OnPublisherAnswer(f func(answer webrtc.SessionDescription)) { - t.publisher.OnAnswer(f) + t.publisher.OnAnswer(func(sd webrtc.SessionDescription) { + t.lastPublisherAnswer.Store(sd) + f(sd) + }) } func (t *TransportManager) OnPublisherTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { @@ -290,6 +296,49 @@ func (t *TransportManager) createDataChannelsForSubscriber(pendingDataChannels [ return nil } +func (t *TransportManager) GetLastUnmatchedMediaForOffer(offer webrtc.SessionDescription, mediaType string) (parsed *sdp.SessionDescription, unmatched *sdp.MediaDescription, err error) { + // prefer codec from offer for clients that don't support setCodecPreferences + parsed, err = offer.Unmarshal() + if err != nil { + t.params.Logger.Errorw("failed to parse offer for codec preference", err) + return + } + + for i := len(parsed.MediaDescriptions) - 1; i >= 0; i-- { + media := parsed.MediaDescriptions[i] + if media.MediaName.Media == mediaType { + unmatched = media + break + } + } + + if unmatched == nil { + return + } + + lastAnswer := t.lastPublisherAnswer.Load() + if lastAnswer != nil { + answer := lastAnswer.(webrtc.SessionDescription) + parsedAnswer, err1 := answer.Unmarshal() + if err1 != nil { + // should not happend + t.params.Logger.Errorw("failed to parse last answer", err) + return + } + + for _, m := range parsedAnswer.MediaDescriptions { + mid, _ := m.Attribute(sdp.AttrKeyMID) + if lastMid, _ := unmatched.Attribute(sdp.AttrKeyMID); lastMid == mid { + // mid matched, return + unmatched = nil + return + } + } + } + + return +} + func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, shouldPend bool) error { t.lock.Lock() if shouldPend {