From 8e74b1fa976be91f6dc646d5e35cc5d84e592015 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 20 Mar 2022 21:19:10 +0000 Subject: [PATCH] call sendData with an empty bytestring --- src/Simplex/Messaging/Transport.hs | 32 +++++++++++------ src/Simplex/Messaging/Transport/KeepAlive.hs | 37 ++++++++++++++++++-- 2 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 8379c35bd..51fd305f5 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -80,6 +80,7 @@ import qualified Network.TLS.Extra as TE import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON) +import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Util (bshow) import Simplex.Messaging.Version import Test.QuickCheck (Arbitrary (..)) @@ -147,7 +148,8 @@ data TLS = TLS tlsPeer :: TransportPeer, tlsUniq :: ByteString, buffer :: TVar ByteString, - getLock :: TMVar () + getLock :: TMVar (), + keepAlive :: Maybe KeepAliveThread } connectTLS :: T.TLSParams p => p -> Socket -> IO T.Context @@ -163,7 +165,7 @@ getTLS tlsPeer cxt = withTlsUnique tlsPeer cxt newTLS newTLS tlsUniq = do buffer <- newTVarIO "" getLock <- newTMVarIO () - pure TLS {tlsContext = cxt, tlsPeer, tlsUniq, buffer, getLock} + pure TLS {tlsContext = cxt, tlsPeer, tlsUniq, buffer, getLock, keepAlive = Nothing} withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c withTlsUnique peer cxt f = @@ -195,12 +197,17 @@ instance Transport TLS where transportName _ = "TLS" transportPeer = tlsPeer getServerConnection = getTLS TServer - getClientConnection = getTLS TClient + getClientConnection cxt = do + tls <- getTLS TClient cxt + keepAlive <- Just <$> startKeepAlive (tlsContext tls) + pure tls {keepAlive} tlsUnique = tlsUniq - closeConnection tls = closeTLS $ tlsContext tls + closeConnection TLS {tlsContext, keepAlive} = do + mapM_ stopKeepAlive keepAlive + closeTLS tlsContext cGet :: TLS -> Int -> IO ByteString - cGet TLS {tlsContext, buffer, getLock} n = + cGet TLS {tlsContext, buffer, getLock, keepAlive} n = E.bracket_ (atomically $ takeTMVar getLock) (atomically $ putTMVar getLock ()) @@ -213,16 +220,21 @@ instance Transport TLS where readChunks :: ByteString -> IO ByteString readChunks b | B.length b >= n = pure b - | otherwise = readChunks . (b <>) =<< T.recvData tlsContext `E.catch` handleEOF + | otherwise = do + chunk <- T.recvData tlsContext `E.catch` handleEOF + mapM_ touchKeepAlive keepAlive + readChunks $ b <> chunk handleEOF = \case T.Error_EOF -> E.throwIO TEBadBlock e -> E.throwIO e cPut :: TLS -> ByteString -> IO () - cPut tls = T.sendData (tlsContext tls) . BL.fromStrict + cPut TLS {tlsContext, keepAlive} s = do + mapM_ touchKeepAlive keepAlive + T.sendData tlsContext $ BL.fromStrict s getLn :: TLS -> IO ByteString - getLn TLS {tlsContext, buffer, getLock} = do + getLn TLS {tlsContext, buffer, getLock, keepAlive} = do E.bracket_ (atomically $ takeTMVar getLock) (atomically $ putTMVar getLock ()) @@ -236,9 +248,9 @@ instance Transport TLS where readChunks b | B.elem '\n' b = pure b | otherwise = readChunks . (b <>) =<< T.recvData tlsContext `E.catch` handleEOF - handleEOF = \case + handleEOF e = mapM_ stopKeepAlive keepAlive >> case e of T.Error_EOF -> E.throwIO TEBadBlock - e -> E.throwIO e + _ -> E.throwIO e -- | Trim trailing CR from ByteString. trimCR :: ByteString -> ByteString diff --git a/src/Simplex/Messaging/Transport/KeepAlive.hs b/src/Simplex/Messaging/Transport/KeepAlive.hs index 0f7ee9b1a..57ce05eb0 100644 --- a/src/Simplex/Messaging/Transport/KeepAlive.hs +++ b/src/Simplex/Messaging/Transport/KeepAlive.hs @@ -1,11 +1,17 @@ {-# LANGUAGE CApiFFI #-} {-# LANGUAGE CPP #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Transport.KeepAlive where +import Control.Concurrent +import Control.Concurrent.STM +import Control.Monad +import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Foreign.C (CInt (..)) import Network.Socket +import qualified Network.TLS as T foreign import capi "netinet/tcp.h value TCP_KEEPCNT" tcpKeepCnt :: CInt @@ -28,9 +34,9 @@ data KeepAliveOpts = KeepAliveOpts defaultKeepAlive :: KeepAliveOpts defaultKeepAlive = KeepAliveOpts - { keepCnt = 2, - keepIdle = 30, - keepIntvl = 15 + { keepCnt = 4, + keepIdle = 60, + keepIntvl = 30 } setSocketKeepAlive :: Socket -> KeepAliveOpts -> IO () @@ -39,3 +45,28 @@ setSocketKeepAlive sock KeepAliveOpts {keepCnt, keepIdle, keepIntvl} = do setSocketOption sock (SockOpt solTcp tcpKeepCnt) keepCnt setSocketOption sock (SockOpt solTcp tcpKeepIdle) keepIdle setSocketOption sock (SockOpt solTcp tcpKeepIntvl) keepIntvl + +data KeepAliveThread = KeepAliveThread + { threadId :: ThreadId, + dataTs :: TVar SystemTime + } + +startKeepAlive :: T.Context -> IO KeepAliveThread +startKeepAlive cxt = do + dataTs <- newTVarIO =<< getSystemTime + threadId <- forkIO . forever $ do + threadDelay 30000000 + ts' <- getSystemTime + doPing <- atomically $ do + ts <- readTVar dataTs + let ping = systemSeconds ts' - systemSeconds ts >= 30 + when ping $ writeTVar dataTs ts' + pure ping + when doPing $ putStrLn "*** ping ***" >> T.sendData cxt "" + pure KeepAliveThread {threadId, dataTs} + +touchKeepAlive :: KeepAliveThread -> IO () +touchKeepAlive KeepAliveThread {dataTs} = atomically . writeTVar dataTs =<< getSystemTime + +stopKeepAlive :: KeepAliveThread -> IO () +stopKeepAlive KeepAliveThread {threadId} = killThread threadId