diff --git a/simplexmq.cabal b/simplexmq.cabal index a482ce29f..5a8d91390 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -618,6 +618,7 @@ test-suite simplexmq-test CoreTests.EncodingTests CoreTests.ProtocolErrorTests CoreTests.RetryIntervalTests + CoreTests.TRcvQueuesTests CoreTests.UtilTests CoreTests.VersionRangeTests FileDescriptionTests diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index b5491a6ed..23caa2254 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -124,6 +124,9 @@ module Simplex.Messaging.Agent.Client getAgentWorkersDetails, AgentWorkersSummary (..), getAgentWorkersSummary, + SMPTransportSession, + NtfTransportSession, + XFTPTransportSession, ) where @@ -532,11 +535,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, where currentActiveClient = (&&) <$> removeTSessVar' v tSess smpClients <*> readTVar active removeSubs = do - qs <- RQ.getDelSessQueues tSess $ activeSubs c - mapM_ (`RQ.addQueue` pendingSubs c) qs - let cs = S.fromList $ map qConnId qs - cs' <- RQ.getConns $ activeSubs c - pure (qs, S.toList $ cs `S.difference` cs') + (qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c + RQ.batchAddQueues (pendingSubs c) qs + pure (qs, cs) serverDown :: ([RcvQueue], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do @@ -594,16 +595,16 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do where resubscribe :: m () resubscribe = do - cs <- atomically . RQ.getConns $ activeSubs c + cs <- readTVarIO $ RQ.getConnections $ activeSubs c rs <- subscribeQueues c $ L.toList qs let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs liftIO $ do - let conns = S.toList $ S.fromList okConns `S.difference` cs + let conns = filter (`M.notMember` cs) okConns unless (null conns) $ notifySub "" $ UP srv conns let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs forM_ (listToMaybe tempErrs) $ \(_, err) -> do - when (null okConns && S.null cs && null finalErrs) . liftIO $ + when (null okConns && M.null cs && null finalErrs) . liftIO $ closeClient c smpClients tSess throwError err notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO () @@ -1060,9 +1061,9 @@ temporaryOrHostError = \case subscribeQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] subscribeQueues c qs = do (errs, qs') <- partitionEithers <$> mapM checkQueue qs - forM_ qs' $ \rq@RcvQueue {connId} -> atomically $ do - modifyTVar (subscrConns c) $ S.insert connId - RQ.addQueue rq $ pendingSubs c + atomically $ do + modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs')) + RQ.batchAddQueues (pendingSubs c) qs' u <- askUnliftIO -- only "checked" queues are subscribed (errs <>) <$> sendTSessionBatches "SUB" 90 id (subscribeQueues_ u) c qs' diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index 38b9a6d47..9ffe325b2 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -1,11 +1,13 @@ +{-# LANGUAGE LambdaCase #-} + module Simplex.Messaging.Agent.TRcvQueues - ( TRcvQueues (getRcvQueues), + ( TRcvQueues (getRcvQueues, getConnections), empty, clear, deleteConn, hasConn, - getConns, addQueue, + batchAddQueues, deleteQueue, getSessQueues, getDelSessQueues, @@ -14,49 +16,85 @@ module Simplex.Messaging.Agent.TRcvQueues where import Control.Concurrent.STM +import Data.Foldable (foldl') +import Data.List.NonEmpty (NonEmpty (..), (<|)) +import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Set (Set) -import qualified Data.Set as S import Simplex.Messaging.Agent.Protocol (ConnId, UserId) import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..)) import Simplex.Messaging.Protocol (RecipientId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -newtype TRcvQueues = TRcvQueues {getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue} +-- the fields in this record have the same data with swapped keys for lookup efficiency, +-- and all methods must maintain this invariant. +data TRcvQueues = TRcvQueues + { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue, + getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId)) + } empty :: STM TRcvQueues -empty = TRcvQueues <$> TM.empty +empty = TRcvQueues <$> TM.empty <*> TM.empty clear :: TRcvQueues -> STM () -clear (TRcvQueues qs) = TM.clear qs +clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs deleteConn :: ConnId -> TRcvQueues -> STM () -deleteConn cId (TRcvQueues qs) = modifyTVar' qs $ M.filter (\rq -> cId /= connId rq) +deleteConn cId (TRcvQueues qs cs) = + TM.lookupDelete cId cs >>= \case + Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks + Nothing -> pure () hasConn :: ConnId -> TRcvQueues -> STM Bool -hasConn cId (TRcvQueues qs) = any (\rq -> cId == connId rq) <$> readTVar qs - -getConns :: TRcvQueues -> STM (Set ConnId) -getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs +hasConn cId (TRcvQueues _ cs) = TM.member cId cs addQueue :: RcvQueue -> TRcvQueues -> STM () -addQueue rq (TRcvQueues qs) = TM.insert (qKey rq) rq qs +addQueue rq (TRcvQueues qs cs) = do + TM.insert k rq qs + TM.alter addQ (connId rq) cs + where + addQ = Just . maybe (k :| []) (k <|) + k = qKey rq + +-- Save time by aggregating modifyTVar +batchAddQueues :: Foldable t => TRcvQueues -> t RcvQueue -> STM () +batchAddQueues (TRcvQueues qs cs) rqs = do + modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs + modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId rq) cs') now rqs + where + addQ k = Just . maybe (k :| []) (k <|) deleteQueue :: RcvQueue -> TRcvQueues -> STM () -deleteQueue rq (TRcvQueues qs) = TM.delete (qKey rq) qs +deleteQueue rq (TRcvQueues qs cs) = do + TM.delete k qs + TM.update delQ (connId rq) cs + where + delQ = L.nonEmpty . L.filter (/= k) + k = qKey rq getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue] -getSessQueues tSess (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs +getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVar qs where addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs' -getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue] -getDelSessQueues tSess (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty) +getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], [ConnId]) +getDelSessQueues tSess (TRcvQueues qs cs) = do + (removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs + writeTVar qs $! qs'' + removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs + pure (removedQs, removedConns) where - addQ (removed, qs') rq - | rq `isSession` tSess = (rq : removed, qs') - | otherwise = (removed, M.insert (qKey rq) rq qs') + delQ acc@(removed, qs') rq + | rq `isSession` tSess = (rq : removed, M.delete (qKey rq) qs') + | otherwise = acc + delConn (removed, cs') rq = M.alterF f cId cs' + where + cId = connId rq + f = \case + Just ks -> case L.nonEmpty $ L.filter (qKey rq /=) ks of + Just ks' -> (removed, Just ks') + Nothing -> (cId : removed, Nothing) + Nothing -> (removed, Nothing) -- "impossible" in invariant holds, because we get keys from the known queues isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool isSession rq (uId, srv, connId_) = diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs new file mode 100644 index 000000000..70f2d93ab --- /dev/null +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -0,0 +1,142 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeApplications #-} + +module CoreTests.TRcvQueuesTests where + +import qualified Data.List.NonEmpty as L +import qualified Data.Map as M +import qualified Data.Set as S +import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId) +import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..)) +import qualified Simplex.Messaging.Agent.TRcvQueues as RQ +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Protocol (SMPServer) +import Test.Hspec +import UnliftIO + +tRcvQueuesTests :: Spec +tRcvQueuesTests = do + describe "connection API" $ do + it "hasConn" hasConnTest + it "hasConn, batch add" hasConnTestBatch + it "deleteConn" deleteConnTest + describe "session API" $ do + it "getSessQueues" getSessQueuesTest + it "getDelSessQueues" getDelSessQueuesTest + +checkDataInvariant :: RQ.TRcvQueues -> IO Bool +checkDataInvariant trq = atomically $ do + conns <- readTVar $ RQ.getConnections trq + qs <- readTVar $ RQ.getRcvQueues trq + -- three invariant checks + let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> connId q == cId) qs))) (M.keys conns) + inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (connId q) conns)) (M.assocs qs) + inv3 = all (\(k, q) -> RQ.qKey q == k) (M.assocs qs) + pure $ inv1 && inv2 && inv3 + +hasConnTest :: IO () +hasConnTest = do + trq <- atomically RQ.empty + atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3") trq + checkDataInvariant trq `shouldReturn` True + atomically (RQ.hasConn "c1" trq) `shouldReturn` True + atomically (RQ.hasConn "c2" trq) `shouldReturn` True + atomically (RQ.hasConn "c3" trq) `shouldReturn` True + atomically (RQ.hasConn "nope" trq) `shouldReturn` False + +hasConnTestBatch :: IO () +hasConnTestBatch = do + trq <- atomically RQ.empty + let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"] + atomically $ RQ.batchAddQueues trq qs + checkDataInvariant trq `shouldReturn` True + atomically (RQ.hasConn "c1" trq) `shouldReturn` True + atomically (RQ.hasConn "c2" trq) `shouldReturn` True + atomically (RQ.hasConn "c3" trq) `shouldReturn` True + atomically (RQ.hasConn "nope" trq) `shouldReturn` False + +deleteConnTest :: IO () +deleteConnTest = do + trq <- atomically RQ.empty + atomically $ do + RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq + RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq + RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3") trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.deleteConn "c1" trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.deleteConn "nope" trq + checkDataInvariant trq `shouldReturn` True + M.keys <$> readTVarIO (RQ.getConnections trq) `shouldReturn` ["c2", "c3"] + +getSessQueuesTest :: IO () +getSessQueuesTest = do + trq <- atomically RQ.empty + atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3") trq + checkDataInvariant trq `shouldReturn` True + atomically $ RQ.addQueue (dummyRQ 1 "smp://1234-w==@beta" "c4") trq + checkDataInvariant trq `shouldReturn` True + atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1"] + atomically (RQ.getSessQueues (1, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [] + atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "nope") trq) `shouldReturn` [] + atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"] + +getDelSessQueuesTest :: IO () +getDelSessQueuesTest = do + trq <- atomically RQ.empty + let qs = + [ dummyRQ 0 "smp://1234-w==@alpha" "c1", + dummyRQ 0 "smp://1234-w==@alpha" "c2", + dummyRQ 0 "smp://1234-w==@beta" "c3", + dummyRQ 1 "smp://1234-w==@beta" "c4" + ] + atomically $ RQ.batchAddQueues trq qs + checkDataInvariant trq `shouldReturn` True + -- no user + atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + checkDataInvariant trq `shouldReturn` True + -- wrong user + atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + checkDataInvariant trq `shouldReturn` True + -- connections intact + atomically (RQ.hasConn "c1" trq) `shouldReturn` True + atomically (RQ.hasConn "c2" trq) `shouldReturn` True + atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) + checkDataInvariant trq `shouldReturn` True + -- connections gone + atomically (RQ.hasConn "c1" trq) `shouldReturn` False + atomically (RQ.hasConn "c2" trq) `shouldReturn` False + -- non-matched connections intact + atomically (RQ.hasConn "c3" trq) `shouldReturn` True + atomically (RQ.hasConn "c4" trq) `shouldReturn` True + +dummyRQ :: UserId -> SMPServer -> ConnId -> RcvQueue +dummyRQ userId server connId = + RcvQueue + { userId, + connId, + server, + rcvId = "", + rcvPrivateKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe", + rcvDhSecret = "01234567890123456789012345678901", + e2ePrivKey = "MC4CAQAwBQYDK2VuBCIEINCzbVFaCiYHoYncxNY8tSIfn0pXcIAhLBfFc0m+gOpk", + e2eDhSecret = Nothing, + sndId = "", + status = New, + dbQueueId = DBQueueId 0, + primary = True, + dbReplaceQueueId = Nothing, + rcvSwchStatus = Nothing, + smpClientVersion = 123, + clientNtfCreds = Nothing, + deleteErrors = 0 + } diff --git a/tests/Test.hs b/tests/Test.hs index b7fc3e9cb..aebceb22d 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -13,6 +13,7 @@ import CoreTests.CryptoTests import CoreTests.EncodingTests import CoreTests.ProtocolErrorTests import CoreTests.RetryIntervalTests +import CoreTests.TRcvQueuesTests import CoreTests.UtilTests import CoreTests.VersionRangeTests import FileDescriptionTests (fileDescriptionTests) @@ -52,6 +53,7 @@ main = do describe "Encryption tests" cryptoTests describe "Encrypted files tests" cryptoFileTests describe "Retry interval tests" retryIntervalTests + describe "TRcvQueues tests" tRcvQueuesTests describe "Util tests" utilTests describe "SMP server via TLS" $ serverTests (transport @TLS) describe "SMP server via WebSockets" $ serverTests (transport @WS)