Merge branch 'master' into short-links

This commit is contained in:
Evgeny Poberezkin
2024-07-25 13:15:34 +01:00
7 changed files with 98 additions and 82 deletions

View File

@@ -2202,7 +2202,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
processSubOk :: RcvQueue -> TVar [ConnId] -> AM ()
processSubOk rq@RcvQueue {connId} upConnIds =
atomically . whenM (isPendingSub connId) $ do
addSubscription c rq
addSubscription c sessId rq
modifyTVar' upConnIds (connId :)
processSubErr :: RcvQueue -> SMPClientError -> AM ()
processSubErr rq@RcvQueue {connId} e = do

View File

@@ -306,8 +306,8 @@ data AgentClient = AgentClient
userNetworkInfo :: TVar UserNetworkInfo,
userNetworkUpdated :: TVar (Maybe UTCTime),
subscrConns :: TVar (Set ConnId),
activeSubs :: TRcvQueues,
pendingSubs :: TRcvQueues,
activeSubs :: TRcvQueues (SessionId, RcvQueue),
pendingSubs :: TRcvQueues RcvQueue,
removedSubs :: TMap (UserId, SMPServer, SMP.RecipientId) SMPClientError,
workerSeq :: TVar Int,
smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()),
@@ -332,7 +332,7 @@ data AgentClient = AgentClient
agentEnv :: Env,
smpServersStats :: TMap (UserId, SMPServer) AgentSMPServerStats,
xftpServersStats :: TMap (UserId, XFTPServer) AgentXFTPServerStats,
ntfServersStats :: TMap (UserId, NtfServer) AgentNtfServerStats,
ntfServersStats :: TMap (UserId, NtfServer) AgentNtfServerStats,
srvStatsStartedAt :: TVar UTCTime
}
@@ -677,11 +677,13 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess
-- because we can have a race condition when a new current client could have already
-- made subscriptions active, and the old client would be processing diconnection later.
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], [])
removeClientAndSubs = atomically $ do
removeSessVar v tSess smpClients
ifM (readTVar active) removeSubs (pure ([], []))
where
currentActiveClient = (&&) <$> removeSessVar' v tSess smpClients <*> readTVar active
sessId = sessionId $ thParams client
removeSubs = do
(qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c
(qs, cs) <- RQ.getDelSessQueues tSess sessId $ activeSubs c
RQ.batchAddQueues (pendingSubs c) qs
-- this removes proxied relays that this client created sessions to
destSrvs <- M.keys <$> readTVar prs
@@ -1347,8 +1349,8 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode sender
qUri = SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey sndSecure
pure (rq, qUri, tSess, sessId)
processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> STM ()
processSubResult c rq@RcvQueue {userId, server, connId} = \case
processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError () -> STM ()
processSubResult c sessId rq@RcvQueue {userId, server, connId} = \case
Left e ->
unless (temporaryClientError e) $ do
incSMPServerStat c userId server connSubErrs
@@ -1356,7 +1358,7 @@ processSubResult c rq@RcvQueue {userId, server, connId} = \case
Right () ->
ifM
(hasPendingSubscription c connId)
(incSMPServerStat c userId server connSubscribed >> addSubscription c rq)
(incSMPServerStat c userId server connSubscribed >> addSubscription c sessId rq)
(incSMPServerStat c userId server connSubIgnored)
temporaryAgentError :: AgentErrorType -> Bool
@@ -1427,7 +1429,7 @@ subscribeQueues c qs = do
sessId = sessionId $ thParams smp
hasTempErrors = any (either temporaryClientError (const False) . snd)
processSubResults :: NonEmpty (RcvQueue, Either SMPClientError ()) -> STM ()
processSubResults = mapM_ $ uncurry $ processSubResult c
processSubResults = mapM_ $ uncurry $ processSubResult c sessId
resubscribe = resubscribeSMPSession c tSess `runReaderT` env
activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool
@@ -1466,10 +1468,10 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
where
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: AgentClient -> RcvQueue -> STM ()
addSubscription c rq@RcvQueue {connId} = do
addSubscription :: AgentClient -> SessionId -> RcvQueue -> STM ()
addSubscription c sessId rq@RcvQueue {connId} = do
modifyTVar' (subscrConns c) $ S.insert connId
RQ.addQueue rq $ activeSubs c
RQ.addQueue (sessId, rq) $ activeSubs c
RQ.deleteQueue rq $ pendingSubs c
failSubscription :: AgentClient -> RcvQueue -> SMPClientError -> STM ()
@@ -1488,7 +1490,7 @@ addNewQueueSubscription c rq tSess sessId = do
atomically $
ifM
(activeClientSession c tSess sessId)
(True <$ addSubscription c rq)
(True <$ addSubscription c sessId rq)
(False <$ addPendingSubscription c rq)
unless same $ resubscribeSMPSession c tSess
@@ -2025,7 +2027,9 @@ getAgentSubsTotal c userIds = do
sess <- hasSession . M.toList =<< readTVarIO (smpClients c)
pure (SMPServerSubs {ssActive, ssPending}, sess)
where
getSubsCount :: (AgentClient -> TRcvQueues q) -> IO Int
getSubsCount subs = M.foldrWithKey' addSub 0 <$> readTVarIO (getRcvQueues $ subs c)
addSub :: (UserId, SMPServer, SMP.RecipientId) -> q -> Int -> Int
addSub (userId, _, _) _ cnt = if userId `elem` userIds then cnt + 1 else cnt
hasSession :: [(SMPTransportSession, SMPClientVar)] -> IO Bool
hasSession = \case
@@ -2106,6 +2110,7 @@ getAgentSubscriptions c = do
removedSubscriptions <- getRemovedSubs
pure $ SubscriptionsInfo {activeSubscriptions, pendingSubscriptions, removedSubscriptions}
where
getSubs :: (AgentClient -> TRcvQueues q) -> IO [SubInfo]
getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c)
getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c)
subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo

