From 713e67cd52c71a0cef77bd87e62cd667b1ea79e8 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Wed, 18 Dec 2024 10:51:34 +0800 Subject: [PATCH] Throttle the publisher data channel sending when subscriber is slow (#3255) * Throttle the publisher data channel sending when subscriber is slow Avoid the publisher overwhelming the sfu data channel buffer when the subscriber has lower receive bitrates. Messages are dropped when the subscriber is considered too slow, so that it does not block the entire room. * Enable nack in mediaengine and disable it in transceiver as needed pion doesn't support per transceiver codec configuration, so the nack of this session will be disabled forever once it is first disabled by a transceiver. https://github.com/pion/webrtc/pull/2972 --- go.mod | 10 +- go.sum | 20 +- pkg/config/config.go | 9 +- pkg/rtc/config.go | 7 +- pkg/rtc/participant.go | 4 +- pkg/rtc/room.go | 2 +- pkg/rtc/transport.go | 200 +++++++++++------- pkg/rtc/transport_test.go | 9 +- pkg/rtc/transportmanager.go | 66 +++--- pkg/service/roommanager.go | 3 +- pkg/sfu/datachannel/bitrate.go | 93 ++++++++ pkg/sfu/datachannel/bitrate_test.go | 29 +++ pkg/sfu/datachannel/datachannel_writer.go | 75 +++++++ .../datachannel/datachannel_writer_test.go | 94 ++++++++ test/client/client.go | 24 ++- test/client/datachannel_reader.go | 35 +++ test/singlenode_test.go | 120 +++++++++++ 17 files changed, 651 insertions(+), 149 deletions(-) create mode 100644 pkg/sfu/datachannel/bitrate.go create mode 100644 pkg/sfu/datachannel/bitrate_test.go create mode 100644 pkg/sfu/datachannel/datachannel_writer.go create mode 100644 pkg/sfu/datachannel/datachannel_writer_test.go create mode 100644 test/client/datachannel_reader.go diff --git a/go.mod b/go.mod index cddb7ceec..f5376a0b2 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20241128072814-c363618d4c98 - 
github.com/livekit/protocol v1.29.5-0.20241209183753-f6b5078b2244 + github.com/livekit/protocol v1.29.5-0.20241217013317-bc388341b9f2 github.com/livekit/psrpc v0.6.1-0.20241018124827-1efff3d113a8 github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 @@ -29,16 +29,17 @@ require ( github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 github.com/ory/dockertest/v3 v3.11.0 + github.com/pion/datachannel v1.5.10 github.com/pion/dtls/v3 v3.0.4 github.com/pion/ice/v4 v4.0.3 github.com/pion/interceptor v0.1.37 - github.com/pion/rtcp v1.2.14 + github.com/pion/rtcp v1.2.15 github.com/pion/rtp v1.8.9 - github.com/pion/sctp v1.8.34 + github.com/pion/sctp v1.8.35 github.com/pion/sdp/v3 v3.0.9 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v4 v4.0.0 - github.com/pion/webrtc/v4 v4.0.5 + github.com/pion/webrtc/v4 v4.0.6 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.20.5 github.com/redis/go-redis/v9 v9.7.0 @@ -108,7 +109,6 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/runc v1.1.14 // indirect - github.com/pion/datachannel v1.5.9 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect diff --git a/go.sum b/go.sum index 9999d9280..fe7cca267 100644 --- a/go.sum +++ b/go.sum @@ -165,8 +165,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20241128072814-c363618d4c98 h1:QA7DqIC/ZSsMj8HC0+zNfMMwssHbA0alZALK68r30LQ= github.com/livekit/mediatransportutil v0.0.0-20241128072814-c363618d4c98/go.mod h1:WIVFAGzVZ7VMjPC5+nbSfwdFjWcbuLgx97KeNSUDTEo= -github.com/livekit/protocol v1.29.5-0.20241209183753-f6b5078b2244 
h1:Eg9HK+5bMCDRKhh5g5g16oyNaMbCqMrJvxFBaBuP7Vo= -github.com/livekit/protocol v1.29.5-0.20241209183753-f6b5078b2244/go.mod h1:NDg1btMpKCzr/w6QR5kDuXw/e4Y7yOBE+RUAHsc+Y/M= +github.com/livekit/protocol v1.29.5-0.20241217013317-bc388341b9f2 h1:knHtTlhR89ly9TZ2JiyfT1ibqziv/rDcfSf3voQw8rE= +github.com/livekit/protocol v1.29.5-0.20241217013317-bc388341b9f2/go.mod h1:NDg1btMpKCzr/w6QR5kDuXw/e4Y7yOBE+RUAHsc+Y/M= github.com/livekit/psrpc v0.6.1-0.20241018124827-1efff3d113a8 h1:Ibh0LoFl5NW5a1KFJEE0eLxxz7dqqKmYTj/BfCb0PbY= github.com/livekit/psrpc v0.6.1-0.20241018124827-1efff3d113a8/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= @@ -228,8 +228,8 @@ github.com/opencontainers/runc v1.1.14 h1:rgSuzbmgz5DUJjeSnw337TxDbRuqjs6iqQck/2 github.com/opencontainers/runc v1.1.14/go.mod h1:E4C2z+7BxR7GHXp0hAY53mek+x49X1LjPNeMTfRGvOA= github.com/ory/dockertest/v3 v3.11.0 h1:OiHcxKAvSDUwsEVh2BjxQQc/5EHz9n0va9awCtNGuyA= github.com/ory/dockertest/v3 v3.11.0/go.mod h1:VIPxS1gwT9NpPOrfD3rACs8Y9Z7yhzO4SB194iUDnUI= -github.com/pion/datachannel v1.5.9 h1:LpIWAOYPyDrXtU+BW7X0Yt/vGtYxtXQ8ql7dFfYUVZA= -github.com/pion/datachannel v1.5.9/go.mod h1:kDUuk4CU4Uxp82NH4LQZbISULkX/HtzKa4P7ldf9izE= +github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= +github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg= github.com/pion/ice/v4 v4.0.3 h1:9s5rI1WKzF5DRqhJ+Id8bls/8PzM7mau0mj1WZb4IXE= @@ -242,12 +242,12 @@ github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod 
h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE= -github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= +github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk= github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= -github.com/pion/sctp v1.8.34 h1:rCuD3m53i0oGxCSp7FLQKvqVx0Nf5AUAHhMRXTTQjBc= -github.com/pion/sctp v1.8.34/go.mod h1:yWkCClkXlzVW7BXfI2PjrUGBwUI0CjXJBkhLt+sdo4U= +github.com/pion/sctp v1.8.35 h1:qwtKvNK1Wc5tHMIYgTDJhfZk7vATGVHhXbUDfHbYwzA= +github.com/pion/sctp v1.8.35/go.mod h1:EcXP8zCYVTRy3W9xtOF7wJm1L1aXfKRQzaM33SjQlzg= github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY= github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M= github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= @@ -258,8 +258,8 @@ github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1 github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= -github.com/pion/webrtc/v4 v4.0.5 h1:8cVPojcv3cQTwVga2vF1rzCNvkiEimnYdCCG7yF317I= -github.com/pion/webrtc/v4 v4.0.5/go.mod h1:LvP8Np5b/sM0uyJIcUPvJcCvhtjHxJwzh2H2PYzE6cQ= +github.com/pion/webrtc/v4 v4.0.6 h1:OfxfGeZGhneUDnZEoebLGDkzwjowSJ0avbOu2xaIUeM= +github.com/pion/webrtc/v4 v4.0.6/go.mod h1:j7oMHYvjl7lESJ/nYiE4d2URyjFbAo3uqJ6Xse6hbSg= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pkg/config/config.go b/pkg/config/config.go index afc27c2a6..bf2a59be0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -84,6 +84,7 @@ type RTCConfig struct { TURNServers []TURNServer `yaml:"turn_servers,omitempty"` + // Deprecated StrictACKs bool `yaml:"strict_acks,omitempty"` // Deprecated: use PacketBufferSizeVideo and PacketBufferSizeAudio @@ -110,9 +111,13 @@ type RTCConfig struct { // force a reconnect on a data channel error ReconnectOnDataChannelError *bool `yaml:"reconnect_on_data_channel_error,omitempty"` - // max number of bytes to buffer for data channel. 0 means unlimited + // Deprecated DataChannelMaxBufferedAmount uint64 `yaml:"data_channel_max_buffered_amount,omitempty"` + // Threshold of data channel writing to be considered too slow, data packet could + // be dropped for a slow data channel to avoid blocking the room. + DatachannelSlowThreshold int `yaml:"datachannel_slow_threshold,omitempty"` + ForwardStats ForwardStatsConfig `yaml:"forward_stats,omitempty"` } @@ -311,7 +316,6 @@ var DefaultConfig = Config{ PacketBufferSize: 500, PacketBufferSizeVideo: 500, PacketBufferSizeAudio: 200, - StrictACKs: true, PLIThrottle: sfu.DefaultPLIThrottleConfig, CongestionControl: CongestionControlConfig{ Enabled: true, @@ -322,6 +326,7 @@ var DefaultConfig = Config{ UseSendSideBWE: false, SendSideBWE: sendsidebwe.DefaultSendSideBWEConfig, }, + DatachannelSlowThreshold: 1000000, }, Audio: sfu.DefaultAudioConfig, Video: VideoConfig{ diff --git a/pkg/rtc/config.go b/pkg/rtc/config.go index 8e2820195..a95a22526 100644 --- a/pkg/rtc/config.go +++ b/pkg/rtc/config.go @@ -58,7 +58,6 @@ type RTCPFeedbackConfig struct { type DirectionConfig struct { RTPHeaderExtension RTPHeaderExtensionConfig RTCPFeedback RTCPFeedbackConfig - StrictACKs bool } func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { @@ -86,7 +85,6 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { // 
publisher configuration publisherConfig := DirectionConfig{ - StrictACKs: true, // publisher is dialed, and will always reply with ACK RTPHeaderExtension: RTPHeaderExtensionConfig{ Audio: []string{ sdp.SDESMidURI, @@ -119,7 +117,6 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { // subscriber configuration subscriberConfig := DirectionConfig{ - StrictACKs: rtcConf.StrictACKs, RTPHeaderExtension: RTPHeaderExtensionConfig{ Video: []string{ dd.ExtensionURI, @@ -130,6 +127,10 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { }, }, RTCPFeedback: RTCPFeedbackConfig{ + Audio: []webrtc.RTCPFeedback{ + // always enable NACK for audio but disable it later for red enabled transceiver. https://github.com/pion/webrtc/pull/2972 + {Type: webrtc.TypeRTCPFBNACK}, + }, Video: []webrtc.RTCPFeedback{ {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, {Type: webrtc.TypeRTCPFBNACK}, diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index a8f1e0b56..fbcbdd228 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -147,7 +147,6 @@ type ParticipantParams struct { ReconnectOnPublicationError bool ReconnectOnSubscriptionError bool ReconnectOnDataChannelError bool - DataChannelMaxBufferedAmount uint64 VersionGenerator utils.TimedVersionGenerator TrackResolver types.MediaTrackResolver DisableDynacast bool @@ -161,6 +160,8 @@ type ParticipantParams struct { MetricConfig metric.MetricConfig UseOneShotSignallingMode bool EnableMetrics bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int FireOnTrackBySdp bool } @@ -1565,6 +1566,7 @@ func (p *ParticipantImpl) setupTransportManager() error { TURNSEnabled: p.params.TURNSEnabled, AllowPlayoutDelay: p.params.PlayoutDelay.GetEnabled(), DataChannelMaxBufferedAmount: p.params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: p.params.DatachannelSlowThreshold, Logger: p.params.Logger.WithComponent(sutils.ComponentTransport), PublisherHandler: pth, SubscriberHandler: 
sth, diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index f6dff4eb2..412f787b1 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -52,7 +52,7 @@ const ( invAudioLevelQuantization = 1.0 / AudioLevelQuantization subscriberUpdateInterval = 3 * time.Second - dataForwardLoadBalanceThreshold = 20 + dataForwardLoadBalanceThreshold = 4 simulateDisconnectSignalTimeout = 5 * time.Second ) diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 0d76fbf70..21990f6db 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -16,6 +16,7 @@ package rtc import ( "fmt" + "io" "net" "strconv" "strings" @@ -29,7 +30,6 @@ import ( "github.com/pion/interceptor/pkg/gcc" "github.com/pion/interceptor/pkg/twcc" "github.com/pion/rtcp" - "github.com/pion/sctp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v4" "github.com/pkg/errors" @@ -42,6 +42,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/bwe" "github.com/livekit/livekit-server/pkg/sfu/bwe/remotebwe" "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor" "github.com/livekit/livekit-server/pkg/sfu/pacer" pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" @@ -78,6 +79,8 @@ const ( maxConnectTimeoutAfterICE = 20 * time.Second // max duration for waiting pc to connect after ICE is connected shortConnectionThreshold = 90 * time.Second + + dataChannelBufferSize = 65535 ) var ( @@ -188,9 +191,9 @@ type PCTransport struct { firstOfferReceived bool firstOfferNoDataChannel bool - reliableDC *webrtc.DataChannel + reliableDC *datachannel.DataChannelWriter[*webrtc.DataChannel] reliableDCOpened bool - lossyDC *webrtc.DataChannel + lossyDC *datachannel.DataChannelWriter[*webrtc.DataChannel] lossyDCOpened bool iceStartedAt time.Time @@ -258,9 +261,13 @@ type TransportParams struct { IsOfferer bool IsSendSide bool AllowPlayoutDelay bool - DataChannelMaxBufferedAmount uint64 
UseOneShotSignallingMode bool FireOnTrackBySdp bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int + + // for development test + DatachannelMaxReceiverBufferSize int } func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimator cc.BandwidthEstimator)) (*webrtc.PeerConnection, *webrtc.MediaEngine, error) { @@ -288,6 +295,13 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat // https://github.com/pion/webrtc/pull/2961 se.DisableCloseByDTLS(true) + se.DetachDataChannels() + if params.DatachannelSlowThreshold > 0 { + se.EnableDataChannelBlockWrite(true) + } + if params.DatachannelMaxReceiverBufferSize > 0 { + se.SetSCTPMaxReceiveBufferSize(uint32(params.DatachannelMaxReceiverBufferSize)) + } if params.FireOnTrackBySdp { se.SetFireOnTrackBeforeFirstRTP(true) } @@ -361,7 +375,6 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat ir := &interceptor.Registry{} if params.IsSendSide { - se.DetachDataChannels() if params.CongestionControlConfig.UseSendSideBWEInterceptor && !params.CongestionControlConfig.UseSendSideBWE { params.Logger.Infow("using send side BWE - interceptor") gf, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { @@ -717,31 +730,65 @@ func (t *PCTransport) onPeerConnectionStateChange(state webrtc.PeerConnectionSta } func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { - t.params.Logger.Debugw(dc.Label() + " data channel open") - switch dc.Label() { - case ReliableDataChannel: - t.lock.Lock() - t.reliableDC = dc - t.reliableDCOpened = true - t.lock.Unlock() - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - t.params.Handler.OnDataPacket(livekit.DataPacket_RELIABLE, msg.Data) - }) + dc.OnOpen(func() { + t.params.Logger.Debugw(dc.Label() + " data channel open") + var kind livekit.DataPacket_Kind + switch dc.Label() { + case ReliableDataChannel: + kind = livekit.DataPacket_RELIABLE + + case LossyDataChannel: + kind = 
livekit.DataPacket_LOSSY + + default: + t.params.Logger.Warnw("unsupported datachannel added", nil, "label", dc.Label()) + return + } + + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Errorw("failed to detach data channel", err, "label", dc.Label()) + return + } + + switch kind { + case livekit.DataPacket_RELIABLE: + t.lock.Lock() + if t.reliableDC != nil { + t.reliableDC.Close() + } + t.reliableDC = datachannel.NewDataChannelWriter(dc, rawDC, t.params.DatachannelSlowThreshold) + t.reliableDCOpened = true + t.lock.Unlock() + + case livekit.DataPacket_LOSSY: + t.lock.Lock() + if t.lossyDC != nil { + t.lossyDC.Close() + } + t.lossyDC = datachannel.NewDataChannelWriter(dc, rawDC, 0) + t.lossyDCOpened = true + t.lock.Unlock() + } + + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + t.params.Handler.OnDataPacket(kind, buffer[:n]) + } + }() t.maybeNotifyFullyEstablished() - case LossyDataChannel: - t.lock.Lock() - t.lossyDC = dc - t.lossyDCOpened = true - t.lock.Unlock() - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - t.params.Handler.OnDataPacket(livekit.DataPacket_LOSSY, msg.Data) - }) - - t.maybeNotifyFullyEstablished() - default: - t.params.Logger.Warnw("unsupported datachannel added", nil, "label", dc.Label()) - } + }) } func (t *PCTransport) maybeNotifyFullyEstablished() { @@ -866,7 +913,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni return err } var ( - dcPtr **webrtc.DataChannel + dcPtr **datachannel.DataChannelWriter[*webrtc.DataChannel] dcReady *bool ) switch dc.Label() { @@ -883,60 +930,41 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni dcReady = &t.lossyDCOpened } - dcReadyHandler := func() { + dc.OnOpen(func() { + 
rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Warnw("failed to detach data channel", err) + return + } + + var slowThreshold int + if dc.Label() == ReliableDataChannel { + slowThreshold = t.params.DatachannelSlowThreshold + } + t.lock.Lock() + if *dcPtr != nil { + (*dcPtr).Close() + } + *dcPtr = datachannel.NewDataChannelWriter(dc, rawDC, slowThreshold) *dcReady = true t.lock.Unlock() t.params.Logger.Debugw(dc.Label() + " data channel open") t.maybeNotifyFullyEstablished() - } + }) - dcCloseHandler := func() { - t.params.Logger.Debugw(dc.Label() + " data channel close") - } - - dcErrorHandler := func(err error) { - if !errors.Is(err, sctp.ErrResetPacketInStateNotExist) && !errors.Is(err, sctp.ErrChunk) { - t.params.Logger.Warnw(dc.Label()+" data channel error", err) - } - } - - t.lock.Lock() - defer t.lock.Unlock() - *dcPtr = dc - if t.params.DirectionConfig.StrictACKs { - dc.OnOpen(func() { - if t.params.IsSendSide { - if _, err := dc.Detach(); err != nil { - t.params.Logger.Warnw("failed to detach data channel", err) - } - } - dcReadyHandler() - }) - } else { - dc.OnOpen(func() { - if t.params.IsSendSide { - if _, err := dc.Detach(); err != nil { - t.params.Logger.Warnw("failed to detach data channel", err) - } - } - }) - dc.OnDial(dcReadyHandler) - } - dc.OnClose(dcCloseHandler) - dc.OnError(dcErrorHandler) return nil } func (t *PCTransport) CreateDataChannelIfEmpty(dcLabel string, dci *webrtc.DataChannelInit) (label string, id uint16, existing bool, err error) { t.lock.RLock() - var dc *webrtc.DataChannel + var dcw *datachannel.DataChannelWriter[*webrtc.DataChannel] switch dcLabel { case ReliableDataChannel: - dc = t.reliableDC + dcw = t.reliableDC case LossyDataChannel: - dc = t.lossyDC + dcw = t.lossyDC default: t.params.Logger.Warnw("unknown data channel label", nil, "label", label) err = errors.New("unknown data channel label") @@ -946,11 +974,12 @@ func (t *PCTransport) CreateDataChannelIfEmpty(dcLabel string, dci *webrtc.DataC 
return } - if dc != nil { + if dcw != nil { + dc := dcw.BufferedAmountGetter() return dc.Label(), *dc.ID(), true, nil } - dc, err = t.pc.CreateDataChannel(dcLabel, dci) + dc, err := t.pc.CreateDataChannel(dcLabel, dci) if err != nil { return } @@ -989,7 +1018,7 @@ func (t *PCTransport) WriteRTCP(pkts []rtcp.Packet) error { } func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { - var dc *webrtc.DataChannel + var dc *datachannel.DataChannelWriter[*webrtc.DataChannel] t.lock.RLock() if kind == livekit.DataPacket_RELIABLE { dc = t.reliableDC @@ -1006,11 +1035,12 @@ func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byt return ErrTransportFailure } - if t.params.DataChannelMaxBufferedAmount > 0 && dc.BufferedAmount() > t.params.DataChannelMaxBufferedAmount { + if t.params.DatachannelSlowThreshold == 0 && t.params.DataChannelMaxBufferedAmount > 0 && dc.BufferedAmountGetter().BufferedAmount() > t.params.DataChannelMaxBufferedAmount { return ErrDataChannelBufferFull } + _, err := dc.Write(encoded) - return dc.Send(encoded) + return err } func (t *PCTransport) Close() { @@ -1031,6 +1061,18 @@ func (t *PCTransport) Close() { _ = t.pc.Close() t.clearConnTimer() + + t.lock.Lock() + if t.reliableDC != nil { + t.reliableDC.Close() + t.reliableDC = nil + } + + if t.lossyDC != nil { + t.lossyDC.Close() + t.lossyDC = nil + } + t.lock.Unlock() } func (t *PCTransport) clearConnTimer() { @@ -2020,6 +2062,8 @@ func (t *PCTransport) handleICERestart(_ event) error { } // configure subscriber transceiver for audio stereo and nack +// pion doesn't support per transciver codec configuration, so the nack of this session will be disabled +// forever once it is first disabled by a transceiver. 
func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool) { sender := tr.Sender() if sender == nil { @@ -2034,17 +2078,13 @@ func configureAudioTransceiver(tr *webrtc.RTPTransceiver, stereo bool, nack bool if stereo { c.SDPFmtpLine += ";sprop-stereo=1" } - if nack { - var nackFound bool - for _, fb := range c.RTCPFeedback { + if !nack { + for i, fb := range c.RTCPFeedback { if fb.Type == webrtc.TypeRTCPFBNACK { - nackFound = true + c.RTCPFeedback = append(c.RTCPFeedback[:i], c.RTCPFeedback[i+1:]...) break } } - if !nackFound { - c.RTCPFeedback = append(c.RTCPFeedback, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBNACK}) - } } } configCodecs = append(configCodecs, c) diff --git a/pkg/rtc/transport_test.go b/pkg/rtc/transport_test.go index 16bff8b52..512b2ff91 100644 --- a/pkg/rtc/transport_test.go +++ b/pkg/rtc/transport_test.go @@ -587,10 +587,6 @@ func untilTransportsConnected(transports ...*transportfakes.FakeHandler) *sync.W } func TestConfigureAudioTransceiver(t *testing.T) { - pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) - require.NoError(t, err) - defer pc.Close() - for _, testcase := range []struct { nack bool stereo bool @@ -601,6 +597,11 @@ func TestConfigureAudioTransceiver(t *testing.T) { {true, true}, } { t.Run(fmt.Sprintf("nack=%v,stereo=%v", testcase.nack, testcase.stereo), func(t *testing.T) { + var me webrtc.MediaEngine + registerCodecs(&me, []*livekit.Codec{{Mime: webrtc.MimeTypeOpus}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) + pc, err := webrtc.NewAPI(webrtc.WithMediaEngine(&me)).NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() tr, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}) require.NoError(t, err) diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index 275219fee..68889d70e 100644 --- a/pkg/rtc/transportmanager.go 
+++ b/pkg/rtc/transportmanager.go @@ -15,6 +15,7 @@ package rtc import ( + "context" "io" "math/bits" "sync" @@ -99,6 +100,7 @@ type TransportManagerParams struct { TURNSEnabled bool AllowPlayoutDelay bool DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int Logger logger.Logger PublisherHandler transport.Handler SubscriberHandler transport.Handler @@ -146,21 +148,23 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro lgr := LoggerWithPCTarget(params.Logger, livekit.SignalTarget_PUBLISHER) publisher, err := NewPCTransport(TransportParams{ - ParticipantID: params.SID, - ParticipantIdentity: params.Identity, - ProtocolVersion: params.ProtocolVersion, - Config: params.Config, - Twcc: params.Twcc, - DirectionConfig: params.Config.Publisher, - CongestionControlConfig: params.CongestionControlConfig, - EnabledCodecs: params.EnabledPublishCodecs, - Logger: lgr, - SimTracks: params.SimTracks, - ClientInfo: params.ClientInfo, - Transport: livekit.SignalTarget_PUBLISHER, - Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t, lgr}}, - UseOneShotSignallingMode: params.UseOneShotSignallingMode, - FireOnTrackBySdp: params.FireOnTrackBySdp, + ParticipantID: params.SID, + ParticipantIdentity: params.Identity, + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + Twcc: params.Twcc, + DirectionConfig: params.Config.Publisher, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledPublishCodecs, + Logger: lgr, + SimTracks: params.SimTracks, + ClientInfo: params.ClientInfo, + Transport: livekit.SignalTarget_PUBLISHER, + Handler: TransportManagerPublisherTransportHandler{TransportManagerTransportHandler{params.PublisherHandler, t, lgr}}, + UseOneShotSignallingMode: params.UseOneShotSignallingMode, + DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + 
FireOnTrackBySdp: params.FireOnTrackBySdp, }) if err != nil { return nil, err @@ -169,21 +173,21 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro lgr = LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER) subscriber, err := NewPCTransport(TransportParams{ - ParticipantID: params.SID, - ParticipantIdentity: params.Identity, - ProtocolVersion: params.ProtocolVersion, - Config: params.Config, - DirectionConfig: params.Config.Subscriber, - CongestionControlConfig: params.CongestionControlConfig, - EnabledCodecs: params.EnabledSubscribeCodecs, - Logger: lgr, - ClientInfo: params.ClientInfo, - IsOfferer: true, - IsSendSide: true, - AllowPlayoutDelay: params.AllowPlayoutDelay, - DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, - Transport: livekit.SignalTarget_SUBSCRIBER, - Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, + ParticipantID: params.SID, + ParticipantIdentity: params.Identity, + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + DirectionConfig: params.Config.Subscriber, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledSubscribeCodecs, + Logger: lgr, + ClientInfo: params.ClientInfo, + IsOfferer: true, + IsSendSide: true, + AllowPlayoutDelay: params.AllowPlayoutDelay, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + Transport: livekit.SignalTarget_SUBSCRIBER, + Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, }) if err != nil { return nil, err @@ -294,7 +298,7 @@ func (t *TransportManager) SendDataPacket(kind livekit.DataPacket_Kind, encoded // downstream data is sent via primary peer connection err := t.getTransport(true).SendDataPacket(kind, encoded) if err != nil { - if !utils.ErrorIsOneOf(err, io.ErrClosedPipe, sctp.ErrStreamClosed, ErrTransportFailure, ErrDataChannelBufferFull) { + if !utils.ErrorIsOneOf(err, io.ErrClosedPipe, sctp.ErrStreamClosed, 
ErrTransportFailure, ErrDataChannelBufferFull, context.DeadlineExceeded) { t.params.Logger.Warnw("send data packet error", err) } if utils.ErrorIsOneOf(err, sctp.ErrStreamClosed, io.ErrClosedPipe) { diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 0977c2f41..a33730c98 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -480,7 +480,6 @@ func (r *RoomManager) StartSession( ReconnectOnPublicationError: reconnectOnPublicationError, ReconnectOnSubscriptionError: reconnectOnSubscriptionError, ReconnectOnDataChannelError: reconnectOnDataChannelError, - DataChannelMaxBufferedAmount: r.config.RTC.DataChannelMaxBufferedAmount, VersionGenerator: r.versionGenerator, TrackResolver: room.ResolveMediaTrackForSubscriber, SubscriberAllowPause: subscriberAllowPause, @@ -491,6 +490,8 @@ func (r *RoomManager) StartSession( ForwardStats: r.forwardStats, MetricConfig: r.config.Metric, UseOneShotSignallingMode: useOneShotSignallingMode, + DataChannelMaxBufferedAmount: r.config.RTC.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: r.config.RTC.DatachannelSlowThreshold, FireOnTrackBySdp: true, }) if err != nil { diff --git a/pkg/sfu/datachannel/bitrate.go b/pkg/sfu/datachannel/bitrate.go new file mode 100644 index 000000000..1e31a8c2b --- /dev/null +++ b/pkg/sfu/datachannel/bitrate.go @@ -0,0 +1,93 @@ +package datachannel + +import ( + "sync" + "time" + + "github.com/gammazero/deque" + + "github.com/livekit/protocol/utils/mono" +) + +const ( + BitrateDuration = 2 * time.Second + BitrateWindow = 100 * time.Millisecond +) + +// BitrateCalculator calculates bitrate over sliding window +type BitrateCalculator struct { + lock sync.Mutex + windowDuration time.Duration + duration time.Duration + + windows *deque.Deque[bitrateWindow] + active bitrateWindow + + bytes int + lastBufferedAmount int + start time.Time +} + +func NewBitrateCalculator(duration time.Duration, window time.Duration) *BitrateCalculator { + windowCnt := int((duration + 
(window - 1)) / window) + if windowCnt == 0 { + windowCnt = 1 + } + now := mono.Now() + c := &BitrateCalculator{ + duration: duration, + windowDuration: window, + windows: deque.New[bitrateWindow](windowCnt+1, windowCnt+1), + start: now, + active: bitrateWindow{start: now}, + } + + return c +} + +func (c *BitrateCalculator) AddBytes(bytes int, bufferedAmout int, ts time.Time) { + c.lock.Lock() + defer c.lock.Unlock() + + bytes -= bufferedAmout - c.lastBufferedAmount + c.lastBufferedAmount = bufferedAmout + if ts.Sub(c.active.start) >= c.windowDuration { + c.windows.PushBack(c.active) + c.active.start = ts + c.active.bytes = 0 + + for c.windows.Len() > 0 { + // pop expired windows + if w := c.windows.Front(); ts.Sub(w.start) > (c.duration + c.windowDuration) { + c.bytes -= w.bytes + c.windows.PopFront() + } else { + c.start = w.start + break + } + } + if c.windows.Len() == 0 { + c.start = ts + c.bytes = 0 + } + } + c.bytes += bytes + c.active.bytes += bytes + +} + +func (c *BitrateCalculator) Bitrate(ts time.Time) int { + c.lock.Lock() + defer c.lock.Unlock() + duration := ts.Sub(c.start) + if duration < c.windowDuration { + duration = c.windowDuration + } + + return c.bytes * 8 * 1000 / int(duration.Milliseconds()) +} + +type bitrateWindow struct { + start time.Time + bytes int +} diff --git a/pkg/sfu/datachannel/bitrate_test.go b/pkg/sfu/datachannel/bitrate_test.go new file mode 100644 index 000000000..4b4e0a1a6 --- /dev/null +++ b/pkg/sfu/datachannel/bitrate_test.go @@ -0,0 +1,29 @@ +package datachannel + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestBitrateCalculator(t *testing.T) { + c := NewBitrateCalculator(BitrateDuration, BitrateWindow) + require.NotNil(t, c) + + t0 := time.Now() + c.AddBytes(100, 0, t0) + // bytes buffered + c.AddBytes(100, 100, t0.Add(50*time.Millisecond)) + // 50 bytes sent (50 bytes buffer flushed) + c.AddBytes(100, 50, t0.Add(time.Second)) + + // 250 bytes sent in 1 second + require.Equal(t, 
2000, c.Bitrate(t0.Add(time.Second))) + + // silence for long time + t1 := t0.Add(2 * BitrateDuration) + // 150 bytes sent (50 bytes buffer flushed) + c.AddBytes(100, 0, t1) + require.Equal(t, 1200, c.Bitrate(t1.Add(time.Second))) +} diff --git a/pkg/sfu/datachannel/datachannel_writer.go b/pkg/sfu/datachannel/datachannel_writer.go new file mode 100644 index 000000000..2ac05b351 --- /dev/null +++ b/pkg/sfu/datachannel/datachannel_writer.go @@ -0,0 +1,75 @@ +package datachannel + +import ( + "context" + "errors" + "time" + + "github.com/pion/datachannel" + + "github.com/livekit/protocol/utils/mono" +) + +const ( + singleWriteTimeout = 50 * time.Millisecond +) + +var ErrDataDroppedBySlowReader = errors.New("data dropped by slow reader") + +type BufferedAmountGetter interface { + BufferedAmount() uint64 +} + +type DataChannelWriter[T BufferedAmountGetter] struct { + bufferGetter T + rawDC datachannel.ReadWriteCloserDeadliner + slowThreshold int + rate *BitrateCalculator +} + +// NewDataChannelWriter creates a new DataChannelWriter by detaching the data channel, when +// writing to the datachanel times out, it will block and retry if the receiver's bitrate is +// above the slowThreshold or drop the data if it's below the threshold. If the slowThreshold +// is 0, it will never retry on write timeout. 
+func NewDataChannelWriter[T BufferedAmountGetter](bufferGetter T, rawDC datachannel.ReadWriteCloserDeadliner, slowThreshold int) *DataChannelWriter[T] { + var rate *BitrateCalculator + if slowThreshold > 0 { + rate = NewBitrateCalculator(BitrateDuration, BitrateWindow) + } + return &DataChannelWriter[T]{ + bufferGetter: bufferGetter, + rawDC: rawDC, + slowThreshold: slowThreshold, + rate: rate, + } +} + +func (w *DataChannelWriter[T]) BufferedAmountGetter() T { + return w.bufferGetter +} + +func (w *DataChannelWriter[T]) Write(p []byte) (n int, err error) { + for { + err = w.rawDC.SetWriteDeadline(time.Now().Add(singleWriteTimeout)) + if err != nil { + return 0, err + } + n, err = w.rawDC.Write(p) + if w.slowThreshold == 0 { + return + } + + now := mono.Now() + w.rate.AddBytes(n, int(w.bufferGetter.BufferedAmount()), now) + // retry if the write timed out on a non-slow receiver + if errors.Is(err, context.DeadlineExceeded) && w.rate.Bitrate(now) > w.slowThreshold { + continue + } + + return + } +} + +func (w *DataChannelWriter[T]) Close() error { + return w.rawDC.Close() +} diff --git a/pkg/sfu/datachannel/datachannel_writer_test.go b/pkg/sfu/datachannel/datachannel_writer_test.go new file mode 100644 index 000000000..b07f8a3b6 --- /dev/null +++ b/pkg/sfu/datachannel/datachannel_writer_test.go @@ -0,0 +1,94 @@ +package datachannel + +import ( + "context" + "testing" + "time" + + "github.com/pion/datachannel" + "github.com/pion/transport/v3/deadline" + "github.com/stretchr/testify/require" +) + +func TestDataChannelWriter(t *testing.T) { + mockDC := newMockDataChannelWriter() + // slow threshold is 1000B/s + w := NewDataChannelWriter(mockDC, mockDC, 8000) + require.Equal(t, mockDC, w.BufferedAmountGetter()) + buf := make([]byte, 2000) + // write 2000 bytes so it should not drop in 2 seconds + t0 := time.Now() + n, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, 2000, n) + + t1 := time.Now() + mockDC.SetNextWriteCompleteAt(t0.Add(time.Second)) + n, 
err = w.Write(buf[:10]) + require.NoError(t, err) + require.Equal(t, 10, n) + require.GreaterOrEqual(t, time.Since(t1), time.Second) + + // bitrate below slow threshold(2000bytes/3sec), should drop by timeout + mockDC.SetNextWriteCompleteAt(t0.Add(3 * time.Second)) + n, err = w.Write(buf[:1000]) + require.ErrorIs(t, err, context.DeadlineExceeded, err) + require.Equal(t, 0, n) +} + +func TestDataChannelWriter_NoSlowThreshold(t *testing.T) { + mockDC := newMockDataChannelWriter() + w := NewDataChannelWriter(mockDC, mockDC, 0) + buf := make([]byte, 2000) + n, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, 2000, n) + mockDC.SetNextWriteCompleteAt(time.Now().Add(singleWriteTimeout / 2)) + n, err = w.Write(buf[:10]) + require.NoError(t, err) + require.Equal(t, 10, n) + + // slow threshold is 0, should not block & retry + mockDC.SetNextWriteCompleteAt(time.Now().Add(singleWriteTimeout * 2)) + n, err = w.Write(buf[:1000]) + require.ErrorIs(t, err, context.DeadlineExceeded, err) + require.Equal(t, 0, n) +} + +type mockDataChannelWriter struct { + datachannel.ReadWriteCloserDeadliner + nextWriteCompleteAt time.Time + deadline *deadline.Deadline +} + +func newMockDataChannelWriter() *mockDataChannelWriter { + return &mockDataChannelWriter{ + deadline: deadline.New(), + } +} + +func (m *mockDataChannelWriter) BufferedAmount() uint64 { + return 0 +} + +func (m *mockDataChannelWriter) Write(b []byte) (int, error) { + wait := time.Until(m.nextWriteCompleteAt) + if wait <= 0 { + return len(b), nil + } + select { + case <-m.deadline.Done(): + return 0, m.deadline.Err() + case <-time.After(wait): + return len(b), nil + } +} + +func (m *mockDataChannelWriter) SetWriteDeadline(t time.Time) error { + m.deadline.Set(t) + return nil +} + +func (m *mockDataChannelWriter) SetNextWriteCompleteAt(t time.Time) { + m.nextWriteCompleteAt = t +} diff --git a/test/client/client.go b/test/client/client.go index 7bdb6ae5c..7304532fa 100644 --- a/test/client/client.go +++ 
b/test/client/client.go @@ -214,23 +214,25 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { // publisherHandler := &transportfakes.FakeHandler{} c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ - Config: &conf, - DirectionConfig: conf.Subscriber, - EnabledCodecs: codecs, - IsOfferer: true, - IsSendSide: true, - Handler: publisherHandler, + Config: &conf, + DirectionConfig: conf.Subscriber, + EnabledCodecs: codecs, + IsOfferer: true, + IsSendSide: true, + Handler: publisherHandler, + DatachannelSlowThreshold: 1024 * 1024 * 1024, }) if err != nil { return nil, err } subscriberHandler := &transportfakes.FakeHandler{} c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ - Config: &conf, - DirectionConfig: conf.Publisher, - EnabledCodecs: codecs, - Handler: subscriberHandler, - FireOnTrackBySdp: true, + Config: &conf, + DirectionConfig: conf.Publisher, + EnabledCodecs: codecs, + Handler: subscriberHandler, + DatachannelMaxReceiverBufferSize: 1500, + FireOnTrackBySdp: true, }) if err != nil { return nil, err diff --git a/test/client/datachannel_reader.go b/test/client/datachannel_reader.go new file mode 100644 index 000000000..e43d8908b --- /dev/null +++ b/test/client/datachannel_reader.go @@ -0,0 +1,35 @@ +package client + +import ( + "time" + + "github.com/livekit/livekit-server/pkg/sfu/datachannel" +) + +type DataChannelReader struct { + bitrate *datachannel.BitrateCalculator + target int +} + +func NewDataChannelReader(bitrate int) *DataChannelReader { + return &DataChannelReader{ + target: bitrate, + bitrate: datachannel.NewBitrateCalculator(datachannel.BitrateDuration*5, datachannel.BitrateWindow), + } +} + +func (d *DataChannelReader) Read(p []byte, sid string) { + for { + if bitrate := d.bitrate.Bitrate(time.Now()); bitrate > 0 && bitrate > d.target { + time.Sleep(20 * time.Millisecond) + d.bitrate.AddBytes(0, 0, time.Now()) + continue + } + break + } + d.bitrate.AddBytes(len(p), 0, time.Now()) +} + +func (d 
*DataChannelReader) Bitrate() int { + return d.bitrate.Bitrate(time.Now()) +} diff --git a/test/singlenode_test.go b/test/singlenode_test.go index de2133ca0..0e83fce10 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -16,6 +16,7 @@ package test import ( "context" + "encoding/binary" "errors" "fmt" "net/http" @@ -28,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/thoas/go-funk" "github.com/twitchtv/twirp" + "go.uber.org/atomic" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -707,6 +709,124 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { require.Nil(t, c2.GetSubscriptionResponseAndClear()) } +func TestDataPublishSlowSubscriber(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + dataChannelSlowThreshold := 101024 + + logger.Infow("----------------STARTING TEST----------------", "test", t.Name()) + s := createSingleNodeServer(func(c *config.Config) { + c.RTC.DatachannelSlowThreshold = dataChannelSlowThreshold + }) + go func() { + if err := s.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + + waitForServerToStart(s) + + defer func() { + s.Stop(true) + logger.Infow("----------------FINISHING TEST----------------", "test", t.Name()) + }() + + pub := createRTCClient("pub", defaultServerPort, nil) + fastSub := createRTCClient("fastSub", defaultServerPort, nil) + slowSubNotDrop := createRTCClient("slowSubNotDrop", defaultServerPort, nil) + slowSubDrop := createRTCClient("slowSubDrop", defaultServerPort, nil) + waitUntilConnected(t, pub, fastSub, slowSubDrop, slowSubNotDrop) + defer func() { + pub.Stop() + fastSub.Stop() + slowSubNotDrop.Stop() + slowSubDrop.Stop() + }() + + // publisher sends data as fast as possible, it will block by the slowest subscriber above the slow threshold + var ( + blocked atomic.Bool + stopWrite atomic.Bool + writeIdx atomic.Uint64 + ) + writeStopped := make(chan struct{}) + go func() { + defer close(writeStopped) + var 
i int + buf := make([]byte, 100) + for !stopWrite.Load() { + i++ + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(i)) + if err := pub.PublishData(buf, livekit.DataPacket_RELIABLE); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + blocked.Store(true) + i-- + continue + } else { + t.Log("error writing", err) + break + } + } + writeIdx.Store(uint64(i)) + } + }() + + // no data should be dropped for fast subscriber + var fastDataIndex atomic.Uint64 + fastSub.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, fastDataIndex.Load()+1, idx) + fastDataIndex.Store(idx) + } + + // no data should be dropped for slow subscriber that is above threshold + var slowNoDropDataIndex atomic.Uint64 + var drainSlowSubNotDrop atomic.Bool + slowNoDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold * 3 / 2) + slowSubNotDrop.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, slowNoDropDataIndex.Load()+1, idx) + slowNoDropDataIndex.Store(idx) + if !drainSlowSubNotDrop.Load() { + slowNoDropReader.Read(data, sid) + } + } + + // data should be dropped for slow subscriber that is below threshold + var slowDropDataIndex atomic.Uint64 + dropped := make(chan struct{}) + slowDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold / 2) + slowSubDrop.OnDataReceived = func(data []byte, sid string) { + select { + case <-dropped: + return + default: + } + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + if idx != slowDropDataIndex.Load()+1 { + close(dropped) + } + slowDropDataIndex.Store(idx) + slowDropReader.Read(data, sid) + } + + <-dropped + + time.Sleep(time.Second) + blocked.Store(false) + require.Eventually(t, func() bool { return blocked.Load() }, 30*time.Second, 100*time.Millisecond) + drainSlowSubNotDrop.Store(true) + stopWrite.Store(true) + <-writeStopped + require.Eventually(t, func() bool { + return 
writeIdx.Load() == fastDataIndex.Load() && + writeIdx.Load() == slowNoDropDataIndex.Load() + }, 5*time.Second, 50*time.Millisecond) +} + func TestFireTrackBySdp(t *testing.T) { _, finish := setupSingleNodeTest("TestFireTrackBySdp") defer finish()