From 8fd3e8fe2d8c50a06d3c438621444cc7c3cd6bd2 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Mon, 17 Oct 2022 10:48:11 +0800 Subject: [PATCH] Support track level stereo and red setting (#1086) * Support track level stereo and red setting * fix test client --- go.mod | 2 +- go.sum | 4 +- pkg/rtc/mediatracksubscriptions.go | 8 +- pkg/rtc/participant.go | 156 +----------- pkg/rtc/participant_sdp.go | 226 ++++++++++++++++++ pkg/rtc/transport.go | 144 +++-------- pkg/rtc/transportmanager.go | 21 +- pkg/rtc/types/interfaces.go | 8 +- .../typesfakes/fake_local_participant.go | 36 +-- test/client/client.go | 3 +- 10 files changed, 308 insertions(+), 300 deletions(-) create mode 100644 pkg/rtc/participant_sdp.go diff --git a/go.mod b/go.mod index 264ff9dd2..03592a135 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 github.com/livekit/mageutil v0.0.0-20221002073820-d9198083cfdc github.com/livekit/mediatransportutil v0.0.0-20221007030528-7440725c362b - github.com/livekit/protocol v1.1.3-0.20221007212651-d9bc6cd9cb77 + github.com/livekit/protocol v1.1.3-0.20221014075341-b0c33b869aa5 github.com/livekit/rtcscore-go v0.0.0-20220815072451-20ee10ae1995 github.com/mackerelio/go-osstat v0.2.3 github.com/magefile/mage v1.14.0 diff --git a/go.sum b/go.sum index e7c588b08..b2227cbed 100644 --- a/go.sum +++ b/go.sum @@ -244,8 +244,8 @@ github.com/livekit/mageutil v0.0.0-20221002073820-d9198083cfdc h1:e3GIA9AL6h4a38 github.com/livekit/mageutil v0.0.0-20221002073820-d9198083cfdc/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20221007030528-7440725c362b h1:RBNV8TckETSkIkKxcD12d8nZKVkB9GSY/sQlMoaruP4= github.com/livekit/mediatransportutil v0.0.0-20221007030528-7440725c362b/go.mod h1:1Dlx20JPoIKGP45eo+yuj0HjeE25zmyeX/EWHiPCjFw= -github.com/livekit/protocol v1.1.3-0.20221007212651-d9bc6cd9cb77 h1:vHVvfoKWUT1eZahFn2CVjg9dHatp4XIRLVczy6uVnGI= -github.com/livekit/protocol v1.1.3-0.20221007212651-d9bc6cd9cb77/go.mod h1:jshI3nWbZkF1y1TUr2WIqzhN9HnyMqM9v/e/31L78z0= +github.com/livekit/protocol v1.1.3-0.20221014075341-b0c33b869aa5 h1:qlpTUN/xw9xk5Y54LdIBjLBxxy923spUi31FAnO5b7o= +github.com/livekit/protocol v1.1.3-0.20221014075341-b0c33b869aa5/go.mod h1:jshI3nWbZkF1y1TUr2WIqzhN9HnyMqM9v/e/31L78z0= github.com/livekit/rtcscore-go v0.0.0-20220815072451-20ee10ae1995 h1:vOaY2qvfLihDyeZtnGGN1Law9wRrw8BMGCr1TygTvMw= github.com/livekit/rtcscore-go v0.0.0-20220815072451-20ee10ae1995/go.mod h1:116ych8UaEs9vfIE8n6iZCZ30iagUFTls0vRmC+Ix5U= github.com/mackerelio/go-osstat v0.2.3 h1:jAMXD5erlDE39kdX2CU7YwCGRcxIO33u/p8+Fhe5dJw= diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 6824d156d..ac810b210 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -203,18 +203,22 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // if cannot replace, find an unused transceiver or add new one if transceiver == nil { + info := t.params.MediaTrack.ToProto() + addTrackParams := types.AddTrackParams{ + Stereo: info.Stereo, + } if sub.ProtocolVersion().SupportsTransceiverReuse() { // // AddTrack will create a new transceiver or re-use an unused one // if the attributes match. This prevents SDP from bloating // because of dormant transceivers building up. // - sender, transceiver, err = sub.AddTrackToSubscriber(downTrack) + sender, transceiver, err = sub.AddTrackToSubscriber(downTrack, addTrackParams) if err != nil { return err } } else { - sender, transceiver, err = sub.AddTransceiverFromTrackToSubscriber(downTrack) + sender, transceiver, err = sub.AddTransceiverFromTrackToSubscriber(downTrack, addTrackParams) if err != nil { return err } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 2e766d389..92eccb2ac 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -412,136 +412,9 @@ func (p *ParticipantImpl) HandleOffer(offer webrtc.SessionDescription) { p.TransportManager.HandleOffer(offer, shouldPend) } -func (p *ParticipantImpl) setCodecPreferencesForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { - offer = p.setCodecPreferencesOpusRedForPublisher(offer) - offer = p.setCodecPreferencesVideoForPublisher(offer) - return offer -} - -func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { - parsed, lastAudio, err := p.TransportManager.GetLastUnmatchedMediaForOffer(offer, "audio") - if err != nil || lastAudio == nil { - return offer - } - - codecs, err := codecsFromMediaDescription(lastAudio) - if err != nil { - return offer - } - - var opusPayload uint8 - for _, codec := range codecs { - if strings.EqualFold(codec.Name, "opus") { - opusPayload = codec.PayloadType - break - } - } - if opusPayload == 0 { - return offer - } - - var preferredCodecs, leftCodecs []string - for _, codec := range codecs { - // codec contain opus/red - if strings.EqualFold(codec.Name, "red") && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { - preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) - } else { - leftCodecs = append(leftCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) - } - } - - // no opus/red found - if len(preferredCodecs) == 0 { - return offer - } - - lastAudio.MediaName.Formats = append(lastAudio.MediaName.Formats[:0], preferredCodecs...) - lastAudio.MediaName.Formats = append(lastAudio.MediaName.Formats, leftCodecs...) - - bytes, err := parsed.Marshal() - if err != nil { - p.params.Logger.Errorw("failed to marshal offer", err) - return offer - } - - return webrtc.SessionDescription{ - Type: offer.Type, - SDP: string(bytes), - } -} - -func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { - parsed, lastVideo, err := p.TransportManager.GetLastUnmatchedMediaForOffer(offer, "video") - if err != nil || lastVideo == nil { - return offer - } - // last video is pending for publish, set codec preference - var streamID string - msid, ok := lastVideo.Attribute(sdp.AttrKeyMsid) - if !ok { - return offer - } - ids := strings.Split(msid, " ") - if len(ids) < 2 { - streamID = msid - } else { - streamID = ids[1] - } - - p.pendingTracksLock.RLock() - _, info := p.getPendingTrack(streamID, livekit.TrackType_VIDEO) - if info == nil { - p.pendingTracksLock.RUnlock() - return offer - } - var mime string - for _, c := range info.Codecs { - if c.Cid == streamID { - mime = c.MimeType - break - } - } - if mime == "" && len(info.Codecs) > 0 { - mime = info.Codecs[0].MimeType - } - p.pendingTracksLock.RUnlock() - - if mime == "" { - return offer - } - - codecs, err := codecsFromMediaDescription(lastVideo) - if err != nil { - return offer - } - - mime = strings.ToUpper(mime) - var preferredCodecs, leftCodecs []string - for _, c := range codecs { - if strings.HasSuffix(mime, strings.ToUpper(c.Name)) { - preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) - } else { - leftCodecs = append(leftCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) - } - } - - lastVideo.MediaName.Formats = append(lastVideo.MediaName.Formats[:0], preferredCodecs...) - lastVideo.MediaName.Formats = append(lastVideo.MediaName.Formats, leftCodecs...) - - bytes, err := parsed.Marshal() - if err != nil { - p.params.Logger.Errorw("failed to marshal offer", err) - return offer - } - - return webrtc.SessionDescription{ - Type: offer.Type, - SDP: string(bytes), - } -} - func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription) error { p.params.Logger.Infow("sending answer", "transport", livekit.SignalTarget_PUBLISHER) + answer = p.configurePublisherAnswer(answer) if err := p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Answer{ Answer: ToProtoSessionDescription(answer), @@ -1166,7 +1039,6 @@ func (p *ParticipantImpl) setupTransportManager() error { tm.OnPublisherICECandidate(func(c *webrtc.ICECandidate) error { return p.onICECandidate(c, livekit.SignalTarget_PUBLISHER) }) - tm.OnPublisherGetDTX(p.onPublisherGetDTX) tm.OnPublisherAnswer(p.onPublisherAnswer) tm.OnPublisherTrack(p.onMediaTrack) tm.OnPublisherInitialConnected(p.onPublisherInitialConnected) @@ -1548,6 +1420,8 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l DisableDtx: req.DisableDtx, Source: req.Source, Layers: req.Layers, + DisableRed: req.DisableRed, + Stereo: req.Stereo, } p.setStableTrackID(req.Cid, ti) for _, codec := range req.SimulcastCodecs { @@ -1640,30 +1514,6 @@ func (p *ParticipantImpl) getPublisherConnectionQuality() map[livekit.TrackID]fl return scores } -func (p *ParticipantImpl) onPublisherGetDTX() bool { - p.pendingTracksLock.RLock() - defer p.pendingTracksLock.RUnlock() - - // - // Although DTX is set per track, there are cases where - // pending track has to be looked up by kind. This happens - // when clients change track id between signalling and SDP. - // In that case, look at all pending tracks by kind and - // enable DTX even if one has it enabled. - // - // Most of the time in practice, there is going to be one - // audio kind track and hence this is fine. - // - for _, pti := range p.pendingTracks { - ti := pti.trackInfos[0] - if ti != nil && ti.Type == livekit.TrackType_AUDIO { - return !ti.DisableDtx - } - } - - return false -} - func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) (*MediaTrack, bool) { p.pendingTracksLock.Lock() newTrack := false diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go new file mode 100644 index 000000000..43e9b763c --- /dev/null +++ b/pkg/rtc/participant_sdp.go @@ -0,0 +1,226 @@ +package rtc + +import ( + "fmt" + "strconv" + "strings" + + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v3" + + "github.com/livekit/protocol/livekit" + lksdp "github.com/livekit/protocol/sdp" +) + +func (p *ParticipantImpl) setCodecPreferencesForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { + offer = p.setCodecPreferencesOpusRedForPublisher(offer) + offer = p.setCodecPreferencesVideoForPublisher(offer) + return offer +} + +func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { + parsed, lastAudio, err := p.TransportManager.GetLastUnmatchedMediaForOffer(offer, "audio") + if err != nil || lastAudio == nil { + return offer + } + + streamID, ok := lksdp.ExtractStreamID(lastAudio) + if !ok { + return offer + } + + p.pendingTracksLock.RLock() + _, info := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + // if RED is disabled for this track, don't prefer RED codec in offer + if info != nil && info.DisableRed { + p.pendingTracksLock.RUnlock() + return offer + } + p.pendingTracksLock.RUnlock() + + codecs, err := codecsFromMediaDescription(lastAudio) + if err != nil { + return offer + } + + var opusPayload uint8 + for _, codec := range codecs { + if strings.EqualFold(codec.Name, "opus") { + opusPayload = codec.PayloadType + break + } + } + if opusPayload == 0 { + return offer + } + + var preferredCodecs, leftCodecs []string + for _, codec := range codecs { + // codec contain opus/red + if strings.EqualFold(codec.Name, "red") && strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { + preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) + } else { + leftCodecs = append(leftCodecs, strconv.FormatInt(int64(codec.PayloadType), 10)) + } + } + + // no opus/red found + if len(preferredCodecs) == 0 { + return offer + } + + lastAudio.MediaName.Formats = append(lastAudio.MediaName.Formats[:0], preferredCodecs...) + lastAudio.MediaName.Formats = append(lastAudio.MediaName.Formats, leftCodecs...) + + bytes, err := parsed.Marshal() + if err != nil { + p.params.Logger.Errorw("failed to marshal offer", err) + return offer + } + + return webrtc.SessionDescription{ + Type: offer.Type, + SDP: string(bytes), + } +} + +func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.SessionDescription) webrtc.SessionDescription { + parsed, lastVideo, err := p.TransportManager.GetLastUnmatchedMediaForOffer(offer, "video") + if err != nil || lastVideo == nil { + return offer + } + // last video is pending for publish, set codec preference + streamID, ok := lksdp.ExtractStreamID(lastVideo) + if !ok { + return offer + } + + p.pendingTracksLock.RLock() + _, info := p.getPendingTrack(streamID, livekit.TrackType_VIDEO) + if info == nil { + p.pendingTracksLock.RUnlock() + return offer + } + var mime string + for _, c := range info.Codecs { + if c.Cid == streamID { + mime = c.MimeType + break + } + } + if mime == "" && len(info.Codecs) > 0 { + mime = info.Codecs[0].MimeType + } + p.pendingTracksLock.RUnlock() + + if mime == "" { + return offer + } + + codecs, err := codecsFromMediaDescription(lastVideo) + if err != nil { + return offer + } + + mime = strings.ToUpper(mime) + var preferredCodecs, leftCodecs []string + for _, c := range codecs { + if strings.HasSuffix(mime, strings.ToUpper(c.Name)) { + preferredCodecs = append(preferredCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) + } else { + leftCodecs = append(leftCodecs, strconv.FormatInt(int64(c.PayloadType), 10)) + } + } + + lastVideo.MediaName.Formats = append(lastVideo.MediaName.Formats[:0], preferredCodecs...) + lastVideo.MediaName.Formats = append(lastVideo.MediaName.Formats, leftCodecs...) + + bytes, err := parsed.Marshal() + if err != nil { + p.params.Logger.Errorw("failed to marshal offer", err) + return offer + } + + return webrtc.SessionDescription{ + Type: offer.Type, + SDP: string(bytes), + } +} + +// configure publisher answer for audio track's dtx and stereo settings +func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescription) webrtc.SessionDescription { + offer := p.TransportManager.LastPublisherOffer() + parsedOffer, err := offer.Unmarshal() + if err != nil { + return answer + } + + parsed, err := answer.Unmarshal() + if err != nil { + return answer + } + + for _, m := range parsed.MediaDescriptions { + switch m.MediaName.Media { + case "audio": + mid, ok := m.Attribute(sdp.AttrKeyMID) + if !ok { + continue + } + // find track info from offer's stream id + var ti *livekit.TrackInfo + for _, om := range parsedOffer.MediaDescriptions { + omid, ok := om.Attribute(sdp.AttrKeyMID) + if ok && omid == mid { + streamID, ok := lksdp.ExtractStreamID(om) + if !ok { + continue + } + track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack) + if track == nil { + p.pendingTracksLock.RLock() + _, ti = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + p.pendingTracksLock.RUnlock() + } else { + ti = track.TrackInfo(false) + } + break + } + } + + if ti == nil || (ti.DisableDtx && !ti.Stereo) { + // no need to configure + continue + } + + opusPT, err := parsed.GetPayloadTypeForCodec(sdp.Codec{Name: "opus"}) + if err != nil { + p.params.Logger.Infow("failed to get opus payload type", "error", err, "trakcID", ti.Sid) + continue + } + + for i, attr := range m.Attributes { + if strings.HasPrefix(attr.String(), fmt.Sprintf("fmtp:%d", opusPT)) { + if !ti.DisableDtx { + attr.Value += ";usedtx=1" + } + if ti.Stereo { + attr.Value += ";stereo=1" + } + m.Attributes[i] = attr + } + } + + default: + continue + } + } + + bytes, err := parsed.Marshal() + if err != nil { + p.params.Logger.Infow("failed to marshal answer", "error", err) + return answer + } + answer.SDP = string(bytes) + return answer +} diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 30a6d5f0c..5a467f143 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -167,7 +167,6 @@ type PCTransport struct { onAnswer func(answer webrtc.SessionDescription) error onInitialConnected func() onFailed func(isShortLived bool) - onGetDTX func() bool onNegotiationStateChanged func(state NegotiationState) onNegotiationFailed func() @@ -557,7 +556,7 @@ func (t *PCTransport) AddICECandidate(candidate webrtc.ICECandidateInit) { }) } -func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { +func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal, params types.AddTrackParams) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { t.lock.Lock() canReuse := t.canReuseTransceiver td, ok := t.previousTrackDescription[trackLocal.ID()] @@ -577,7 +576,7 @@ func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal) (sender *webrtc.RTP // if never negotiated with client, can't reuse transeiver for track not subscribed before migration if !canReuse { - return t.AddTransceiverFromTrack(trackLocal) + return t.AddTransceiverFromTrack(trackLocal, params) } sender, err = t.pc.AddTrack(trackLocal) @@ -598,10 +597,12 @@ func (t *PCTransport) AddTrack(trackLocal webrtc.TrackLocal) (sender *webrtc.RTP return } + configureTransceiverStereo(transceiver, params.Stereo) + return } -func (t *PCTransport) AddTransceiverFromTrack(trackLocal webrtc.TrackLocal) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { +func (t *PCTransport) AddTransceiverFromTrack(trackLocal webrtc.TrackLocal, params types.AddTrackParams) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { transceiver, err = t.pc.AddTransceiverFromTrack(trackLocal) if err != nil { return @@ -613,6 +614,8 @@ func (t *PCTransport) AddTransceiverFromTrack(trackLocal webrtc.TrackLocal) (sen return } + configureTransceiverStereo(transceiver, params.Stereo) + return } @@ -820,19 +823,6 @@ func (t *PCTransport) getOnFailed() func(isShortLived bool) { return t.onFailed } -func (t *PCTransport) OnGetDTX(f func() bool) { - t.lock.Lock() - t.onGetDTX = f - t.lock.Unlock() -} - -func (t *PCTransport) getOnGetDTX() func() bool { - t.lock.RLock() - defer t.lock.RUnlock() - - return t.onGetDTX -} - func (t *PCTransport) OnTrack(f func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver)) { t.pc.OnTrack(f) } @@ -933,98 +923,6 @@ func (t *PCTransport) Negotiate(force bool) { } } -func (t *PCTransport) configureReceiverDTXAndStereo(enableDTX bool) { - // - // DTX (Discontinuous Transmission) allows audio bandwidth saving - // by not sending packets during silence periods. - // - // Publisher side DTX can enabled by including `usedtx=1` in - // the `fmtp` line corresponding to audio codec (Opus) in SDP. - // By doing this in the SDP `answer`, it can be controlled from - // server side and avoid doing it in all the client SDKs. - // - // Ideally, a publisher should be able to specify per audio - // track if DTX should be enabled. But, translating the - // DTX preference of publisher to the correct transceiver - // is non-deterministic due to the lack of a synchronizing id - // like the track id. - // - // The codec preference to set DTX needs to be done - // - after calling `SetRemoteDescription` which sets up - // the transceivers, but only if there are no tracks in the - // transceiver yet - // - before calling `CreateAnswer` - // Due to the absence of tracks when it is required to set DTX, - // it is not possible to cross reference against a pending track - // with the same track id. - // - // Due to the restriction above and given that in practice - // most of the time there is going to be only one audio track - // that is published, do the following - // - if there is no pending audio track, no-op - // - if there are no audio transceivers without tracks, no-op - // - else, apply the DTX setting from pending audio track - // to the audio transceiver without any track - // - // NOTE: The above logic will fail if there is an `offer` SDP with - // multiple audio tracks. At that point, there might be a need to - // rely on something like order of tracks. TODO - // - transceivers := t.pc.GetTransceivers() - for _, transceiver := range transceivers { - if transceiver.Kind() != webrtc.RTPCodecTypeAudio { - continue - } - - receiver := transceiver.Receiver() - if receiver == nil || receiver.Track() != nil { - continue - } - - var modifiedReceiverCodecs []webrtc.RTPCodecParameters - - receiverCodecs := receiver.GetParameters().Codecs - for _, receiverCodec := range receiverCodecs { - if receiverCodec.MimeType == webrtc.MimeTypeOpus { - fmtpUseDTX := "usedtx=1" - // remove occurrence in the middle - sdpFmtpLine := strings.ReplaceAll(receiverCodec.SDPFmtpLine, fmtpUseDTX+";", "") - // remove occurrence at the end - sdpFmtpLine = strings.ReplaceAll(sdpFmtpLine, fmtpUseDTX, "") - if enableDTX { - sdpFmtpLine += ";" + fmtpUseDTX - } - - fmtpStereo := "stereo=1" - // remove occurrence in the middle - sdpFmtpLine = strings.ReplaceAll(sdpFmtpLine, fmtpStereo+";", "") - // remove occurrence at the end - sdpFmtpLine = strings.ReplaceAll(sdpFmtpLine, fmtpStereo, "") - sdpFmtpLine += ";" + fmtpStereo - - receiverCodec.SDPFmtpLine = sdpFmtpLine - } - modifiedReceiverCodecs = append(modifiedReceiverCodecs, receiverCodec) - } - - // - // As `SetCodecPreferences` on a transceiver replaces all codecs, - // cycle through sender codecs also and add them before calling - // `SetCodecPreferences` - // - var senderCodecs []webrtc.RTPCodecParameters - sender := transceiver.Sender() - if sender != nil { - senderCodecs = sender.GetParameters().Codecs - } - - err := transceiver.SetCodecPreferences(append(modifiedReceiverCodecs, senderCodecs...)) - if err != nil { - t.params.Logger.Warnw("failed to SetCodecPreferences", err) - } - } -} - func (t *PCTransport) ICERestart() { t.postEvent(event{ signal: signalICERestart, @@ -1689,12 +1587,6 @@ func (t *PCTransport) setRemoteDescription(sd webrtc.SessionDescription) error { } func (t *PCTransport) createAndSendAnswer() error { - enableDTX := false - if onGetDTX := t.getOnGetDTX(); onGetDTX != nil { - enableDTX = onGetDTX() - } - t.configureReceiverDTXAndStereo(enableDTX) - answer, err := t.pc.CreateAnswer(nil) if err != nil { prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "create").Add(1) @@ -1836,3 +1728,25 @@ func (t *PCTransport) doICERestart() error { func (t *PCTransport) handleICERestart(e *event) error { return t.doICERestart() } + +// configure subscriber tranceiver for audio stereo +func configureTransceiverStereo(tr *webrtc.RTPTransceiver, stereo bool) { + sender := tr.Sender() + if sender == nil { + return + } + // enable stereo + codecs := sender.GetParameters().Codecs + configCodecs := make([]webrtc.RTPCodecParameters, 0, len(codecs)) + for _, c := range codecs { + if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) { + c.SDPFmtpLine = strings.ReplaceAll(c.SDPFmtpLine, ";sprop-stereo=1", "") + if stereo { + c.SDPFmtpLine += ";sprop-stereo=1" + } + } + configCodecs = append(configCodecs, c) + } + + tr.SetCodecPreferences(configCodecs) +} diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index c013abf0e..99f31e72b 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -52,6 +52,7 @@ type TransportManager struct { pendingOfferPublisher *webrtc.SessionDescription pendingDataChannelsPublisher []*livekit.DataChannelInfo lastPublisherAnswer atomic.Value + lastPublisherOffer atomic.Value iceConfig types.IceConfig onPublisherInitialConnected func() @@ -172,10 +173,6 @@ func (t *TransportManager) OnPublisherICECandidate(f func(c *webrtc.ICECandidate t.publisher.OnICECandidate(f) } -func (t *TransportManager) OnPublisherGetDTX(f func() bool) { - t.publisher.OnGetDTX(f) -} - func (t *TransportManager) OnPublisherAnswer(f func(answer webrtc.SessionDescription) error) { t.publisher.OnAnswer(func(sd webrtc.SessionDescription) error { t.lastPublisherAnswer.Store(sd) @@ -227,12 +224,12 @@ func (t *TransportManager) HasSubscriberEverConnected() bool { return t.subscriber.HasEverConnected() } -func (t *TransportManager) AddTrackToSubscriber(trackLocal webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { - return t.subscriber.AddTrack(trackLocal) +func (t *TransportManager) AddTrackToSubscriber(trackLocal webrtc.TrackLocal, params types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + return t.subscriber.AddTrack(trackLocal, params) } -func (t *TransportManager) AddTransceiverFromTrackToSubscriber(trackLocal webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { - return t.subscriber.AddTransceiverFromTrack(trackLocal) +func (t *TransportManager) AddTransceiverFromTrackToSubscriber(trackLocal webrtc.TrackLocal, params types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + return t.subscriber.AddTransceiverFromTrack(trackLocal, params) } func (t *TransportManager) RemoveTrackFromSubscriber(sender *webrtc.RTPSender) error { @@ -368,6 +365,13 @@ func (t *TransportManager) GetLastUnmatchedMediaForOffer(offer webrtc.SessionDes return } +func (t *TransportManager) LastPublisherOffer() webrtc.SessionDescription { + if sd := t.lastPublisherOffer.Load(); sd != nil { + return sd.(webrtc.SessionDescription) + } + return webrtc.SessionDescription{} +} + func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, shouldPend bool) { t.lock.Lock() if shouldPend { @@ -376,6 +380,7 @@ func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, shouldPe return } t.lock.Unlock() + t.lastPublisherOffer.Store(offer) t.publisher.HandleRemoteDescription(offer) } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 9a72956bc..8796daca2 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -225,6 +225,10 @@ const ( ICEConnectionTypeUnknown ICEConnectionType = "unknown" ) +type AddTrackParams struct { + Stereo bool +} + //counterfeiter:generate . LocalParticipant type LocalParticipant interface { Participant @@ -259,8 +263,8 @@ type LocalParticipant interface { HandleAnswer(sdp webrtc.SessionDescription) Negotiate(force bool) ICERestart(iceConfig *IceConfig) - AddTrackToSubscriber(trackLocal webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) - AddTransceiverFromTrackToSubscriber(trackLocal webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + AddTrackToSubscriber(trackLocal webrtc.TrackLocal, params AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + AddTransceiverFromTrackToSubscriber(trackLocal webrtc.TrackLocal, params AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) RemoveTrackFromSubscriber(sender *webrtc.RTPSender) error // subscriptions diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index f6ee8c0cf..a963fe020 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -45,10 +45,11 @@ type FakeLocalParticipant struct { addTrackArgsForCall []struct { arg1 *livekit.AddTrackRequest } - AddTrackToSubscriberStub func(webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + AddTrackToSubscriberStub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) addTrackToSubscriberMutex sync.RWMutex addTrackToSubscriberArgsForCall []struct { arg1 webrtc.TrackLocal + arg2 types.AddTrackParams } addTrackToSubscriberReturns struct { result1 *webrtc.RTPSender @@ -60,10 +61,11 @@ type FakeLocalParticipant struct { result2 *webrtc.RTPTransceiver result3 error } - AddTransceiverFromTrackToSubscriberStub func(webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + AddTransceiverFromTrackToSubscriberStub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) addTransceiverFromTrackToSubscriberMutex sync.RWMutex addTransceiverFromTrackToSubscriberArgsForCall []struct { arg1 webrtc.TrackLocal + arg2 types.AddTrackParams } addTransceiverFromTrackToSubscriberReturns struct { result1 *webrtc.RTPSender @@ -912,18 +914,19 @@ func (fake *FakeLocalParticipant) AddTrackArgsForCall(i int) *livekit.AddTrackRe return argsForCall.arg1 } -func (fake *FakeLocalParticipant) AddTrackToSubscriber(arg1 webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { +func (fake *FakeLocalParticipant) AddTrackToSubscriber(arg1 webrtc.TrackLocal, arg2 types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { fake.addTrackToSubscriberMutex.Lock() ret, specificReturn := fake.addTrackToSubscriberReturnsOnCall[len(fake.addTrackToSubscriberArgsForCall)] fake.addTrackToSubscriberArgsForCall = append(fake.addTrackToSubscriberArgsForCall, struct { arg1 webrtc.TrackLocal - }{arg1}) + arg2 types.AddTrackParams + }{arg1, arg2}) stub := fake.AddTrackToSubscriberStub fakeReturns := fake.addTrackToSubscriberReturns - fake.recordInvocation("AddTrackToSubscriber", []interface{}{arg1}) + fake.recordInvocation("AddTrackToSubscriber", []interface{}{arg1, arg2}) fake.addTrackToSubscriberMutex.Unlock() if stub != nil { - return stub(arg1) + return stub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2, ret.result3 @@ -937,17 +940,17 @@ func (fake *FakeLocalParticipant) AddTrackToSubscriberCallCount() int { return len(fake.addTrackToSubscriberArgsForCall) } -func (fake *FakeLocalParticipant) AddTrackToSubscriberCalls(stub func(webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error)) { +func (fake *FakeLocalParticipant) AddTrackToSubscriberCalls(stub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error)) { fake.addTrackToSubscriberMutex.Lock() defer fake.addTrackToSubscriberMutex.Unlock() fake.AddTrackToSubscriberStub = stub } -func (fake *FakeLocalParticipant) AddTrackToSubscriberArgsForCall(i int) webrtc.TrackLocal { +func (fake *FakeLocalParticipant) AddTrackToSubscriberArgsForCall(i int) (webrtc.TrackLocal, types.AddTrackParams) { fake.addTrackToSubscriberMutex.RLock() defer fake.addTrackToSubscriberMutex.RUnlock() argsForCall := fake.addTrackToSubscriberArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeLocalParticipant) AddTrackToSubscriberReturns(result1 *webrtc.RTPSender, result2 *webrtc.RTPTransceiver, result3 error) { @@ -979,18 +982,19 @@ func (fake *FakeLocalParticipant) AddTrackToSubscriberReturnsOnCall(i int, resul }{result1, result2, result3} } -func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriber(arg1 webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { +func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriber(arg1 webrtc.TrackLocal, arg2 types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { fake.addTransceiverFromTrackToSubscriberMutex.Lock() ret, specificReturn := fake.addTransceiverFromTrackToSubscriberReturnsOnCall[len(fake.addTransceiverFromTrackToSubscriberArgsForCall)] fake.addTransceiverFromTrackToSubscriberArgsForCall = append(fake.addTransceiverFromTrackToSubscriberArgsForCall, struct { arg1 webrtc.TrackLocal - }{arg1}) + arg2 types.AddTrackParams + }{arg1, arg2}) stub := fake.AddTransceiverFromTrackToSubscriberStub fakeReturns := fake.addTransceiverFromTrackToSubscriberReturns - fake.recordInvocation("AddTransceiverFromTrackToSubscriber", []interface{}{arg1}) + fake.recordInvocation("AddTransceiverFromTrackToSubscriber", []interface{}{arg1, arg2}) fake.addTransceiverFromTrackToSubscriberMutex.Unlock() if stub != nil { - return stub(arg1) + return stub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2, ret.result3 @@ -1004,17 +1008,17 @@ func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriberCallCount() return len(fake.addTransceiverFromTrackToSubscriberArgsForCall) } -func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriberCalls(stub func(webrtc.TrackLocal) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error)) { +func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriberCalls(stub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error)) { fake.addTransceiverFromTrackToSubscriberMutex.Lock() defer fake.addTransceiverFromTrackToSubscriberMutex.Unlock() fake.AddTransceiverFromTrackToSubscriberStub = stub } -func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriberArgsForCall(i int) webrtc.TrackLocal { +func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriberArgsForCall(i int) (webrtc.TrackLocal, types.AddTrackParams) { fake.addTransceiverFromTrackToSubscriberMutex.RLock() defer fake.addTransceiverFromTrackToSubscriberMutex.RUnlock() argsForCall := fake.addTransceiverFromTrackToSubscriberArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeLocalParticipant) AddTransceiverFromTrackToSubscriberReturns(result1 *webrtc.RTPSender, result2 *webrtc.RTPTransceiver, result3 error) { diff --git a/test/client/client.go b/test/client/client.go index e000f7710..0923b81f2 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -23,6 +23,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/types" ) type RTCClient struct { @@ -507,7 +508,7 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string) c.lock.Lock() defer c.lock.Unlock() - sender, _, err := c.publisher.AddTrack(track) + sender, _, err := c.publisher.AddTrack(track, types.AddTrackParams{}) if err != nil { logger.Errorw("add track failed", err, "trackID", ti.Sid, "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) return