diff --git a/package.yaml b/package.yaml index ac686a874..f6999668c 100644 --- a/package.yaml +++ b/package.yaml @@ -49,7 +49,7 @@ dependencies: - memory == 0.15.* - mtl == 2.2.* - network >= 3.1.2.7 && < 3.2 - - network-transport == 0.5.* + - network-transport == 0.5.4 - optparse-applicative >= 0.15 && < 0.17 - QuickCheck == 2.14.* - process == 1.6.* diff --git a/simplexmq.cabal b/simplexmq.cabal index 8410ce723..617abb389 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -126,7 +126,7 @@ library , memory ==0.15.* , mtl ==2.2.* , network >=3.1.2.7 && <3.2 - , network-transport ==0.5.* + , network-transport ==0.5.4 , optparse-applicative >=0.15 && <0.17 , process ==1.6.* , random >=1.1 && <1.3 @@ -187,7 +187,7 @@ executable ntf-server , memory ==0.15.* , mtl ==2.2.* , network >=3.1.2.7 && <3.2 - , network-transport ==0.5.* + , network-transport ==0.5.4 , optparse-applicative >=0.15 && <0.17 , process ==1.6.* , random >=1.1 && <1.3 @@ -249,7 +249,7 @@ executable smp-agent , memory ==0.15.* , mtl ==2.2.* , network >=3.1.2.7 && <3.2 - , network-transport ==0.5.* + , network-transport ==0.5.4 , optparse-applicative >=0.15 && <0.17 , process ==1.6.* , random >=1.1 && <1.3 @@ -311,7 +311,7 @@ executable smp-server , memory ==0.15.* , mtl ==2.2.* , network >=3.1.2.7 && <3.2 - , network-transport ==0.5.* + , network-transport ==0.5.4 , optparse-applicative >=0.15 && <0.17 , process ==1.6.* , random >=1.1 && <1.3 @@ -392,7 +392,7 @@ test-suite smp-server-test , memory ==0.15.* , mtl ==2.2.* , network >=3.1.2.7 && <3.2 - , network-transport ==0.5.* + , network-transport ==0.5.4 , optparse-applicative >=0.15 && <0.17 , process ==1.6.* , random >=1.1 && <1.3 diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index ea6c3a72d..51d0aca61 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -627,7 +627,7 @@ enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do resumeMsgDelivery c cData sq msgId <- storeSentMsg - queuePendingMsgs c connId sq [msgId] + queuePendingMsgs c sq [msgId] pure $ unId msgId where storeSentMsg :: m InternalId @@ -647,29 +647,29 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do - let qKey = (connId, server, sndId) + let qKey = (server, sndId) unlessM (queueDelivering qKey) $ do - mq <- atomically $ getPendingMsgQ c connId sq + mq <- atomically $ getPendingMsgQ c sq async (runSmpQueueMsgDelivery c cData mq) >>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c) unlessM connQueued $ withStore' c (`getPendingMsgs` connId) - >>= queuePendingMsgs c connId sq + >>= queuePendingMsgs c sq where queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c) connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c) -queuePendingMsgs :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> [InternalId] -> m () -queuePendingMsgs c connId sq msgIds = atomically $ do +queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m () +queuePendingMsgs c sq msgIds = atomically $ do modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds} -- s <- readTVar (msgDeliveryOp c) -- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s) - q <- getPendingMsgQ c connId sq + q <- getPendingMsgQ c sq mapM_ (writeTQueue q) msgIds -getPendingMsgQ :: AgentClient -> ConnId -> SndQueue -> STM (TQueue InternalId) -getPendingMsgQ c connId SndQueue {server, sndId} = do - let qKey = (connId, server, sndId) +getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId) +getPendingMsgQ c SndQueue {server, sndId} = do + let qKey = (server, sndId) maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c) where newMsgQueue qKey = do @@ -881,11 +881,11 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = (Just tknId, Nothing) | savedDeviceToken == suppliedDeviceToken -> when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered - | otherwise -> replaceToken tknId $> NTRegistered + | otherwise -> replaceToken tknId (Just tknId, Just (NTAVerify code)) | savedDeviceToken == suppliedDeviceToken -> t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code - | otherwise -> replaceToken tknId $> NTRegistered + | otherwise -> replaceToken tknId (Just tknId, Just NTACheck) | savedDeviceToken == suppliedDeviceToken -> do ns <- asks ntfSupervisor @@ -897,7 +897,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete pure ntfTknStatus -- TODO -- agentNtfCheckToken c tknId tkn >>= \case - | otherwise -> replaceToken tknId $> NTRegistered + | otherwise -> replaceToken tknId (Just tknId, Just NTADelete) -> do agentNtfDeleteToken c tknId tkn withStore' c (`removeNtfToken` tkn) @@ -908,13 +908,27 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = withStore' c $ \db -> updateNtfMode db tkn suppliedNtfMode pure status where - replaceToken :: NtfTokenId -> m () + replaceToken :: NtfTokenId -> m NtfTknStatus replaceToken tknId = do - agentNtfReplaceToken c tknId tkn suppliedDeviceToken - withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken ns <- asks ntfSupervisor - atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode} - _ -> + tryReplace ns `catchError` \e -> + if temporaryAgentError e || e == BROKER HOST + then throwError e + else do + withStore' c $ \db -> removeNtfToken db tkn + atomically $ nsRemoveNtfToken ns + createToken + where + tryReplace ns = do + agentNtfReplaceToken c tknId tkn suppliedDeviceToken + withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken + atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode} + pure NTRegistered + _ -> createToken + where + t tkn = withToken c tkn Nothing + createToken :: m NtfTknStatus + createToken = getNtfServer c >>= \case Just ntfServer -> asks (cmdSignAlg . config) >>= \case @@ -926,8 +940,6 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = registerToken tkn pure NTRegistered _ -> throwError $ CMD PROHIBITED - where - t tkn = withToken c tkn Nothing registerToken :: NtfToken -> m () registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do (tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey @@ -1409,8 +1421,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} transmission@(srv, v, se DuplexConnection _ _ sq@SndQueue {server, sndId} nextRq_ nextSq_ -> case nextSq_ of Just sq'@SndQueue {server = server', sndId = sndId'} -> do unless (smpServer == server' && senderId == sndId') . throwError $ INTERNAL "incorrect queue address" - let qKey = (connId, server, sndId) - qKey' = (connId, server', sndId') + let qKey = (server, sndId) + qKey' = (server', sndId') ok <- switchQueues qKey qKey' `catchError` \e -> do atomically (switchDeliveries qKey' qKey) @@ -1505,7 +1517,7 @@ enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQ enqueueConfirmation c cData@ConnData {connId, connAgentVersion} sq connInfo e2eEncryption = do resumeMsgDelivery c cData sq msgId <- storeConfirmation - queuePendingMsgs c connId sq [msgId] + queuePendingMsgs c sq [msgId] where storeConfirmation :: m InternalId storeConfirmation = withStore c $ \db -> runExceptT $ do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 035643bae..6e03e24ad 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -89,6 +89,7 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (listToMaybe) import Data.Set (Set) +import qualified Data.Set as S import Data.Text.Encoding import Data.Time.Clock (getCurrentTime) import Data.Tuple (swap) @@ -146,7 +147,7 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient) type NtfClientVar = TMVar (Either AgentErrorType NtfClient) -type MsgDeliveryKey = (ConnId, SMPServer, SMP.SenderId) +type MsgDeliveryKey = (SMPServer, SMP.SenderId) data AgentClient = AgentClient { active :: TVar Bool, @@ -160,7 +161,8 @@ data AgentClient = AgentClient useNetworkConfig :: TVar NetworkConfig, subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), - subscrConns :: TMap ConnId SMPServer, + subscrConns :: TVar (Set ConnId), + activeSubscrConns :: TMap ConnId SMPServer, connMsgsQueued :: TMap ConnId Bool, smpQueueMsgQueues :: TMap MsgDeliveryKey (TQueue InternalId), smpQueueMsgDeliveries :: TMap MsgDeliveryKey (Async ()), @@ -212,7 +214,8 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do useNetworkConfig <- newTVar netCfg subscrSrvrs <- TM.empty pendingSubscrSrvrs <- TM.empty - subscrConns <- TM.empty + subscrConns <- newTVar S.empty + activeSubscrConns <- TM.empty connMsgsQueued <- TM.empty smpQueueMsgQueues <- TM.empty smpQueueMsgDeliveries <- TM.empty @@ -228,7 +231,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i') lock <- newTMVar () - return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} + return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} agentDbPath :: AgentClient -> FilePath agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath @@ -271,7 +274,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do where updateSubs cVar = do cs <- readTVar cVar - modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) + modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs) addPendingSubs cVar cs pure cs @@ -413,12 +416,13 @@ closeAgentClient c = liftIO $ do clear subscrSrvrs clear pendingSubscrSrvrs clear subscrConns + clear activeSubscrConns clear connMsgsQueued clear smpQueueMsgQueues clear getMsgLocks where - clear :: (AgentClient -> TMap k a) -> IO () - clear sel = atomically $ writeTVar (sel c) M.empty + clear :: Monoid m => (AgentClient -> TVar m) -> IO () + clear sel = atomically $ writeTVar (sel c) mempty closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = @@ -522,7 +526,9 @@ newRcvQueue_ a c srv vRange current = do subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED - atomically $ addPendingSubscription c rq connId + atomically $ do + modifyTVar (subscrConns c) $ S.insert connId + addPendingSubscription c rq connId withLogClient c server rcvId "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId) >>= either throwError pure @@ -552,7 +558,9 @@ temporaryAgentError = \case subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ())) subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) - forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap + forM_ qs_ $ \q -> atomically $ do + modifyTVar (subscrConns c) . S.insert $ fst q + uncurry (addPendingSubscription c) $ swap q case L.nonEmpty qs_ of Just qs' -> do smp_ <- tryError (getSMPServerClient c srv) @@ -574,12 +582,13 @@ subscribeQueues c srv qs = do addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do - TM.insert connId server $ subscrConns c + TM.insert connId server $ activeSubscrConns c + modifyTVar (subscrConns c) $ S.insert connId addSubs_ (subscrSrvrs c) rq connId removePendingSubscription c server connId hasActiveSubscription :: AgentClient -> ConnId -> STM Bool -hasActiveSubscription c connId = TM.member connId (subscrConns c) +hasActiveSubscription c connId = TM.member connId (activeSubscrConns c) addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM () addPendingSubscription = addSubs_ . pendingSubscrSrvrs @@ -591,8 +600,9 @@ addSubs_ ss rq@RcvQueue {server} connId = _ -> TM.singleton connId rq >>= \m -> TM.insert server m ss removeSubscription :: AgentClient -> ConnId -> STM () -removeSubscription c@AgentClient {subscrConns} connId = do - server_ <- TM.lookupDelete connId subscrConns +removeSubscription c connId = do + modifyTVar (subscrConns c) $ S.delete connId + server_ <- TM.lookupDelete connId $ activeSubscrConns c mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_ removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM () @@ -603,9 +613,7 @@ removeSubs_ ss server connId = TM.lookup server ss >>= mapM_ (TM.delete connId) getSubscriptions :: AgentClient -> STM (Set ConnId) -getSubscriptions AgentClient {subscrConns} = do - m <- readTVar subscrConns - pure $ M.keysSet m +getSubscriptions = readTVar . subscrConns logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 38ef6cc72..d9af0fb29 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -106,6 +106,7 @@ removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} = forM_ tIds $ \(regKey, tId') -> do TM.delete regKey tknRegs TM.delete tId' $ tokens st + -- TODO remove token subscriptions as in deleteNtfToken pure $ map snd tIds removeTokenRegistration :: NtfStore -> NtfTknData -> STM () @@ -130,6 +131,7 @@ deleteNtfToken st tknId = do ) ) + -- TODO refactor qs <- TM.lookupDelete tknId (tokenSubscriptions st) >>= mapM diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index f3a0cc111..a8dbd1a17 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -276,31 +276,25 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do liftIO $ threadDelay 1000000 aliceId <- joinConnection bob True qInfo "bob's connInfo" liftIO $ threadDelay 750000 - liftIO $ print 0 void $ messageNotification apnsQ ("", _, CONF confId _ "bob's connInfo") <- get alice liftIO $ threadDelay 500000 allowConnection alice bobId confId "alice's connInfo" - liftIO $ print 1 void $ messageNotification apnsQ get bob ##> ("", aliceId, INFO "alice's connInfo") - liftIO $ print 2 void $ messageNotification apnsQ get alice ##> ("", bobId, CON) - liftIO $ print 3 void $ messageNotification apnsQ get bob ##> ("", aliceId, CON) -- bob sends message 1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello" get bob ##> ("", aliceId, SENT $ baseId + 1) - liftIO $ print 4 void $ messageNotification apnsQ get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False ackMessage alice bobId $ baseId + 1 -- alice sends message 2 <- msgId <$> sendMessage alice bobId (SMP.MsgFlags True) "hey there" get alice ##> ("", bobId, SENT $ baseId + 2) - liftIO $ print 5 void $ messageNotification apnsQ get bob =##> \case ("", c, Msg "hey there") -> c == aliceId; _ -> False ackMessage bob aliceId $ baseId + 2