diff --git a/CHANGELOG.md b/CHANGELOG.md index 9696a64b3..48e6d77ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +# 6.0.1 + +SMP agent: +- support changing user of the new connection. +- do not start delivery workers when there are no messages to deliver. +- enable notifications for all connections. +- combine database transactions when subscribing. +- store query errors, reduce slow query threshold to 1ms. + +SMP server: +- safe compacting of store log. +- fix possible race when creating client that might lead to memory leak. + +Dependencies: upgrade tls to 1.9 + # 6.0.0 Version 6.0.0.8 diff --git a/package.yaml b/package.yaml index 26cdcc51a..789e2e151 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 6.0.0.8 +version: 6.0.1.0 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, @@ -69,7 +69,7 @@ dependencies: - temporary == 1.3.* - time == 1.12.* - time-manager == 0.0.* - - tls >= 1.7.0 && < 1.8 + - tls >= 1.9.0 && < 1.10 - transformers == 0.6.* - unliftio == 0.2.* - unliftio-core == 0.2.* diff --git a/simplexmq.cabal b/simplexmq.cabal index d557ac509..d222352a1 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 6.0.0.8 +version: 6.0.1.0 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and @@ -258,7 +258,7 @@ library , temporary ==1.3.* , time ==1.12.* , time-manager ==0.0.* - , tls >=1.7.0 && <1.8 + , tls >=1.9.0 && <1.10 , transformers ==0.6.* , unliftio ==0.2.* , unliftio-core ==0.2.* @@ -333,7 +333,7 @@ executable ntf-server , temporary ==1.3.* , time ==1.12.* , time-manager ==0.0.* - , tls >=1.7.0 && <1.8 + , tls >=1.9.0 && <1.10 , transformers ==0.6.* , unliftio ==0.2.* , unliftio-core ==0.2.* @@ -412,7 +412,7 @@ executable smp-server , temporary ==1.3.* , time ==1.12.* , time-manager ==0.0.* - , tls >=1.7.0 && <1.8 + , tls >=1.9.0 && <1.10 , transformers ==0.6.* , unliftio ==0.2.* , unliftio-core ==0.2.* @@ -490,7 +490,7 @@ executable xftp , temporary ==1.3.* , time ==1.12.* , time-manager ==0.0.* - , tls >=1.7.0 && <1.8 + , tls >=1.9.0 && <1.10 , transformers ==0.6.* , unliftio ==0.2.* , unliftio-core ==0.2.* @@ -565,7 +565,7 @@ executable xftp-server , temporary ==1.3.* , time ==1.12.* , time-manager ==0.0.* - , tls >=1.7.0 && <1.8 + , tls >=1.9.0 && <1.10 , transformers ==0.6.* , unliftio ==0.2.* , unliftio-core ==0.2.* @@ -681,7 +681,7 @@ test-suite simplexmq-test , time ==1.12.* , time-manager ==0.0.* , timeit ==2.0.* - , tls >=1.7.0 && <1.8 + , tls >=1.9.0 && <1.10 , transformers ==0.6.* , unliftio ==0.2.* , unliftio-core ==0.2.* diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 672375aaf..176d33403 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -53,6 +53,7 @@ module Simplex.Messaging.Agent deleteConnectionAsync, deleteConnectionsAsync, createConnection, + changeConnectionUser, prepareConnectionToJoin, joinConnection, allowConnection, @@ -131,6 +132,7 @@ import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) +import Data.Containers.ListUtils (nubOrd) import Data.Either (isRight, rights) import Data.Foldable (foldl', toList) import Data.Functor (($>)) @@ -333,6 +335,11 @@ createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs {-# INLINE createConnection #-} +-- | Changes the user id associated with a connection +changeConnectionUser :: AgentClient -> UserId -> ConnId -> UserId -> AE () +changeConnectionUser c oldUserId connId newUserId = withAgentEnv c $ changeConnectionUser' c oldUserId connId newUserId +{-# INLINE changeConnectionUser #-} + -- | Create SMP agent connection without queue (to be joined with joinConnection passing connection ID). -- This method is required to prevent race condition when confirmation from peer is received before -- the caller of joinConnection saves connection ID to the database. @@ -741,6 +748,16 @@ newConn :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe newConn c userId connId enableNtfs cMode clientData pqInitKeys subMode = getSMPServer c userId >>= newConnSrv c userId connId False enableNtfs cMode clientData pqInitKeys subMode +changeConnectionUser' :: AgentClient -> UserId -> ConnId -> UserId -> AM () +changeConnectionUser' c oldUserId connId newUserId = do + SomeConn _ conn <- withStore c (`getConn` connId) + case conn of + NewConnection {} -> updateConn + RcvConnection {} -> updateConn + _ -> throwE $ CMD PROHIBITED "changeConnectionUser: established connection" + where + updateConn = withStore' c $ \db -> setConnUserId db oldUserId connId newUserId + newConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, ConnectionRequestUri c) newConnSrv c userId connId hasNewConn enableNtfs cMode clientData pqInitKeys subMode srv = do connId' <- @@ -958,12 +975,12 @@ subscribeConnections' c connIds = do let (errs, cs) = M.mapEither id conns errs' = M.map (Left . storeError) errs (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs - mapM_ (mapM_ (\(cData, sqs) -> mapM_ (lift . resumeMsgDelivery c cData) sqs) . sndQueue) cs - mapM_ (resumeConnCmds c) $ M.keys cs + resumeDelivery cs + lift $ resumeConnCmds c $ M.keys cs rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) - when (instantNotifications tkn) . void . lift . forkIO . void . runExceptT $ sendNtfCreate ns rcvRs conns + lift $ when (instantNotifications tkn) . void . forkIO . void $ sendNtfCreate ns rcvRs cs let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())]) notifyResultError rs pure rs @@ -995,15 +1012,20 @@ subscribeConnections' c connIds = do order (Active, _) = 2 order (_, Right _) = 3 order _ = 4 - sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> AM () - sendNtfCreate ns rcvRs conns = - forM_ (M.assocs rcvRs) $ \case - (connId, Right _) -> forM_ (M.lookup connId conns) $ \case - Right (SomeConn _ conn) -> do - let cmd = if enableNtfs $ toConnData conn then NSCCreate else NSCDelete - atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd) - _ -> pure () - _ -> pure () + sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId SomeConn -> AM' () + sendNtfCreate ns rcvRs cs = do + -- TODO this needs to be batched end to end. + -- Currently, the only change is to ignore failed subscriptions. + let oks = M.keysSet $ M.filter (either temporaryAgentError $ const True) rcvRs + forM_ (M.restrictKeys cs oks) $ \case + SomeConn _ conn -> do + let cmd = if enableNtfs $ toConnData conn then NSCCreate else NSCDelete + ConnData {connId} = toConnData conn + atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd) + resumeDelivery :: Map ConnId SomeConn -> AM () + resumeDelivery conns = do + conns' <- M.restrictKeys conns . S.fromList <$> withStore' c getConnectionsForDelivery + lift $ mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) conns' sndQueue :: SomeConn -> Maybe (ConnData, NonEmpty SndQueue) sndQueue (SomeConn _ conn) = case conn of DuplexConnection cData _ sqs -> Just (cData, sqs) @@ -1118,13 +1140,10 @@ resumeSrvCmds :: AgentClient -> Maybe SMPServer -> AM' () resumeSrvCmds = void .: getAsyncCmdWorker False {-# INLINE resumeSrvCmds #-} -resumeConnCmds :: AgentClient -> ConnId -> AM () -resumeConnCmds c connId = - unlessM connQueued $ - withStore' c (`getPendingCommandServers` connId) - >>= mapM_ (lift . resumeSrvCmds c) - where - connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connCmdsQueued c) +resumeConnCmds :: AgentClient -> [ConnId] -> AM' () +resumeConnCmds c connIds = do + srvs <- nubOrd . concat . rights <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds) + mapM_ (resumeSrvCmds c) srvs getAsyncCmdWorker :: Bool -> AgentClient -> Maybe SMPServer -> AM' Worker getAsyncCmdWorker hasWork c server = diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 23f0a98d1..fbdb53548 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -163,7 +163,7 @@ module Simplex.Messaging.Agent.Client where import Control.Applicative ((<|>)) -import Control.Concurrent (ThreadId, forkIO) +import Control.Concurrent (ThreadId, killThread) import Control.Concurrent.Async (Async, uninterruptibleCancel) import Control.Concurrent.STM (retry) import Control.Exception (AsyncException (..), BlockedIndefinitelyOnSTM (..)) @@ -266,10 +266,11 @@ import Simplex.Messaging.Transport (SMPVersion, SessionId, THandleParams (sessio import Simplex.Messaging.Transport.Client (TransportHost (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version -import System.Mem.Weak (Weak) +import System.Mem.Weak (Weak, deRefWeak) import System.Random (randomR) import UnliftIO (mapConcurrently, timeout) import UnliftIO.Async (async) +import UnliftIO.Concurrent (forkIO, mkWeakThreadId) import UnliftIO.Directory (doesFileExist, getTemporaryDirectory, removeFile) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -313,7 +314,6 @@ data AgentClient = AgentClient workerSeq :: TVar Int, smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()), asyncCmdWorkers :: TMap (Maybe SMPServer) Worker, - connCmdsQueued :: TMap ConnId Bool, ntfNetworkOp :: TVar AgentOpState, rcvNetworkOp :: TVar AgentOpState, msgDeliveryOp :: TVar AgentOpState, @@ -411,7 +411,7 @@ runWorkerAsync Worker {action} work = (atomically . tryPutTMVar action) -- if it was running (or if start crashes), put it back and unlock (don't lock if it was just started) (\a -> when (isNothing a) start) -- start worker if it's not running where - start = atomically . putTMVar action . Just =<< async work + start = atomically . putTMVar action . Just =<< mkWeakThreadId =<< forkIO work data AgentOperation = AONtfNetwork | AORcvNetwork | AOMsgDelivery | AOSndNetwork | AODatabase deriving (Eq, Show) @@ -480,7 +480,6 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs a workerSeq <- newTVarIO 0 smpDeliveryWorkers <- TM.emptyIO asyncCmdWorkers <- TM.emptyIO - connCmdsQueued <- TM.emptyIO ntfNetworkOp <- newTVarIO $ AgentOpState False 0 rcvNetworkOp <- newTVarIO $ AgentOpState False 0 msgDeliveryOp <- newTVarIO $ AgentOpState False 0 @@ -519,7 +518,6 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs a workerSeq, smpDeliveryWorkers, asyncCmdWorkers, - connCmdsQueued, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, @@ -893,7 +891,6 @@ closeAgentClient c = do atomically (swapTVar (smpSubWorkers c) M.empty) >>= mapM_ cancelReconnect clearWorkers smpDeliveryWorkers >>= mapM_ (cancelWorker . fst) clearWorkers asyncCmdWorkers >>= mapM_ cancelWorker - clear connCmdsQueued atomically . RQ.clear $ activeSubs c atomically . RQ.clear $ pendingSubs c clear subscrConns @@ -909,7 +906,7 @@ closeAgentClient c = do cancelWorker :: Worker -> IO () cancelWorker Worker {doWork, action} = do noWorkToDo doWork - atomically (tryTakeTMVar action) >>= mapM_ (mapM_ uninterruptibleCancel) + atomically (tryTakeTMVar action) >>= mapM_ (mapM_ $ deRefWeak >=> mapM_ killThread) waitUntilActive :: AgentClient -> IO () waitUntilActive AgentClient {active} = unlessM (readTVarIO active) $ atomically $ unlessM (readTVar active) retry diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index f57cf91e9..bc5e800c5 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -41,6 +41,7 @@ module Simplex.Messaging.Agent.Env.SQLite ) where +import Control.Concurrent (ThreadId) import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader @@ -76,8 +77,9 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..)) import Simplex.Messaging.Transport.Client (defaultSMPPort) import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors') +import System.Mem.Weak (Weak) import System.Random (StdGen, newStdGen) -import UnliftIO (Async, SomeException) +import UnliftIO (SomeException) import UnliftIO.STM type AM' a = ReaderT Env IO a @@ -312,7 +314,7 @@ mkInternal = INTERNAL . show data Worker = Worker { workerId :: Int, doWork :: TMVar (), - action :: TMVar (Maybe (Async ())), + action :: TMVar (Maybe (Weak ThreadId)), restarts :: TVar RestartCount } diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 20f382d40..69b1b07e9 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -57,6 +57,7 @@ module Simplex.Messaging.Agent.Store.SQLite getDeletedConns, getConnData, setConnDeleted, + setConnUserId, setConnAgentVersion, setConnPQSupport, getDeletedConnIds, @@ -110,6 +111,7 @@ module Simplex.Messaging.Agent.Store.SQLite getSndMsgViaRcpt, updateSndMsgRcpt, getPendingQueueMsg, + getConnectionsForDelivery, updatePendingMsgRIState, deletePendingMsgs, getExpiredSndMessages, @@ -1008,6 +1010,10 @@ updateSndMsgRcpt db connId sndMsgId MsgReceipt {agentMsgId, msgRcptStatus} = "UPDATE snd_messages SET rcpt_internal_id = ?, rcpt_status = ? WHERE conn_id = ? AND internal_snd_id = ?" (agentMsgId, msgRcptStatus, connId, sndMsgId) +getConnectionsForDelivery :: DB.Connection -> IO [ConnId] +getConnectionsForDelivery db = + map fromOnly <$> DB.query_ db "SELECT DISTINCT conn_id FROM snd_message_deliveries WHERE failed = 0" + getPendingQueueMsg :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError (Maybe (Maybe RcvQueue, PendingMsgData))) getPendingQueueMsg db connId SndQueue {dbQueueId} = getWorkItem "message" getMsgId getMsgData markMsgFailed @@ -1909,9 +1915,11 @@ newQueueId_ (Only maxId : _) = DBQueueId (maxId + 1) getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getConn = getAnyConn False +{-# INLINE getConn #-} getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getDeletedConn = getAnyConn True +{-# INLINE getDeletedConn #-} getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getAnyConn deleted' dbConn connId = @@ -1932,9 +1940,11 @@ getAnyConn deleted' dbConn connId = getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getConns = getAnyConns_ False +{-# INLINE getConns #-} getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getDeletedConns = getAnyConns_ True +{-# INLINE getDeletedConns #-} getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db @@ -1967,6 +1977,10 @@ setConnDeleted db waitDelivery connId | otherwise = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) +setConnUserId :: DB.Connection -> UserId -> ConnId -> UserId -> IO () +setConnUserId db oldUserId connId newUserId = + DB.execute db "UPDATE connections SET user_id = ? WHERE conn_id = ? and user_id = ?" (newUserId, connId, oldUserId) + setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index c5d067475..a70995d5e 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -175,7 +175,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do $>>= endPreviousSubscriptions >>= liftIO . mapM_ unsub where - updateSubscribers :: TVar (IM.IntMap Client) -> STM (Maybe (QueueId, Client)) + updateSubscribers :: TVar (IM.IntMap (Maybe Client)) -> STM (Maybe (QueueId, Client)) updateSubscribers cls = do (qId, clnt, subscribed) <- readTQueue $ subQ s current <- IM.member (clientId clnt) <$> readTVar cls @@ -412,7 +412,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do CPClients -> withAdminRole $ do active <- unliftIO u (asks clients) >>= readTVarIO hPutStrLn h "clientId,sessionId,connected,createdAt,rcvActiveAt,sndActiveAt,age,subscriptions" - forM_ (IM.toList active) $ \(cid, Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions}) -> do + forM_ (IM.toList active) $ \(cid, cl) -> forM_ cl $ \Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions} -> do connected' <- bshow <$> readTVarIO connected rcvActiveAt' <- strEncode <$> readTVarIO rcvActiveAt sndActiveAt' <- strEncode <$> readTVarIO sndActiveAt @@ -507,7 +507,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do activeClients <- readTVarIO clients hPutStrLn h $ "Clients: " <> show (IM.size activeClients) when (r == CPRAdmin) $ do - clQs <- clientTBQueueLengths activeClients + clQs <- clientTBQueueLengths' activeClients hPutStrLn h $ "Client queues (rcvQ, sndQ, msgQ): " <> show clQs (smpSubCnt, smpSubCntByGroup, smpClCnt, smpClQs) <- countClientSubs subscriptions (Just countSMPSubs) activeClients hPutStrLn h $ "SMP subscriptions (via clients): " <> show smpSubCnt @@ -542,11 +542,12 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do | otherwise = (cl : cls, IS.insert clientId clSet) countSubClients :: M.Map QueueId Client -> Int countSubClients = IS.size . M.foldr' (IS.insert . clientId) IS.empty - countClientSubs :: (Client -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + countClientSubs :: (Client -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap (Maybe Client) -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0)) where - addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) - addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) cl = do + addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Maybe Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + addSubs acc Nothing = pure acc + addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) (Just cl) = do subs <- readTVarIO $ subSel cl cnts' <- case countSubs_ of Nothing -> pure cnts @@ -559,6 +560,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do pure (subCnt + cnt, cnts', clCnt', qs') clientTBQueueLengths :: Foldable t => t Client -> IO (Natural, Natural, Natural) clientTBQueueLengths = foldM addQueueLengths (0, 0, 0) + clientTBQueueLengths' :: Foldable t => t (Maybe Client) -> IO (Natural, Natural, Natural) + clientTBQueueLengths' = foldM (\acc -> maybe (pure acc) (addQueueLengths acc)) (0, 0, 0) addQueueLengths (!rl, !sl, !ml) cl = do (rl', sl', ml') <- queueLengths cl pure (rl + rl', sl + sl', ml + ml') @@ -619,15 +622,18 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio ts <- liftIO getSystemTime active <- asks clients nextClientId <- asks clientSeq - c@Client {clientId} <- liftIO $ newClient nextClientId q thVersion sessionId ts - atomically $ modifyTVar' active $ IM.insert clientId c - s <- asks server - expCfg <- asks $ inactiveClientExpiration . config - th <- newMVar h -- put TH under a fair lock to interleave messages and command responses - labelMyThread . B.unpack $ "client $" <> encode sessionId - raceAny_ ([liftIO $ send th c, liftIO $ sendMsg th c, client thParams c s, receive h c] <> disconnectThread_ c expCfg) - `finally` clientDisconnected c + clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1) + atomically $ modifyTVar' active $ IM.insert clientId Nothing + c <- liftIO $ newClient clientId q thVersion sessionId ts + runClientThreads active c clientId `finally` clientDisconnected c where + runClientThreads active c clientId = do + atomically $ modifyTVar' active $ IM.insert clientId $ Just c + s <- asks server + expCfg <- asks $ inactiveClientExpiration . config + th <- newMVar h -- put TH under a fair lock to interleave messages and command responses + labelMyThread . B.unpack $ "client $" <> encode sessionId + raceAny_ $ [liftIO $ send th c, liftIO $ sendMsg th c, client thParams c s, receive h c] <> disconnectThread_ c expCfg disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c)] disconnectThread_ _ _ = [] noSubscriptions c = atomically $ (&&) <$> TM.null (ntfSubscriptions c) <*> (not . hasSubs <$> readTVar (subscriptions c)) diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 84e664607..9b1dd9405 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -127,7 +127,7 @@ data Env = Env serverStats :: ServerStats, sockets :: SocketState, clientSeq :: TVar ClientId, - clients :: TVar (IntMap Client), + clients :: TVar (IntMap (Maybe Client)), proxyAgent :: ProxyAgent -- senders served on this proxy } @@ -183,9 +183,8 @@ newServer = do savingLock <- atomically createLock return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock} -newClient :: TVar ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO Client -newClient nextClientId qSize thVersion sessionId createdAt = do - clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1) +newClient :: ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO Client +newClient clientId qSize thVersion sessionId createdAt = do subscriptions <- TM.emptyIO ntfSubscriptions <- TM.emptyIO rcvQ <- newTBQueueIO qSize diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs index d1ce15ed6..94a340d94 100644 --- a/src/Simplex/Messaging/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -36,7 +36,7 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore (NtfCreds (..), QueueRec (..), ServerQueueStatus (..)) import Simplex.Messaging.Transport.Buffer (trimCR) import Simplex.Messaging.Util (ifM) -import System.Directory (doesFileExist) +import System.Directory (doesFileExist, renameFile) import System.IO -- | opaque container for file handle with a type-safe IOMode @@ -140,10 +140,12 @@ logDeleteNotifier s = writeStoreLogRecord s . DeleteNotifier readWriteStoreLog :: FilePath -> IO (Map RecipientId QueueRec, StoreLog 'WriteMode) readWriteStoreLog f = do - qs <- ifM (doesFileExist f) (readQueues f) (pure M.empty) + qs <- ifM (doesFileExist f) readQS (pure M.empty) s <- openWriteStoreLog f writeQueues s qs pure (qs, s) + where + readQS = readQueues f <* renameFile f (f <> ".bak") writeQueues :: StoreLog 'WriteMode -> Map RecipientId QueueRec -> IO () writeQueues s = mapM_ $ \q -> when (active q) $ logCreateQueue s q diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 58843b7f5..3386f82f3 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -113,6 +113,7 @@ import Simplex.Messaging.Transport.Buffer import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal +import System.IO.Error (isEOFError) import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -339,11 +340,12 @@ instance Transport TLS where getLn :: TLS -> IO ByteString getLn TLS {tlsContext, tlsBuffer} = do - getLnBuffered tlsBuffer (T.recvData tlsContext) `E.catch` handleEOF + getLnBuffered tlsBuffer (T.recvData tlsContext) `E.catches` [E.Handler handleTlsEOF, E.Handler handleEOF] where - handleEOF = \case - T.Error_EOF -> E.throwIO TEBadBlock + handleTlsEOF = \case + T.PostHandshake T.Error_EOF -> E.throwIO TEBadBlock e -> E.throwIO e + handleEOF e = if isEOFError e then E.throwIO TEBadBlock else E.throwIO e -- * SMP transport diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index 0883fcc28..866d0d197 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -25,6 +25,7 @@ import Simplex.Messaging.Transport withTlsUnique, ) import Simplex.Messaging.Transport.Buffer (trimCR) +import System.IO.Error (isEOFError) data WS = WS { wsPeer :: TransportPeer, @@ -108,9 +109,11 @@ makeTLSContextStream cxt = S.makeStream readStream writeStream where readStream :: IO (Maybe ByteString) - readStream = - (Just <$> T.recvData cxt) `E.catch` \case - T.Error_EOF -> pure Nothing - e -> E.throwIO e + readStream = (Just <$> T.recvData cxt) `E.catches` [E.Handler handleTlsEOF, E.Handler handleEOF] + where + handleTlsEOF = \case + T.PostHandshake T.Error_EOF -> pure Nothing + e -> E.throwIO e + handleEOF e = if isEOFError e then pure Nothing else E.throwIO e writeStream :: Maybe LB.ByteString -> IO () writeStream = maybe (closeTLS cxt) (T.sendData cxt) diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index de0cbce3b..381397c6e 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -305,7 +305,7 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a catchRCError = catchAllErrors $ \e -> case fromException e of - Just (TLS.Terminated _ _ (TLS.Error_Protocol (_, _, TLS.UnknownCa))) -> RCEIdentity + Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity _ -> RCEException $ show e {-# INLINE catchRCError #-} diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 4d61d8463..f0a25b758 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -283,6 +283,9 @@ functionalAPITests t = do testPQMatrix3 t $ runAgentClientContactTestPQ3 True it "should support rejecting contact request" $ withSmpServer t testRejectContactRequest + describe "Changing connection user id" $ do + it "should change user id for new connections" $ do + withSmpServer t testUpdateConnectionUserId describe "Establishing connection asynchronously" $ do it "should connect with initiating client going offline" $ withSmpServer t testAsyncInitiatingOffline @@ -912,6 +915,25 @@ testRejectContactRequest = rejectContact alice addrConnId invId liftIO $ noMessages bob "nothing delivered to bob" +testUpdateConnectionUserId :: HasCallStack => IO () +testUpdateConnectionUserId = + withAgentClients2 $ \alice bob -> runRight_ $ do + (connId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe + newUserId <- createUser alice [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] + _ <- changeConnectionUser alice 1 connId newUserId + aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn + (aliceId', sqSecured') <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured' `shouldBe` True + ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` PQSupportOn + allowConnection alice connId confId "alice's connInfo" + let pqEnc = CR.pqSupportToEnc PQSupportOn + get alice ##> ("", connId, A.CON pqEnc) + get bob ##> ("", aliceId, A.INFO PQSupportOn "alice's connInfo") + get bob ##> ("", aliceId, A.CON pqEnc) + testAsyncInitiatingOffline :: HasCallStack => IO () testAsyncInitiatingOffline = withAgent 2 agentCfg initAgentServers testDB2 $ \bob -> runRight_ $ do diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index cc79faeca..012da704d 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -508,6 +508,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali suspendAgent alice 0 closeSQLiteStore store threadDelay 1000000 + print "before opening the database from another agent" -- aliceNtf client doesn't have subscription and is allowed to get notification message withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> runRight_ $ do @@ -515,6 +516,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali pure () threadDelay 1000000 + print "after closing the database in another agent" reopenSQLiteStore store foregroundAgent alice threadDelay 500000 diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 39a4b1b95..95096e800 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -114,6 +114,8 @@ storeTests = do testDeleteRcvConn testDeleteSndConn testDeleteDuplexConn + describe "setConnUserId" $ do + testSetConnUserIdNewConn describe "upgradeRcvConnToDuplex" $ do testUpgradeRcvConnToDuplex describe "upgradeSndConnToDuplex" $ do @@ -336,6 +338,21 @@ testGetRcvConn = getRcvConn db smpServer recipientId `shouldReturn` Right (rq, SomeConn SCRcv (RcvConnection cData1 rq)) +testSetConnUserIdNewConn :: SpecWith SQLiteStore +testSetConnUserIdNewConn = + it "should set user id for new connection" . withStoreTransaction $ \db -> do + g <- C.newRandom + Right connId <- createNewConn db g cData1 {connId = ""} SCMInvitation + newUserId <- createUserRecord db + _ <- setConnUserId db 1 connId newUserId + connResult <- getConn db connId + case connResult of + Right (SomeConn SCNew (NewConnection connData)) -> do + let ConnData {userId} = connData + userId `shouldBe` newUserId + _ -> do + expectationFailure "Failed to get connection" + testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn = it "should create RcvConnection and delete it" . withStoreTransaction $ \db -> do