diff --git a/cmd/cli/client/client.go b/cmd/cli/client/client.go index 1a37fed81..dd60527e5 100644 --- a/cmd/cli/client/client.go +++ b/cmd/cli/client/client.go @@ -101,7 +101,7 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { peerConn.OnTrack(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { c.AppendLog("track received", "label", track.StreamID(), "id", track.ID()) peerId, _ := rtc.UnpackTrackId(track.ID()) - r := rtc.NewReceiver(c.ctx, peerId, rtpReceiver, nil) + r := rtc.NewReceiver(peerId, rtpReceiver, nil) c.lock.Lock() c.receivers = append(c.receivers, r) r.Start() @@ -452,7 +452,7 @@ func (c *RTCClient) consumeReceiver(r *rtc.Receiver) { numBytes := 0 for { pkt, err := r.ReadRTP() - if err == io.EOF || err == io.ErrClosedPipe { + if rtc.IsEOF(err) { // all done return } diff --git a/go.mod b/go.mod index 25170004f..d69dcfdc1 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/gorilla/websocket v1.4.2 github.com/lithammer/shortuuid/v3 v3.0.4 github.com/lunixbochs/vtclean v1.0.0 // indirect + github.com/magiconair/properties v1.8.1 github.com/manifoldco/promptui v0.8.0 github.com/pion/interceptor v0.0.5 github.com/pion/ion-log v1.0.0 diff --git a/go.sum b/go.sum index 66f39f0ab..cefa0373d 100644 --- a/go.sum +++ b/go.sum @@ -222,6 +222,7 @@ github.com/lunixbochs/vtclean v0.0.0-20180621232353-2d01aacdc34a/go.mod h1:pHhQN github.com/lunixbochs/vtclean v1.0.0 h1:xu2sLAri4lGiovBDQKxl5mrXyESr3gUr5m5SM5+LVb8= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= +github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/manifoldco/promptui v0.8.0 h1:R95mMF+McvXZQ7j1g8ucVZE1gLP3Sv6j9vlF9kyRqQo= github.com/manifoldco/promptui v0.8.0/go.mod h1:n4zTdgP0vr0S3w7/O/g98U+e0gwLScEXGwov2nIKuGQ= @@ -296,8 +297,6 @@ github.com/pion/interceptor v0.0.5 h1:BOwlubM1lntji3eNaVrhW1Qk3u1UoemrhM4mbv24XG github.com/pion/interceptor v0.0.5/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU= github.com/pion/ion-log v1.0.0 h1:2lJLImCmfCWCR38hLWsjQfBWe6NFz/htbqiYHwvOP/Q= github.com/pion/ion-log v1.0.0/go.mod h1:jwcla9KoB9bB/4FxYDSRJPcPYSLp5XiUUMnOLaqwl4E= -github.com/pion/ion-sfu v1.6.3 h1:qK0nn57I2DDsylszNZPjbroF8V1MI8nE4wsDePf/s9U= -github.com/pion/ion-sfu v1.6.3/go.mod h1:xHrwxirzClAvn056es4grzQq0BactA7esDBsQuRf7k8= github.com/pion/ion-sfu v1.6.5 h1:L1V0eJ2hW0ox6LJAKBayOVaoHzQMIqKMP+kjS5IMp6Q= github.com/pion/ion-sfu v1.6.5/go.mod h1:1NUUIynUZuNjfnc/r7sjeI7RlVk+sq6q/sFnu8x9Sv8= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index f1af7cf60..66f74f5d7 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -360,7 +360,7 @@ func (p *Participant) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *webrt logger.GetLogger().Debugw("remoteTrack added", "participantId", p.ID(), "remoteTrack", track.ID()) // create Receiver - receiver := NewReceiver(p.ctx, p.id, rtpReceiver, p.bi) + receiver := NewReceiver(p.id, rtpReceiver, p.bi) mt := NewMediaTrack(p.id, p.rtcpCh, track, receiver) p.handleTrackPublished(mt) diff --git a/pkg/rtc/receiver.go b/pkg/rtc/receiver.go index d699accdc..7e52d49bb 100644 --- a/pkg/rtc/receiver.go +++ b/pkg/rtc/receiver.go @@ -1,8 +1,6 @@ package rtc import ( - "context" - "io" "sync" "github.com/pion/ion-sfu/pkg/buffer" @@ -20,8 +18,6 @@ const ( // A receiver is responsible for pulling from a remoteTrack type Receiver struct { peerId string - ctx context.Context - cancel context.CancelFunc rtpReceiver *webrtc.RTPReceiver track *webrtc.TrackRemote bi *buffer.Interceptor @@ -29,15 +25,11 @@ type Receiver struct { bytesRead int64 } -func NewReceiver(ctx context.Context, peerId string, rtpReceiver *webrtc.RTPReceiver, bi *buffer.Interceptor) *Receiver { - ctx, cancel := context.WithCancel(ctx) - track := rtpReceiver.Track() +func NewReceiver(peerId string, rtpReceiver *webrtc.RTPReceiver, bi *buffer.Interceptor) *Receiver { return &Receiver{ - ctx: ctx, - cancel: cancel, peerId: peerId, rtpReceiver: rtpReceiver, - track: track, + track: rtpReceiver.Track(), bi: bi, once: sync.Once{}, } @@ -58,14 +50,6 @@ func (r *Receiver) Start() { }) } -// Close gracefully close the remoteTrack. if the context is canceled -func (r *Receiver) Close() { - if r.ctx.Err() != nil { - return - } - r.cancel() -} - // PacketBuffer interface, to provide forwarders packets from the buffer func (r *Receiver) GetBufferedPackets(mediaSSRC uint32, snOffset uint16, tsOffset uint32, sn []uint16) []rtp.Packet { if r.bi == nil { @@ -80,9 +64,10 @@ func (r *Receiver) ReadRTP() (*rtp.Packet, error) { // rtcpWorker reads RTCP messages from receiver, notifies buffer func (r *Receiver) rtcpWorker() { + // consume RTCP from the sender/source, but don't need to do anything with the packets for { _, err := r.rtpReceiver.ReadRTCP() - if err == io.ErrClosedPipe || r.ctx.Err() != nil { + if IsEOF(err) { return } if err != nil { diff --git a/pkg/rtc/utils.go b/pkg/rtc/utils.go index b13e323e2..f65b6239f 100644 --- a/pkg/rtc/utils.go +++ b/pkg/rtc/utils.go @@ -2,6 +2,7 @@ package rtc import ( "encoding/json" + "io" "strings" "github.com/pion/webrtc/v3" @@ -85,3 +86,7 @@ func FromProtoTrickle(trickle *livekit.Trickle) webrtc.ICECandidateInit { json.Unmarshal([]byte(trickle.CandidateInit), &ci) return ci } + +func IsEOF(err error) bool { + return err == io.ErrClosedPipe || err == io.EOF +} diff --git a/pkg/rtc/utils_test.go b/pkg/rtc/utils_test.go new file mode 100644 index 000000000..6c9d91e48 --- /dev/null +++ b/pkg/rtc/utils_test.go @@ -0,0 +1,29 @@ +package rtc + +import ( + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestPackTrackId(t *testing.T) { + packed := "PA_123abc|uuid-id" + pId, trackId := UnpackTrackId(packed) + assert.Equal(t, "PA_123abc", pId) + assert.Equal(t, "uuid-id", trackId) + + assert.Equal(t, packed, PackTrackId(pId, trackId)) +} + +func TestPackDataTrackLabel(t *testing.T) { + pId := "PA_123abc" + trackId := "TR_b3da25" + label := "trackLabel" + packed := "PA_123abc|TR_b3da25|trackLabel" + assert.Equal(t, packed, PackDataTrackLabel(pId, trackId, label)) + + p, tr, l := UnpackDataTrackLabel(packed) + assert.Equal(t, pId, p) + assert.Equal(t, trackId, tr) + assert.Equal(t, label, l) +} diff --git a/pkg/rtc/wsprotocol.go b/pkg/rtc/wsprotocol.go index d01843a28..de4460868 100644 --- a/pkg/rtc/wsprotocol.go +++ b/pkg/rtc/wsprotocol.go @@ -2,6 +2,7 @@ package rtc import ( "sync" + "time" "github.com/gorilla/websocket" "google.golang.org/protobuf/encoding/protojson" @@ -11,23 +12,36 @@ import ( "github.com/livekit/livekit-server/proto/livekit" ) +const ( + pingFrequency = 10 * time.Second + pingTimeout = 2 * time.Second +) + +type WebsocketClient interface { + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error +} + type SignalConnection interface { ReadRequest() (*livekit.SignalRequest, error) WriteResponse(*livekit.SignalResponse) error } type WSSignalConnection struct { - conn *websocket.Conn + conn WebsocketClient mu sync.Mutex useJSON bool } -func NewWSSignalConnection(conn *websocket.Conn) *WSSignalConnection { - return &WSSignalConnection{ +func NewWSSignalConnection(conn WebsocketClient) *WSSignalConnection { + wsc := &WSSignalConnection{ conn: conn, mu: sync.Mutex{}, useJSON: true, } + go wsc.pingWorker() + return wsc } func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, error) { @@ -40,9 +54,6 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, error) { msg := &livekit.SignalRequest{} switch messageType { - case websocket.PingMessage: - c.conn.WriteMessage(websocket.PongMessage, nil) - continue case websocket.BinaryMessage: // protobuf encoded err := proto.Unmarshal(payload, msg) @@ -79,3 +90,13 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) error { defer c.mu.Unlock() return c.conn.WriteMessage(msgType, payload) } + +func (c *WSSignalConnection) pingWorker() { + for { + <-time.After(pingFrequency) + err := c.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(pingTimeout)) + if err != nil { + return + } + } +} diff --git a/pkg/service/rtc.go b/pkg/service/rtc.go index a6eb37676..476832b1d 100644 --- a/pkg/service/rtc.go +++ b/pkg/service/rtc.go @@ -98,16 +98,14 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err == io.EOF { // client disconnected from websocket return - } - if err != nil { - // most of these are disconnection, just return vs clogging up logs - //logger.GetLogger().Errorw("error reading WS", - // "err", err, - // "participantName", pName, - // "roomId", roomId) + } else if err != nil { return } + if req == nil { + continue + } + switch msg := req.Message.(type) { case *livekit.SignalRequest_Offer: err = s.handleOffer(participant, msg.Offer)