diff --git a/config-sample.yaml b/config-sample.yaml index 99b124dd7..989a3254d 100644 --- a/config-sample.yaml +++ b/config-sample.yaml @@ -176,6 +176,13 @@ keys: # # cert_file: /path/to/cert.pem # # key_file: /path/to/key.pem +# ingress server +# ingress: +# # Prefix used to generate RTMP URLs for RTMP ingress. +# # The stream_key will be appended to this base and returned as part of the +# # ingress info +# rtmp_base_url: "rtmp://my.domain.com/live" + # Region of the current node. Required if using regionaware node selector # region: us-west-2 diff --git a/go.mod b/go.mod index cf982eec5..4f07c6bca 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.13.5-0.20220801175011-ae34dc3ec45d + github.com/livekit/protocol v0.13.5-0.20220805160532-dc99a5ad3ce2 github.com/livekit/rtcscore-go v0.0.0-20220524203225-dfd1ba40744a github.com/mackerelio/go-osstat v0.2.1 github.com/magefile/mage v1.13.0 diff --git a/go.sum b/go.sum index d9051edc1..810debff7 100644 --- a/go.sum +++ b/go.sum @@ -237,14 +237,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lithammer/shortuuid/v3 v3.0.7 h1:trX0KTHy4Pbwo/6ia8fscyHoGA+mf1jWbPJVuvyJQQ8= github.com/lithammer/shortuuid/v3 v3.0.7/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= -github.com/livekit/protocol v0.13.5-0.20220726184153-ad9c55ddef52 h1:E0trQ3RLu2b9hjSiJG1+1hyK/8v57NPJznA7/lKj0qY= -github.com/livekit/protocol v0.13.5-0.20220726184153-ad9c55ddef52/go.mod h1:Qd/Dn4BkJfZQy/IjtEeUOGXARrR7l09WDkg5SY8thkw= -github.com/livekit/protocol v0.13.5-0.20220727215941-ac26418a52e9 h1:e12j1EyiiTG56Ag44fwpVtnYQ6MVgLv4bYYI0nTgxZY= -github.com/livekit/protocol v0.13.5-0.20220727215941-ac26418a52e9/go.mod h1:Qd/Dn4BkJfZQy/IjtEeUOGXARrR7l09WDkg5SY8thkw= -github.com/livekit/protocol v0.13.5-0.20220801175011-ae34dc3ec45d h1:9VHZG4Tu723DA/jsg0APEmnk5blWRif9indB/nkdeFY= -github.com/livekit/protocol v0.13.5-0.20220801175011-ae34dc3ec45d/go.mod h1:vGQzKUaSYC92o5y7EbnhosgpoLWK9a3PneyYkGOGL0o= -github.com/livekit/protocol v0.13.5-0.20220728214908-67539ebcab2a h1:tRioM9WNDjxGryt03ROYa8zq17J0MqHftCLr8Ex4dM0= -github.com/livekit/protocol v0.13.5-0.20220728214908-67539ebcab2a/go.mod h1:vGQzKUaSYC92o5y7EbnhosgpoLWK9a3PneyYkGOGL0o= +github.com/livekit/protocol v0.13.5-0.20220805160532-dc99a5ad3ce2 h1:PFZfzLm1gNjX4Z3jOlKcSDMMSt1bIbqZ7av2399uoO0= +github.com/livekit/protocol v0.13.5-0.20220805160532-dc99a5ad3ce2/go.mod h1:vGQzKUaSYC92o5y7EbnhosgpoLWK9a3PneyYkGOGL0o= github.com/livekit/rtcscore-go v0.0.0-20220524203225-dfd1ba40744a h1:cENjhGfslLSDV07gt8ASy47Wd12Q0kBS7hsdunyQ62I= github.com/livekit/rtcscore-go v0.0.0-20220524203225-dfd1ba40744a/go.mod h1:116ych8UaEs9vfIE8n6iZCZ30iagUFTls0vRmC+Ix5U= github.com/mackerelio/go-osstat v0.2.1 h1:5AeAcBEutEErAOlDz6WCkEvm6AKYgHTUQrfwm5RbeQc= @@ -763,8 +757,6 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 2af215e13..99734f778 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -152,6 +152,7 @@ type ParticipantImpl struct { pendingDataChannels []*livekit.DataChannelInfo onClose func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID) onClaimsChanged func(participant types.LocalParticipant) + onICEConfigChanged func(participant types.LocalParticipant, iceConfig types.IceConfig) activeCounter atomic.Int32 firstConnected atomic.Bool @@ -277,8 +278,17 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { } else { p.activeCounter.Add(2) } + + primaryPC.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + p.handleICEStateChange(true, state) + }) + secondaryPC.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + p.handleICEStateChange(false, state) + }) + primaryPC.OnConnectionStateChange(p.handlePrimaryStateChange) secondaryPC.OnConnectionStateChange(p.handleSecondaryStateChange) + p.publisher.pc.OnTrack(p.onMediaTrack) p.publisher.pc.OnDataChannel(p.onDataChannel) @@ -529,7 +539,9 @@ func (p *ParticipantImpl) OnClose(callback func(types.LocalParticipant, map[live } func (p *ParticipantImpl) OnClaimsChanged(callback func(types.LocalParticipant)) { + p.lock.Lock() p.onClaimsChanged = callback + p.lock.Unlock() } // HandleOffer an offer from remote participant, used when clients make the initial connection @@ -561,6 +573,8 @@ func (p *ParticipantImpl) HandleOffer(sdp webrtc.SessionDescription) (answer web return } + answer = p.publisher.FilterCandidates(answer) + if err = p.publisher.pc.SetLocalDescription(answer); err != nil { prometheus.ServiceOperationCounter.WithLabelValues("answer", "error", "local_description").Add(1) err = errors.Wrap(err, "could not set local description") @@ -830,9 +844,7 @@ func (p *ParticipantImpl) MigrateState() types.MigrateState { // ICERestart restarts subscriber ICE connections func (p *ParticipantImpl) ICERestart(iceConfig *types.IceConfig) error { if iceConfig != nil { - p.lock.Lock() - p.iceConfig = *iceConfig - p.lock.Unlock() + p.SetICEConfig(*iceConfig) } if p.subscriber.pc.RemoteDescription() == nil { @@ -847,6 +859,31 @@ func (p *ParticipantImpl) ICERestart(iceConfig *types.IceConfig) error { }) } +func (p *ParticipantImpl) OnICEConfigChanged(f func(participant types.LocalParticipant, iceConfig types.IceConfig)) { + p.lock.Lock() + p.onICEConfigChanged = f + p.lock.Unlock() +} + +func (p *ParticipantImpl) SetICEConfig(iceConfig types.IceConfig) { + p.lock.Lock() + p.iceConfig = iceConfig + if iceConfig.PreferPubTcp { + p.publisher.SetPreferTCP(true) + } + + if iceConfig.PreferSubTcp { + p.subscriber.SetPreferTCP(true) + } + + onICEConfigChanged := p.onICEConfigChanged + p.lock.Unlock() + + if onICEConfigChanged != nil { + onICEConfigChanged(p, iceConfig) + } +} + // // signal connection methods // @@ -989,7 +1026,7 @@ func (p *ParticipantImpl) UpdateSubscribedTrackSettings(trackID livekit.TrackID, // AddSubscribedTrack adds a track to the participant's subscribed list func (p *ParticipantImpl) AddSubscribedTrack(subTrack types.SubscribedTrack) { - p.params.Logger.Debugw("added subscribedTrack", + p.params.Logger.Infow("added subscribedTrack", "publisherID", subTrack.PublisherID(), "publisherIdentity", subTrack.PublisherIdentity(), "trackID", subTrack.ID()) @@ -1031,7 +1068,7 @@ func (p *ParticipantImpl) AddSubscribedTrack(subTrack types.SubscribedTrack) { // RemoveSubscribedTrack removes a track to the participant's subscribed list func (p *ParticipantImpl) RemoveSubscribedTrack(subTrack types.SubscribedTrack) { - p.params.Logger.Debugw("removed subscribedTrack", + p.params.Logger.Infow("removed subscribedTrack", "publisherID", subTrack.PublisherID(), "publisherIdentity", subTrack.PublisherIdentity(), "trackID", subTrack.ID(), "kind", subTrack.DownTrack().Kind()) @@ -1283,6 +1320,54 @@ func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data [ } } +func (p *ParticipantImpl) getTransport(isPrimary bool) *PCTransport { + pcTransport := p.publisher + if (isPrimary && p.SubscriberAsPrimary()) || (!isPrimary && !p.SubscriberAsPrimary()) { + pcTransport = p.subscriber + } + + return pcTransport +} + +func (p *ParticipantImpl) handleICEConnected(isPrimary bool) { + pcTransport := p.getTransport(isPrimary) + pcTransport.SetICEConnectedAt(time.Now()) + + if pair, err := pcTransport.GetSelectedPair(); err != nil { + pcTransport.Logger().Errorw("error getting selected ICE candidate pair", err) + } else { + pcTransport.Logger().Infow("selected ICE candidate pair", "pair", pair) + } +} + +func (p *ParticipantImpl) handleConnectionFailed(isPrimary bool) { + pcTransport := p.getTransport(isPrimary) + isShort, duration := pcTransport.IsShortConnection(time.Now()) + if isShort { + // irrespective of which one fails, force TCP on both as the other one might + // fail at a different time and cause another disruption + pair, err := pcTransport.GetSelectedPair() + if err != nil { + pcTransport.Logger().Errorw("short ICE connection", err, "duration", duration) + } else { + pcTransport.Logger().Infow("short ICE connection", "pair", pair, "duration", duration) + } + pcTransport.Logger().Infow("restricting transport to TCP on both peer connections") + p.SetICEConfig(types.IceConfig{ + PreferPubTcp: true, + PreferSubTcp: true, + }) + } +} + +func (p *ParticipantImpl) handleICEStateChange(isPrimary bool, state webrtc.ICEConnectionState) { + if state == webrtc.ICEConnectionStateConnected { + p.handleICEConnected(isPrimary) + } else if state == webrtc.ICEConnectionStateFailed { + p.handleConnectionFailed(isPrimary) + } +} + func (p *ParticipantImpl) handlePrimaryStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateConnected { if !p.firstConnected.Swap(true) { @@ -1294,6 +1379,8 @@ func (p *ParticipantImpl) handlePrimaryStateChange(state webrtc.PeerConnectionSt } p.incActiveCounter() } else if state == webrtc.PeerConnectionStateFailed { + p.handleConnectionFailed(true) + // clients support resuming of connections when websocket becomes disconnected p.closeSignalConnection() @@ -1302,9 +1389,11 @@ func (p *ParticipantImpl) handlePrimaryStateChange(state webrtc.PeerConnectionSt p.lock.Lock() if p.disconnectTimer != nil { p.disconnectTimer.Stop() + p.disconnectTimer = nil } p.disconnectTimer = time.AfterFunc(disconnectCleanupDuration, func() { p.lock.Lock() + p.disconnectTimer.Stop() p.disconnectTimer = nil p.lock.Unlock() @@ -1330,6 +1419,8 @@ func (p *ParticipantImpl) handlePrimaryStateChange(state webrtc.PeerConnectionSt // instead of allowing them to silently fail. func (p *ParticipantImpl) handleSecondaryStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateFailed { + p.handleConnectionFailed(false) + // clients support resuming of connections when websocket becomes disconnected p.closeSignalConnection() } diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go index 971e52a14..670b96012 100644 --- a/pkg/rtc/participant_signal.go +++ b/pkg/rtc/participant_signal.go @@ -54,6 +54,9 @@ func (p *ParticipantImpl) SendJoinResponse( // indicates both server and client support subscriber as primary SubscriberPrimary: p.SubscriberAsPrimary(), ClientConfiguration: p.params.ClientConf, + // sane defaults for ping interval & timeout + PingInterval: 10, + PingTimeout: 20, }, }, }) diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 2a008c814..5b10381f1 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -35,6 +35,8 @@ const ( iceDisconnectedTimeout = 10 * time.Second // compatible for ice-lite with firefox client iceFailedTimeout = 25 * time.Second // pion's default iceKeepaliveInterval = 2 * time.Second // pion's default + + shortConnectionThreshold = 2 * time.Minute ) var ( @@ -61,6 +63,7 @@ type PCTransport struct { me *webrtc.MediaEngine lock sync.RWMutex + iceConnectedAt time.Time pendingCandidates []webrtc.ICECandidateInit debouncedNegotiate func(func()) negotiationPending map[livekit.ParticipantID]bool @@ -76,6 +79,8 @@ type PCTransport struct { streamAllocator *sfu.StreamAllocator previousAnswer *webrtc.SessionDescription + + preferTCP bool } type TransportParams struct { @@ -212,6 +217,53 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { return t, nil } +func (t *PCTransport) Logger() logger.Logger { + return t.params.Logger +} + +func (t *PCTransport) SetICEConnectedAt(at time.Time) { + t.lock.Lock() + t.iceConnectedAt = at + t.lock.Unlock() +} + +func (t *PCTransport) IsShortConnection(at time.Time) (bool, time.Duration) { + t.lock.RLock() + defer t.lock.RUnlock() + + if t.iceConnectedAt.IsZero() { + return false, 0 + } + + duration := at.Sub(t.iceConnectedAt) + return duration < shortConnectionThreshold, duration +} + +func (t *PCTransport) GetSelectedPair() (*webrtc.ICECandidatePair, error) { + sctp := t.pc.SCTP() + if sctp == nil { + return nil, errors.New("no SCTP") + } + + dtlsTransport := sctp.Transport() + if dtlsTransport == nil { + return nil, errors.New("no DTLS transport") + } + + iceTransport := dtlsTransport.ICETransport() + if iceTransport == nil { + return nil, errors.New("no ICE transport") + } + + return iceTransport.GetSelectedCandidatePair() +} + +func (t *PCTransport) SetPreferTCP(preferTCP bool) { + t.lock.Lock() + t.preferTCP = preferTCP + t.lock.Unlock() +} + func (t *PCTransport) createPeerConnection() error { var bwe cc.BandwidthEstimator pc, me, err := newPeerConnection(t.params, func(estimator cc.BandwidthEstimator) { @@ -442,6 +494,8 @@ func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { return err } + offer = t.filterCandidates(offer) + err = t.pc.SetLocalDescription(offer) if err != nil { prometheus.ServiceOperationCounter.WithLabelValues("offer", "error", "local_description").Add(1) @@ -624,6 +678,55 @@ func (t *PCTransport) SetPreviousAnswer(answer *webrtc.SessionDescription) { } } +func (t *PCTransport) FilterCandidates(sd webrtc.SessionDescription) webrtc.SessionDescription { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.filterCandidates(sd) +} + +func (t *PCTransport) filterCandidates(sd webrtc.SessionDescription) webrtc.SessionDescription { + parsed, err := sd.Unmarshal() + if err != nil { + t.params.Logger.Errorw("could not unmarshal SDP to filter candidates", err) + return sd + } + + filterAttributes := func(attrs []sdp.Attribute) []sdp.Attribute { + filteredAttrs := make([]sdp.Attribute, 0, len(attrs)) + for _, a := range attrs { + if a.Key == "candidate" { + if t.preferTCP { + if strings.Contains(a.Value, "tcp") { + filteredAttrs = append(filteredAttrs, a) + } + } else { + filteredAttrs = append(filteredAttrs, a) + } + } else { + filteredAttrs = append(filteredAttrs, a) + } + } + + return filteredAttrs + } + + parsed.Attributes = filterAttributes(parsed.Attributes) + for _, m := range parsed.MediaDescriptions { + m.Attributes = filterAttributes(m.Attributes) + } + + bytes, err := parsed.Marshal() + if err != nil { + t.params.Logger.Errorw("could not marshal SDP to filter candidates", err) + return sd + } + sd.SDP = string(bytes) + return sd +} + +// --------------------------------------------- + func getMidValue(media *sdp.MediaDescription) string { for _, attr := range media.Attributes { if attr.Key == "mid" { diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index b16261566..6d74ebec3 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -1,10 +1,12 @@ package rtc import ( + "strings" "sync/atomic" "testing" "time" + "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" @@ -269,6 +271,125 @@ func TestNegotiationFailed(t *testing.T) { }, negotiationFailedTimout+time.Second, 10*time.Millisecond, "negotiation failed") } +func TestFilteringCandidates(t *testing.T) { + params := TransportParams{ + ParticipantID: "id", + ParticipantIdentity: "identity", + Target: livekit.SignalTarget_PUBLISHER, + Config: &WebRTCConfig{}, + EnabledCodecs: []*livekit.Codec{ + {Mime: webrtc.MimeTypeOpus}, + {Mime: webrtc.MimeTypeVP8}, + {Mime: webrtc.MimeTypeH264}, + }, + } + transport, err := NewPCTransport(params) + require.NoError(t, err) + + _, err = transport.pc.CreateDataChannel("test", nil) + require.NoError(t, err) + + _, err = transport.pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio) + require.NoError(t, err) + + _, err = transport.pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo) + require.NoError(t, err) + + offer, err := transport.pc.CreateOffer(nil) + require.NoError(t, err) + + offerGatheringComplete := webrtc.GatheringCompletePromise(transport.pc) + require.NoError(t, transport.pc.SetLocalDescription(offer)) + <-offerGatheringComplete + + // should not filter out UDP candidates if TCP is not preferred + offer = *transport.pc.LocalDescription() + filteredOffer := transport.FilterCandidates(offer) + require.EqualValues(t, offer.SDP, filteredOffer.SDP) + + parsed, err := offer.Unmarshal() + require.NoError(t, err) + + // add a couple of TCP candidates + done := false + for _, m := range parsed.MediaDescriptions { + for _, a := range m.Attributes { + if a.Key == "candidate" { + for idx, aa := range m.Attributes { + if aa.Key == "end-of-candidates" { + modifiedAttributes := make([]sdp.Attribute, idx) + copy(modifiedAttributes, m.Attributes[:idx]) + modifiedAttributes = append(modifiedAttributes, []sdp.Attribute{ + { + Key: "candidate", + Value: "054225987 1 tcp 2124414975 159.203.70.248 7881 typ host tcptype passive", + }, + { + Key: "candidate", + Value: "054225987 2 tcp 2124414975 159.203.70.248 7881 typ host tcptype passive", + }, + }...) + m.Attributes = append(modifiedAttributes, m.Attributes[idx:]...) + done = true + break + } + } + } + if done { + break + } + } + if done { + break + } + } + bytes, err := parsed.Marshal() + require.NoError(t, err) + offer.SDP = string(bytes) + + parsed, err = offer.Unmarshal() + require.NoError(t, err) + + getNumTransportTypeCandidates := func(sdp *sdp.SessionDescription) (int, int) { + numUDPCandidates := 0 + numTCPCandidates := 0 + for _, a := range sdp.Attributes { + if a.Key == "candidate" { + if strings.Contains(a.Value, "udp") { + numUDPCandidates++ + } + if strings.Contains(a.Value, "tcp") { + numTCPCandidates++ + } + } + } + for _, m := range sdp.MediaDescriptions { + for _, a := range m.Attributes { + if a.Key == "candidate" { + if strings.Contains(a.Value, "udp") { + numUDPCandidates++ + } + if strings.Contains(a.Value, "tcp") { + numTCPCandidates++ + } + } + } + } + return numUDPCandidates, numTCPCandidates + } + udp, tcp := getNumTransportTypeCandidates(parsed) + require.NotZero(t, udp) + require.Equal(t, 2, tcp) + + transport.SetPreferTCP(true) + filteredOffer = transport.FilterCandidates(offer) + parsed, err = filteredOffer.Unmarshal() + require.NoError(t, err) + udp, tcp = getNumTransportTypeCandidates(parsed) + require.Zero(t, udp) + require.Equal(t, 2, tcp) +} + func handleOfferFunc(t *testing.T, current, other *PCTransport) func(sd webrtc.SessionDescription) { return func(sd webrtc.SessionDescription) { t.Logf("handling offer") diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 43035ce50..95c09956f 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -274,8 +274,8 @@ type LocalParticipant interface { OnParticipantUpdate(callback func(LocalParticipant)) OnDataPacket(callback func(LocalParticipant, *livekit.DataPacket)) OnSubscribedTo(callback func(LocalParticipant, livekit.ParticipantID)) - OnClose(_callback func(LocalParticipant, map[livekit.TrackID]livekit.ParticipantID)) - OnClaimsChanged(_callback func(LocalParticipant)) + OnClose(callback func(LocalParticipant, map[livekit.TrackID]livekit.ParticipantID)) + OnClaimsChanged(callback func(LocalParticipant)) // session migration SetMigrateState(s MigrateState) @@ -292,6 +292,9 @@ type LocalParticipant interface { EnqueueUnsubscribeTrack(trackID livekit.TrackID, willBeResumed bool, f func(subscriberID livekit.ParticipantID, willBeResumed bool) error) ProcessSubscriptionRequestsQueue(trackID livekit.TrackID) ClearInProgressAndProcessSubscriptionRequestsQueue(trackID livekit.TrackID) + + SetICEConfig(iceConfig IceConfig) + OnICEConfigChanged(callback func(participant LocalParticipant, iceConfig IceConfig)) } // Room is a container of participants, and can provide room-level actions diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 4b1fe9182..e79f086ae 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -406,6 +406,11 @@ type FakeLocalParticipant struct { onDataPacketArgsForCall []struct { arg1 func(types.LocalParticipant, *livekit.DataPacket) } + OnICEConfigChangedStub func(func(participant types.LocalParticipant, iceConfig types.IceConfig)) + onICEConfigChangedMutex sync.RWMutex + onICEConfigChangedArgsForCall []struct { + arg1 func(participant types.LocalParticipant, iceConfig types.IceConfig) + } OnParticipantUpdateStub func(func(types.LocalParticipant)) onParticipantUpdateMutex sync.RWMutex onParticipantUpdateArgsForCall []struct { @@ -538,6 +543,11 @@ type FakeLocalParticipant struct { sendSpeakerUpdateReturnsOnCall map[int]struct { result1 error } + SetICEConfigStub func(types.IceConfig) + setICEConfigMutex sync.RWMutex + setICEConfigArgsForCall []struct { + arg1 types.IceConfig + } SetMetadataStub func(string) setMetadataMutex sync.RWMutex setMetadataArgsForCall []struct { @@ -2823,6 +2833,38 @@ func (fake *FakeLocalParticipant) OnDataPacketArgsForCall(i int) func(types.Loca return argsForCall.arg1 } +func (fake *FakeLocalParticipant) OnICEConfigChanged(arg1 func(participant types.LocalParticipant, iceConfig types.IceConfig)) { + fake.onICEConfigChangedMutex.Lock() + fake.onICEConfigChangedArgsForCall = append(fake.onICEConfigChangedArgsForCall, struct { + arg1 func(participant types.LocalParticipant, iceConfig types.IceConfig) + }{arg1}) + stub := fake.OnICEConfigChangedStub + fake.recordInvocation("OnICEConfigChanged", []interface{}{arg1}) + fake.onICEConfigChangedMutex.Unlock() + if stub != nil { + fake.OnICEConfigChangedStub(arg1) + } +} + +func (fake *FakeLocalParticipant) OnICEConfigChangedCallCount() int { + fake.onICEConfigChangedMutex.RLock() + defer fake.onICEConfigChangedMutex.RUnlock() + return len(fake.onICEConfigChangedArgsForCall) +} + +func (fake *FakeLocalParticipant) OnICEConfigChangedCalls(stub func(func(participant types.LocalParticipant, iceConfig types.IceConfig))) { + fake.onICEConfigChangedMutex.Lock() + defer fake.onICEConfigChangedMutex.Unlock() + fake.OnICEConfigChangedStub = stub +} + +func (fake *FakeLocalParticipant) OnICEConfigChangedArgsForCall(i int) func(participant types.LocalParticipant, iceConfig types.IceConfig) { + fake.onICEConfigChangedMutex.RLock() + defer fake.onICEConfigChangedMutex.RUnlock() + argsForCall := fake.onICEConfigChangedArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalParticipant) OnParticipantUpdate(arg1 func(types.LocalParticipant)) { fake.onParticipantUpdateMutex.Lock() fake.onParticipantUpdateArgsForCall = append(fake.onParticipantUpdateArgsForCall, struct { @@ -3584,6 +3626,38 @@ func (fake *FakeLocalParticipant) SendSpeakerUpdateReturnsOnCall(i int, result1 }{result1} } +func (fake *FakeLocalParticipant) SetICEConfig(arg1 types.IceConfig) { + fake.setICEConfigMutex.Lock() + fake.setICEConfigArgsForCall = append(fake.setICEConfigArgsForCall, struct { + arg1 types.IceConfig + }{arg1}) + stub := fake.SetICEConfigStub + fake.recordInvocation("SetICEConfig", []interface{}{arg1}) + fake.setICEConfigMutex.Unlock() + if stub != nil { + fake.SetICEConfigStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetICEConfigCallCount() int { + fake.setICEConfigMutex.RLock() + defer fake.setICEConfigMutex.RUnlock() + return len(fake.setICEConfigArgsForCall) +} + +func (fake *FakeLocalParticipant) SetICEConfigCalls(stub func(types.IceConfig)) { + fake.setICEConfigMutex.Lock() + defer fake.setICEConfigMutex.Unlock() + fake.SetICEConfigStub = stub +} + +func (fake *FakeLocalParticipant) SetICEConfigArgsForCall(i int) types.IceConfig { + fake.setICEConfigMutex.RLock() + defer fake.setICEConfigMutex.RUnlock() + argsForCall := fake.setICEConfigArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalParticipant) SetMetadata(arg1 string) { fake.setMetadataMutex.Lock() fake.setMetadataArgsForCall = append(fake.setMetadataArgsForCall, struct { @@ -4667,6 +4741,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.onCloseMutex.RUnlock() fake.onDataPacketMutex.RLock() defer fake.onDataPacketMutex.RUnlock() + fake.onICEConfigChangedMutex.RLock() + defer fake.onICEConfigChangedMutex.RUnlock() fake.onParticipantUpdateMutex.RLock() defer fake.onParticipantUpdateMutex.RUnlock() fake.onStateChangeMutex.RLock() @@ -4699,6 +4775,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.sendRoomUpdateMutex.RUnlock() fake.sendSpeakerUpdateMutex.RLock() defer fake.sendSpeakerUpdateMutex.RUnlock() + fake.setICEConfigMutex.RLock() + defer fake.setICEConfigMutex.RUnlock() fake.setMetadataMutex.RLock() defer fake.setMetadataMutex.RUnlock() fake.setMigrateInfoMutex.RLock() diff --git a/pkg/service/ingress.go b/pkg/service/ingress.go index ec6483c80..dd0934e40 100644 --- a/pkg/service/ingress.go +++ b/pkg/service/ingress.go @@ -25,7 +25,7 @@ type IngressService struct { } func NewIngressService( - conf *config.IngressConfig, + conf *config.Config, rpc ingress.RPC, store IngressStore, rs livekit.RoomService, @@ -33,7 +33,7 @@ func NewIngressService( ) *IngressService { return &IngressService{ - conf: conf, + conf: &conf.Ingress, rpc: rpc, store: store, roomService: rs, @@ -54,9 +54,13 @@ func (s *IngressService) Stop() { } func (s *IngressService) CreateIngress(ctx context.Context, req *livekit.CreateIngressRequest) (*livekit.IngressInfo, error) { - if err := EnsureRecordPermission(ctx); err != nil { + roomName, err := EnsureJoinPermission(ctx) + if err != nil { return nil, twirpAuthError(err) } + if req.RoomName != "" && req.RoomName != string(roomName) { + return nil, twirpAuthError(ErrPermissionDenied) + } sk := utils.NewGuid("") @@ -86,9 +90,14 @@ func (s *IngressService) CreateIngress(ctx context.Context, req *livekit.CreateI } func (s *IngressService) UpdateIngress(ctx context.Context, req *livekit.UpdateIngressRequest) (*livekit.IngressInfo, error) { - if err := EnsureRecordPermission(ctx); err != nil { + roomName, err := EnsureJoinPermission(ctx) + if err != nil { return nil, twirpAuthError(err) } + if req.RoomName != "" && req.RoomName != string(roomName) { + return nil, twirpAuthError(ErrPermissionDenied) + } + if s.rpc == nil { return nil, ErrIngressNotConnected } @@ -146,9 +155,14 @@ func (s *IngressService) UpdateIngress(ctx context.Context, req *livekit.UpdateI } func (s *IngressService) ListIngress(ctx context.Context, req *livekit.ListIngressRequest) (*livekit.ListIngressResponse, error) { - if err := EnsureRecordPermission(ctx); err != nil { + roomName, err := EnsureJoinPermission(ctx) + if err != nil { return nil, twirpAuthError(err) } + if req.RoomName != "" && req.RoomName != string(roomName) { + return nil, twirpAuthError(ErrPermissionDenied) + } + if s.rpc == nil { return nil, ErrIngressNotConnected } @@ -163,9 +177,10 @@ func (s *IngressService) ListIngress(ctx context.Context, req *livekit.ListIngre } func (s *IngressService) DeleteIngress(ctx context.Context, req *livekit.DeleteIngressRequest) (*livekit.IngressInfo, error) { - if err := EnsureRecordPermission(ctx); err != nil { + if _, err := EnsureJoinPermission(ctx); err != nil { return nil, twirpAuthError(err) } + if s.rpc == nil { return nil, ErrIngressNotConnected } diff --git a/pkg/service/redisstore.go b/pkg/service/redisstore.go index 3aad87344..a1e3b415d 100644 --- a/pkg/service/redisstore.go +++ b/pkg/service/redisstore.go @@ -471,7 +471,7 @@ func (s *RedisStore) StoreIngress(_ context.Context, info *livekit.IngressInfo) results, err := tx.TxPipelined(s.ctx, func(p redis.Pipeliner) error { p.HSet(s.ctx, IngressKey, info.IngressId, data) - p.HSet(s.ctx, StreamKeyKey, info.IngressId, info.StreamKey) + p.HSet(s.ctx, StreamKeyKey, info.StreamKey, info.IngressId) if oldRoom != info.RoomName { if oldRoom != "" { diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 9ec95bd2e..e29c6a355 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -25,8 +25,14 @@ const ( roomPurgeSeconds = 24 * 60 * 60 tokenRefreshInterval = 5 * time.Minute tokenDefaultTTL = 10 * time.Minute + iceConfigTTL = 60 * time.Minute ) +type iceConfigCacheEntry struct { + iceConfig types.IceConfig + modifiedAt time.Time +} + // RoomManager manages rooms and its interaction with participants. // It's responsible for creating, deleting rooms, as well as running sessions for participants type RoomManager struct { @@ -41,6 +47,8 @@ type RoomManager struct { clientConfManager clientconfiguration.ClientConfigurationManager rooms map[livekit.RoomName]*rtc.Room + + iceConfigCache map[livekit.ParticipantIdentity]*iceConfigCacheEntry } func NewLocalRoomManager( @@ -67,6 +75,8 @@ func NewLocalRoomManager( clientConfManager: clientConfManager, rooms: make(map[livekit.RoomName]*rtc.Room), + + iceConfigCache: make(map[livekit.ParticipantIdentity]*iceConfigCacheEntry), } // hook up to router @@ -264,6 +274,7 @@ func (r *RoomManager) StartSession( if err != nil { return err } + r.setIceConfig(participant) // join room opts := rtc.ParticipantOptions{ @@ -310,6 +321,14 @@ func (r *RoomManager) StartSession( logger.Errorw("could not refresh token", err) } }) + participant.OnICEConfigChanged(func(participant types.LocalParticipant, iceConfig types.IceConfig) { + r.lock.Lock() + r.iceConfigCache[participant.Identity()] = &iceConfigCacheEntry{ + iceConfig: iceConfig, + modifiedAt: time.Now(), + } + r.lock.Unlock() + }) go r.rtcSessionWorker(room, participant, requestSource) return nil @@ -618,6 +637,21 @@ func (r *RoomManager) refreshToken(participant types.LocalParticipant) error { return nil } +func (r *RoomManager) setIceConfig(participant types.LocalParticipant) { + r.lock.Lock() + iceConfigCacheEntry, ok := r.iceConfigCache[participant.Identity()] + if !ok || time.Since(iceConfigCacheEntry.modifiedAt) > iceConfigTTL { + delete(r.iceConfigCache, participant.Identity()) + r.lock.Unlock() + return + } + r.lock.Unlock() + + participant.SetICEConfig(iceConfigCacheEntry.iceConfig) +} + +// ------------------------------------ + func iceServerForStunServers(servers []string) *livekit.ICEServer { iceServer := &livekit.ICEServer{} for _, stunServer := range servers { diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index d1151f09a..f7f484dfa 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -256,6 +256,14 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + if _, ok := req.Message.(*livekit.SignalRequest_Ping); ok { + _ = sigConn.WriteResponse(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{ + Pong: 1, + }, + }) + continue + } if err := reqSink.WriteMessage(req); err != nil { pLogger.Warnw("error writing to request sink", err, "connID", connId) diff --git a/pkg/service/server.go b/pkg/service/server.go index 8f0ba1ea2..7d398a2f0 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -27,23 +27,25 @@ import ( ) type LivekitServer struct { - config *config.Config - egressService *EgressService - rtcService *RTCService - httpServer *http.Server - promServer *http.Server - router routing.Router - roomManager *RoomManager - turnServer *turn.Server - currentNode routing.LocalNode - running atomic.Bool - doneChan chan struct{} - closedChan chan struct{} + config *config.Config + egressService *EgressService + ingressService *IngressService + rtcService *RTCService + httpServer *http.Server + promServer *http.Server + router routing.Router + roomManager *RoomManager + turnServer *turn.Server + currentNode routing.LocalNode + running atomic.Bool + doneChan chan struct{} + closedChan chan struct{} } func NewLivekitServer(conf *config.Config, roomService livekit.RoomService, egressService *EgressService, + ingressService *IngressService, rtcService *RTCService, keyProvider auth.KeyProvider, router routing.Router, @@ -52,11 +54,12 @@ func NewLivekitServer(conf *config.Config, currentNode routing.LocalNode, ) (s *LivekitServer, err error) { s = &LivekitServer{ - config: conf, - egressService: egressService, - rtcService: rtcService, - router: router, - roomManager: roomManager, + config: conf, + egressService: egressService, + ingressService: ingressService, + rtcService: rtcService, + router: router, + roomManager: roomManager, // turn server starts automatically turnServer: turnServer, currentNode: currentNode, @@ -80,6 +83,7 @@ func NewLivekitServer(conf *config.Config, roomServer := livekit.NewRoomServiceServer(roomService) egressServer := livekit.NewEgressServer(egressService) + ingressServer := livekit.NewIngressServer(ingressService) mux := http.NewServeMux() if conf.Development { @@ -90,6 +94,7 @@ func NewLivekitServer(conf *config.Config, } mux.Handle(roomServer.PathPrefix(), roomServer) mux.Handle(egressServer.PathPrefix(), egressServer) + mux.Handle(ingressServer.PathPrefix(), ingressServer) mux.Handle("/rtc", rtcService) mux.HandleFunc("/rtc/validate", rtcService.Validate) mux.HandleFunc("/", s.healthCheck) @@ -150,6 +155,8 @@ func (s *LivekitServer) Start() error { return err } + s.ingressService.Start() + addresses := s.config.BindAddresses if addresses == nil { addresses = []string{""} @@ -238,6 +245,7 @@ func (s *LivekitServer) Start() error { s.roomManager.Stop() s.egressService.Stop() + s.ingressService.Stop() close(s.closedChan) return nil diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 8fa768742..c6bf2e300 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -16,6 +16,7 @@ import ( "github.com/livekit/protocol/auth" "github.com/livekit/protocol/egress" + "github.com/livekit/protocol/ingress" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/webhook" @@ -44,6 +45,9 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live egress.NewRedisRPCClient, getEgressStore, NewEgressService, + ingress.NewRedisRPC, + getIngressStore, + NewIngressService, NewRoomAllocator, NewRoomService, NewRTCService, @@ -175,6 +179,15 @@ func getEgressStore(s ObjectStore) EgressStore { } } +func getIngressStore(s ObjectStore) IngressStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + func createClientConfiguration() clientconfiguration.ClientConfigurationManager { return clientconfiguration.NewStaticClientConfigurationManager(clientconfiguration.StaticConfigurations) } diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index d15c07306..650b96e6f 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -17,6 +17,7 @@ import ( "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/egress" + "github.com/livekit/protocol/ingress" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/webhook" @@ -61,6 +62,9 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live analyticsService := telemetry.NewAnalyticsService(conf, currentNode) telemetryService := telemetry.NewTelemetryService(notifier, analyticsService) egressService := NewEgressService(rpcClient, objectStore, egressStore, roomService, telemetryService) + rpc := ingress.NewRedisRPC(nodeID, client) + ingressStore := getIngressStore(objectStore) + ingressService := NewIngressService(conf, rpc, ingressStore, roomService, telemetryService) rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode) clientConfigurationManager := createClientConfiguration() roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager) @@ -72,7 +76,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, rtcService, keyProvider, router, roomManager, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, rtcService, keyProvider, router, roomManager, server, currentNode) if err != nil { return nil, err } @@ -201,6 +205,15 @@ func getEgressStore(s ObjectStore) EgressStore { } } +func getIngressStore(s ObjectStore) IngressStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + func createClientConfiguration() clientconfiguration.ClientConfigurationManager { return clientconfiguration.NewStaticClientConfigurationManager(clientconfiguration.StaticConfigurations) } diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index b4f4b172b..9685100bc 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -1275,7 +1275,7 @@ func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer in if f.targetLayers.Spatial == layer { if extPkt.KeyFrame || tp.switchingToTargetLayer { // lock to target layer - f.logger.Debugw("locking to target layer", "current", f.currentLayers, "target", f.targetLayers) + f.logger.Infow("locking to target layer", "current", f.currentLayers, "target", f.targetLayers) f.currentLayers.Spatial = f.targetLayers.Spatial if !f.isTemporalSupported { f.currentLayers.Temporal = f.targetLayers.Temporal diff --git a/test/client/client.go b/test/client/client.go index eebd76b10..086f90654 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -55,6 +55,7 @@ type RTCClient struct { lossyDCSub *webrtc.DataChannel publisherConnected atomic.Bool publisherNegotiated atomic.Bool + pongReceivedAt atomic.Int64 // tracks waiting to be acked, cid => trackInfo pendingPublishedTracks map[string]*livekit.TrackInfo @@ -347,6 +348,8 @@ func (c *RTCClient) Run() error { delete(c.trackSenders, sid) delete(c.localTracks, sid) c.lock.Unlock() + case *livekit.SignalResponse_Pong: + c.pongReceivedAt.Store(msg.Pong) } } } @@ -441,6 +444,18 @@ func (c *RTCClient) RefreshToken() string { return c.refreshToken } +func (c *RTCClient) PongReceivedAt() int64 { + return c.pongReceivedAt.Load() +} + +func (c *RTCClient) SendPing() error { + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Ping{ + Ping: time.Now().UnixNano(), + }, + }) +} + func (c *RTCClient) SendRequest(msg *livekit.SignalRequest) error { payload, err := proto.Marshal(msg) if err != nil { diff --git a/test/singlenode_test.go b/test/singlenode_test.go index b0a8e9fb4..5af6602e0 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -328,6 +328,23 @@ func TestSingleNodeCORS(t *testing.T) { require.Equal(t, "testhost.com", res.Header.Get("Access-Control-Allow-Origin")) } +func TestPingPong(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestPingPong") + defer finish() + + c1 := createRTCClient("c1", defaultServerPort, nil) + waitUntilConnected(t, c1) + + require.NoError(t, c1.SendPing()) + require.Eventually(t, func() bool { + return c1.PongReceivedAt() > 0 + }, time.Second, 10*time.Millisecond) +} + func TestSingleNodeJoinAfterClose(t *testing.T) { if testing.Short() { t.SkipNow()