mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-05 04:36:11 +00:00
Merge branch 'stable'
This commit is contained in:
@@ -618,6 +618,7 @@ test-suite simplexmq-test
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
CoreTests.RetryIntervalTests
|
||||
CoreTests.TRcvQueuesTests
|
||||
CoreTests.UtilTests
|
||||
CoreTests.VersionRangeTests
|
||||
FileDescriptionTests
|
||||
|
||||
@@ -124,6 +124,9 @@ module Simplex.Messaging.Agent.Client
|
||||
getAgentWorkersDetails,
|
||||
AgentWorkersSummary (..),
|
||||
getAgentWorkersSummary,
|
||||
SMPTransportSession,
|
||||
NtfTransportSession,
|
||||
XFTPTransportSession,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -532,11 +535,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
|
||||
where
|
||||
currentActiveClient = (&&) <$> removeTSessVar' v tSess smpClients <*> readTVar active
|
||||
removeSubs = do
|
||||
qs <- RQ.getDelSessQueues tSess $ activeSubs c
|
||||
mapM_ (`RQ.addQueue` pendingSubs c) qs
|
||||
let cs = S.fromList $ map qConnId qs
|
||||
cs' <- RQ.getConns $ activeSubs c
|
||||
pure (qs, S.toList $ cs `S.difference` cs')
|
||||
(qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c
|
||||
RQ.batchAddQueues (pendingSubs c) qs
|
||||
pure (qs, cs)
|
||||
|
||||
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
|
||||
serverDown (qs, conns) = whenM (readTVarIO active) $ do
|
||||
@@ -594,16 +595,16 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
|
||||
where
|
||||
resubscribe :: m ()
|
||||
resubscribe = do
|
||||
cs <- atomically . RQ.getConns $ activeSubs c
|
||||
cs <- readTVarIO $ RQ.getConnections $ activeSubs c
|
||||
rs <- subscribeQueues c $ L.toList qs
|
||||
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
|
||||
liftIO $ do
|
||||
let conns = S.toList $ S.fromList okConns `S.difference` cs
|
||||
let conns = filter (`M.notMember` cs) okConns
|
||||
unless (null conns) $ notifySub "" $ UP srv conns
|
||||
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
|
||||
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
|
||||
forM_ (listToMaybe tempErrs) $ \(_, err) -> do
|
||||
when (null okConns && S.null cs && null finalErrs) . liftIO $
|
||||
when (null okConns && M.null cs && null finalErrs) . liftIO $
|
||||
closeClient c smpClients tSess
|
||||
throwError err
|
||||
notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO ()
|
||||
@@ -1060,9 +1061,9 @@ temporaryOrHostError = \case
|
||||
subscribeQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
subscribeQueues c qs = do
|
||||
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
|
||||
forM_ qs' $ \rq@RcvQueue {connId} -> atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
atomically $ do
|
||||
modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs'))
|
||||
RQ.batchAddQueues (pendingSubs c) qs'
|
||||
u <- askUnliftIO
|
||||
-- only "checked" queues are subscribed
|
||||
(errs <>) <$> sendTSessionBatches "SUB" 90 id (subscribeQueues_ u) c qs'
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module Simplex.Messaging.Agent.TRcvQueues
|
||||
( TRcvQueues (getRcvQueues),
|
||||
( TRcvQueues (getRcvQueues, getConnections),
|
||||
empty,
|
||||
clear,
|
||||
deleteConn,
|
||||
hasConn,
|
||||
getConns,
|
||||
addQueue,
|
||||
batchAddQueues,
|
||||
deleteQueue,
|
||||
getSessQueues,
|
||||
getDelSessQueues,
|
||||
@@ -14,49 +16,85 @@ module Simplex.Messaging.Agent.TRcvQueues
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Data.Foldable (foldl')
|
||||
import Data.List.NonEmpty (NonEmpty (..), (<|))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Simplex.Messaging.Agent.Protocol (ConnId, UserId)
|
||||
import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..))
|
||||
import Simplex.Messaging.Protocol (RecipientId, SMPServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
|
||||
newtype TRcvQueues = TRcvQueues {getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue}
|
||||
-- the fields in this record have the same data with swapped keys for lookup efficiency,
|
||||
-- and all methods must maintain this invariant.
|
||||
data TRcvQueues = TRcvQueues
|
||||
{ getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue,
|
||||
getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId))
|
||||
}
|
||||
|
||||
empty :: STM TRcvQueues
|
||||
empty = TRcvQueues <$> TM.empty
|
||||
empty = TRcvQueues <$> TM.empty <*> TM.empty
|
||||
|
||||
clear :: TRcvQueues -> STM ()
|
||||
clear (TRcvQueues qs) = TM.clear qs
|
||||
clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs
|
||||
|
||||
deleteConn :: ConnId -> TRcvQueues -> STM ()
|
||||
deleteConn cId (TRcvQueues qs) = modifyTVar' qs $ M.filter (\rq -> cId /= connId rq)
|
||||
deleteConn cId (TRcvQueues qs cs) =
|
||||
TM.lookupDelete cId cs >>= \case
|
||||
Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks
|
||||
Nothing -> pure ()
|
||||
|
||||
hasConn :: ConnId -> TRcvQueues -> STM Bool
|
||||
hasConn cId (TRcvQueues qs) = any (\rq -> cId == connId rq) <$> readTVar qs
|
||||
|
||||
getConns :: TRcvQueues -> STM (Set ConnId)
|
||||
getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs
|
||||
hasConn cId (TRcvQueues _ cs) = TM.member cId cs
|
||||
|
||||
addQueue :: RcvQueue -> TRcvQueues -> STM ()
|
||||
addQueue rq (TRcvQueues qs) = TM.insert (qKey rq) rq qs
|
||||
addQueue rq (TRcvQueues qs cs) = do
|
||||
TM.insert k rq qs
|
||||
TM.alter addQ (connId rq) cs
|
||||
where
|
||||
addQ = Just . maybe (k :| []) (k <|)
|
||||
k = qKey rq
|
||||
|
||||
-- Save time by aggregating modifyTVar
|
||||
batchAddQueues :: Foldable t => TRcvQueues -> t RcvQueue -> STM ()
|
||||
batchAddQueues (TRcvQueues qs cs) rqs = do
|
||||
modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs
|
||||
modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId rq) cs') now rqs
|
||||
where
|
||||
addQ k = Just . maybe (k :| []) (k <|)
|
||||
|
||||
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
|
||||
deleteQueue rq (TRcvQueues qs) = TM.delete (qKey rq) qs
|
||||
deleteQueue rq (TRcvQueues qs cs) = do
|
||||
TM.delete k qs
|
||||
TM.update delQ (connId rq) cs
|
||||
where
|
||||
delQ = L.nonEmpty . L.filter (/= k)
|
||||
k = qKey rq
|
||||
|
||||
getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
|
||||
getSessQueues tSess (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
|
||||
getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVar qs
|
||||
where
|
||||
addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs'
|
||||
|
||||
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
|
||||
getDelSessQueues tSess (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
|
||||
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], [ConnId])
|
||||
getDelSessQueues tSess (TRcvQueues qs cs) = do
|
||||
(removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs
|
||||
writeTVar qs $! qs''
|
||||
removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs
|
||||
pure (removedQs, removedConns)
|
||||
where
|
||||
addQ (removed, qs') rq
|
||||
| rq `isSession` tSess = (rq : removed, qs')
|
||||
| otherwise = (removed, M.insert (qKey rq) rq qs')
|
||||
delQ acc@(removed, qs') rq
|
||||
| rq `isSession` tSess = (rq : removed, M.delete (qKey rq) qs')
|
||||
| otherwise = acc
|
||||
delConn (removed, cs') rq = M.alterF f cId cs'
|
||||
where
|
||||
cId = connId rq
|
||||
f = \case
|
||||
Just ks -> case L.nonEmpty $ L.filter (qKey rq /=) ks of
|
||||
Just ks' -> (removed, Just ks')
|
||||
Nothing -> (cId : removed, Nothing)
|
||||
Nothing -> (removed, Nothing) -- "impossible" in invariant holds, because we get keys from the known queues
|
||||
|
||||
isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool
|
||||
isSession rq (uId, srv, connId_) =
|
||||
|
||||
142
tests/CoreTests/TRcvQueuesTests.hs
Normal file
142
tests/CoreTests/TRcvQueuesTests.hs
Normal file
@@ -0,0 +1,142 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module CoreTests.TRcvQueuesTests where
|
||||
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId)
|
||||
import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..))
|
||||
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (SMPServer)
|
||||
import Test.Hspec
|
||||
import UnliftIO
|
||||
|
||||
tRcvQueuesTests :: Spec
|
||||
tRcvQueuesTests = do
|
||||
describe "connection API" $ do
|
||||
it "hasConn" hasConnTest
|
||||
it "hasConn, batch add" hasConnTestBatch
|
||||
it "deleteConn" deleteConnTest
|
||||
describe "session API" $ do
|
||||
it "getSessQueues" getSessQueuesTest
|
||||
it "getDelSessQueues" getDelSessQueuesTest
|
||||
|
||||
checkDataInvariant :: RQ.TRcvQueues -> IO Bool
|
||||
checkDataInvariant trq = atomically $ do
|
||||
conns <- readTVar $ RQ.getConnections trq
|
||||
qs <- readTVar $ RQ.getRcvQueues trq
|
||||
-- three invariant checks
|
||||
let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> connId q == cId) qs))) (M.keys conns)
|
||||
inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (connId q) conns)) (M.assocs qs)
|
||||
inv3 = all (\(k, q) -> RQ.qKey q == k) (M.assocs qs)
|
||||
pure $ inv1 && inv2 && inv3
|
||||
|
||||
hasConnTest :: IO ()
|
||||
hasConnTest = do
|
||||
trq <- atomically RQ.empty
|
||||
atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically (RQ.hasConn "c1" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "c2" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "c3" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "nope" trq) `shouldReturn` False
|
||||
|
||||
hasConnTestBatch :: IO ()
|
||||
hasConnTestBatch = do
|
||||
trq <- atomically RQ.empty
|
||||
let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"]
|
||||
atomically $ RQ.batchAddQueues trq qs
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically (RQ.hasConn "c1" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "c2" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "c3" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "nope" trq) `shouldReturn` False
|
||||
|
||||
deleteConnTest :: IO ()
|
||||
deleteConnTest = do
|
||||
trq <- atomically RQ.empty
|
||||
atomically $ do
|
||||
RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq
|
||||
RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq
|
||||
RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.deleteConn "c1" trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.deleteConn "nope" trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
M.keys <$> readTVarIO (RQ.getConnections trq) `shouldReturn` ["c2", "c3"]
|
||||
|
||||
getSessQueuesTest :: IO ()
|
||||
getSessQueuesTest = do
|
||||
trq <- atomically RQ.empty
|
||||
atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically $ RQ.addQueue (dummyRQ 1 "smp://1234-w==@beta" "c4") trq
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1"]
|
||||
atomically (RQ.getSessQueues (1, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` []
|
||||
atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "nope") trq) `shouldReturn` []
|
||||
atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"]
|
||||
|
||||
getDelSessQueuesTest :: IO ()
|
||||
getDelSessQueuesTest = do
|
||||
trq <- atomically RQ.empty
|
||||
let qs =
|
||||
[ dummyRQ 0 "smp://1234-w==@alpha" "c1",
|
||||
dummyRQ 0 "smp://1234-w==@alpha" "c2",
|
||||
dummyRQ 0 "smp://1234-w==@beta" "c3",
|
||||
dummyRQ 1 "smp://1234-w==@beta" "c4"
|
||||
]
|
||||
atomically $ RQ.batchAddQueues trq qs
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
-- no user
|
||||
atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], [])
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
-- wrong user
|
||||
atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], [])
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
-- connections intact
|
||||
atomically (RQ.hasConn "c1" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "c2" trq) `shouldReturn` True
|
||||
atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"])
|
||||
checkDataInvariant trq `shouldReturn` True
|
||||
-- connections gone
|
||||
atomically (RQ.hasConn "c1" trq) `shouldReturn` False
|
||||
atomically (RQ.hasConn "c2" trq) `shouldReturn` False
|
||||
-- non-matched connections intact
|
||||
atomically (RQ.hasConn "c3" trq) `shouldReturn` True
|
||||
atomically (RQ.hasConn "c4" trq) `shouldReturn` True
|
||||
|
||||
dummyRQ :: UserId -> SMPServer -> ConnId -> RcvQueue
|
||||
dummyRQ userId server connId =
|
||||
RcvQueue
|
||||
{ userId,
|
||||
connId,
|
||||
server,
|
||||
rcvId = "",
|
||||
rcvPrivateKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe",
|
||||
rcvDhSecret = "01234567890123456789012345678901",
|
||||
e2ePrivKey = "MC4CAQAwBQYDK2VuBCIEINCzbVFaCiYHoYncxNY8tSIfn0pXcIAhLBfFc0m+gOpk",
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = "",
|
||||
status = New,
|
||||
dbQueueId = DBQueueId 0,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
rcvSwchStatus = Nothing,
|
||||
smpClientVersion = 123,
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import CoreTests.CryptoTests
|
||||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
import CoreTests.RetryIntervalTests
|
||||
import CoreTests.TRcvQueuesTests
|
||||
import CoreTests.UtilTests
|
||||
import CoreTests.VersionRangeTests
|
||||
import FileDescriptionTests (fileDescriptionTests)
|
||||
@@ -52,6 +53,7 @@ main = do
|
||||
describe "Encryption tests" cryptoTests
|
||||
describe "Encrypted files tests" cryptoFileTests
|
||||
describe "Retry interval tests" retryIntervalTests
|
||||
describe "TRcvQueues tests" tRcvQueuesTests
|
||||
describe "Util tests" utilTests
|
||||
describe "SMP server via TLS" $ serverTests (transport @TLS)
|
||||
describe "SMP server via WebSockets" $ serverTests (transport @WS)
|
||||
|
||||
Reference in New Issue
Block a user