View File

@@ -1,7 +1,9 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
module Simplex.Messaging.Agent.TRcvQueues
( TRcvQueues (getRcvQueues, getConnections),
Queue (..),
empty,
clear,
deleteConn,
@@ -11,7 +13,6 @@ module Simplex.Messaging.Agent.TRcvQueues
deleteQueue,
getSessQueues,
getDelSessQueues,
qKey,
)
where
@@ -25,46 +26,51 @@ import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..))
import Simplex.Messaging.Protocol (RecipientId, SMPServer)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
class Queue q where
connId' :: q -> ConnId
qKey :: q -> (UserId, SMPServer, ConnId)
-- the fields in this record have the same data with swapped keys for lookup efficiency,
-- and all methods must maintain this invariant.
data TRcvQueues = TRcvQueues
{ getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue,
data TRcvQueues q = TRcvQueues
{ getRcvQueues :: TMap (UserId, SMPServer, RecipientId) q,
getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId))
}
empty :: STM TRcvQueues
empty :: STM (TRcvQueues q)
empty = TRcvQueues <$> TM.empty <*> TM.empty
clear :: TRcvQueues -> STM ()
clear :: TRcvQueues q -> STM ()
clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs
deleteConn :: ConnId -> TRcvQueues -> STM ()
deleteConn :: ConnId -> TRcvQueues q -> STM ()
deleteConn cId (TRcvQueues qs cs) =
TM.lookupDelete cId cs >>= \case
Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks
Nothing -> pure ()
hasConn :: ConnId -> TRcvQueues -> STM Bool
hasConn :: ConnId -> TRcvQueues q -> STM Bool
hasConn cId (TRcvQueues _ cs) = TM.member cId cs
addQueue :: RcvQueue -> TRcvQueues -> STM ()
addQueue :: Queue q => q -> TRcvQueues q -> STM ()
addQueue rq (TRcvQueues qs cs) = do
TM.insert k rq qs
TM.alter addQ (connId rq) cs
TM.alter addQ (connId' rq) cs
where
addQ = Just . maybe (k :| []) (k <|)
k = qKey rq
-- Save time by aggregating modifyTVar
batchAddQueues :: Foldable t => TRcvQueues -> t RcvQueue -> STM ()
batchAddQueues :: (Foldable t, Queue q) => TRcvQueues q -> t q -> STM ()
batchAddQueues (TRcvQueues qs cs) rqs = do
modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs
modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId rq) cs') now rqs
modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId' rq) cs') now rqs
where
addQ k = Just . maybe (k :| []) (k <|)
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
deleteQueue :: RcvQueue -> TRcvQueues RcvQueue -> STM ()
deleteQueue rq (TRcvQueues qs cs) = do
TM.delete k qs
TM.update delQ (connId rq) cs
@@ -72,21 +78,22 @@ deleteQueue rq (TRcvQueues qs cs) = do
delQ = L.nonEmpty . L.filter (/= k)
k = qKey rq
getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> STM [RcvQueue]
getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVar qs
where
addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs'
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], [ConnId])
getDelSessQueues tSess (TRcvQueues qs cs) = do
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> SessionId -> TRcvQueues (SessionId, RcvQueue) -> STM ([RcvQueue], [ConnId])
getDelSessQueues tSess sessId' (TRcvQueues qs cs) = do
(removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs
writeTVar qs $! qs''
removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs
pure (removedQs, removedConns)
where
delQ acc@(removed, qs') rq
| rq `isSession` tSess = (rq : removed, M.delete (qKey rq) qs')
delQ acc@(removed, qs') (sessId, rq)
| rq `isSession` tSess && sessId == sessId' = (rq : removed, M.delete (qKey rq) qs')
| otherwise = acc
delConn :: ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) -> RcvQueue -> ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId)))
delConn (removed, cs') rq = M.alterF f cId cs'
where
cId = connId rq
@@ -100,5 +107,10 @@ isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool
isSession rq (uId, srv, connId_) =
userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_
qKey :: RcvQueue -> (UserId, SMPServer, ConnId)
qKey rq = (userId rq, server rq, connId rq)
instance Queue RcvQueue where
connId' = connId
qKey rq = (userId rq, server rq, connId rq)
instance Queue (SessionId, RcvQueue) where
connId' = connId . snd
qKey = qKey . snd

View File

@@ -100,7 +100,7 @@ data SMPClientAgent = SMPClientAgent
randomDrg :: TVar ChaChaDRG,
smpClients :: TMap SMPServer SMPClientVar,
smpSessions :: TMap SessionId (OwnServer, SMPClient),
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
srvSubs :: TMap SMPServer (TMap SMPSub (SessionId, C.APrivateAuthKey)),
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
smpSubWorkers :: TMap SMPServer (SessionVar (Async ())),
workerSeq :: TVar Int
@@ -204,14 +204,17 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random
removeClientAndSubs :: SMPClient -> IO (Maybe (Map SMPSub C.APrivateAuthKey))
removeClientAndSubs smp = atomically $ do
TM.delete sessId smpSessions
removeSessVar v srv smpClients
TM.delete (sessionId $ thParams smp) smpSessions
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
TM.lookup srv (srvSubs ca) >>= mapM updateSubs
where
sessId = sessionId $ thParams smp
updateSubs sVar = do
ss <- readTVar sVar
addSubs_ (pendingSrvSubs ca) srv ss
pure ss
-- removing subscriptions that have matching sessionId to disconnected client
-- and keep the other ones (they can be made by the new client)
pending <- M.map snd <$> stateTVar sVar (M.partition ((sessId ==) . fst))
addSubs_ (pendingSrvSubs ca) srv pending
pure pending
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
serverDown ss = unless (M.null ss) $ do
@@ -256,9 +259,9 @@ reconnectSMPClient ca@SMPClientAgent {agentCfg} srv cs =
subscribe_ smp SPNotifier nSubs
subscribe_ smp SPRecipient rSubs
where
groupSub :: Map SMPSub C.APrivateAuthKey -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)])
groupSub currSubs (s@(party, qId), k) (nSubs, rSubs)
| M.member s currSubs = (nSubs, rSubs)
groupSub :: Map SMPSub (SessionId, C.APrivateAuthKey) -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)])
groupSub currSubs (s@(party, qId), k) acc@(nSubs, rSubs)
| M.member s currSubs = acc
| otherwise = case party of
SPNotifier -> (s' : nSubs, rSubs)
SPRecipient -> (nSubs, s' : rSubs)
@@ -346,17 +349,18 @@ smpSubscribeQueues party ca smp srv subs = do
when tempErrs $ reconnectClient ca srv
Nothing -> reconnectClient ca srv
where
processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId])
processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId])
processSubscriptions rs = do
pending <- maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca)
let acc@(_, _, oks, notPending) = foldr (groupSub pending) (False, [], [], []) (L.zip subs rs)
unless (null oks) $ addSubscriptions ca srv party oks
unless (null notPending) $ removePendingSubs ca srv party notPending
pure acc
groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId])
groupSub pending (s@(qId, _), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of
sessId = sessionId $ thParams smp
groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId])
groupSub pending ((qId, pk), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of
Right ()
| M.member (party, qId) pending -> (tempErrs, finalErrs, s : oks, qId : notPending)
| M.member (party, qId) pending -> (tempErrs, finalErrs, (qId, (sessId, pk)) : oks, qId : notPending)
| otherwise -> acc
Left e
| temporaryClientError e -> (True, finalErrs, oks, notPending)
@@ -379,7 +383,7 @@ showServer :: SMPServer -> ByteString
showServer ProtocolServer {host, port} =
strEncode host <> B.pack (if null port then "" else ':' : port)
addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM ()
addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, (SessionId, C.APrivateAuthKey))] -> STM ()
addSubscriptions = addSubsList_ . srvSubs
{-# INLINE addSubscriptions #-}
@@ -387,12 +391,12 @@ addPendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APr
addPendingSubs = addSubsList_ . pendingSrvSubs
{-# INLINE addPendingSubs #-}
addSubsList_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM ()
addSubsList_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSubParty -> [(QueueId, s)] -> STM ()
addSubsList_ subs srv party ss = addSubs_ subs srv ss'
where
ss' = M.fromList $ map (first (party,)) ss
addSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> Map SMPSub C.APrivateAuthKey -> STM ()
addSubs_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> Map SMPSub s -> STM ()
addSubs_ subs srv ss =
TM.lookup srv subs >>= \case
Just m -> TM.union ss m
@@ -402,7 +406,7 @@ removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
removeSubscription = removeSub_ . srvSubs
{-# INLINE removeSubscription #-}
removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM ()
removeSub_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSub -> STM ()
removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s)
removePendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM ()

View File

@@ -11,7 +11,7 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8)
import Data.Time (UTCTime)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Messaging.Agent.Protocol (UserId, ConnId, NotificationsMode (..))
import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..), UserId)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Protocol
@@ -48,6 +48,7 @@ data NtfToken = NtfToken
ntfServer :: NtfServer,
ntfTokenId :: Maybe NtfTokenId,
-- TODO combine keys to key pair as the types should match
-- | key used by the ntf server to verify transmissions
ntfPubKey :: C.APublicAuthKey,
-- | key used by the ntf client to sign transmissions

View File

@@ -5,9 +5,6 @@
module Simplex.Messaging.Session where
import Control.Concurrent.STM
import Control.Monad
import Data.Composition ((.:.))
import Data.Functor (($>))
import Data.Time (UTCTime)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
@@ -31,14 +28,10 @@ getSessVar sessSeq sessKey vs sessionVarTs = maybe (Left <$> newSessionVar) (pur
pure v
removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM ()
removeSessVar = void .:. removeSessVar'
{-# INLINE removeSessVar #-}
removeSessVar' :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM Bool
removeSessVar' v sessKey vs =
removeSessVar v sessKey vs =
TM.lookup sessKey vs >>= \case
Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True
_ -> pure False
Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs
_ -> pure ()
tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a)
tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar)