Merge branch 'master' into ep/rfc-rotation

This commit is contained in:
Evgeny Poberezkin
2022-09-06 15:19:38 +01:00
6 changed files with 67 additions and 51 deletions
+1 -1
View File
@@ -49,7 +49,7 @@ dependencies:
- memory == 0.15.*
- mtl == 2.2.*
- network >= 3.1.2.7 && < 3.2
- network-transport == 0.5.*
- network-transport == 0.5.4
- optparse-applicative >= 0.15 && < 0.17
- QuickCheck == 2.14.*
- process == 1.6.*
+5 -5
View File
@@ -126,7 +126,7 @@ library
, memory ==0.15.*
, mtl ==2.2.*
, network >=3.1.2.7 && <3.2
, network-transport ==0.5.*
, network-transport ==0.5.4
, optparse-applicative >=0.15 && <0.17
, process ==1.6.*
, random >=1.1 && <1.3
@@ -187,7 +187,7 @@ executable ntf-server
, memory ==0.15.*
, mtl ==2.2.*
, network >=3.1.2.7 && <3.2
, network-transport ==0.5.*
, network-transport ==0.5.4
, optparse-applicative >=0.15 && <0.17
, process ==1.6.*
, random >=1.1 && <1.3
@@ -249,7 +249,7 @@ executable smp-agent
, memory ==0.15.*
, mtl ==2.2.*
, network >=3.1.2.7 && <3.2
, network-transport ==0.5.*
, network-transport ==0.5.4
, optparse-applicative >=0.15 && <0.17
, process ==1.6.*
, random >=1.1 && <1.3
@@ -311,7 +311,7 @@ executable smp-server
, memory ==0.15.*
, mtl ==2.2.*
, network >=3.1.2.7 && <3.2
, network-transport ==0.5.*
, network-transport ==0.5.4
, optparse-applicative >=0.15 && <0.17
, process ==1.6.*
, random >=1.1 && <1.3
@@ -392,7 +392,7 @@ test-suite smp-server-test
, memory ==0.15.*
, mtl ==2.2.*
, network >=3.1.2.7 && <3.2
, network-transport ==0.5.*
, network-transport ==0.5.4
, optparse-applicative >=0.15 && <0.17
, process ==1.6.*
, random >=1.1 && <1.3
+35 -23
View File
@@ -627,7 +627,7 @@ enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue
enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do
resumeMsgDelivery c cData sq
msgId <- storeSentMsg
queuePendingMsgs c connId sq [msgId]
queuePendingMsgs c sq [msgId]
pure $ unId msgId
where
storeSentMsg :: m InternalId
@@ -647,29 +647,29 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
let qKey = (connId, server, sndId)
let qKey = (server, sndId)
unlessM (queueDelivering qKey) $ do
mq <- atomically $ getPendingMsgQ c connId sq
mq <- atomically $ getPendingMsgQ c sq
async (runSmpQueueMsgDelivery c cData mq)
>>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c)
unlessM connQueued $
withStore' c (`getPendingMsgs` connId)
>>= queuePendingMsgs c connId sq
>>= queuePendingMsgs c sq
where
queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c)
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c)
queuePendingMsgs :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> [InternalId] -> m ()
queuePendingMsgs c connId sq msgIds = atomically $ do
queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m ()
queuePendingMsgs c sq msgIds = atomically $ do
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
-- s <- readTVar (msgDeliveryOp c)
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
q <- getPendingMsgQ c connId sq
q <- getPendingMsgQ c sq
mapM_ (writeTQueue q) msgIds
getPendingMsgQ :: AgentClient -> ConnId -> SndQueue -> STM (TQueue InternalId)
getPendingMsgQ c connId SndQueue {server, sndId} = do
let qKey = (connId, server, sndId)
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId)
getPendingMsgQ c SndQueue {server, sndId} = do
let qKey = (server, sndId)
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
where
newMsgQueue qKey = do
@@ -881,11 +881,11 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
(Just tknId, Nothing)
| savedDeviceToken == suppliedDeviceToken ->
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
| otherwise -> replaceToken tknId $> NTRegistered
| otherwise -> replaceToken tknId
(Just tknId, Just (NTAVerify code))
| savedDeviceToken == suppliedDeviceToken ->
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
| otherwise -> replaceToken tknId $> NTRegistered
| otherwise -> replaceToken tknId
(Just tknId, Just NTACheck)
| savedDeviceToken == suppliedDeviceToken -> do
ns <- asks ntfSupervisor
@@ -897,7 +897,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
pure ntfTknStatus -- TODO
-- agentNtfCheckToken c tknId tkn >>= \case
| otherwise -> replaceToken tknId $> NTRegistered
| otherwise -> replaceToken tknId
(Just tknId, Just NTADelete) -> do
agentNtfDeleteToken c tknId tkn
withStore' c (`removeNtfToken` tkn)
@@ -908,13 +908,27 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
withStore' c $ \db -> updateNtfMode db tkn suppliedNtfMode
pure status
where
replaceToken :: NtfTokenId -> m ()
replaceToken :: NtfTokenId -> m NtfTknStatus
replaceToken tknId = do
agentNtfReplaceToken c tknId tkn suppliedDeviceToken
withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken
ns <- asks ntfSupervisor
atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode}
_ ->
tryReplace ns `catchError` \e ->
if temporaryAgentError e || e == BROKER HOST
then throwError e
else do
withStore' c $ \db -> removeNtfToken db tkn
atomically $ nsRemoveNtfToken ns
createToken
where
tryReplace ns = do
agentNtfReplaceToken c tknId tkn suppliedDeviceToken
withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken
atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode}
pure NTRegistered
_ -> createToken
where
t tkn = withToken c tkn Nothing
createToken :: m NtfTknStatus
createToken =
getNtfServer c >>= \case
Just ntfServer ->
asks (cmdSignAlg . config) >>= \case
@@ -926,8 +940,6 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
registerToken tkn
pure NTRegistered
_ -> throwError $ CMD PROHIBITED
where
t tkn = withToken c tkn Nothing
registerToken :: NtfToken -> m ()
registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do
(tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey
@@ -1409,8 +1421,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se
DuplexConnection _ _ sq@SndQueue {server, sndId} nextRq_ nextSq_ -> case nextSq_ of
Just sq'@SndQueue {server = server', sndId = sndId'} -> do
unless (smpServer == server' && senderId == sndId') . throwError $ INTERNAL "incorrect queue address"
let qKey = (connId, server, sndId)
qKey' = (connId, server', sndId')
let qKey = (server, sndId)
qKey' = (server', sndId')
ok <-
switchQueues qKey qKey' `catchError` \e -> do
atomically (switchDeliveries qKey' qKey)
@@ -1505,7 +1517,7 @@ enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQ
enqueueConfirmation c cData@ConnData {connId, connAgentVersion} sq connInfo e2eEncryption = do
resumeMsgDelivery c cData sq
msgId <- storeConfirmation
queuePendingMsgs c connId sq [msgId]
queuePendingMsgs c sq [msgId]
where
storeConfirmation :: m InternalId
storeConfirmation = withStore c $ \db -> runExceptT $ do
+24 -16
View File
@@ -89,6 +89,7 @@ import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text.Encoding
import Data.Time.Clock (getCurrentTime)
import Data.Tuple (swap)
@@ -146,7 +147,7 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
type MsgDeliveryKey = (ConnId, SMPServer, SMP.SenderId)
type MsgDeliveryKey = (SMPServer, SMP.SenderId)
data AgentClient = AgentClient
{ active :: TVar Bool,
@@ -160,7 +161,8 @@ data AgentClient = AgentClient
useNetworkConfig :: TVar NetworkConfig,
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
subscrConns :: TMap ConnId SMPServer,
subscrConns :: TVar (Set ConnId),
activeSubscrConns :: TMap ConnId SMPServer,
connMsgsQueued :: TMap ConnId Bool,
smpQueueMsgQueues :: TMap MsgDeliveryKey (TQueue InternalId),
smpQueueMsgDeliveries :: TMap MsgDeliveryKey (Async ()),
@@ -212,7 +214,8 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
useNetworkConfig <- newTVar netCfg
subscrSrvrs <- TM.empty
pendingSubscrSrvrs <- TM.empty
subscrConns <- TM.empty
subscrConns <- newTVar S.empty
activeSubscrConns <- TM.empty
connMsgsQueued <- TM.empty
smpQueueMsgQueues <- TM.empty
smpQueueMsgDeliveries <- TM.empty
@@ -228,7 +231,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
asyncClients <- newTVar []
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
lock <- newTMVar ()
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
agentDbPath :: AgentClient -> FilePath
agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath
@@ -271,7 +274,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
where
updateSubs cVar = do
cs <- readTVar cVar
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
@@ -413,12 +416,13 @@ closeAgentClient c = liftIO $ do
clear subscrSrvrs
clear pendingSubscrSrvrs
clear subscrConns
clear activeSubscrConns
clear connMsgsQueued
clear smpQueueMsgQueues
clear getMsgLocks
where
clear :: (AgentClient -> TMap k a) -> IO ()
clear sel = atomically $ writeTVar (sel c) M.empty
clear :: Monoid m => (AgentClient -> TVar m) -> IO ()
clear sel = atomically $ writeTVar (sel c) mempty
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
@@ -522,7 +526,9 @@ newRcvQueue_ a c srv vRange current = do
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ addPendingSubscription c rq connId
atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
addPendingSubscription c rq connId
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
>>= either throwError pure
@@ -552,7 +558,9 @@ temporaryAgentError = \case
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap
forM_ qs_ $ \q -> atomically $ do
modifyTVar (subscrConns c) . S.insert $ fst q
uncurry (addPendingSubscription c) $ swap q
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
@@ -574,12 +582,13 @@ subscribeQueues c srv qs = do
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
TM.insert connId server $ subscrConns c
TM.insert connId server $ activeSubscrConns c
modifyTVar (subscrConns c) $ S.insert connId
addSubs_ (subscrSrvrs c) rq connId
removePendingSubscription c server connId
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
hasActiveSubscription c connId = TM.member connId (subscrConns c)
hasActiveSubscription c connId = TM.member connId (activeSubscrConns c)
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
@@ -591,8 +600,9 @@ addSubs_ ss rq@RcvQueue {server} connId =
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
removeSubscription :: AgentClient -> ConnId -> STM ()
removeSubscription c@AgentClient {subscrConns} connId = do
server_ <- TM.lookupDelete connId subscrConns
removeSubscription c connId = do
modifyTVar (subscrConns c) $ S.delete connId
server_ <- TM.lookupDelete connId $ activeSubscrConns c
mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_
removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM ()
@@ -603,9 +613,7 @@ removeSubs_ ss server connId =
TM.lookup server ss >>= mapM_ (TM.delete connId)
getSubscriptions :: AgentClient -> STM (Set ConnId)
getSubscriptions AgentClient {subscrConns} = do
m <- readTVar subscrConns
pure $ M.keysSet m
getSubscriptions = readTVar . subscrConns
logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m ()
logServer dir AgentClient {clientId} srv qId cmdStr =
@@ -106,6 +106,7 @@ removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} =
forM_ tIds $ \(regKey, tId') -> do
TM.delete regKey tknRegs
TM.delete tId' $ tokens st
-- TODO remove token subscriptions as in deleteNtfToken
pure $ map snd tIds
removeTokenRegistration :: NtfStore -> NtfTknData -> STM ()
@@ -130,6 +131,7 @@ deleteNtfToken st tknId = do
)
)
-- TODO refactor
qs <-
TM.lookupDelete tknId (tokenSubscriptions st)
>>= mapM
-6
View File
@@ -276,31 +276,25 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
liftIO $ threadDelay 1000000
aliceId <- joinConnection bob True qInfo "bob's connInfo"
liftIO $ threadDelay 750000
liftIO $ print 0
void $ messageNotification apnsQ
("", _, CONF confId _ "bob's connInfo") <- get alice
liftIO $ threadDelay 500000
allowConnection alice bobId confId "alice's connInfo"
liftIO $ print 1
void $ messageNotification apnsQ
get bob ##> ("", aliceId, INFO "alice's connInfo")
liftIO $ print 2
void $ messageNotification apnsQ
get alice ##> ("", bobId, CON)
liftIO $ print 3
void $ messageNotification apnsQ
get bob ##> ("", aliceId, CON)
-- bob sends message
1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello"
get bob ##> ("", aliceId, SENT $ baseId + 1)
liftIO $ print 4
void $ messageNotification apnsQ
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
ackMessage alice bobId $ baseId + 1
-- alice sends message
2 <- msgId <$> sendMessage alice bobId (SMP.MsgFlags True) "hey there"
get alice ##> ("", bobId, SENT $ baseId + 2)
liftIO $ print 5
void $ messageNotification apnsQ
get bob =##> \case ("", c, Msg "hey there") -> c == aliceId; _ -> False
ackMessage bob aliceId $ baseId + 2