diff --git a/go.mod b/go.mod index cddb7ceec..f5376a0b2 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20241128072814-c363618d4c98 - github.com/livekit/protocol v1.29.5-0.20241209183753-f6b5078b2244 + github.com/livekit/protocol v1.29.5-0.20241217013317-bc388341b9f2 github.com/livekit/psrpc v0.6.1-0.20241018124827-1efff3d113a8 github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 @@ -29,16 +29,17 @@ require ( github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 github.com/ory/dockertest/v3 v3.11.0 + github.com/pion/datachannel v1.5.10 github.com/pion/dtls/v3 v3.0.4 github.com/pion/ice/v4 v4.0.3 github.com/pion/interceptor v0.1.37 - github.com/pion/rtcp v1.2.14 + github.com/pion/rtcp v1.2.15 github.com/pion/rtp v1.8.9 - github.com/pion/sctp v1.8.34 + github.com/pion/sctp v1.8.35 github.com/pion/sdp/v3 v3.0.9 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v4 v4.0.0 - github.com/pion/webrtc/v4 v4.0.5 + github.com/pion/webrtc/v4 v4.0.6 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.20.5 github.com/redis/go-redis/v9 v9.7.0 @@ -108,7 +109,6 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/runc v1.1.14 // indirect - github.com/pion/datachannel v1.5.9 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect diff --git a/go.sum b/go.sum index 9999d9280..fe7cca267 100644 --- a/go.sum +++ b/go.sum @@ -165,8 +165,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20241128072814-c363618d4c98 h1:QA7DqIC/ZSsMj8HC0+zNfMMwssHbA0alZALK68r30LQ= github.com/livekit/mediatransportutil v0.0.0-20241128072814-c363618d4c98/go.mod h1:WIVFAGzVZ7VMjPC5+nbSfwdFjWcbuLgx97KeNSUDTEo= -github.com/livekit/protocol v1.29.5-0.20241209183753-f6b5078b2244 h1:Eg9HK+5bMCDRKhh5g5g16oyNaMbCqMrJvxFBaBuP7Vo= -github.com/livekit/protocol v1.29.5-0.20241209183753-f6b5078b2244/go.mod h1:NDg1btMpKCzr/w6QR5kDuXw/e4Y7yOBE+RUAHsc+Y/M= +github.com/livekit/protocol v1.29.5-0.20241217013317-bc388341b9f2 h1:knHtTlhR89ly9TZ2JiyfT1ibqziv/rDcfSf3voQw8rE= +github.com/livekit/protocol v1.29.5-0.20241217013317-bc388341b9f2/go.mod h1:NDg1btMpKCzr/w6QR5kDuXw/e4Y7yOBE+RUAHsc+Y/M= github.com/livekit/psrpc v0.6.1-0.20241018124827-1efff3d113a8 h1:Ibh0LoFl5NW5a1KFJEE0eLxxz7dqqKmYTj/BfCb0PbY= github.com/livekit/psrpc v0.6.1-0.20241018124827-1efff3d113a8/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= @@ -228,8 +228,8 @@ github.com/opencontainers/runc v1.1.14 h1:rgSuzbmgz5DUJjeSnw337TxDbRuqjs6iqQck/2 github.com/opencontainers/runc v1.1.14/go.mod h1:E4C2z+7BxR7GHXp0hAY53mek+x49X1LjPNeMTfRGvOA= github.com/ory/dockertest/v3 v3.11.0 h1:OiHcxKAvSDUwsEVh2BjxQQc/5EHz9n0va9awCtNGuyA= github.com/ory/dockertest/v3 v3.11.0/go.mod h1:VIPxS1gwT9NpPOrfD3rACs8Y9Z7yhzO4SB194iUDnUI= -github.com/pion/datachannel v1.5.9 h1:LpIWAOYPyDrXtU+BW7X0Yt/vGtYxtXQ8ql7dFfYUVZA= -github.com/pion/datachannel v1.5.9/go.mod h1:kDUuk4CU4Uxp82NH4LQZbISULkX/HtzKa4P7ldf9izE= +github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= +github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg= github.com/pion/ice/v4 v4.0.3 h1:9s5rI1WKzF5DRqhJ+Id8bls/8PzM7mau0mj1WZb4IXE= @@ -242,12 +242,12 @@ github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE= -github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= +github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk= github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= -github.com/pion/sctp v1.8.34 h1:rCuD3m53i0oGxCSp7FLQKvqVx0Nf5AUAHhMRXTTQjBc= -github.com/pion/sctp v1.8.34/go.mod h1:yWkCClkXlzVW7BXfI2PjrUGBwUI0CjXJBkhLt+sdo4U= +github.com/pion/sctp v1.8.35 h1:qwtKvNK1Wc5tHMIYgTDJhfZk7vATGVHhXbUDfHbYwzA= +github.com/pion/sctp v1.8.35/go.mod h1:EcXP8zCYVTRy3W9xtOF7wJm1L1aXfKRQzaM33SjQlzg= github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY= github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M= github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= @@ -258,8 +258,8 @@ github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1 github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= -github.com/pion/webrtc/v4 v4.0.5 h1:8cVPojcv3cQTwVga2vF1rzCNvkiEimnYdCCG7yF317I= -github.com/pion/webrtc/v4 v4.0.5/go.mod h1:LvP8Np5b/sM0uyJIcUPvJcCvhtjHxJwzh2H2PYzE6cQ= +github.com/pion/webrtc/v4 v4.0.6 h1:OfxfGeZGhneUDnZEoebLGDkzwjowSJ0avbOu2xaIUeM= +github.com/pion/webrtc/v4 v4.0.6/go.mod h1:j7oMHYvjl7lESJ/nYiE4d2URyjFbAo3uqJ6Xse6hbSg= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pkg/config/config.go b/pkg/config/config.go index afc27c2a6..bf2a59be0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -84,6 +84,7 @@ type RTCConfig struct { TURNServers []TURNServer `yaml:"turn_servers,omitempty"` + // Deprecated StrictACKs bool `yaml:"strict_acks,omitempty"` // Deprecated: use PacketBufferSizeVideo and PacketBufferSizeAudio @@ -110,9 +111,13 @@ type RTCConfig struct { // force a reconnect on a data channel error ReconnectOnDataChannelError *bool `yaml:"reconnect_on_data_channel_error,omitempty"` - // max number of bytes to buffer for data channel. 0 means unlimited + // Deprecated DataChannelMaxBufferedAmount uint64 `yaml:"data_channel_max_buffered_amount,omitempty"` + // Threshold of data channel writing to be considered too slow, data packet could + // be dropped for a slow data channel to avoid blocking the room. + DatachannelSlowThreshold int `yaml:"datachannel_slow_threshold,omitempty"` + ForwardStats ForwardStatsConfig `yaml:"forward_stats,omitempty"` } @@ -311,7 +316,6 @@ var DefaultConfig = Config{ PacketBufferSize: 500, PacketBufferSizeVideo: 500, PacketBufferSizeAudio: 200, - StrictACKs: true, PLIThrottle: sfu.DefaultPLIThrottleConfig, CongestionControl: CongestionControlConfig{ Enabled: true, @@ -322,6 +326,7 @@ var DefaultConfig = Config{ UseSendSideBWE: false, SendSideBWE: sendsidebwe.DefaultSendSideBWEConfig, }, + DatachannelSlowThreshold: 1000000, }, Audio: sfu.DefaultAudioConfig, Video: VideoConfig{ diff --git a/pkg/rtc/config.go b/pkg/rtc/config.go index 8e2820195..a95a22526 100644 --- a/pkg/rtc/config.go +++ b/pkg/rtc/config.go @@ -58,7 +58,6 @@ type RTCPFeedbackConfig struct { type DirectionConfig struct { RTPHeaderExtension RTPHeaderExtensionConfig RTCPFeedback RTCPFeedbackConfig - StrictACKs bool } func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { @@ -86,7 +85,6 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { // publisher configuration publisherConfig := DirectionConfig{ - StrictACKs: true, // publisher is dialed, and will always reply with ACK RTPHeaderExtension: RTPHeaderExtensionConfig{ Audio: []string{ sdp.SDESMidURI, @@ -119,7 +117,6 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { // subscriber configuration subscriberConfig := DirectionConfig{ - StrictACKs: rtcConf.StrictACKs, RTPHeaderExtension: RTPHeaderExtensionConfig{ Video: []string{ dd.ExtensionURI, @@ -130,6 +127,10 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { }, }, RTCPFeedback: RTCPFeedbackConfig{ + Audio: []webrtc.RTCPFeedback{ + // always enable NACK for audio but disable it later for red enabled transceiver. https://github.com/pion/webrtc/pull/2972 + {Type: webrtc.TypeRTCPFBNACK}, + }, Video: []webrtc.RTCPFeedback{ {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, {Type: webrtc.TypeRTCPFBNACK}, diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index a8f1e0b56..fbcbdd228 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -147,7 +147,6 @@ type ParticipantParams struct { ReconnectOnPublicationError bool ReconnectOnSubscriptionError bool ReconnectOnDataChannelError bool - DataChannelMaxBufferedAmount uint64 VersionGenerator utils.TimedVersionGenerator TrackResolver types.MediaTrackResolver DisableDynacast bool @@ -161,6 +160,8 @@ type ParticipantParams struct { MetricConfig metric.MetricConfig UseOneShotSignallingMode bool EnableMetrics bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int FireOnTrackBySdp bool } @@ -1565,6 +1566,7 @@ func (p *ParticipantImpl) setupTransportManager() error { TURNSEnabled: p.params.TURNSEnabled, AllowPlayoutDelay: p.params.PlayoutDelay.GetEnabled(), DataChannelMaxBufferedAmount: p.params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: p.params.DatachannelSlowThreshold, Logger: p.params.Logger.WithComponent(sutils.ComponentTransport), PublisherHandler: pth, SubscriberHandler: sth, diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index f6dff4eb2..412f787b1 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -52,7 +52,7 @@ const ( invAudioLevelQuantization = 1.0 / AudioLevelQuantization subscriberUpdateInterval = 3 * time.Second - dataForwardLoadBalanceThreshold = 20 + dataForwardLoadBalanceThreshold = 4 simulateDisconnectSignalTimeout = 5 * time.Second ) diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 0d76fbf70..21990f6db 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -16,6 +16,7 @@ package rtc import ( "fmt" + "io" "net" "strconv" "strings" @@ -29,7 +30,6 @@ import ( "github.com/pion/interceptor/pkg/gcc" "github.com/pion/interceptor/pkg/twcc" "github.com/pion/rtcp" - "github.com/pion/sctp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v4" "github.com/pkg/errors" @@ -42,6 +42,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/bwe" "github.com/livekit/livekit-server/pkg/sfu/bwe/remotebwe" "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor" "github.com/livekit/livekit-server/pkg/sfu/pacer" pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" @@ -78,6 +79,8 @@ const ( maxConnectTimeoutAfterICE = 20 * time.Second // max duration for waiting pc to connect after ICE is connected shortConnectionThreshold = 90 * time.Second + + dataChannelBufferSize = 65535 ) var ( @@ -188,9 +191,9 @@ type PCTransport struct { firstOfferReceived bool firstOfferNoDataChannel bool - reliableDC *webrtc.DataChannel + reliableDC *datachannel.DataChannelWriter[*webrtc.DataChannel] reliableDCOpened bool - lossyDC *webrtc.DataChannel + lossyDC *datachannel.DataChannelWriter[*webrtc.DataChannel] lossyDCOpened bool iceStartedAt time.Time @@ -258,9 +261,13 @@ type TransportParams struct { IsOfferer bool IsSendSide bool AllowPlayoutDelay bool - DataChannelMaxBufferedAmount uint64 UseOneShotSignallingMode bool FireOnTrackBySdp bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int + + // for development test + DatachannelMaxReceiverBufferSize int } func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimator cc.BandwidthEstimator)) (*webrtc.PeerConnection, *webrtc.MediaEngine, error) { @@ -288,6 +295,13 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat // https://github.com/pion/webrtc/pull/2961 se.DisableCloseByDTLS(true) + se.DetachDataChannels() + if params.DatachannelSlowThreshold > 0 { + se.EnableDataChannelBlockWrite(true) + } + if params.DatachannelMaxReceiverBufferSize > 0 { + se.SetSCTPMaxReceiveBufferSize(uint32(params.DatachannelMaxReceiverBufferSize)) + } if params.FireOnTrackBySdp { se.SetFireOnTrackBeforeFirstRTP(true) } @@ -361,7 +375,6 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat ir := &interceptor.Registry{} if params.IsSendSide { - se.DetachDataChannels() if params.CongestionControlConfig.UseSendSideBWEInterceptor && !params.CongestionControlConfig.UseSendSideBWE { params.Logger.Infow("using send side BWE - interceptor") gf, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { @@ -717,31 +730,65 @@ func (t *PCTransport) onPeerConnectionStateChange(state webrtc.PeerConnectionSta } func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { - t.params.Logger.Debugw(dc.Label() + " data channel open") - switch dc.Label() { - case ReliableDataChannel: - t.lock.Lock() - t.reliableDC = dc - t.reliableDCOpened = true - t.lock.Unlock() - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - t.params.Handler.OnDataPacket(livekit.DataPacket_RELIABLE, msg.Data) - }) + dc.OnOpen(func() { + t.params.Logger.Debugw(dc.Label() + " data channel open") + var kind livekit.DataPacket_Kind + switch dc.Label() { + case ReliableDataChannel: + kind = livekit.DataPacket_RELIABLE + + case LossyDataChannel: + kind = livekit.DataPacket_LOSSY + + default: + t.params.Logger.Warnw("unsupported datachannel added", nil, "label", dc.Label()) + return + } + + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Errorw("failed to detach data channel", err, "label", dc.Label()) + return + } + + switch kind { + case livekit.DataPacket_RELIABLE: + t.lock.Lock() + if t.reliableDC != nil { + t.reliableDC.Close() + } + t.reliableDC = datachannel.NewDataChannelWriter(dc, rawDC, t.params.DatachannelSlowThreshold) + t.reliableDCOpened = true + t.lock.Unlock() + + case livekit.DataPacket_LOSSY: + t.lock.Lock() + if t.lossyDC != nil { + t.lossyDC.Close() + } + t.lossyDC = datachannel.NewDataChannelWriter(dc, rawDC, 0) + t.lossyDCOpened = true + t.lock.Unlock() + } + + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + t.params.Handler.OnDataPacket(kind, buffer[:n]) + } + }() t.maybeNotifyFullyEstablished() - case LossyDataChannel: - t.lock.Lock() - t.lossyDC = dc - t.lossyDCOpened = true - t.lock.Unlock() - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - t.params.Handler.OnDataPacket(livekit.DataPacket_LOSSY, msg.Data) - }) - - t.maybeNotifyFullyEstablished() - default: - t.params.Logger.Warnw("unsupported datachannel added", nil, "label", dc.Label()) - } + }) } func (t *PCTransport) maybeNotifyFullyEstablished() { @@ -866,7 +913,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni return err } var ( - dcPtr **webrtc.DataChannel + dcPtr **datachannel.DataChannelWriter[*webrtc.DataChannel] dcReady *bool ) switch dc.Label() { @@ -883,60 +930,41 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni dcReady = &t.lossyDCOpened } - dcReadyHandler := func() { + dc.OnOpen(func() { + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Warnw("failed to detach data channel", err) + return + } + + var slowThreshold int + if dc.Label() == ReliableDataChannel { + slowThreshold = t.params.DatachannelSlowThreshold + } + t.lock.Lock() + if *dcPtr != nil { + (*dcPtr).Close() + } + *dcPtr = datachannel.NewDataChannelWriter(dc, rawDC, slowThreshold) *dcReady = true t.lock.Unlock() t.params.Logger.Debugw(dc.Label() + " data channel open") t.maybeNotifyFullyEstablished() - } + }) - dcCloseHandler := func() { - t.params.Logger.Debugw(dc.Label() + " data channel close") - } - - dcErrorHandler := func(err error) { - if !errors.Is(err, sctp.ErrResetPacketInStateNotExist) && !errors.Is(err, sctp.ErrChunk) { - t.params.Logger.Warnw(dc.Label()+" data channel error", err) - } - } - - t.lock.Lock() - defer t.lock.Unlock() - *dcPtr = dc - if t.params.DirectionConfig.StrictACKs { - dc.OnOpen(func() { - if t.params.IsSendSide { - if _, err := dc.Detach(); err != nil { - t.params.Logger.Warnw("failed to detach data channel", err) - } - } - dcReadyHandler() - }) - } else { - dc.OnOpen(func() { - if t.params.IsSendSide { - if _, err := dc.Detach(); err != nil { - t.params.Logger.Warnw("failed to detach data channel", err) - } - } - }) - dc.OnDial(dcReadyHandler) - } - dc.OnClose(dcCloseHandler) - dc.OnError(dcErrorHandler) return nil } func (t *PCTransport) CreateDataChannelIfEmpty(dcLabel string, dci *webrtc.DataChannelInit) (label string, id uint16, existing bool, err error) { t.lock.RLock() - var dc *webrtc.DataChannel + var dcw *datachannel.DataChannelWriter[*webrtc.DataChannel] switch dcLabel { case ReliableDataChannel: - dc = t.reliableDC + dcw = t.reliableDC case LossyDataChannel: - dc = t.lossyDC + dcw = t.lossyDC default: t.params.Logger.Warnw("unknown data channel label", nil, "label", label) err = errors.New("unknown data channel label") @@ -946,11 +974,12 @@ func (t *PCTransport) CreateDataChannelIfEmpty(dcLabel string, dci *webrtc.DataC return } - if dc != nil { + if dcw != nil { + dc := dcw.BufferedAmountGetter() return dc.Label(), *dc.ID(), true, nil } - dc, err = t.pc.CreateDataChannel(dcLabel, dci) + dc, err := t.pc.CreateDataChannel(dcLabel, dci) if err != nil { return } @@ -989,7 +1018,7 @@ func (t *PCTransport) WriteRTCP(pkts []rtcp.Packet) error { } func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { - var dc *webrtc.DataChannel + var dc *datachannel.DataChannelWriter[*webrtc.DataChannel] t.lock.RLock() if kind == livekit.DataPacket_RELIABLE { dc = t.reliableDC @@ -1006,11 +1035,12 @@ func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byt return ErrTransportFailure } - if t.params.DataChannelMaxBufferedAmount > 0 && dc.BufferedAmount() > t.params.DataChannelMaxBufferedAmount { + if t.params.DatachannelSlowThreshold == 0 && t.params.DataChannelMaxBufferedAmount > 0 && dc.BufferedAmountGetter().BufferedAmount() > t.params.DataChannelMaxBufferedAmount { return ErrDataChannelBufferFull } + _, err := dc.Write(encoded) - return dc.Send(encoded) + return err } func (t *PCTransport) Close() { @@ -1031,6 +1061,18 @@ func (t *PCTransport) Close() { _ = t.pc.Close() t.clearConnTimer() + + t.lock.Lock() + if t.reliableDC != nil { + t.reliableDC.Close() + t.reliableDC = nil + } + + if t.lossyDC != nil { + t.lossyDC.Close() + t.lossyDC = nil + } + t.lock.Unlock() } func (t *PCTransport) clearConnTimer() { @@ -2020,6 +2062,8 @@ func (t *PCTransport) handleICERestart(_ event) error { } // configure subscriber transceiver for audio stereo and nack +// pion doesn't support per transciver codec configuration, so the nack of this session will be disabled +// forever once it is first disabled by a transceiver. func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool) { sender := tr.Sender() if sender == nil { @@ -2034,17 +2078,13 @@ func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool if stereo { c.SDPFmtpLine += ";sprop-stereo=1" } - if nack { - var nackFound bool - for _, fb := range c.RTCPFeedback { + if !nack { + for i, fb := range c.RTCPFeedback { if fb.Type == webrtc.TypeRTCPFBNACK { - nackFound = true + c.RTCPFeedback = append(c.RTCPFeedback[:i], c.RTCPFeedback[i+1:]...) break } } - if !nackFound { - c.RTCPFeedback = append(c.RTCPFeedback, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBNACK}) - } } } configCodecs = append(configCodecs, c) diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 16bff8b52..512b2ff91 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -587,10 +587,6 @@ func untilTransportsConnected(transports ...*transportfakes.FakeHandler) *sync.W } func TestConfigureAudioTransceiver(t *testing.T) { - pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) - require.NoError(t, err) - defer pc.Close() - for _, testcase := range []struct { nack bool stereo bool @@ -601,6 +597,11 @@ func TestConfigureAudioTransceiver(t *testing.T) { {true, true}, } { t.Run(fmt.Sprintf("nack=%v,stereo=%v", testcase.nack, testcase.stereo), func(t *testing.T) { + var me webrtc.MediaEngine + registerCodecs(&me, []*livekit.Codec{{Mime: webrtc.MimeTypeOpus}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) + pc, err := webrtc.NewAPI(webrtc.WithMediaEngine(&me)).NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() tr, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}) require.NoError(t, err) diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 275219fee..68889d70e 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -15,6 +15,7 @@ package rtc import ( + "context" "io" "math/bits" "sync" @@ -99,6 +100,7 @@ type TransportManagerParams struct { TURNSEnabled bool AllowPlayoutDelay bool DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int Logger logger.Logger PublisherHandler transport.Handler SubscriberHandler transport.Handler @@ -146,21 +148,23 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro lgr := LoggerWithPCTarget(params.Logger, livekit.SignalTarget_PUBLISHER) publisher, err := NewPCTransport(TransportParams{ - ParticipantID: params.SID, - ParticipantIdentity: params.Identity, - ProtocolVersion: params.ProtocolVersion, - Config: params.Config, - Twcc: params.Twcc, - DirectionConfig: params.Config.Publisher, - CongestionControlConfig: params.CongestionControlConfig, - EnabledCodecs: params.EnabledPublishCodecs, - Logger: lgr, - SimTracks: params.SimTracks, - ClientInfo: params.ClientInfo, - Transport: livekit.SignalTarget_PUBLISHER, - Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t, lgr}}, - UseOneShotSignallingMode: params.UseOneShotSignallingMode, - FireOnTrackBySdp: params.FireOnTrackBySdp, + ParticipantID: params.SID, + ParticipantIdentity: params.Identity, + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + Twcc: params.Twcc, + DirectionConfig: params.Config.Publisher, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledPublishCodecs, + Logger: lgr, + SimTracks: params.SimTracks, + ClientInfo: params.ClientInfo, + Transport: livekit.SignalTarget_PUBLISHER, + Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t, lgr}}, + UseOneShotSignallingMode: params.UseOneShotSignallingMode, + DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + FireOnTrackBySdp: params.FireOnTrackBySdp, }) if err != nil { return nil, err @@ -169,21 +173,21 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro lgr = LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER) subscriber, err := NewPCTransport(TransportParams{ - ParticipantID: params.SID, - ParticipantIdentity: params.Identity, - ProtocolVersion: params.ProtocolVersion, - Config: params.Config, - DirectionConfig: params.Config.Subscriber, - CongestionControlConfig: params.CongestionControlConfig, - EnabledCodecs: params.EnabledSubscribeCodecs, - Logger: lgr, - ClientInfo: params.ClientInfo, - IsOfferer: true, - IsSendSide: true, - AllowPlayoutDelay: params.AllowPlayoutDelay, - DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, - Transport: livekit.SignalTarget_SUBSCRIBER, - Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, + ParticipantID: params.SID, + ParticipantIdentity: params.Identity, + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + DirectionConfig: params.Config.Subscriber, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledSubscribeCodecs, + Logger: lgr, + ClientInfo: params.ClientInfo, + IsOfferer: true, + IsSendSide: true, + AllowPlayoutDelay: params.AllowPlayoutDelay, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + Transport: livekit.SignalTarget_SUBSCRIBER, + Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, }) if err != nil { return nil, err @@ -294,7 +298,7 @@ func (t *TransportManager) SendDataPacket(kind livekit.DataPacket_Kind, encoded // downstream data is sent via primary peer connection err := t.getTransport(true).SendDataPacket(kind, encoded) if err != nil { - if !utils.ErrorIsOneOf(err, io.ErrClosedPipe, sctp.ErrStreamClosed, ErrTransportFailure, ErrDataChannelBufferFull) { + if !utils.ErrorIsOneOf(err, io.ErrClosedPipe, sctp.ErrStreamClosed, ErrTransportFailure, ErrDataChannelBufferFull, context.DeadlineExceeded) { t.params.Logger.Warnw("send data packet error", err) } if utils.ErrorIsOneOf(err, sctp.ErrStreamClosed, io.ErrClosedPipe) { diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 0977c2f41..a33730c98 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -480,7 +480,6 @@ func (r *RoomManager) StartSession( ReconnectOnPublicationError: reconnectOnPublicationError, ReconnectOnSubscriptionError: reconnectOnSubscriptionError, ReconnectOnDataChannelError: reconnectOnDataChannelError, - DataChannelMaxBufferedAmount: r.config.RTC.DataChannelMaxBufferedAmount, VersionGenerator: r.versionGenerator, TrackResolver: room.ResolveMediaTrackForSubscriber, SubscriberAllowPause: subscriberAllowPause, @@ -491,6 +490,8 @@ func (r *RoomManager) StartSession( ForwardStats: r.forwardStats, MetricConfig: r.config.Metric, UseOneShotSignallingMode: useOneShotSignallingMode, + DataChannelMaxBufferedAmount: r.config.RTC.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: r.config.RTC.DatachannelSlowThreshold, FireOnTrackBySdp: true, }) if err != nil { diff --git a/pkg/sfu/datachannel/bitrate.go b/pkg/sfu/datachannel/bitrate.go new file mode 100644 index 000000000..1e31a8c2b --- /dev/null +++ b/pkg/sfu/datachannel/bitrate.go @@ -0,0 +1,93 @@ +package datachannel + +import ( + "sync" + "time" + + "github.com/gammazero/deque" + + "github.com/livekit/protocol/utils/mono" +) + +const ( + BitrateDuration = 2 * time.Second + BitrateWindow = 100 * time.Millisecond +) + +// BitrateCalculator calculates bitrate over sliding window +type BitrateCalculator struct { + lock sync.Mutex + windowDuration time.Duration + duration time.Duration + + windows *deque.Deque[bitrateWindow] + active bitrateWindow + + bytes int + lastBufferedAmount int + start time.Time +} + +func NewBitrateCalculator(duration time.Duration, window time.Duration) *BitrateCalculator { + windowCnt := int((duration + (window - 1)) / window) + if windowCnt == 0 { + windowCnt = 1 + } + now := mono.Now() + c := &BitrateCalculator{ + duration: duration, + windowDuration: window, + windows: deque.New[bitrateWindow](windowCnt+1, windowCnt+1), + start: now, + active: bitrateWindow{start: now}, + } + + return c +} + +func (c *BitrateCalculator) AddBytes(bytes int, bufferedAmout int, ts time.Time) { + c.lock.Lock() + defer c.lock.Unlock() + + bytes -= bufferedAmout - c.lastBufferedAmount + c.lastBufferedAmount = bufferedAmout + if ts.Sub(c.active.start) >= c.windowDuration { + c.windows.PushBack(c.active) + c.active.start = ts + c.active.bytes = 0 + + for c.windows.Len() > 0 { + // pop expired windows + if w := c.windows.Front(); ts.Sub(w.start) > (c.duration + c.windowDuration) { + c.bytes -= w.bytes + c.windows.PopFront() + } else { + c.start = w.start + break + } + } + if c.windows.Len() == 0 { + c.start = ts + c.bytes = 0 + } + } + c.bytes += bytes + c.active.bytes += bytes + +} + +func (c *BitrateCalculator) Bitrate(ts time.Time) int { + c.lock.Lock() + defer c.lock.Unlock() + duration := ts.Sub(c.start) + if duration < c.windowDuration { + duration = c.windowDuration + } + + return c.bytes * 8 * 1000 / int(duration.Milliseconds()) +} + +type bitrateWindow struct { + start time.Time + bytes int +} diff --git a/pkg/sfu/datachannel/bitrate_test.go b/pkg/sfu/datachannel/bitrate_test.go new file mode 100644 index 000000000..4b4e0a1a6 --- /dev/null +++ b/pkg/sfu/datachannel/bitrate_test.go @@ -0,0 +1,29 @@ +package datachannel + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestBitrateCalculator(t *testing.T) { + c := NewBitrateCalculator(BitrateDuration, BitrateWindow) + require.NotNil(t, c) + + t0 := time.Now() + c.AddBytes(100, 0, t0) + // bytes buffered + c.AddBytes(100, 100, t0.Add(50*time.Millisecond)) + // 50 bytes sent (50 bytes buffer flushed) + c.AddBytes(100, 50, t0.Add(time.Second)) + + // 250 bytes sent in 1 second + require.Equal(t, 2000, c.Bitrate(t0.Add(time.Second))) + + // silence for long time + t1 := t0.Add(2 * BitrateDuration) + // 150 bytes sent (50 bytes buffer flushed) + c.AddBytes(100, 0, t1) + require.Equal(t, 1200, c.Bitrate(t1.Add(time.Second))) +} diff --git a/pkg/sfu/datachannel/datachannel_writer.go b/pkg/sfu/datachannel/datachannel_writer.go new file mode 100644 index 000000000..2ac05b351 --- /dev/null +++ b/pkg/sfu/datachannel/datachannel_writer.go @@ -0,0 +1,75 @@ +package datachannel + +import ( + "context" + "errors" + "time" + + "github.com/pion/datachannel" + + "github.com/livekit/protocol/utils/mono" +) + +const ( + singleWriteTimeout = 50 * time.Millisecond +) + +var ErrDataDroppedBySlowReader = errors.New("data dropped by slow reader") + +type BufferedAmountGetter interface { + BufferedAmount() uint64 +} + +type DataChannelWriter[T BufferedAmountGetter] struct { + bufferGetter T + rawDC datachannel.ReadWriteCloserDeadliner + slowThreshold int + rate *BitrateCalculator +} + +// NewDataChannelWriter creates a new DataChannelWriter by detaching the data channel, when +// writing to the datachanel times out, it will block and retry if the receiver's bitrate is +// above the slowThreshold or drop the data if it's below the threshold. If the slowThreshold +// is 0, it will never retry on write timeout. +func NewDataChannelWriter[T BufferedAmountGetter](bufferGetter T, rawDC datachannel.ReadWriteCloserDeadliner, slowThreshold int) *DataChannelWriter[T] { + var rate *BitrateCalculator + if slowThreshold > 0 { + rate = NewBitrateCalculator(BitrateDuration, BitrateWindow) + } + return &DataChannelWriter[T]{ + bufferGetter: bufferGetter, + rawDC: rawDC, + slowThreshold: slowThreshold, + rate: rate, + } +} + +func (w *DataChannelWriter[T]) BufferedAmountGetter() T { + return w.bufferGetter +} + +func (w *DataChannelWriter[T]) Write(p []byte) (n int, err error) { + for { + err = w.rawDC.SetWriteDeadline(time.Now().Add(singleWriteTimeout)) + if err != nil { + return 0, err + } + n, err = w.rawDC.Write(p) + if w.slowThreshold == 0 { + return + } + + now := mono.Now() + w.rate.AddBytes(n, int(w.bufferGetter.BufferedAmount()), now) + // retry if the write timed out on a non-slow receiver + if errors.Is(err, context.DeadlineExceeded) && w.rate.Bitrate(now) > w.slowThreshold { + continue + } + + return + } +} + +func (w *DataChannelWriter[T]) Close() error { + return w.rawDC.Close() +} diff --git a/pkg/sfu/datachannel/datachannel_writer_test.go b/pkg/sfu/datachannel/datachannel_writer_test.go new file mode 100644 index 000000000..b07f8a3b6 --- /dev/null +++ b/pkg/sfu/datachannel/datachannel_writer_test.go @@ -0,0 +1,94 @@ +package datachannel + +import ( + "context" + "testing" + "time" + + "github.com/pion/datachannel" + "github.com/pion/transport/v3/deadline" + "github.com/stretchr/testify/require" +) + +func TestDataChannelWriter(t *testing.T) { + mockDC := newMockDataChannelWriter() + // slow threshold is 1000B/s + w := NewDataChannelWriter(mockDC, mockDC, 8000) + require.Equal(t, mockDC, w.BufferedAmountGetter()) + buf := make([]byte, 2000) + // write 2000 bytes so it should not drop in 2 seconds + t0 := time.Now() + n, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, 2000, n) + + t1 := time.Now() + mockDC.SetNextWriteCompleteAt(t0.Add(time.Second)) + n, err = w.Write(buf[:10]) + require.NoError(t, err) + require.Equal(t, 10, n) + require.GreaterOrEqual(t, time.Since(t1), time.Second) + + // bitrate below slow threshold(2000bytes/3sec), should drop by timeout + mockDC.SetNextWriteCompleteAt(t0.Add(3 * time.Second)) + n, err = w.Write(buf[:1000]) + require.ErrorIs(t, err, context.DeadlineExceeded, err) + require.Equal(t, 0, n) +} + +func TestDataChannelWriter_NoSlowThreshold(t *testing.T) { + mockDC := newMockDataChannelWriter() + w := NewDataChannelWriter(mockDC, mockDC, 0) + buf := make([]byte, 2000) + n, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, 2000, n) + mockDC.SetNextWriteCompleteAt(time.Now().Add(singleWriteTimeout / 2)) + n, err = w.Write(buf[:10]) + require.NoError(t, err) + require.Equal(t, 10, n) + + // slow threshold is 0, should not block & retry + mockDC.SetNextWriteCompleteAt(time.Now().Add(singleWriteTimeout * 2)) + n, err = w.Write(buf[:1000]) + require.ErrorIs(t, err, context.DeadlineExceeded, err) + require.Equal(t, 0, n) +} + +type mockDataChannelWriter struct { + datachannel.ReadWriteCloserDeadliner + nextWriteCompleteAt time.Time + deadline *deadline.Deadline +} + +func newMockDataChannelWriter() *mockDataChannelWriter { + return &mockDataChannelWriter{ + deadline: deadline.New(), + } +} + +func (m *mockDataChannelWriter) BufferedAmount() uint64 { + return 0 +} + +func (m *mockDataChannelWriter) Write(b []byte) (int, error) { + wait := time.Until(m.nextWriteCompleteAt) + if wait <= 0 { + return len(b), nil + } + select { + case <-m.deadline.Done(): + return 0, m.deadline.Err() + case <-time.After(wait): + return len(b), nil + } +} + +func (m *mockDataChannelWriter) SetWriteDeadline(t time.Time) error { + m.deadline.Set(t) + return nil +} + +func (m *mockDataChannelWriter) SetNextWriteCompleteAt(t time.Time) { + m.nextWriteCompleteAt = t +} diff --git a/test/client/client.go b/test/client/client.go index 7bdb6ae5c..7304532fa 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -214,23 +214,25 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { // publisherHandler := &transportfakes.FakeHandler{} c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ - Config: &conf, - DirectionConfig: conf.Subscriber, - EnabledCodecs: codecs, - IsOfferer: true, - IsSendSide: true, - Handler: publisherHandler, + Config: &conf, + DirectionConfig: conf.Subscriber, + EnabledCodecs: codecs, + IsOfferer: true, + IsSendSide: true, + Handler: publisherHandler, + DatachannelSlowThreshold: 1024 * 1024 * 1024, }) if err != nil { return nil, err } subscriberHandler := &transportfakes.FakeHandler{} c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ - Config: &conf, - DirectionConfig: conf.Publisher, - EnabledCodecs: codecs, - Handler: subscriberHandler, - FireOnTrackBySdp: true, + Config: &conf, + DirectionConfig: conf.Publisher, + EnabledCodecs: codecs, + Handler: subscriberHandler, + DatachannelMaxReceiverBufferSize: 1500, + FireOnTrackBySdp: true, }) if err != nil { return nil, err diff --git a/test/client/datachannel_reader.go b/test/client/datachannel_reader.go new file mode 100644 index 000000000..e43d8908b --- /dev/null +++ b/test/client/datachannel_reader.go @@ -0,0 +1,35 @@ +package client + +import ( + "time" + + "github.com/livekit/livekit-server/pkg/sfu/datachannel" +) + +type DataChannelReader struct { + bitrate *datachannel.BitrateCalculator + target int +} + +func NewDataChannelReader(bitrate int) *DataChannelReader { + return &DataChannelReader{ + target: bitrate, + bitrate: datachannel.NewBitrateCalculator(datachannel.BitrateDuration*5, datachannel.BitrateWindow), + } +} + +func (d *DataChannelReader) Read(p []byte, sid string) { + for { + if bitrate := d.bitrate.Bitrate(time.Now()); bitrate > 0 && bitrate > d.target { + time.Sleep(20 * time.Millisecond) + d.bitrate.AddBytes(0, 0, time.Now()) + continue + } + break + } + d.bitrate.AddBytes(len(p), 0, time.Now()) +} + +func (d *DataChannelReader) Bitrate() int { + return d.bitrate.Bitrate(time.Now()) +} diff --git a/test/singlenode_test.go b/test/singlenode_test.go index de2133ca0..0e83fce10 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -16,6 +16,7 @@ package test import ( "context" + "encoding/binary" "errors" "fmt" "net/http" @@ -28,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/thoas/go-funk" "github.com/twitchtv/twirp" + "go.uber.org/atomic" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -707,6 +709,124 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { require.Nil(t, c2.GetSubscriptionResponseAndClear()) } +func TestDataPublishSlowSubscriber(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + dataChannelSlowThreshold := 101024 + + logger.Infow("----------------STARTING TEST----------------", "test", t.Name()) + s := createSingleNodeServer(func(c *config.Config) { + c.RTC.DatachannelSlowThreshold = dataChannelSlowThreshold + }) + go func() { + if err := s.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + + waitForServerToStart(s) + + defer func() { + s.Stop(true) + logger.Infow("----------------FINISHING TEST----------------", "test", t.Name()) + }() + + pub := createRTCClient("pub", defaultServerPort, nil) + fastSub := createRTCClient("fastSub", defaultServerPort, nil) + slowSubNotDrop := createRTCClient("slowSubNotDrop", defaultServerPort, nil) + slowSubDrop := createRTCClient("slowSubDrop", defaultServerPort, nil) + waitUntilConnected(t, pub, fastSub, slowSubDrop, slowSubNotDrop) + defer func() { + pub.Stop() + fastSub.Stop() + slowSubNotDrop.Stop() + slowSubDrop.Stop() + }() + + // publisher sends data as fast as possible, it will block by the slowest subscriber above the slow threshold + var ( + blocked atomic.Bool + stopWrite atomic.Bool + writeIdx atomic.Uint64 + ) + writeStopped := make(chan struct{}) + go func() { + defer close(writeStopped) + var i int + buf := make([]byte, 100) + for !stopWrite.Load() { + i++ + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(i)) + if err := pub.PublishData(buf, livekit.DataPacket_RELIABLE); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + blocked.Store(true) + i-- + continue + } else { + t.Log("error writing", err) + break + } + } + writeIdx.Store(uint64(i)) + } + }() + + // no data should be dropped for fast subscriber + var fastDataIndex atomic.Uint64 + fastSub.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, fastDataIndex.Load()+1, idx) + fastDataIndex.Store(idx) + } + + // no data should be dropped for slow subscriber that is above threshold + var slowNoDropDataIndex atomic.Uint64 + var drainSlowSubNotDrop atomic.Bool + slowNoDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold * 3 / 2) + slowSubNotDrop.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, slowNoDropDataIndex.Load()+1, idx) + slowNoDropDataIndex.Store(idx) + if !drainSlowSubNotDrop.Load() { + slowNoDropReader.Read(data, sid) + } + } + + // data should be dropped for slow subscriber that is below threshold + var slowDropDataIndex atomic.Uint64 + dropped := make(chan struct{}) + slowDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold / 2) + slowSubDrop.OnDataReceived = func(data []byte, sid string) { + select { + case <-dropped: + return + default: + } + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + if idx != slowDropDataIndex.Load()+1 { + close(dropped) + } + slowDropDataIndex.Store(idx) + slowDropReader.Read(data, sid) + } + + <-dropped + + time.Sleep(time.Second) + blocked.Store(false) + require.Eventually(t, func() bool { return blocked.Load() }, 30*time.Second, 100*time.Millisecond) + drainSlowSubNotDrop.Store(true) + stopWrite.Store(true) + <-writeStopped + require.Eventually(t, func() bool { + return writeIdx.Load() == fastDataIndex.Load() && + writeIdx.Load() == slowNoDropDataIndex.Load() + }, 5*time.Second, 50*time.Millisecond) +} + func TestFireTrackBySdp(t *testing.T) { _, finish := setupSingleNodeTest("TestFireTrackBySdp") defer finish()