diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 0bd95d1dc..7159a1324 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -2202,7 +2202,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId processSubOk :: RcvQueue -> TVar [ConnId] -> AM () processSubOk rq@RcvQueue {connId} upConnIds = atomically . whenM (isPendingSub connId) $ do - addSubscription c rq + addSubscription c sessId rq modifyTVar' upConnIds (connId :) processSubErr :: RcvQueue -> SMPClientError -> AM () processSubErr rq@RcvQueue {connId} e = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 8aacd8836..55caf754c 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -306,8 +306,8 @@ data AgentClient = AgentClient userNetworkInfo :: TVar UserNetworkInfo, userNetworkUpdated :: TVar (Maybe UTCTime), subscrConns :: TVar (Set ConnId), - activeSubs :: TRcvQueues, - pendingSubs :: TRcvQueues, + activeSubs :: TRcvQueues (SessionId, RcvQueue), + pendingSubs :: TRcvQueues RcvQueue, removedSubs :: TMap (UserId, SMPServer, SMP.RecipientId) SMPClientError, workerSeq :: TVar Int, smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()), @@ -332,7 +332,7 @@ data AgentClient = AgentClient agentEnv :: Env, smpServersStats :: TMap (UserId, SMPServer) AgentSMPServerStats, xftpServersStats :: TMap (UserId, XFTPServer) AgentXFTPServerStats, - ntfServersStats :: TMap (UserId, NtfServer) AgentNtfServerStats, + ntfServersStats :: TMap (UserId, NtfServer) AgentNtfServerStats, srvStatsStartedAt :: TVar UTCTime } @@ -677,11 +677,13 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess -- because we can have a race condition when a new current client could have already -- made subscriptions active, and the old client would be processing diconnection later. removeClientAndSubs :: IO ([RcvQueue], [ConnId]) - removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], []) + removeClientAndSubs = atomically $ do + removeSessVar v tSess smpClients + ifM (readTVar active) removeSubs (pure ([], [])) where - currentActiveClient = (&&) <$> removeSessVar' v tSess smpClients <*> readTVar active + sessId = sessionId $ thParams client removeSubs = do - (qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c + (qs, cs) <- RQ.getDelSessQueues tSess sessId $ activeSubs c RQ.batchAddQueues (pendingSubs c) qs -- this removes proxied relays that this client created sessions to destSrvs <- M.keys <$> readTVar prs @@ -1347,8 +1349,8 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode sender qUri = SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey sndSecure pure (rq, qUri, tSess, sessId) -processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> STM () -processSubResult c rq@RcvQueue {userId, server, connId} = \case +processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError () -> STM () +processSubResult c sessId rq@RcvQueue {userId, server, connId} = \case Left e -> unless (temporaryClientError e) $ do incSMPServerStat c userId server connSubErrs @@ -1356,7 +1358,7 @@ processSubResult c rq@RcvQueue {userId, server, connId} = \case Right () -> ifM (hasPendingSubscription c connId) - (incSMPServerStat c userId server connSubscribed >> addSubscription c rq) + (incSMPServerStat c userId server connSubscribed >> addSubscription c sessId rq) (incSMPServerStat c userId server connSubIgnored) temporaryAgentError :: AgentErrorType -> Bool @@ -1427,7 +1429,7 @@ subscribeQueues c qs = do sessId = sessionId $ thParams smp hasTempErrors = any (either temporaryClientError (const False) . snd) processSubResults :: NonEmpty (RcvQueue, Either SMPClientError ()) -> STM () - processSubResults = mapM_ $ uncurry $ processSubResult c + processSubResults = mapM_ $ uncurry $ processSubResult c sessId resubscribe = resubscribeSMPSession c tSess `runReaderT` env activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool @@ -1466,10 +1468,10 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) -addSubscription :: AgentClient -> RcvQueue -> STM () -addSubscription c rq@RcvQueue {connId} = do +addSubscription :: AgentClient -> SessionId -> RcvQueue -> STM () +addSubscription c sessId rq@RcvQueue {connId} = do modifyTVar' (subscrConns c) $ S.insert connId - RQ.addQueue rq $ activeSubs c + RQ.addQueue (sessId, rq) $ activeSubs c RQ.deleteQueue rq $ pendingSubs c failSubscription :: AgentClient -> RcvQueue -> SMPClientError -> STM () @@ -1488,7 +1490,7 @@ addNewQueueSubscription c rq tSess sessId = do atomically $ ifM (activeClientSession c tSess sessId) - (True <$ addSubscription c rq) + (True <$ addSubscription c sessId rq) (False <$ addPendingSubscription c rq) unless same $ resubscribeSMPSession c tSess @@ -2025,7 +2027,9 @@ getAgentSubsTotal c userIds = do sess <- hasSession . M.toList =<< readTVarIO (smpClients c) pure (SMPServerSubs {ssActive, ssPending}, sess) where + getSubsCount :: (AgentClient -> TRcvQueues q) -> IO Int getSubsCount subs = M.foldrWithKey' addSub 0 <$> readTVarIO (getRcvQueues $ subs c) + addSub :: (UserId, SMPServer, SMP.RecipientId) -> q -> Int -> Int addSub (userId, _, _) _ cnt = if userId `elem` userIds then cnt + 1 else cnt hasSession :: [(SMPTransportSession, SMPClientVar)] -> IO Bool hasSession = \case @@ -2106,6 +2110,7 @@ getAgentSubscriptions c = do removedSubscriptions <- getRemovedSubs pure $ SubscriptionsInfo {activeSubscriptions, pendingSubscriptions, removedSubscriptions} where + getSubs :: (AgentClient -> TRcvQueues q) -> IO [SubInfo] getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c) getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c) subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index 9ffe325b2..10e4574cb 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -1,7 +1,9 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} module Simplex.Messaging.Agent.TRcvQueues ( TRcvQueues (getRcvQueues, getConnections), + Queue (..), empty, clear, deleteConn, @@ -11,7 +13,6 @@ module Simplex.Messaging.Agent.TRcvQueues deleteQueue, getSessQueues, getDelSessQueues, - qKey, ) where @@ -25,46 +26,51 @@ import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..)) import Simplex.Messaging.Protocol (RecipientId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Transport + +class Queue q where + connId' :: q -> ConnId + qKey :: q -> (UserId, SMPServer, ConnId) -- the fields in this record have the same data with swapped keys for lookup efficiency, -- and all methods must maintain this invariant. -data TRcvQueues = TRcvQueues - { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue, +data TRcvQueues q = TRcvQueues + { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) q, getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId)) } -empty :: STM TRcvQueues +empty :: STM (TRcvQueues q) empty = TRcvQueues <$> TM.empty <*> TM.empty -clear :: TRcvQueues -> STM () +clear :: TRcvQueues q -> STM () clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs -deleteConn :: ConnId -> TRcvQueues -> STM () +deleteConn :: ConnId -> TRcvQueues q -> STM () deleteConn cId (TRcvQueues qs cs) = TM.lookupDelete cId cs >>= \case Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks Nothing -> pure () -hasConn :: ConnId -> TRcvQueues -> STM Bool +hasConn :: ConnId -> TRcvQueues q -> STM Bool hasConn cId (TRcvQueues _ cs) = TM.member cId cs -addQueue :: RcvQueue -> TRcvQueues -> STM () +addQueue :: Queue q => q -> TRcvQueues q -> STM () addQueue rq (TRcvQueues qs cs) = do TM.insert k rq qs - TM.alter addQ (connId rq) cs + TM.alter addQ (connId' rq) cs where addQ = Just . maybe (k :| []) (k <|) k = qKey rq -- Save time by aggregating modifyTVar -batchAddQueues :: Foldable t => TRcvQueues -> t RcvQueue -> STM () +batchAddQueues :: (Foldable t, Queue q) => TRcvQueues q -> t q -> STM () batchAddQueues (TRcvQueues qs cs) rqs = do modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs - modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId rq) cs') now rqs + modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId' rq) cs') now rqs where addQ k = Just . maybe (k :| []) (k <|) -deleteQueue :: RcvQueue -> TRcvQueues -> STM () +deleteQueue :: RcvQueue -> TRcvQueues RcvQueue -> STM () deleteQueue rq (TRcvQueues qs cs) = do TM.delete k qs TM.update delQ (connId rq) cs @@ -72,21 +78,22 @@ deleteQueue rq (TRcvQueues qs cs) = do delQ = L.nonEmpty . L.filter (/= k) k = qKey rq -getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue] +getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> STM [RcvQueue] getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVar qs where addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs' -getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], [ConnId]) -getDelSessQueues tSess (TRcvQueues qs cs) = do +getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> SessionId -> TRcvQueues (SessionId, RcvQueue) -> STM ([RcvQueue], [ConnId]) +getDelSessQueues tSess sessId' (TRcvQueues qs cs) = do (removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs writeTVar qs $! qs'' removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs pure (removedQs, removedConns) where - delQ acc@(removed, qs') rq - | rq `isSession` tSess = (rq : removed, M.delete (qKey rq) qs') + delQ acc@(removed, qs') (sessId, rq) + | rq `isSession` tSess && sessId == sessId' = (rq : removed, M.delete (qKey rq) qs') | otherwise = acc + delConn :: ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) -> RcvQueue -> ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) delConn (removed, cs') rq = M.alterF f cId cs' where cId = connId rq @@ -100,5 +107,10 @@ isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool isSession rq (uId, srv, connId_) = userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_ -qKey :: RcvQueue -> (UserId, SMPServer, ConnId) -qKey rq = (userId rq, server rq, connId rq) +instance Queue RcvQueue where + connId' = connId + qKey rq = (userId rq, server rq, connId rq) + +instance Queue (SessionId, RcvQueue) where + connId' = connId . snd + qKey = qKey . snd diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 99c77f67c..a95f706bf 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -100,7 +100,7 @@ data SMPClientAgent = SMPClientAgent randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, smpSessions :: TMap SessionId (OwnServer, SMPClient), - srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), + srvSubs :: TMap SMPServer (TMap SMPSub (SessionId, C.APrivateAuthKey)), pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), smpSubWorkers :: TMap SMPServer (SessionVar (Async ())), workerSeq :: TVar Int @@ -204,14 +204,17 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random removeClientAndSubs :: SMPClient -> IO (Maybe (Map SMPSub C.APrivateAuthKey)) removeClientAndSubs smp = atomically $ do + TM.delete sessId smpSessions removeSessVar v srv smpClients - TM.delete (sessionId $ thParams smp) smpSessions - TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs + TM.lookup srv (srvSubs ca) >>= mapM updateSubs where + sessId = sessionId $ thParams smp updateSubs sVar = do - ss <- readTVar sVar - addSubs_ (pendingSrvSubs ca) srv ss - pure ss + -- removing subscriptions that have matching sessionId to disconnected client + -- and keep the other ones (they can be made by the new client) + pending <- M.map snd <$> stateTVar sVar (M.partition ((sessId ==) . fst)) + addSubs_ (pendingSrvSubs ca) srv pending + pure pending serverDown :: Map SMPSub C.APrivateAuthKey -> IO () serverDown ss = unless (M.null ss) $ do @@ -256,9 +259,9 @@ reconnectSMPClient ca@SMPClientAgent {agentCfg} srv cs = subscribe_ smp SPNotifier nSubs subscribe_ smp SPRecipient rSubs where - groupSub :: Map SMPSub C.APrivateAuthKey -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) - groupSub currSubs (s@(party, qId), k) (nSubs, rSubs) - | M.member s currSubs = (nSubs, rSubs) + groupSub :: Map SMPSub (SessionId, C.APrivateAuthKey) -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) + groupSub currSubs (s@(party, qId), k) acc@(nSubs, rSubs) + | M.member s currSubs = acc | otherwise = case party of SPNotifier -> (s' : nSubs, rSubs) SPRecipient -> (nSubs, s' : rSubs) @@ -346,17 +349,18 @@ smpSubscribeQueues party ca smp srv subs = do when tempErrs $ reconnectClient ca srv Nothing -> reconnectClient ca srv where - processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) + processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) processSubscriptions rs = do pending <- maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca) let acc@(_, _, oks, notPending) = foldr (groupSub pending) (False, [], [], []) (L.zip subs rs) unless (null oks) $ addSubscriptions ca srv party oks unless (null notPending) $ removePendingSubs ca srv party notPending pure acc - groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) - groupSub pending (s@(qId, _), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of + sessId = sessionId $ thParams smp + groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) + groupSub pending ((qId, pk), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of Right () - | M.member (party, qId) pending -> (tempErrs, finalErrs, s : oks, qId : notPending) + | M.member (party, qId) pending -> (tempErrs, finalErrs, (qId, (sessId, pk)) : oks, qId : notPending) | otherwise -> acc Left e | temporaryClientError e -> (True, finalErrs, oks, notPending) @@ -379,7 +383,7 @@ showServer :: SMPServer -> ByteString showServer ProtocolServer {host, port} = strEncode host <> B.pack (if null port then "" else ':' : port) -addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM () +addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, (SessionId, C.APrivateAuthKey))] -> STM () addSubscriptions = addSubsList_ . srvSubs {-# INLINE addSubscriptions #-} @@ -387,12 +391,12 @@ addPendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APr addPendingSubs = addSubsList_ . pendingSrvSubs {-# INLINE addPendingSubs #-} -addSubsList_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM () +addSubsList_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSubParty -> [(QueueId, s)] -> STM () addSubsList_ subs srv party ss = addSubs_ subs srv ss' where ss' = M.fromList $ map (first (party,)) ss -addSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> Map SMPSub C.APrivateAuthKey -> STM () +addSubs_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> Map SMPSub s -> STM () addSubs_ subs srv ss = TM.lookup srv subs >>= \case Just m -> TM.union ss m @@ -402,7 +406,7 @@ removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () removeSubscription = removeSub_ . srvSubs {-# INLINE removeSubscription #-} -removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM () +removeSub_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSub -> STM () removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s) removePendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM () diff --git a/src/Simplex/Messaging/Notifications/Types.hs b/src/Simplex/Messaging/Notifications/Types.hs index 97f9e6bdd..8fcedab53 100644 --- a/src/Simplex/Messaging/Notifications/Types.hs +++ b/src/Simplex/Messaging/Notifications/Types.hs @@ -11,7 +11,7 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time (UTCTime) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) -import Simplex.Messaging.Agent.Protocol (UserId, ConnId, NotificationsMode (..)) +import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..), UserId) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol @@ -48,6 +48,7 @@ data NtfToken = NtfToken ntfServer :: NtfServer, ntfTokenId :: Maybe NtfTokenId, -- TODO combine keys to key pair as the types should match + -- | key used by the ntf server to verify transmissions ntfPubKey :: C.APublicAuthKey, -- | key used by the ntf client to sign transmissions diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs index 3ce5a35c8..45c182046 100644 --- a/src/Simplex/Messaging/Session.hs +++ b/src/Simplex/Messaging/Session.hs @@ -5,9 +5,6 @@ module Simplex.Messaging.Session where import Control.Concurrent.STM -import Control.Monad -import Data.Composition ((.:.)) -import Data.Functor (($>)) import Data.Time (UTCTime) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -31,14 +28,10 @@ getSessVar sessSeq sessKey vs sessionVarTs = maybe (Left <$> newSessionVar) (pur pure v removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM () -removeSessVar = void .:. removeSessVar' -{-# INLINE removeSessVar #-} - -removeSessVar' :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM Bool -removeSessVar' v sessKey vs = +removeSessVar v sessKey vs = TM.lookup sessKey vs >>= \case - Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True - _ -> pure False + Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs + _ -> pure () tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a) tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar) diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 9f7c4932e..7e39d7fd9 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} module CoreTests.TRcvQueuesTests where @@ -30,13 +31,13 @@ tRcvQueuesTests = do describe "queue transfer" $ do it "getDelSessQueues-batchAddQueues preserves total length" removeSubsTest -checkDataInvariant :: RQ.TRcvQueues -> IO Bool +checkDataInvariant :: RQ.Queue q => RQ.TRcvQueues q -> IO Bool checkDataInvariant trq = atomically $ do conns <- readTVar $ RQ.getConnections trq qs <- readTVar $ RQ.getRcvQueues trq -- three invariant checks - let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> connId q == cId) qs))) (M.keys conns) - inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (connId q) conns)) (M.assocs qs) + let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> RQ.connId' q == cId) qs))) (M.keys conns) + inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (RQ.connId' q) conns)) (M.assocs qs) inv3 = all (\(k, q) -> RQ.qKey q == k) (M.assocs qs) pure $ inv1 && inv2 && inv3 @@ -76,7 +77,7 @@ batchIdempotentTest = do atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True readTVarIO (RQ.getRcvQueues trq) `shouldReturn` qs' - fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn`cs' -- connections get duplicated, but that doesn't appear to affect anybody + fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn` cs' -- connections get duplicated, but that doesn't appear to affect anybody deleteConnTest :: IO () deleteConnTest = do @@ -112,23 +113,23 @@ getDelSessQueuesTest :: IO () getDelSessQueuesTest = do trq <- atomically RQ.empty let qs = - [ dummyRQ 0 "smp://1234-w==@alpha" "c1", - dummyRQ 0 "smp://1234-w==@alpha" "c2", - dummyRQ 0 "smp://1234-w==@beta" "c3", - dummyRQ 1 "smp://1234-w==@beta" "c4" + [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1"), + ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2"), + ("1", dummyRQ 0 "smp://1234-w==@beta" "c3"), + ("1", dummyRQ 1 "smp://1234-w==@beta" "c4") ] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True -- no user - atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) checkDataInvariant trq `shouldReturn` True -- wrong user - atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) checkDataInvariant trq `shouldReturn` True -- connections intact atomically (RQ.hasConn "c1" trq) `shouldReturn` True atomically (RQ.hasConn "c2" trq) `shouldReturn` True - atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) + atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) checkDataInvariant trq `shouldReturn` True -- connections gone atomically (RQ.hasConn "c1" trq) `shouldReturn` False @@ -141,29 +142,29 @@ removeSubsTest :: IO () removeSubsTest = do aq <- atomically RQ.empty let qs = - [ dummyRQ 0 "smp://1234-w==@alpha" "c1", - dummyRQ 0 "smp://1234-w==@alpha" "c2", - dummyRQ 0 "smp://1234-w==@beta" "c3", - dummyRQ 1 "smp://1234-w==@beta" "c4" + [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1"), + ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2"), + ("1", dummyRQ 0 "smp://1234-w==@beta" "c3"), + ("1", dummyRQ 1 "smp://1234-w==@beta" "c4") ] atomically $ RQ.batchAddQueues aq qs pq <- atomically RQ.empty atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) -totalSize :: RQ.TRcvQueues -> RQ.TRcvQueues -> STM (Int, Int) +totalSize :: RQ.TRcvQueues q -> RQ.TRcvQueues q -> STM (Int, Int) totalSize a b = do qsizeA <- M.size <$> readTVar (RQ.getRcvQueues a) qsizeB <- M.size <$> readTVar (RQ.getRcvQueues b)