diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a408cfb8..28921a626 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,48 @@ +# 4.3.0 + +SMP server: + +- additional server usage statistics. + +SMP agent: + +- increase retry interval when sending messages after ERR QUOTA. + +# 4.2.0 + +SMP agent and server: + +- reduce sender traffic in cases when queue quota is exceeded: + - server sends quota exceeded message to the recipient when sender receives ERR QUOTA. + - recipient sends QCONT message to the send once the queue is drained (via reply queue). + - sender retry delays are increased, reducing traffic, but sender instantly resumes delivery where QCONT is received. + +SMP server: + +- increase internal queue sizes. + +SMP agent: + +- deduplicate connection IDs in connect/disconnect responses. +- unit tests for Crypto.hs. +- fix connection switch to another queue: correctly set primary send queue. + +Notification server (v1.3.0): + +- check token status when sending verification notificaiton. + +# 4.1.0 + +SMP agent and server: + +- option to toggle TLS handshake error logs (disabled by default). + +SMP agent: + +- include server address in BROKER error. +- api to get hash of double ratchet associated data (for connection verification). +- api to get agent statistics. + # 4.0.0 SMP server: diff --git a/package.yaml b/package.yaml index e74fac751..d4ed1c127 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 4.0.0 +version: 4.3.0 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index d9f23b64b..93b4ea5f3 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 4.0.0 +version: 4.3.0 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and @@ -366,6 +366,7 @@ test-suite smp-server-test CoreTests.CryptoTests CoreTests.EncodingTests CoreTests.ProtocolErrorTests + CoreTests.RetryIntervalTests CoreTests.VersionRangeTests NtfClient NtfServerTests diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 763c51522..632610690 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -504,7 +504,7 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO pure connId' Left e -> do - -- TODO recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md + -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md unless asyncMode $ withStore' c (`deleteConn` connId') throwError e where @@ -791,7 +791,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do atomically $ beginAgentOperation c AOSndNetwork E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e) - Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd + Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd where processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m () processCmd ri corrId connId cmdId command = case command of @@ -905,7 +905,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do atomically $ do srvs <- readTVar $ smpServers c let used' = if length used + 1 >= L.length srvs then initUsed else srv : used - writeTVar usedSrvs used' + writeTVar usedSrvs $! used' action srvAuth -- ^ ^ ^ async command processing / @@ -964,22 +964,22 @@ queuePendingMsgs c sq msgIds = atomically $ do modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds} -- s <- readTVar (msgDeliveryOp c) -- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s) - q <- getPendingMsgQ c sq - mapM_ (writeTQueue q) msgIds + (mq, _) <- getPendingMsgQ c sq + mapM_ (writeTQueue mq) msgIds -getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId) +getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId, TMVar ()) getPendingMsgQ c SndQueue {server, sndId} = do let qKey = (server, sndId) maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c) where newMsgQueue qKey = do - mq <- newTQueue - TM.insert qKey mq $ smpQueueMsgQueues c - pure mq + q <- (,) <$> newTQueue <*> newEmptyTMVar + TM.insert qKey q $ smpQueueMsgQueues c + pure q runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do - mq <- atomically $ getPendingMsgQ c sq + (mq, qLock) <- atomically $ getPendingMsgQ c sq ri <- asks $ messageRetryInterval . config forever $ do atomically $ endAgentOperation c AOSndNetwork @@ -993,7 +993,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh Left (e :: E.SomeException) -> notify $ MERR mId (INTERNAL $ show e) Right (rq_, PendingMsgData {msgType, msgBody, msgFlags, internalTs}) -> - withRetryInterval ri $ \loop -> do + withRetryLock2 ri qLock $ \loop -> do resp <- tryError $ case msgType of AM_CONN_INFO -> sendConfirmation c sq msgBody _ -> sendAgentMessage c sq msgFlags msgBody @@ -1004,7 +1004,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh SMP SMP.QUOTA -> case msgType of AM_CONN_INFO -> connError msgId NOT_AVAILABLE AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE - _ -> retrySndOp c loop + _ -> retrySndOp c $ loop RISlow SMP SMP.AUTH -> case msgType of AM_CONN_INFO -> connError msgId NOT_AVAILABLE AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE @@ -1013,7 +1013,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- because the queue must be secured by the time the confirmation or the first HELLO is received | duplexHandshake == Just True -> connErr | otherwise -> - ifM (msgExpired helloTimeout) connErr (retrySndOp c loop) + ifM (msgExpired helloTimeout) connErr (retrySndOp c $ loop RIFast) where connErr = case rq_ of -- party initiating connection @@ -1022,6 +1022,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh _ -> connError msgId NOT_ACCEPTED AM_REPLY_ -> notifyDel msgId err AM_A_MSG_ -> notifyDel msgId err + AM_QCONT_ -> notifyDel msgId err AM_QADD_ -> qError msgId "QADD: AUTH" AM_QKEY_ -> qError msgId "QKEY: AUTH" AM_QUSE_ -> qError msgId "QUSE: AUTH" @@ -1031,7 +1032,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- the message sending would be retried | temporaryOrHostError e -> do let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout - ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c loop) + ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c $ loop RIFast) | otherwise -> notifyDel msgId err where msgExpired timeoutSel = do @@ -1044,7 +1045,6 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh withStore' c $ \db -> do setSndQueueStatus db sq Confirmed when (isJust rq_) $ removeConfirmations db connId - -- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO AM_CONN_INFO_REPLY -> pure () AM_REPLY_ -> pure () @@ -1071,6 +1071,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh qInfo <- createReplyQueue c cData sq srv void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo] AM_A_MSG_ -> notify $ SENT mId + AM_QCONT_ -> pure () AM_QADD_ -> pure () AM_QKEY_ -> pure () AM_QUSE_ -> pure () @@ -1080,16 +1081,18 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh case conn of DuplexConnection cData' rqs sqs -> do -- remove old snd queue from connection once QTEST is sent to the new queue - case findQ (qAddress sq) sqs of + let addr = qAddress sq + case findQ addr sqs of -- this is the same queue where this loop delivers messages to but with updated state Just SndQueue {dbReplaceQueueId = Just replacedId, primary} -> - case removeQP (\SndQueue {dbQueueId} -> dbQueueId == replacedId) sqs of + -- second part of this condition is a sanity check because dbReplaceQueueId cannot point to the same queue, see switchConnection' + case removeQP (\sq'@SndQueue {dbQueueId} -> dbQueueId == replacedId && not (sameQueue addr sq')) sqs of Nothing -> internalErr msgId "sent QTEST: queue not found in connection" Just (sq', sq'' : sqs') -> do -- remove the delivery from the map to stop the thread when the delivery loop is complete atomically $ TM.delete (qAddress sq') $ smpQueueMsgQueues c withStore' c $ \db -> do - when primary $ setSndQueuePrimary db connId sq' + when primary $ setSndQueuePrimary db connId sq deletePendingMsgs db connId sq' deleteConnSndQueue db connId sq' let sqs'' = sq'' :| sqs' @@ -1225,7 +1228,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = (Nothing, Just NTARegister) -> do when (savedDeviceToken /= suppliedDeviceToken) $ withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken registerToken tkn $> NTRegistered - -- TODO minimal time before repeat registration + -- possible improvement: add minimal time before repeat registration (Just tknId, Nothing) | savedDeviceToken == suppliedDeviceToken -> when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered @@ -1243,8 +1246,8 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = agentNtfEnableCron c tknId tkn cron when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete - pure ntfTknStatus -- TODO - -- agentNtfCheckToken c tknId tkn >>= \case + -- possible improvement: get updated token status from the server, or maybe TCRON could return the current status + pure ntfTknStatus | otherwise -> replaceToken tknId (Just tknId, Just NTADelete) -> do agentNtfDeleteToken c tknId tkn @@ -1411,11 +1414,6 @@ sendNtfConnCommands c cmd = do _ -> atomically $ writeTBQueue (subQ c) ("", connId, ERR $ INTERNAL "no connection data") --- TODO --- There should probably be another function to cancel all subscriptions that would flush the queue first, --- so that supervisor stops processing pending commands? --- It is an optimization, but I am thinking how it would behave if a user were to flip on/off quickly several times. - setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m () setNtfServers' c = atomically . writeTVar (ntfServers c) @@ -1492,88 +1490,96 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm processSMP :: RcvQueue -> Connection c -> ConnData -> m () processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $ case cmd of - SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do - SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} <- decryptSMPMessage v rq msg - clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- - parseMessage msgBody - clientVRange <- asks $ smpClientVRange . config - unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION - case (e2eDhSecret, e2ePubKey_) of - (Nothing, Just e2ePubKey) -> do - let e2eDh = C.dh' e2ePubKey e2ePrivKey - decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) -> - smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack - (SMP.PHEmpty, AgentInvitation {connReq, connInfo}) -> - smpInvitation connReq connInfo >> ack - _ -> prohibited >> ack - (Just e2eDh, Nothing) -> do - decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do - -- primary queue is set as Active in helloMsg, below is to set additional queues Active - let RcvQueue {primary, dbReplaceQueueId} = rq - unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active - case (conn, dbReplaceQueueId) of - (DuplexConnection _ rqs _, Just replacedId) -> do - when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq - case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of - Just RcvQueue {server, rcvId} -> do - enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId - _ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection" - _ -> pure () - tryError agentClientMsg >>= \case - Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of - HELLO -> helloMsg >> ackDel msgId - REPLY cReq -> replyMsg cReq >> ackDel msgId - -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command - A_MSG body -> do - logServer "<--" c srv rId "MSG " - notify $ MSG msgMeta msgFlags body - QADD qs -> qDuplex "QADD" $ qAddMsg qs - QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs - QUSE qs -> qDuplex "QUSE" $ qUseMsg qs - -- no action needed for QTEST - -- any message in the new queue will mark it active and trigger deletion of the old queue - QTEST _ -> logServer "<--" c srv rId "MSG " >> ackDel msgId - where - qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m () - qDuplex name a = case conn of - DuplexConnection {} -> a conn >> ackDel msgId - _ -> qError $ name <> ": message must be sent to duplex connection" - Right _ -> prohibited >> ack - Left e@(AGENT A_DUPLICATE) -> do - withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case - Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} - | userAck -> ackDel internalId - | otherwise -> do - liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case - AgentMessage _ (A_MSG body) -> do - logServer "<--" c srv rId "MSG " - notify $ MSG msgMeta msgFlags body - _ -> pure () - _ -> throwError e - Left e -> throwError e - where - agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage)) - agentClientMsg = withStore c $ \db -> runExceptT $ do - agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg - liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case - agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do - let msgType = agentMessageType agentMsg - internalHash = C.sha256Hash agentMsgBody - internalTs <- liftIO getCurrentTime - (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId - let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash - recipient = (unId internalId, internalTs) - broker = (srvMsgId, systemToUTCTime srvTs) - msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} - rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash} - liftIO $ createRcvMsg db connId rq rcvMsg - pure $ Just (internalId, msgMeta, aMessage) - _ -> pure Nothing - _ -> prohibited >> ack - _ -> prohibited >> ack + SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> + handleNotifyAck $ + decryptSMPMessage v rq msg >>= \case + SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody + SMP.ClientRcvMsgQuota {} -> queueDrained >> ack where + queueDrained = case conn of + DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) + _ -> pure () + processClientMsg srvTs msgFlags msgBody = do + clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <- + parseMessage msgBody + clientVRange <- asks $ smpClientVRange . config + unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION + case (e2eDhSecret, e2ePubKey_) of + (Nothing, Just e2ePubKey) -> do + let e2eDh = C.dh' e2ePubKey e2ePrivKey + decryptClientMessage e2eDh clientMsg >>= \case + (SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) -> + smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack + (SMP.PHEmpty, AgentInvitation {connReq, connInfo}) -> + smpInvitation connReq connInfo >> ack + _ -> prohibited >> ack + (Just e2eDh, Nothing) -> do + decryptClientMessage e2eDh clientMsg >>= \case + (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do + -- primary queue is set as Active in helloMsg, below is to set additional queues Active + let RcvQueue {primary, dbReplaceQueueId} = rq + unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active + case (conn, dbReplaceQueueId) of + (DuplexConnection _ rqs _, Just replacedId) -> do + when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq + case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of + Just RcvQueue {server, rcvId} -> do + enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId + _ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection" + _ -> pure () + tryError agentClientMsg >>= \case + Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of + HELLO -> helloMsg >> ackDel msgId + REPLY cReq -> replyMsg cReq >> ackDel msgId + -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command + A_MSG body -> do + logServer "<--" c srv rId "MSG " + notify $ MSG msgMeta msgFlags body + QCONT addr -> qDuplex "QCONT" $ continueSending addr + QADD qs -> qDuplex "QADD" $ qAddMsg qs + QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs + QUSE qs -> qDuplex "QUSE" $ qUseMsg qs + -- no action needed for QTEST + -- any message in the new queue will mark it active and trigger deletion of the old queue + QTEST _ -> logServer "<--" c srv rId "MSG " >> ackDel msgId + where + qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m () + qDuplex name a = case conn of + DuplexConnection {} -> a conn >> ackDel msgId + _ -> qError $ name <> ": message must be sent to duplex connection" + Right _ -> prohibited >> ack + Left e@(AGENT A_DUPLICATE) -> do + withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case + Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} + | userAck -> ackDel internalId + | otherwise -> do + liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case + AgentMessage _ (A_MSG body) -> do + logServer "<--" c srv rId "MSG " + notify $ MSG msgMeta msgFlags body + _ -> pure () + _ -> throwError e + Left e -> throwError e + where + agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage)) + agentClientMsg = withStore c $ \db -> runExceptT $ do + agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg + liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case + agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do + let msgType = agentMessageType agentMsg + internalHash = C.sha256Hash agentMsgBody + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId + let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash + recipient = (unId internalId, internalTs) + broker = (srvMsgId, systemToUTCTime srvTs) + msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} + rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash} + liftIO $ createRcvMsg db connId rq rcvMsg + pure $ Just (internalId, msgMeta, aMessage) + _ -> pure Nothing + _ -> prohibited >> ack + _ -> prohibited >> ack ack :: m () ack = enqueueCmd $ ICAck rId srvMsgId ackDel :: InternalId -> m () @@ -1698,6 +1704,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR) _ -> prohibited + continueSending :: (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () + continueSending addr (DuplexConnection _ _ sqs) = + case findQ addr sqs of + Just sq -> do + logServer "<--" c srv rId "MSG " + atomically $ do + (_, qLock) <- getPendingMsgQ c sq + void $ tryPutTMVar qLock () + Nothing -> qError "QCONT: queue address not found" + -- processed by queue sender qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m () qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported" diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 700459cc2..a94d866d0 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -98,7 +98,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (isRight, partitionEithers) import Data.Functor (($>)) -import Data.List (partition, (\\)) +import Data.List (partition) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -180,7 +180,7 @@ data AgentClient = AgentClient activeSubs :: TRcvQueues, pendingSubs :: TRcvQueues, pendingMsgsQueued :: TMap SndQAddr Bool, - smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId), + smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId, TMVar ()), smpQueueMsgDeliveries :: TMap SndQAddr (Async ()), connCmdsQueued :: TMap ConnId Bool, asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId), @@ -303,12 +303,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do removeClientAndSubs :: IO ([RcvQueue], [ConnId]) removeClientAndSubs = atomically $ do TM.delete srv smpClients - qs <- RQ.getDelSrvQueues srv $ activeSubs c + (qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c mapM_ (`RQ.addQueue` pendingSubs c) qs - cs <- RQ.getConns (activeSubs c) - -- TODO deduplicate conns - let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs - pure (qs, conns) + pure (qs, S.toList conns) serverDown :: ([RcvQueue], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do @@ -345,8 +342,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do unless connected . forM_ client_ $ \cl -> do incClientStat c cl "CONNECT" "" notifySub "" $ hostEvent CONNECT cl - -- TODO deduplicate okConns - let conns = okConns \\ S.toList cs + let conns = S.toList $ S.fromList okConns `S.difference` cs unless (null conns) $ notifySub "" $ UP srv conns let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs @@ -611,7 +607,7 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED atomically $ do - modifyTVar (subscrConns c) $ S.insert connId + modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c withLogClient c server rcvId "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq) @@ -647,9 +643,8 @@ temporaryOrHostError = \case subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())]) subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue qs - forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do - -- TODO check server is correct - modifyTVar (subscrConns c) $ S.insert connId + forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do + modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c case L.nonEmpty qs_ of Just qs' -> do @@ -667,14 +662,16 @@ subscribeQueues c srv qs = do pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs _ -> pure (Nothing, errs) where - checkQueue rq@RcvQueue {rcvId, server} = do - prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c - pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq + checkQueue rq@RcvQueue {rcvId, server} + | server == srv = do + prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c + pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq + | otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter") queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m () addSubscription c rq@RcvQueue {connId} = atomically $ do - modifyTVar (subscrConns c) $ S.insert connId + modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ activeSubs c RQ.deleteQueue rq $ pendingSubs c @@ -683,7 +680,7 @@ hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do - modifyTVar (subscrConns c) $ S.delete connId + modifyTVar' (subscrConns c) $ S.delete connId RQ.deleteConn connId $ activeSubs c RQ.deleteConn connId $ pendingSubs c @@ -948,7 +945,7 @@ storeError = \case incStat :: AgentClient -> Int -> AgentStatsKey -> STM () incStat AgentClient {agentStats} n k = do TM.lookup k agentStats >>= \case - Just v -> modifyTVar v (+ n) + Just v -> modifyTVar' v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO () diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index ce754e5e9..ca41e0b3e 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -83,7 +83,7 @@ data AgentConfig = AgentConfig smpCfg :: ProtocolClientConfig, ntfCfg :: ProtocolClientConfig, reconnectInterval :: RetryInterval, - messageRetryInterval :: RetryInterval, + messageRetryInterval :: RetryInterval2, messageTimeout :: NominalDiffTime, helloTimeout :: NominalDiffTime, ntfCron :: Word16, @@ -108,12 +108,24 @@ defaultReconnectInterval = maxInterval = 180_000000 } -defaultMessageRetryInterval :: RetryInterval +defaultMessageRetryInterval :: RetryInterval2 defaultMessageRetryInterval = - RetryInterval - { initialInterval = 1_000000, - increaseAfter = 10_000000, - maxInterval = 60_000000 + RetryInterval2 + { riFast = + RetryInterval + { initialInterval = 1_000000, + increaseAfter = 10_000000, + maxInterval = 60_000000 + }, + riSlow = + -- TODO: these timeouts can be increased in v4.6 once most clients are updated + -- to resume sending on QCONT messages. + -- After that local message expiration period should be also increased. + RetryInterval + { initialInterval = 30_000000, + increaseAfter = 30_000000, + maxInterval = 600_000000 + } } defaultAgentConfig :: AgentConfig diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index a76787db2..5d6af212b 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -191,7 +191,7 @@ runNtfWorker c srv doWork = do case clientNtfCreds of Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey - -- TODO smaller retry until Active, less frequently (daily?) once Active + -- possible improvement: smaller retry until Active, less frequently (daily?) once Active let actionTs' = addUTCTime 30 ts withStore' c $ \db -> updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NtfSubNTFAction NSACheck) actionTs' diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 9d2b33a70..8f9c833d1 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -577,6 +577,7 @@ data AgentMessageType | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_ + | AM_QCONT_ | AM_QADD_ | AM_QKEY_ | AM_QUSE_ @@ -590,6 +591,7 @@ instance Encoding AgentMessageType where AM_HELLO_ -> "H" AM_REPLY_ -> "R" AM_A_MSG_ -> "M" + AM_QCONT_ -> "QC" AM_QADD_ -> "QA" AM_QKEY_ -> "QK" AM_QUSE_ -> "QU" @@ -603,6 +605,7 @@ instance Encoding AgentMessageType where 'M' -> pure AM_A_MSG_ 'Q' -> A.anyChar >>= \case + 'C' -> pure AM_QCONT_ 'A' -> pure AM_QADD_ 'K' -> pure AM_QKEY_ 'U' -> pure AM_QUSE_ @@ -623,6 +626,7 @@ agentMessageType = \case -- REPLY is only used in v1 REPLY _ -> AM_REPLY_ A_MSG _ -> AM_A_MSG_ + QCONT _ -> AM_QCONT_ QADD _ -> AM_QADD_ QKEY _ -> AM_QKEY_ QUSE _ -> AM_QUSE_ @@ -645,6 +649,7 @@ data AMsgType = HELLO_ | REPLY_ | A_MSG_ + | QCONT_ | QADD_ | QKEY_ | QUSE_ @@ -656,6 +661,7 @@ instance Encoding AMsgType where HELLO_ -> "H" REPLY_ -> "R" A_MSG_ -> "M" + QCONT_ -> "QC" QADD_ -> "QA" QKEY_ -> "QK" QUSE_ -> "QU" @@ -667,6 +673,7 @@ instance Encoding AMsgType where 'M' -> pure A_MSG_ 'Q' -> A.anyChar >>= \case + 'C' -> pure QCONT_ 'A' -> pure QADD_ 'K' -> pure QKEY_ 'U' -> pure QUSE_ @@ -684,6 +691,8 @@ data AMessage REPLY (L.NonEmpty SMPQueueInfo) | -- | agent envelope for the client message A_MSG MsgBody + | -- | the message instructing the client to continue sending messages (after ERR QUOTA) + QCONT SndQAddr | -- add queue to connection (sent by recipient), with optional address of the replaced queue QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr)) | -- key to secure the added queues and agree e2e encryption key (sent by sender) @@ -701,6 +710,7 @@ instance Encoding AMessage where HELLO -> smpEncode HELLO_ REPLY smpQueues -> smpEncode (REPLY_, smpQueues) A_MSG body -> smpEncode (A_MSG_, Tail body) + QCONT addr -> smpEncode (QCONT_, addr) QADD qs -> smpEncode (QADD_, qs) QKEY qs -> smpEncode (QKEY_, qs) QUSE qs -> smpEncode (QUSE_, qs) @@ -711,6 +721,7 @@ instance Encoding AMessage where HELLO_ -> pure HELLO REPLY_ -> REPLY <$> smpP A_MSG_ -> A_MSG . unTail <$> smpP + QCONT_ -> QCONT <$> smpP QADD_ -> QADD <$> smpP QKEY_ -> QKEY <$> smpP QUSE_ -> QUSE <$> smpP @@ -1119,7 +1130,6 @@ instance ToJSON BrokerErrorType where toEncoding = J.genericToEncoding $ sumTypeJSON id -- | Errors of another SMP agent. --- TODO encode/decode without A prefix data SMPAgentError = -- | client or agent message that failed to parse A_MESSAGE diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs index 048b9e09c..3d5cfcbae 100644 --- a/src/Simplex/Messaging/Agent/RetryInterval.hs +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -1,10 +1,21 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} -module Simplex.Messaging.Agent.RetryInterval where +module Simplex.Messaging.Agent.RetryInterval + ( RetryInterval (..), + RetryInterval2 (..), + RetryIntervalMode (..), + withRetryInterval, + withRetryLock2, + ) +where -import Control.Concurrent (threadDelay) +import Control.Concurrent (forkIO, threadDelay) +import Control.Monad (void) import Control.Monad.IO.Class (MonadIO, liftIO) +import Simplex.Messaging.Util (whenM) +import UnliftIO.STM data RetryInterval = RetryInterval { initialInterval :: Int, @@ -12,17 +23,51 @@ data RetryInterval = RetryInterval maxInterval :: Int } +data RetryInterval2 = RetryInterval2 + { riSlow :: RetryInterval, + riFast :: RetryInterval + } + +data RetryIntervalMode = RISlow | RIFast + deriving (Eq) + withRetryInterval :: forall m. MonadIO m => RetryInterval -> (m () -> m ()) -> m () -withRetryInterval RetryInterval {initialInterval, increaseAfter, maxInterval} action = - callAction 0 initialInterval +withRetryInterval ri action = callAction 0 $ initialInterval ri where callAction :: Int -> Int -> m () - callAction elapsedTime delay = action loop + callAction elapsed delay = action loop where loop = do - let newDelay = - if elapsedTime < increaseAfter || delay == maxInterval - then delay - else min (delay * 3 `div` 2) maxInterval liftIO $ threadDelay delay - callAction (elapsedTime + delay) newDelay + let elapsed' = elapsed + delay + callAction elapsed' $ nextDelay elapsed' delay ri + +-- This function allows action to toggle between slow and fast retry intervals. +withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> ((RetryIntervalMode -> m ()) -> m ()) -> m () +withRetryLock2 RetryInterval2 {riSlow, riFast} lock action = + callAction (0, initialInterval riSlow) (0, initialInterval riFast) + where + callAction :: (Int, Int) -> (Int, Int) -> m () + callAction slow fast = action loop + where + loop = \case + RISlow -> run slow riSlow (`callAction` fast) + RIFast -> run fast riFast (callAction slow) + run (elapsed, delay) ri call = do + wait delay + let elapsed' = elapsed + delay + call (elapsed', nextDelay elapsed' delay ri) + wait delay = do + waiting <- newTVarIO True + _ <- liftIO . forkIO $ do + threadDelay delay + atomically $ whenM (readTVar waiting) $ void $ tryPutTMVar lock () + atomically $ do + takeTMVar lock + writeTVar waiting False + +nextDelay :: Int -> Int -> RetryInterval -> Int +nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} = + if elapsed < increaseAfter || delay == maxInterval + then delay + else min (delay * 3 `div` 2) maxInterval diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index bed4138a2..9ace32db7 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -40,9 +40,9 @@ getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs where addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs' -getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue] -getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty) +getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId) +getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty) where - addQ (removed, qs') rq@RcvQueue {server, rcvId} - | srv == server = (rq : removed, qs') + addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId} + | srv == server = ((rq : remQs, S.insert connId remConns), qs') | otherwise = (removed, M.insert (server, rcvId) rq qs') diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 370965c7f..611702b69 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DisambiguateRecordFields #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -94,7 +93,7 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (SocksProxy, TransportClientConfig (..), TransportHost (..), runTransportClient) import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Transport.WebSockets (WS) -import Simplex.Messaging.Util (bshow, liftError, raceAny_) +import Simplex.Messaging.Util (bshow, raceAny_) import Simplex.Messaging.Version import System.Timeout (timeout) @@ -541,7 +540,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}} pKey mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg)) mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do corrId <- liftIO $ atomically getNextCorrId - t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd) + let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd) r <- liftIO . atomically $ mkRequest corrId pure (t, r) where @@ -549,12 +548,8 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo getNextCorrId = do i <- stateTVar clientCorrId $ \i -> (i, i + 1) pure . CorrId $ bshow i - signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission - signTransmission t = case pKey of - Nothing -> pure (Nothing, t) - Just pk -> do - sig <- liftError PCESignatureError $ C.sign pk t - return (Just sig, t) + signTransmission :: ByteString -> SentRawTransmission + signTransmission t = ((`C.sign` t) <$> pKey, t) mkRequest :: CorrId -> STM (TMVar (Response msg)) mkRequest corrId = do r <- newEmptyTMVar diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index c5d727aad..6b56a7847 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -67,7 +67,9 @@ module Simplex.Messaging.Crypto -- * key encoding/decoding encodePubKey, + decodePubKey, encodePrivKey, + decodePrivKey, pubKeyBytes, -- * sign/verify @@ -888,12 +890,12 @@ cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError -- | Message signing. -- -- Used by SMP clients to sign SMP commands and by SMP agents to sign messages. -sign' :: SignatureAlgorithm a => PrivateKey a -> ByteString -> ExceptT CryptoError IO (Signature a) -sign' (PrivateKeyEd25519 pk k) msg = pure . SignatureEd25519 $ Ed25519.sign pk k msg -sign' (PrivateKeyEd448 pk k) msg = pure . SignatureEd448 $ Ed448.sign pk k msg +sign' :: SignatureAlgorithm a => PrivateKey a -> ByteString -> Signature a +sign' (PrivateKeyEd25519 pk k) msg = SignatureEd25519 $ Ed25519.sign pk k msg +sign' (PrivateKeyEd448 pk k) msg = SignatureEd448 $ Ed448.sign pk k msg -sign :: APrivateSignKey -> ByteString -> ExceptT CryptoError IO ASignature -sign (APrivateSignKey a k) = fmap (ASignature a) . sign' k +sign :: APrivateSignKey -> ByteString -> ASignature +sign (APrivateSignKey a k) = ASignature a . sign' k -- | Signature verification. -- @@ -962,11 +964,7 @@ pseudoRandomCbNonce :: TVar ChaChaDRG -> STM CbNonce pseudoRandomCbNonce gVar = CbNonce <$> pseudoRandomBytes 24 gVar pseudoRandomBytes :: Int -> TVar ChaChaDRG -> STM ByteString -pseudoRandomBytes n gVar = do - g <- readTVar gVar - let (bytes, g') = randomBytesGenerate n g - writeTVar gVar g' - return bytes +pseudoRandomBytes n gVar = stateTVar gVar $ randomBytesGenerate n instance Encoding CbNonce where smpEncode = unCbNonce diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 54b30e0f9..80069a534 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -257,24 +257,27 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do (tkn@NtfTknData {ntfTknId, token = DeviceToken pp _, tknStatus}, ntf) <- atomically (readTBQueue pushQ) liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp) status <- readTVarIO tknStatus - case (status, ntf) of - (_, PNVerification _) -> - -- TODO check token status - deliverNotification pp tkn ntf >>= \case - Right _ -> do - status_ <- atomically $ stateTVar tknStatus $ \status' -> if status' == NTActive then (Nothing, NTActive) else (Just NTConfirmed, NTConfirmed) - forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status' - _ -> pure () - (NTActive, PNCheckMessages) -> + case ntf of + PNVerification _ + | status /= NTInvalid && status /= NTExpired -> + deliverNotification pp tkn ntf >>= \case + Right _ -> do + status_ <- atomically $ stateTVar tknStatus $ \status' -> if status' == NTActive then (Nothing, NTActive) else (Just NTConfirmed, NTConfirmed) + forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status' + _ -> pure () + | otherwise -> logError "bad notification token status" + PNCheckMessages -> checkActiveTkn status $ do void $ deliverNotification pp tkn ntf - (NTActive, PNMessage {}) -> do + PNMessage {} -> checkActiveTkn status $ do stats <- asks serverStats atomically $ updatePeriodStats (activeTokens stats) ntfTknId void $ deliverNotification pp tkn ntf incNtfStat ntfDelivered - _ -> - liftIO $ logError "bad notification token status" where + checkActiveTkn :: NtfTknStatus -> M () -> M () + checkActiveTkn status action + | status == NTActive = action + | otherwise = liftIO $ logError "bad notification token status" deliverNotification :: PushProvider -> NtfTknData -> PushNotification -> M (Either PushProviderError ()) deliverNotification pp tkn@NtfTknData {ntfTknId, tknStatus} ntf = do deliver <- liftIO $ getPushClient s pp @@ -361,13 +364,11 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do s_ <- atomically $ findNtfSubscription st smpQueue case s_ of Nothing -> do - -- TODO move active token check here to differentiate error t_ <- atomically $ getActiveNtfToken st tknId verifyToken' t_ $ VRVerified (NtfReqNew corrId (ANE SSubscription sub)) Just s@NtfSubData {tokenId = subTknId} -> if subTknId == tknId then do - -- TODO move active token check here to differentiate error t_ <- atomically $ getActiveNtfToken st subTknId verifyToken' t_ $ verifiedSubCmd s c else pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed @@ -375,7 +376,6 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do s_ <- atomically $ getNtfSubscription st entId case s_ of Just s@NtfSubData {tokenId = subTknId} -> do - -- TODO move active token check here to differentiate error t_ <- atomically $ getActiveNtfToken st subTknId verifyToken' t_ $ verifiedSubCmd s c _ -> pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed @@ -512,7 +512,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu (corrId,subId,) <$> case cmd of SNEW (NewNtfSub _ _ notifierKey) -> do logDebug "SNEW - existing subscription" - -- TODO retry if subscription failed, if pending or AUTH do nothing + -- possible improvement: retry if subscription failed, if pending or AUTH do nothing pure $ if notifierKey == registeredNKey then NRSubId subId @@ -548,7 +548,7 @@ withNtfLog action = liftIO . mapM_ action =<< asks storeLog incNtfStat :: (NtfServerStats -> TVar Int) -> M () incNtfStat statSel = do stats <- asks serverStats - atomically $ modifyTVar (statSel stats) (+ 1) + atomically $ modifyTVar' (statSel stats) (+ 1) saveServerStats :: M () saveServerStats = diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index 15139ec90..ee0f4e279 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -27,6 +27,9 @@ import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) import Text.Read (readMaybe) +ntfServerVersion :: String +ntfServerVersion = "1.3.0" + ntfServerCLI :: FilePath -> FilePath -> IO () ntfServerCLI cfgPath logPath = getCliCommand' (cliCommandP cfgPath logPath iniFile) serverVersion >>= \case @@ -45,7 +48,7 @@ ntfServerCLI cfgPath logPath = putStrLn "Deleted configuration and log files" where iniFile = combine cfgPath "ntf-server.ini" - serverVersion = "SMP notifications server v1.2.0" + serverVersion = "SMP notifications server v" <> ntfServerVersion defaultServerPort = "443" executableName = "ntf-server" storeLogFilePath = combine logPath "ntf-server-store.log" diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs index 01a41c1ec..5221d7dbc 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs @@ -97,7 +97,7 @@ readECPrivateKey f = do data PushNotification = PNVerification NtfRegCode | PNMessage PNMessageData - | PNAlert Text + -- | PNAlert Text | PNCheckMessages deriving (Show) @@ -287,14 +287,14 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case PNMessage pnMessageData -> encrypt (strEncode pnMessageData) $ \ntfData -> apn apnMutableContent . Just $ J.object ["nonce" .= nonce, "message" .= ntfData] - PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing + -- PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing PNCheckMessages -> Right $ apn APNSBackground {contentAvailable = 1} . Just $ J.object ["checkMessages" .= True] where encrypt :: ByteString -> (Text -> APNSNotification) -> Either C.CryptoError APNSNotification encrypt ntfData f = f . safeDecodeUtf8 . U.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen apn aps notificationData = APNSNotification {aps, notificationData} apnMutableContent = APNSMutableContent {mutableContent = 1, alert = APNSAlertText "Encrypted message or another app event", category = Just ntfCategoryCheckMessage} - apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing} + -- apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing} apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request apnsRequest c tkn ntf@APNSNotification {aps} = do diff --git a/src/Simplex/Messaging/Notifications/Server/Stats.hs b/src/Simplex/Messaging/Notifications/Server/Stats.hs index 6af4b0611..10703d284 100644 --- a/src/Simplex/Messaging/Notifications/Server/Stats.hs +++ b/src/Simplex/Messaging/Notifications/Server/Stats.hs @@ -70,14 +70,14 @@ getNtfServerStatsData s = do setNtfServerStats :: NtfServerStats -> NtfServerStatsData -> STM () setNtfServerStats s d = do - writeTVar (fromTime (s :: NtfServerStats)) (_fromTime (d :: NtfServerStatsData)) - writeTVar (tknCreated s) (_tknCreated d) - writeTVar (tknVerified s) (_tknVerified d) - writeTVar (tknDeleted s) (_tknDeleted d) - writeTVar (subCreated s) (_subCreated d) - writeTVar (subDeleted s) (_subDeleted d) - writeTVar (ntfReceived s) (_ntfReceived d) - writeTVar (ntfDelivered s) (_ntfDelivered d) + writeTVar (fromTime (s :: NtfServerStats)) $! _fromTime (d :: NtfServerStatsData) + writeTVar (tknCreated s) $! _tknCreated d + writeTVar (tknVerified s) $! _tknVerified d + writeTVar (tknDeleted s) $! _tknDeleted d + writeTVar (subCreated s) $! _subCreated d + writeTVar (subDeleted s) $! _subDeleted d + writeTVar (ntfReceived s) $! _ntfReceived d + writeTVar (ntfDelivered s) $! _ntfDelivered d setPeriodStats (activeTokens s) (_activeTokens d) setPeriodStats (activeSubs s) (_activeSubs d) diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 80101e577..b975251e9 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -141,6 +141,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isPrint, isSpace) +import Data.Functor (($>)) import Data.Kind import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L @@ -303,12 +304,17 @@ data RcvMessage = RcvMessage deriving (Eq, Show) -- | received message without server/recipient encryption -data Message = Message - { msgId :: MsgId, - msgTs :: SystemTime, - msgFlags :: MsgFlags, - msgBody :: C.MaxLenBS MaxMessageLen - } +data Message + = Message + { msgId :: MsgId, + msgTs :: SystemTime, + msgFlags :: MsgFlags, + msgBody :: C.MaxLenBS MaxMessageLen + } + | MessageQuota + { msgId :: MsgId, + msgTs :: SystemTime + } instance StrEncoding RcvMessage where strEncode RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = @@ -328,44 +334,72 @@ instance StrEncoding RcvMessage where newtype EncRcvMsgBody = EncRcvMsgBody ByteString deriving (Eq, Show) -data RcvMsgBody = RcvMsgBody - { msgTs :: SystemTime, - msgFlags :: MsgFlags, - msgBody :: C.MaxLenBS MaxMessageLen - } +data RcvMsgBody + = RcvMsgBody + { msgTs :: SystemTime, + msgFlags :: MsgFlags, + msgBody :: C.MaxLenBS MaxMessageLen + } + | RcvMsgQuota + { msgTs :: SystemTime + } + +msgQuotaTag :: ByteString +msgQuotaTag = "QUOTA" encodeRcvMsgBody :: RcvMsgBody -> C.MaxLenBS MaxRcvMessageLen -encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} = - let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ') - in C.appendMaxLenBS rcvMeta msgBody +encodeRcvMsgBody = \case + RcvMsgBody {msgTs, msgFlags, msgBody} -> + let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ') + in C.appendMaxLenBS rcvMeta msgBody + RcvMsgQuota {msgTs} -> + C.unsafeMaxLenBS $ msgQuotaTag <> " " <> smpEncode msgTs -data ClientRcvMsgBody = ClientRcvMsgBody - { msgTs :: SystemTime, - msgFlags :: MsgFlags, - msgBody :: ByteString - } +data ClientRcvMsgBody + = ClientRcvMsgBody + { msgTs :: SystemTime, + msgFlags :: MsgFlags, + msgBody :: ByteString + } + | ClientRcvMsgQuota + { msgTs :: SystemTime + } clientRcvMsgBodyP :: Parser ClientRcvMsgBody -clientRcvMsgBodyP = do - msgTs <- smpP - msgFlags <- smpP - Tail msgBody <- _smpP - pure ClientRcvMsgBody {msgTs, msgFlags, msgBody} +clientRcvMsgBodyP = msgQuotaP <|> msgBodyP + where + msgQuotaP = A.string msgQuotaTag *> (ClientRcvMsgQuota <$> _smpP) + msgBodyP = do + msgTs <- smpP + msgFlags <- smpP + Tail msgBody <- _smpP + pure ClientRcvMsgBody {msgTs, msgFlags, msgBody} instance StrEncoding Message where - strEncode Message {msgId, msgTs, msgFlags, msgBody} = - B.unwords - [ strEncode msgId, - strEncode msgTs, - "flags=" <> strEncode msgFlags, - strEncode msgBody - ] + strEncode = \case + Message {msgId, msgTs, msgFlags, msgBody} -> + B.unwords + [ strEncode msgId, + strEncode msgTs, + "flags=" <> strEncode msgFlags, + strEncode msgBody + ] + MessageQuota {msgId, msgTs} -> + B.unwords + [ strEncode msgId, + strEncode msgTs, + "quota" + ] strP = do msgId <- strP_ msgTs <- strP_ - msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags - msgBody <- strP - pure Message {msgId, msgTs, msgFlags, msgBody} + msgQuotaP msgId msgTs <|> msgP msgId msgTs + where + msgQuotaP msgId msgTs = "quota" $> MessageQuota {msgId, msgTs} + msgP msgId msgTs = do + msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags + msgBody <- strP + pure Message {msgId, msgTs, msgFlags, msgBody} type EncNMsgMeta = ByteString @@ -377,7 +411,9 @@ data SMPMsgMeta = SMPMsgMeta deriving (Show) rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta -rcvMessageMeta msgId ClientRcvMsgBody {msgTs, msgFlags} = SMPMsgMeta {msgId, msgTs, msgFlags} +rcvMessageMeta msgId = \case + ClientRcvMsgBody {msgTs, msgFlags} -> SMPMsgMeta {msgId, msgTs, msgFlags} + ClientRcvMsgQuota {msgTs} -> SMPMsgMeta {msgId, msgTs, msgFlags = noMsgFlags} data NMsgMeta = NMsgMeta { msgId :: MsgId, @@ -860,7 +896,7 @@ data ErrorType | -- | internal server error INTERNAL | -- | used internally, never returned by the server (to be removed) - DUPLICATE_ -- TODO remove, not part of SMP protocol + DUPLICATE_ -- not part of SMP protocol, used internally deriving (Eq, Generic, Read, Show) instance ToJSON ErrorType where @@ -1132,26 +1168,33 @@ instance Encoding CommandError where _ -> fail "bad command error type" -- | Send signed SMP transmission to TCP transport. -tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO (NonEmpty (Either TransportError ())) +tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO [Either TransportError ()] tPut th trs | batch th = tPutBatch [] $ L.map tEncode trs - | otherwise = forM trs $ tPutBlock th . tEncode + | otherwise = forM (L.toList trs) $ tPutLog . tEncode where - tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ())) + tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO [Either TransportError ()] tPutBatch rs ts = do let (n, s, ts_) = encodeBatch 0 "" ts - r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s) + r <- if n == 0 then largeMsg else replicate n <$> tPutLog (lenEncode n `B.cons` s) let rs' = rs <> r case ts_ of Just ts' -> tPutBatch rs' ts' - _ -> pure $ L.fromList rs' + _ -> pure rs' + largeMsg = putStrLn "tPut error: large message" >> pure [Left TELargeMsg] + tPutLog s = do + r <- tPutBlock th s + case r of + Left e -> putStrLn ("tPut error: " <> show e) + _ -> pure () + pure r encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString)) encodeBatch n s ts@(t :| ts_) | n == 255 = (n, s, Just ts) | otherwise = let s' = s <> smpEncode (Large t) n' = n + 1 - in if B.length s' > blockSize th - 1 + in if B.length s' > blockSize th - 1 -- one byte is reserved for the number of messages in the batch then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts else case L.nonEmpty ts_ of Just ts' -> encodeBatch n' s' ts' diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 9fdb99599..e5cb7f919 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -65,9 +65,9 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.MsgStore -import Simplex.Messaging.Server.MsgStore.STM (MsgQueue) +import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore -import Simplex.Messaging.Server.QueueStore.STM (QueueStore) +import Simplex.Messaging.Server.QueueStore.STM import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) @@ -104,8 +104,8 @@ type M a = ReaderT Env IO a smpServer :: TMVar Bool -> ServerConfig -> M () smpServer started cfg@ServerConfig {transports, logTLSErrors} = do s <- asks server - restoreServerStats restoreServerMessages + restoreServerStats raceAny_ ( serverThread s subscribedQ subscribers subscriptions cancelSub : serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) : @@ -174,7 +174,7 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath threadDelay $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0) - ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues} <- asks serverStats + ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount} <- asks serverStats let interval = 1000000 * logInterval withFile statsFilePath AppendMode $ \h -> liftIO $ do hSetBuffering h LineBuffering @@ -187,7 +187,31 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do msgSent' <- atomically $ swapTVar msgSent 0 msgRecv' <- atomically $ swapTVar msgRecv 0 ps <- atomically $ periodStatCounts activeQueues ts - hPutStrLn h $ intercalate "," [iso8601Show $ utctDay fromTime', show qCreated', show qSecured', show qDeleted', show msgSent', show msgRecv', dayCount ps, weekCount ps, monthCount ps] + msgSentNtf' <- atomically $ swapTVar msgSentNtf 0 + msgRecvNtf' <- atomically $ swapTVar msgRecvNtf 0 + psNtf <- atomically $ periodStatCounts activeQueuesNtf ts + qCount' <- readTVarIO qCount + msgCount' <- readTVarIO msgCount + hPutStrLn h $ + intercalate + "," + [ iso8601Show $ utctDay fromTime', + show qCreated', + show qSecured', + show qDeleted', + show msgSent', + show msgRecv', + dayCount ps, + weekCount ps, + monthCount ps, + show msgSentNtf', + show msgRecvNtf', + dayCount psNtf, + weekCount psNtf, + monthCount psNtf, + show qCount', + show msgCount' + ] threadDelay interval runClient :: Transport c => TProxy c -> c -> M () @@ -256,7 +280,6 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do send :: Transport c => THandle c -> Client -> IO () send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ - -- TODO the line below can return Lefts, but we ignore it and do not disconnect the client void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts atomically . writeTVar activeAt =<< liftIO getSystemTime where @@ -387,7 +410,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv Right _ -> do withLog (`logCreateById` rId) stats <- asks serverStats - atomically $ modifyTVar (qCreated stats) (+ 1) + atomically $ modifyTVar' (qCreated stats) (+ 1) + atomically $ modifyTVar' (qCount stats) (+ 1) subscribeQueue qr rId $> IDS (qik ids) logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO () @@ -405,7 +429,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv secureQueue_ st sKey = time "KEY" $ do withLog $ \s -> logSecureQueue s queueId sKey stats <- asks serverStats - atomically $ modifyTVar (qSecured stats) (+ 1) + atomically $ modifyTVar' (qSecured stats) (+ 1) atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg) @@ -510,12 +534,12 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv q <- getStoreMsgQueue "ACK" queueId case s of Sub {subThread = ProhibitSub} -> do - msgDeleted <- atomically $ tryDelMsg q msgId - when msgDeleted updateStats + deletedMsg_ <- atomically $ tryDelMsg q msgId + mapM_ updateStats deletedMsg_ pure ok _ -> do - (msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId - when msgDeleted updateStats + (deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId + mapM_ updateStats deletedMsg_ deliverMessage "ACK" qr queueId sub q msg_ _ -> pure $ err NO_MSG where @@ -526,11 +550,17 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv if msgId == msgId' || B.null msgId then pure $ Just s else putTMVar delivered msgId' $> Nothing - updateStats :: m () - updateStats = do - stats <- asks serverStats - atomically $ modifyTVar (msgRecv stats) (+ 1) - atomically $ updatePeriodStats (activeQueues stats) queueId + updateStats :: Message -> m () + updateStats = \case + MessageQuota {} -> pure () + Message {msgFlags} -> do + stats <- asks serverStats + atomically $ modifyTVar' (msgRecv stats) (+ 1) + atomically $ modifyTVar' (msgCount stats) (+ 1) + atomically $ updatePeriodStats (activeQueues stats) queueId + when (notification msgFlags) $ do + atomically $ modifyTVar' (msgRecvNtf stats) (+ 1) + atomically $ updatePeriodStats (activeQueuesNtf stats) queueId sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg) sendMessage qr msgFlags msgBody @@ -538,20 +568,25 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv | otherwise = case status qr of QueueOff -> return $ err AUTH QueueActive -> - mapM mkMessage (C.maxLenBS msgBody) >>= \case + case C.maxLenBS msgBody of Left _ -> pure $ err LARGE_MSG - Right msg -> do - resp@(_, _, sent) <- time "SEND" $ do + Right body -> do + msg_ <- time "SEND" $ do q <- getStoreMsgQueue "SEND" $ recipientId qr expireMessages q - atomically $ ifM (isFull q) (pure $ err QUOTA) (writeMsg q msg $> ok) - when (sent == OK) . time "SEND ok" $ do - when (notification msgFlags) $ - atomically . trySendNotification msg =<< asks idsDrg - stats <- asks serverStats - atomically $ modifyTVar (msgSent stats) (+ 1) - atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) - pure resp + atomically . writeMsg q =<< mkMessage body + case msg_ of + Nothing -> pure $ err QUOTA + Just msg -> time "SEND ok" $ do + stats <- asks serverStats + when (notification msgFlags) $ do + atomically . trySendNotification msg =<< asks idsDrg + atomically $ modifyTVar' (msgSentNtf stats) (+ 1) + atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr) + atomically $ modifyTVar' (msgSent stats) (+ 1) + atomically $ modifyTVar' (msgCount stats) (subtract 1) + atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) + pure ok where mkMessage :: C.MaxLenBS MaxMessageLen -> m Message mkMessage body = do @@ -572,12 +607,14 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM () writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} = - unlessM (isFullTBQueue q) $ do - (nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg - writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)] + unlessM (isFullTBQueue q) $ case msg of + Message {msgId, msgTs} -> do + (nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg + writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)] + _ -> pure () - mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta) - mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do + mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta) + mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg let msgMeta = NMsgMeta {msgId, msgTs} encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128 @@ -596,9 +633,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv where forkSub :: m () forkSub = do - atomically . modifyTVar sub $ \s -> s {subThread = SubPending} + atomically . modifyTVar' sub $ \s -> s {subThread = SubPending} t <- mkWeakThreadId =<< forkIO subscriber - atomically . modifyTVar sub $ \case + atomically . modifyTVar' sub $ \case s@Sub {subThread = SubPending} -> s {subThread = SubThread t} s -> s where @@ -609,23 +646,28 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)] s <- readTVar sub void $ setDelivered s msg - writeTVar sub s {subThread = NoSub} + writeTVar sub $! s {subThread = NoSub} time :: T.Text -> m a -> m a time name = timed name queueId encryptMsg :: QueueRec -> Message -> RcvMessage - encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody} - | thVersion == 1 || thVersion == 2 = encrypt msgBody - | otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} + encryptMsg qr msg = case msg of + Message {msgFlags, msgBody} + | thVersion == 1 || thVersion == 2 -> encrypt msgFlags msgBody + | otherwise -> encrypt msgFlags $ encodeRcvMsgBody RcvMsgBody {msgTs = msgTs', msgFlags, msgBody} + MessageQuota {} -> + encrypt noMsgFlags $ encodeRcvMsgBody (RcvMsgQuota msgTs') where - encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage - encrypt body = - let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body - in RcvMessage msgId msgTs msgFlags encBody + encrypt :: KnownNat i => MsgFlags -> C.MaxLenBS i -> RcvMessage + encrypt msgFlags body = + let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body + in RcvMessage msgId' msgTs' msgFlags encBody + msgId' = msgId (msg :: Message) + msgTs' = msgTs (msg :: Message) setDelivered :: Sub -> Message -> STM Bool - setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId + setDelivered s msg = tryPutTMVar (delivered s) $ msgId (msg :: Message) getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do @@ -638,7 +680,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv withLog (`logDeleteQueue` queueId) ms <- asks msgStore stats <- asks serverStats - atomically $ modifyTVar (qDeleted stats) (+ 1) + atomically $ modifyTVar' (qDeleted stats) (+ 1) + atomically $ modifyTVar' (qCount stats) (subtract 1) atomically $ deleteQueue st queueId >>= \case Left e -> pure $ err e @@ -717,8 +760,11 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages addToMsgQueue rId msg = do full <- atomically $ do q <- getMsgQueue ms rId quota - ifM (isFull q) (pure True) (writeMsg q msg $> False) - when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message)) + isNothing <$> writeMsg q msg + case msg of + Message {} -> + when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message)) + MessageQuota {} -> pure () updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do let nonce = C.cbNonce msgId msgBody <- liftEither . first (msgErr "v1 message decryption") $ C.maxLenBS =<< C.cbDecrypt rcvDhSecret nonce body @@ -744,7 +790,9 @@ restoreServerStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStat liftIO (strDecode <$> B.readFile f) >>= \case Right d -> do s <- asks serverStats - atomically $ setServerStats s d + _qCount <- fmap (length . M.keys) . readTVarIO . queues =<< asks queueStore + _msgCount <- foldM (\n q -> (n +) <$> readTVarIO (size q)) 0 =<< readTVarIO =<< asks msgStore + atomically $ setServerStats s d {_qCount, _msgCount} renameFile f $ f <> ".bak" logInfo "server stats restored" Left e -> do diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 61458fdc8..15253fec8 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -39,7 +39,7 @@ data ServerConfig = ServerConfig { transports :: [(ServiceName, ATransport)], tbqSize :: Natural, serverTbqSize :: Natural, - msgQueueQuota :: Natural, + msgQueueQuota :: Int, queueIdBytes :: Int, msgIdBytes :: Int, storeLogFile :: Maybe FilePath, @@ -164,8 +164,8 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, (qs, s') <- liftIO $ readWriteStoreLog s atomically $ do writeTVar queues =<< mapM newTVar qs - writeTVar senders $ M.foldr' addSender M.empty qs - writeTVar notifiers $ M.foldr' addNotifier M.empty qs + writeTVar senders $! M.foldr' addSender M.empty qs + writeTVar notifiers $! M.foldr' addNotifier M.empty qs pure s' addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 35ecce103..b66ed5ba4 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -160,8 +160,8 @@ smpServerCLI cfgPath logPath = serverConfig = ServerConfig { transports = iniTransports ini, - tbqSize = 16, - serverTbqSize = 64, + tbqSize = 32, + serverTbqSize = 128, msgQueueQuota = 128, queueIdBytes = 24, msgIdBytes = 24, -- must be at least 24 bytes, it is used as 192-bit nonce for XSalsa20 diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index 476a03f3f..55a9c5499 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -1,14 +1,11 @@ -{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Server.MsgStore where import Control.Applicative ((<|>)) -import Data.Int (Int64) -import Numeric.Natural import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (Message (..), MsgId, RcvMessage (..), RecipientId) +import Simplex.Messaging.Protocol (Message (..), RcvMessage (..), RecipientId) data MsgLogRecord = MLRv3 RecipientId Message | MLRv1 RecipientId RcvMessage @@ -17,17 +14,3 @@ instance StrEncoding MsgLogRecord where MLRv3 rId msg -> strEncode (Str "v3", rId, msg) MLRv1 rId msg -> strEncode (rId, msg) strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP - -class MonadMsgStore s q m | s -> q where - getMsgQueue :: s -> RecipientId -> Natural -> m q - delMsgQueue :: s -> RecipientId -> m () - flushMsgQueue :: s -> RecipientId -> m [Message] - -class MonadMsgQueue q m where - isFull :: q -> m Bool - writeMsg :: q -> Message -> m () -- non blocking - tryPeekMsg :: q -> m (Maybe Message) -- non blocking - peekMsg :: q -> m Message -- blocking - tryDelMsg :: q -> MsgId -> m Bool -- non blocking - tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available - deleteExpiredMsgs :: q -> Int64 -> m () diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 5905a1789..e9dd95eec 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -1,83 +1,120 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TupleSections #-} -module Simplex.Messaging.Server.MsgStore.STM where +module Simplex.Messaging.Server.MsgStore.STM + ( STMMsgStore, + MsgQueue (..), + newMsgStore, + getMsgQueue, + delMsgQueue, + flushMsgQueue, + writeMsg, + tryPeekMsg, + peekMsg, + tryDelMsg, + tryDelPeekMsg, + deleteExpiredMsgs, + ) +where -import Control.Concurrent.STM.TBQueue (flushTBQueue) +import Control.Concurrent.STM.TQueue (flushTQueue) import Control.Monad (when) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) import Data.Time.Clock.System (SystemTime (systemSeconds)) -import Numeric.Natural import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId) -import Simplex.Messaging.Server.MsgStore import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import UnliftIO.STM -newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message} +data MsgQueue = MsgQueue + { msgQueue :: TQueue Message, + quota :: Int, + canWrite :: TVar Bool, + size :: TVar Int + } type STMMsgStore = TMap RecipientId MsgQueue newMsgStore :: STM STMMsgStore newMsgStore = TM.empty -instance MonadMsgStore STMMsgStore MsgQueue STM where - getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue - getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st - where - newQ = do - q <- MsgQueue <$> newTBQueue quota - TM.insert rId q st - return q +getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue +getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st + where + newQ = do + msgQueue <- newTQueue + canWrite <- newTVar True + size <- newTVar 0 + let q = MsgQueue {msgQueue, quota, canWrite, size} + TM.insert rId q st + pure q - delMsgQueue :: STMMsgStore -> RecipientId -> STM () - delMsgQueue st rId = TM.delete rId st +delMsgQueue :: STMMsgStore -> RecipientId -> STM () +delMsgQueue st rId = TM.delete rId st - flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] - flushMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (flushTBQueue . msgQueue) +flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] +flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) -instance MonadMsgQueue MsgQueue STM where - isFull :: MsgQueue -> STM Bool - isFull = isFullTBQueue . msgQueue +writeMsg :: MsgQueue -> Message -> STM (Maybe Message) +writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do + canWrt <- readTVar canWrite + empty <- isEmptyTQueue q + if canWrt || empty + then do + canWrt' <- (quota >) <$> readTVar size + writeTVar canWrite $! canWrt' + modifyTVar' size (+ 1) + if canWrt' + then writeTQueue q msg $> Just msg + else writeTQueue q msgQuota $> Nothing + else pure Nothing + where + msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg} - writeMsg :: MsgQueue -> Message -> STM () - writeMsg = writeTBQueue . msgQueue +tryPeekMsg :: MsgQueue -> STM (Maybe Message) +tryPeekMsg = tryPeekTQueue . msgQueue +{-# INLINE tryPeekMsg #-} - tryPeekMsg :: MsgQueue -> STM (Maybe Message) - tryPeekMsg = tryPeekTBQueue . msgQueue +peekMsg :: MsgQueue -> STM Message +peekMsg = peekTQueue . msgQueue +{-# INLINE peekMsg #-} - peekMsg :: MsgQueue -> STM Message - peekMsg = peekTBQueue . msgQueue +tryDelMsg :: MsgQueue -> MsgId -> STM (Maybe Message) +tryDelMsg mq msgId' = + tryPeekMsg mq >>= \case + msg_@(Just msg) + | msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure msg_ + | otherwise -> pure Nothing + _ -> pure Nothing - tryDelMsg :: MsgQueue -> MsgId -> STM Bool - tryDelMsg (MsgQueue q) msgId' = - tryPeekTBQueue q >>= \case - Just Message {msgId} - | msgId == msgId' || B.null msgId' -> tryReadTBQueue q $> True - | otherwise -> pure False - _ -> pure False +-- atomic delete (== read) last and peek next message if available +tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Maybe Message, Maybe Message) +tryDelPeekMsg mq msgId' = + tryPeekMsg mq >>= \case + msg_@(Just msg) + | msgId msg == msgId' || B.null msgId' -> (msg_,) <$> (tryDeleteMsg mq >> tryPeekMsg mq) + | otherwise -> pure (Nothing, msg_) + _ -> pure (Nothing, Nothing) - -- atomic delete (== read) last and peek next message if available - tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message) - tryDelPeekMsg (MsgQueue q) msgId' = - tryPeekTBQueue q >>= \case - msg_@(Just Message {msgId}) - | msgId == msgId' || B.null msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q) - | otherwise -> pure (False, msg_) - _ -> pure (False, Nothing) - - deleteExpiredMsgs :: MsgQueue -> Int64 -> STM () - deleteExpiredMsgs (MsgQueue q) old = loop - where - loop = tryPeekTBQueue q >>= mapM_ delOldMsg - delOldMsg Message {msgTs} = +deleteExpiredMsgs :: MsgQueue -> Int64 -> STM () +deleteExpiredMsgs mq old = loop + where + loop = tryPeekMsg mq >>= mapM_ delOldMsg + delOldMsg = \case + Message {msgTs} -> when (systemSeconds msgTs < old) $ - tryReadTBQueue q >> loop + tryDeleteMsg mq >> loop + _ -> pure () + +tryDeleteMsg :: MsgQueue -> STM () +tryDeleteMsg MsgQueue {msgQueue = q, size} = + tryReadTQueue q >>= \case + Just _ -> modifyTVar' size (subtract 1) + _ -> pure () diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index c05a0d3a6..8a7856eb6 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -9,20 +9,20 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol data QueueRec = QueueRec - { recipientId :: RecipientId, - recipientKey :: RcvPublicVerifyKey, - rcvDhSecret :: RcvDhSecret, - senderId :: SenderId, - senderKey :: Maybe SndPublicVerifyKey, - notifier :: Maybe NtfCreds, - status :: ServerQueueStatus + { recipientId :: !RecipientId, + recipientKey :: !RcvPublicVerifyKey, + rcvDhSecret :: !RcvDhSecret, + senderId :: !SenderId, + senderKey :: !(Maybe SndPublicVerifyKey), + notifier :: !(Maybe NtfCreds), + status :: !ServerQueueStatus } deriving (Eq, Show) data NtfCreds = NtfCreds - { notifierId :: NotifierId, - notifierKey :: NtfPublicVerifyKey, - rcvNtfDhSecret :: RcvNtfDhSecret + { notifierId :: !NotifierId, + notifierKey :: !NtfPublicVerifyKey, + rcvNtfDhSecret :: !RcvNtfDhSecret } deriving (Eq, Show) @@ -33,12 +33,3 @@ instance StrEncoding NtfCreds where pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} data ServerQueueStatus = QueueActive | QueueOff deriving (Eq, Show) - -class MonadQueueStore s m where - addQueue :: s -> QueueRec -> m (Either ErrorType ()) - getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec) - secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec) - addQueueNotifier :: s -> RecipientId -> NtfCreds -> m (Either ErrorType QueueRec) - deleteQueueNotifier :: s -> RecipientId -> m (Either ErrorType ()) - suspendQueue :: s -> RecipientId -> m (Either ErrorType ()) - deleteQueue :: s -> RecipientId -> m (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 02375f168..b4c41c0de 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -10,7 +9,18 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE UndecidableInstances #-} -module Simplex.Messaging.Server.QueueStore.STM where +module Simplex.Messaging.Server.QueueStore.STM + ( QueueStore (..), + newQueueStore, + addQueue, + getQueue, + secureQueue, + addQueueNotifier, + deleteQueueNotifier, + suspendQueue, + deleteQueue, + ) +where import Control.Monad import Data.Functor (($>)) @@ -34,66 +44,65 @@ newQueueStore = do notifiers <- TM.empty pure QueueStore {queues, senders, notifiers} -instance MonadQueueStore QueueStore STM where - addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) - addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do - ifM hasId (pure $ Left DUPLICATE_) $ do - qVar <- newTVar q - TM.insert rId qVar queues - TM.insert sId rId senders - pure $ Right () - where - hasId = (||) <$> TM.member rId queues <*> TM.member sId senders +addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) +addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do + ifM hasId (pure $ Left DUPLICATE_) $ do + qVar <- newTVar q + TM.insert rId qVar queues + TM.insert sId rId senders + pure $ Right () + where + hasId = (||) <$> TM.member rId queues <*> TM.member sId senders - getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec) - getQueue QueueStore {queues, senders, notifiers} party qId = - toResult <$> (mapM readTVar =<< getVar) - where - getVar = case party of - SRecipient -> TM.lookup qId queues - SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues) - SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues) +getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec) +getQueue QueueStore {queues, senders, notifiers} party qId = + toResult <$> (mapM readTVar =<< getVar) + where + getVar = case party of + SRecipient -> TM.lookup qId queues + SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues) + SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues) - secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) - secureQueue QueueStore {queues} rId sKey = - withQueue rId queues $ \qVar -> - readTVar qVar >>= \q -> case senderKey q of - Just k -> pure $ if sKey == k then Just q else Nothing - _ -> - let q' = q {senderKey = Just sKey} - in writeTVar qVar q' $> Just q' +secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) +secureQueue QueueStore {queues} rId sKey = + withQueue rId queues $ \qVar -> + readTVar qVar >>= \q -> case senderKey q of + Just k -> pure $ if sKey == k then Just q else Nothing + _ -> + let q' = q {senderKey = Just sKey} + in writeTVar qVar q' $> Just q' - addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec) - addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do - ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ - withQueue rId queues $ \qVar -> do - q <- readTVar qVar - forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId - writeTVar qVar q {notifier = Just ntfCreds} - TM.insert nId rId notifiers - pure $ Just q - - deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - deleteQueueNotifier QueueStore {queues, notifiers} rId = +addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec) +addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do + ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ withQueue rId queues $ \qVar -> do q <- readTVar qVar - forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers - writeTVar qVar q {notifier = Nothing} - pure $ Just () + forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId + writeTVar qVar $! q {notifier = Just ntfCreds} + TM.insert nId rId notifiers + pure $ Just q - suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - suspendQueue QueueStore {queues} rId = - withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just () +deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ()) +deleteQueueNotifier QueueStore {queues, notifiers} rId = + withQueue rId queues $ \qVar -> do + q <- readTVar qVar + forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers + writeTVar qVar $! q {notifier = Nothing} + pure $ Just () - deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - deleteQueue QueueStore {queues, senders, notifiers} rId = do - TM.lookupDelete rId queues >>= \case - Just qVar -> - readTVar qVar >>= \q -> do - TM.delete (senderId q) senders - forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers - pure $ Right () - _ -> pure $ Left AUTH +suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) +suspendQueue QueueStore {queues} rId = + withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just () + +deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) +deleteQueue QueueStore {queues, senders, notifiers} rId = do + TM.lookupDelete rId queues >>= \case + Just qVar -> + readTVar qVar >>= \q -> do + TM.delete (senderId q) senders + forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers + pure $ Right () + _ -> pure $ Left AUTH toResult :: Maybe a -> Either ErrorType a toResult = maybe (Left AUTH) Right diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index 44c1b97d7..82170e90f 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} @@ -5,7 +6,7 @@ module Simplex.Messaging.Server.Stats where -import Control.Applicative (optional) +import Control.Applicative (optional, (<|>)) import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteString.Char8 as B import Data.Set (Set) @@ -24,7 +25,12 @@ data ServerStats = ServerStats qDeleted :: TVar Int, msgSent :: TVar Int, msgRecv :: TVar Int, - activeQueues :: PeriodStats RecipientId + activeQueues :: PeriodStats RecipientId, + msgSentNtf :: TVar Int, + msgRecvNtf :: TVar Int, + activeQueuesNtf :: PeriodStats RecipientId, + qCount :: TVar Int, + msgCount :: TVar Int } data ServerStatsData = ServerStatsData @@ -34,7 +40,12 @@ data ServerStatsData = ServerStatsData _qDeleted :: Int, _msgSent :: Int, _msgRecv :: Int, - _activeQueues :: PeriodStatsData RecipientId + _activeQueues :: PeriodStatsData RecipientId, + _msgSentNtf :: Int, + _msgRecvNtf :: Int, + _activeQueuesNtf :: PeriodStatsData RecipientId, + _qCount :: Int, + _msgCount :: Int } newServerStats :: UTCTime -> STM ServerStats @@ -46,7 +57,12 @@ newServerStats ts = do msgSent <- newTVar 0 msgRecv <- newTVar 0 activeQueues <- newPeriodStats - pure ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues} + msgSentNtf <- newTVar 0 + msgRecvNtf <- newTVar 0 + activeQueuesNtf <- newPeriodStats + qCount <- newTVar 0 + msgCount <- newTVar 0 + pure ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount} getServerStatsData :: ServerStats -> STM ServerStatsData getServerStatsData s = do @@ -57,20 +73,30 @@ getServerStatsData s = do _msgSent <- readTVar $ msgSent s _msgRecv <- readTVar $ msgRecv s _activeQueues <- getPeriodStatsData $ activeQueues s - pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues} + _msgSentNtf <- readTVar $ msgSentNtf s + _msgRecvNtf <- readTVar $ msgRecvNtf s + _activeQueuesNtf <- getPeriodStatsData $ activeQueuesNtf s + _qCount <- readTVar $ qCount s + _msgCount <- readTVar $ msgCount s + pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues, _msgSentNtf, _msgRecvNtf, _activeQueuesNtf, _qCount, _msgCount} setServerStats :: ServerStats -> ServerStatsData -> STM () setServerStats s d = do - writeTVar (fromTime s) (_fromTime d) - writeTVar (qCreated s) (_qCreated d) - writeTVar (qSecured s) (_qSecured d) - writeTVar (qDeleted s) (_qDeleted d) - writeTVar (msgSent s) (_msgSent d) - writeTVar (msgRecv s) (_msgRecv d) - setPeriodStats (activeQueues s) (_activeQueues d) + writeTVar (fromTime s) $! _fromTime d + writeTVar (qCreated s) $! _qCreated d + writeTVar (qSecured s) $! _qSecured d + writeTVar (qDeleted s) $! _qDeleted d + writeTVar (msgSent s) $! _msgSent d + writeTVar (msgRecv s) $! _msgRecv d + setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d) + writeTVar (msgSentNtf s) $! _msgSentNtf d + writeTVar (msgRecvNtf s) $! _msgRecvNtf d + setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d) + writeTVar (qCount s) $! _qCount d + writeTVar (msgCount s) $! _qCount d instance StrEncoding ServerStatsData where - strEncode ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues} = + strEncode ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _msgSentNtf, _msgRecvNtf, _activeQueues, _activeQueuesNtf} = B.unlines [ "fromTime=" <> strEncode _fromTime, "qCreated=" <> strEncode _qCreated, @@ -78,8 +104,12 @@ instance StrEncoding ServerStatsData where "qDeleted=" <> strEncode _qDeleted, "msgSent=" <> strEncode _msgSent, "msgRecv=" <> strEncode _msgRecv, + "msgSentNtf=" <> strEncode _msgSentNtf, + "msgRecvNtf=" <> strEncode _msgRecvNtf, "activeQueues:", - strEncode _activeQueues + strEncode _activeQueues, + "activeQueuesNtf:", + strEncode _activeQueuesNtf ] strP = do _fromTime <- "fromTime=" *> strP <* A.endOfLine @@ -88,15 +118,21 @@ instance StrEncoding ServerStatsData where _qDeleted <- "qDeleted=" *> strP <* A.endOfLine _msgSent <- "msgSent=" *> strP <* A.endOfLine _msgRecv <- "msgRecv=" *> strP <* A.endOfLine - r <- optional ("activeQueues:" <* A.endOfLine) - _activeQueues <- case r of - Just _ -> strP <* optional A.endOfLine - _ -> do - _day <- "dayMsgQueues=" *> strP <* A.endOfLine - _week <- "weekMsgQueues=" *> strP <* A.endOfLine - _month <- "monthMsgQueues=" *> strP <* optional A.endOfLine - pure PeriodStatsData {_day, _week, _month} - pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues} + _msgSentNtf <- "msgSentNtf=" *> strP <* A.endOfLine <|> pure 0 + _msgRecvNtf <- "msgRecvNtf=" *> strP <* A.endOfLine <|> pure 0 + _activeQueues <- + optional ("activeQueues:" <* A.endOfLine) >>= \case + Just _ -> strP <* optional A.endOfLine + _ -> do + _day <- "dayMsgQueues=" *> strP <* A.endOfLine + _week <- "weekMsgQueues=" *> strP <* A.endOfLine + _month <- "monthMsgQueues=" *> strP <* optional A.endOfLine + pure PeriodStatsData {_day, _week, _month} + _activeQueuesNtf <- + optional ("activeQueuesNtf:" <* A.endOfLine) >>= \case + Just _ -> strP <* optional A.endOfLine + _ -> pure newPeriodStatsData + pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _msgSentNtf, _msgRecvNtf, _activeQueues, _activeQueuesNtf, _qCount = 0, _msgCount = 0} data PeriodStats a = PeriodStats { day :: TVar (Set a), @@ -117,6 +153,9 @@ data PeriodStatsData a = PeriodStatsData _month :: Set a } +newPeriodStatsData :: PeriodStatsData a +newPeriodStatsData = PeriodStatsData {_day = S.empty, _week = S.empty, _month = S.empty} + getPeriodStatsData :: PeriodStats a -> STM (PeriodStatsData a) getPeriodStatsData s = do _day <- readTVar $ day s @@ -126,9 +165,9 @@ getPeriodStatsData s = do setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM () setPeriodStats s d = do - writeTVar (day s) (_day d) - writeTVar (week s) (_week d) - writeTVar (month s) (_month d) + writeTVar (day s) $! _day d + writeTVar (week s) $! _week d + writeTVar (month s) $! _month d instance (Ord a, StrEncoding a) => StrEncoding (PeriodStatsData a) where strEncode PeriodStatsData {_day, _week, _month} = @@ -165,4 +204,4 @@ updatePeriodStats stats pId = do updatePeriod week updatePeriod month where - updatePeriod pSel = modifyTVar (pSel stats) (S.insert pId) + updatePeriod pSel = modifyTVar' (pSel stats) (S.insert pId) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 7d0f3cc1d..0234c5146 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -75,12 +75,14 @@ import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as BL import Data.Default (def) import Data.Functor (($>)) +import Data.Version (showVersion) import GHC.Generics (Generic) import GHC.IO.Handle.Internals (ioe_EOF) import Generic.Random (genericArbitraryU) import Network.Socket import qualified Network.TLS as T import qualified Network.TLS.Extra as TE +import qualified Paths_simplexmq as SMQ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON) @@ -100,7 +102,7 @@ supportedSMPServerVRange :: VersionRange supportedSMPServerVRange = mkVersionRange 1 5 simplexMQVersion :: String -simplexMQVersion = "4.0.0" +simplexMQVersion = showVersion SMQ.version -- * Transport connection class @@ -214,7 +216,7 @@ instance Transport TLS where $ do b <- readChunks =<< readTVarIO buffer let (s, b') = B.splitAt n b - atomically $ writeTVar buffer b' + atomically $ writeTVar buffer $! b' pure s where readChunks :: ByteString -> IO ByteString @@ -237,7 +239,7 @@ instance Transport TLS where $ do b <- readChunks =<< readTVarIO buffer let (s, b') = B.break (== '\n') b - atomically $ writeTVar buffer (B.drop 1 b') -- drop '\n' we made a break at + atomically $ writeTVar buffer $! B.drop 1 b' -- drop '\n' we made a break at pure $ trimCR s where readChunks :: ByteString -> IO ByteString diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 1568eda57..bdc3a28d0 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -78,6 +78,8 @@ agentTests (ATransport t) = do smpAgentTest2_2_1 $ testConcurrentMsgDelivery t it "should deliver messages if one of connections has quota exceeded" $ smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t + it "should resume delivering messages after exceeding quota once all messages are received" $ + smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent) tGetAgent h = do @@ -430,6 +432,32 @@ testMsgDeliveryQuotaExceeded _ alice bob = do -- if delivery is blocked it won't go further alice <# ("", "bob2", SENT 4) +testResumeDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO () +testResumeDeliveryQuotaExceeded _ alice bob = do + connect (alice, "alice") (bob, "bob") + forM_ [1 .. 4 :: Int] $ \i -> do + let corrId = bshow i + msg = "message " <> bshow i + (_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg) + alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False + ("5", "bob", Right (MID 8)) <- alice #: ("5", "bob", "SEND F :over quota") + alice #:# "the last message not sent yet" + bob <#= \case ("", "alice", Msg "message 1") -> True; _ -> False + bob #: ("1", "alice", "ACK 4") #> ("1", "alice", OK) + alice #:# "the last message not sent" + bob <#= \case ("", "alice", Msg "message 2") -> True; _ -> False + bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK) + alice #:# "the last message not sent" + bob <#= \case ("", "alice", Msg "message 3") -> True; _ -> False + bob #: ("3", "alice", "ACK 6") #> ("3", "alice", OK) + alice #:# "the last message not sent" + bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False + bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK) + alice <# ("", "bob", SENT 8) + bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False + -- message 8 is skipped because of alice agent sending "QCONT" message + bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) + connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO () connect (h1, name1) (h2, name2) = do ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV") diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 7f69070c6..282b57ea3 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -15,6 +15,8 @@ module AgentTests.FunctionalAPITests makeConnection, exchangeGreetingsMsgId, switchComplete, + runRight, + runRight_, get, (##>), (=##>), @@ -80,6 +82,15 @@ agentCfgRatchetV1 = agentCfg {e2eEncryptVRange = vr11} vr11 :: VersionRange vr11 = mkVersionRange 1 1 +runRight_ :: ExceptT AgentErrorType IO () -> Expectation +runRight_ action = runExceptT action `shouldReturn` Right () + +runRight :: ExceptT AgentErrorType IO a -> IO a +runRight action = + runExceptT action >>= \case + Right x -> pure x + Left e -> error $ "Unexpected error: " <> show e + functionalAPITests :: ATransport -> Spec functionalAPITests t = do describe "Establishing duplex connection" $ @@ -217,7 +228,7 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientTest alice bob baseId = do - Right () <- runExceptT $ do + runRight_ $ do (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -247,13 +258,12 @@ runAgentClientTest alice bob baseId = do get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" - pure () where msgId = subtract baseId runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientContactTest alice bob baseId = do - Right () <- runExceptT $ do + runRight_ $ do (_, qInfo) <- createConnection alice True SCMContact Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" ("", _, REQ invId _ "bob's connInfo") <- get alice @@ -285,7 +295,6 @@ runAgentClientContactTest alice bob baseId = do get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" - pure () where msgId = subtract baseId @@ -301,7 +310,7 @@ testAsyncInitiatingOffline :: IO () testAsyncInitiatingOffline = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (bobId, cReq) <- createConnection alice True SCMInvitation Nothing disconnectAgentClient alice aliceId <- joinConnection bob True cReq "bob's connInfo" @@ -313,13 +322,12 @@ testAsyncInitiatingOffline = do get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, CON) exchangeGreetings alice' bobId bob aliceId - pure () testAsyncJoiningOfflineBeforeActivation :: IO () testAsyncJoiningOfflineBeforeActivation = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" disconnectAgentClient bob @@ -331,13 +339,12 @@ testAsyncJoiningOfflineBeforeActivation = do get bob' ##> ("", aliceId, INFO "alice's connInfo") get bob' ##> ("", aliceId, CON) exchangeGreetings alice bobId bob' aliceId - pure () testAsyncBothOffline :: IO () testAsyncBothOffline = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (bobId, cReq) <- createConnection alice True SCMInvitation Nothing disconnectAgentClient alice aliceId <- joinConnection bob True cReq "bob's connInfo" @@ -352,22 +359,21 @@ testAsyncBothOffline = do get bob' ##> ("", aliceId, INFO "alice's connInfo") get bob' ##> ("", aliceId, CON) exchangeGreetings alice' bobId bob' aliceId - pure () testAsyncServerOffline :: ATransport -> IO () testAsyncServerOffline t = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers -- create connection and shutdown the server - Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ -> - runExceptT $ createConnection alice True SCMInvitation Nothing + (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ -> + runRight $ createConnection alice True SCMInvitation Nothing -- connection fails Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo" ("", "", DOWN srv conns) <- get alice srv `shouldBe` testSMPServer conns `shouldBe` [bobId] -- connection succeeds after server start - Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do ("", "", UP srv1 conns1) <- get alice liftIO $ do srv1 `shouldBe` testSMPServer @@ -379,27 +385,25 @@ testAsyncServerOffline t = do get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, CON) exchangeGreetings alice bobId bob aliceId - pure () testAsyncHelloTimeout :: IO () testAsyncHelloTimeout = do -- this test would only work if any of the agent is v1, there is no HELLO timeout in v2 alice <- getSMPAgentClient agentCfgV1 initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (_, cReq) <- createConnection alice True SCMInvitation Nothing disconnectAgentClient alice aliceId <- joinConnection bob True cReq "bob's connInfo" get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED) - pure () testDuplicateMessage :: ATransport -> IO () testDuplicateMessage t = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers (aliceId, bobId, bob1) <- withSmpServerStoreMsgLogOn t testPort $ \_ -> do - Right (aliceId, bobId) <- runExceptT $ makeConnection alice bob - Right () <- runExceptT $ do + (aliceId, bobId) <- runRight $ makeConnection alice bob + runRight_ $ do 4 <- sendMessage alice bobId SMP.noMsgFlags "hello" get alice ##> ("", bobId, SENT 4) get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False @@ -407,7 +411,7 @@ testDuplicateMessage t = do -- if the agent user did not send ACK, the message will be delivered again bob1 <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do subscribeConnection bob1 aliceId get bob1 =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False ackMessage bob1 aliceId 4 @@ -419,7 +423,7 @@ testDuplicateMessage t = do get alice =##> \case ("", "", DOWN _ [c]) -> c == bobId; _ -> False get bob1 =##> \case ("", "", DOWN _ [c]) -> c == aliceId; _ -> False - -- commenting two lines below and uncommenting further two lines would also pass, + -- commenting two lines below and uncommenting further two lines would also runRight_, -- it is the scenario tested above, when the message was not acknowledged by the user threadDelay 200000 Left (BROKER _ TIMEOUT) <- runExceptT $ ackMessage bob1 aliceId 5 @@ -431,7 +435,7 @@ testDuplicateMessage t = do bob2 <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers withSmpServerStoreMsgLogOn t testPort $ \_ -> do - Right () <- runExceptT $ do + runRight_ $ do subscribeConnection bob2 aliceId subscribeConnection alice2 bobId -- get bob2 =##> \case ("", c, Msg "hello 2") -> c == aliceId; _ -> False @@ -440,7 +444,6 @@ testDuplicateMessage t = do 6 <- sendMessage alice2 bobId SMP.noMsgFlags "hello 3" get alice2 ##> ("", bobId, SENT 6) get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False - pure () makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnection alice bob = do @@ -458,10 +461,9 @@ testInactiveClientDisconnected t = do let cfg' = cfg {inactiveClientExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 1}} withSmpServerConfigOn t cfg' testPort $ \_ -> do alice <- getSMPAgentClient agentCfg initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (connId, _cReq) <- createConnection alice True SCMInvitation Nothing get alice ##> ("", "", DOWN testSMPServer [connId]) - pure () testActiveClientNotDisconnected :: ATransport -> IO () testActiveClientNotDisconnected t = do @@ -469,10 +471,9 @@ testActiveClientNotDisconnected t = do withSmpServerConfigOn t cfg' testPort $ \_ -> do alice <- getSMPAgentClient agentCfg initAgentServers ts <- getSystemTime - Right () <- runExceptT $ do + runRight_ $ do (connId, _cReq) <- createConnection alice True SCMInvitation Nothing keepSubscribing alice connId ts - pure () where keepSubscribing :: AgentClient -> ConnId -> SystemTime -> ExceptT AgentErrorType IO () keepSubscribing alice connId ts = do @@ -495,7 +496,7 @@ testSuspendingAgent :: IO () testSuspendingAgent = do a <- getSMPAgentClient agentCfg initAgentServers b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (aId, bId) <- makeConnection a b 4 <- sendMessage a bId SMP.noMsgFlags "hello" get a ##> ("", bId, SENT 4) @@ -508,13 +509,12 @@ testSuspendingAgent = do Nothing <- 100000 `timeout` get b activateAgent b get b =##> \case ("", c, Msg "hello 2") -> c == aId; _ -> False - pure () testSuspendingAgentCompleteSending :: ATransport -> IO () testSuspendingAgentCompleteSending t = do a <- getSMPAgentClient agentCfg initAgentServers b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do + (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do (aId, bId) <- makeConnection a b 4 <- sendMessage a bId SMP.noMsgFlags "hello" get a ##> ("", bId, SENT 4) @@ -522,7 +522,7 @@ testSuspendingAgentCompleteSending t = do ackMessage b aId 4 pure (aId, bId) - Right () <- runExceptT $ do + runRight_ $ do ("", "", DOWN {}) <- get a ("", "", DOWN {}) <- get b 5 <- sendMessage b aId SMP.noMsgFlags "hello too" @@ -530,7 +530,7 @@ testSuspendingAgentCompleteSending t = do liftIO $ threadDelay 100000 suspendAgent b 5000000 - Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False @@ -544,13 +544,11 @@ testSuspendingAgentCompleteSending t = do get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False ackMessage a bId 6 - pure () - testSuspendingAgentTimeout :: ATransport -> IO () testSuspendingAgentTimeout t = do a <- getSMPAgentClient agentCfg initAgentServers b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right (aId, _) <- withSmpServer t . runExceptT $ do + (aId, _) <- withSmpServer t . runRight $ do (aId, bId) <- makeConnection a b 4 <- sendMessage a bId SMP.noMsgFlags "hello" get a ##> ("", bId, SENT 4) @@ -558,7 +556,7 @@ testSuspendingAgentTimeout t = do ackMessage b aId 4 pure (aId, bId) - Right () <- runExceptT $ do + runRight_ $ do ("", "", DOWN {}) <- get a ("", "", DOWN {}) <- get b 5 <- sendMessage b aId SMP.noMsgFlags "hello too" @@ -567,13 +565,11 @@ testSuspendingAgentTimeout t = do ("", "", SUSPENDED) <- get b pure () - pure () - testBatchedSubscriptions :: ATransport -> IO () testBatchedSubscriptions t = do a <- getSMPAgentClient agentCfg initAgentServers2 b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers2 - Right conns <- runServers $ do + conns <- runServers $ do conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId forM_ (take 10 conns) $ \(aId, bId) -> do @@ -585,7 +581,7 @@ testBatchedSubscriptions t = do ("", "", DOWN {}) <- get a ("", "", DOWN {}) <- get b ("", "", DOWN {}) <- get b - Right () <- runServers $ do + runServers $ do ("", "", UP {}) <- get a ("", "", UP {}) <- get a ("", "", UP {}) <- get b @@ -594,7 +590,6 @@ testBatchedSubscriptions t = do subscribe a $ map snd conns subscribe b $ map fst conns forM_ (drop 10 conns) $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId - pure () where subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO () subscribe c cs = do @@ -604,13 +599,11 @@ testBatchedSubscriptions t = do all (== Right ()) (M.withoutKeys r dc) `shouldBe` True all (== Left (CONN NOT_FOUND)) (M.restrictKeys r dc) `shouldBe` True M.keys r `shouldMatchList` cs - runServers :: ExceptT AgentErrorType IO a -> IO (Either AgentErrorType a) + runServers :: ExceptT AgentErrorType IO a -> IO a runServers a = do withSmpServerStoreLogOn t testPort $ \t1 -> do - res <- withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \t2 -> do - res <- runExceptT a - killThread t2 - pure res + res <- withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \t2 -> + runRight a `finally` killThread t2 killThread t1 pure res @@ -618,7 +611,7 @@ testAsyncCommands :: IO () testAsyncCommands = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do bobId <- createConnectionAsync alice "1" True SCMInvitation ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId @@ -655,7 +648,6 @@ testAsyncCommands = do deleteConnectionAsync alice "8" bobId ("8", _, OK) <- get alice liftIO $ noMessages alice "nothing else should be delivered to alice" - pure () where baseId = 3 msgId = subtract baseId @@ -663,22 +655,21 @@ testAsyncCommands = do testAsyncCommandsRestore :: ATransport -> IO () testAsyncCommandsRestore t = do alice <- getSMPAgentClient agentCfg initAgentServers - Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation + bobId <- runRight $ createConnectionAsync alice "1" True SCMInvitation liftIO $ noMessages alice "alice doesn't receive INV because server is down" disconnectAgentClient alice alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers withSmpServerStoreLogOn t testPort $ \_ -> do - Right () <- runExceptT $ do + runRight_ $ do subscribeConnection alice' bobId ("1", _, INV _) <- get alice' pure () - pure () testAcceptContactAsync :: IO () testAcceptContactAsync = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (_, qInfo) <- createConnection alice True SCMContact Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" ("", _, REQ invId _ "bob's connInfo") <- get alice @@ -712,7 +703,6 @@ testAcceptContactAsync = do get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" - pure () where baseId = 3 msgId = subtract baseId @@ -721,13 +711,12 @@ testSwitchConnection :: InitialAgentServers -> IO () testSwitchConnection servers = do a <- getSMPAgentClient agentCfg servers b <- getSMPAgentClient agentCfg {database = testDB2, initialClientId = 1} servers - Right () <- runExceptT $ do + runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId switchConnectionAsync a "" bId switchComplete a bId b aId exchangeGreetingsMsgId 10 a bId b aId - pure () switchComplete :: AgentClient -> ByteString -> AgentClient -> ByteString -> ExceptT AgentErrorType IO () switchComplete a bId b aId = do @@ -749,12 +738,12 @@ phase c connId d p = ERR (AGENT A_DUPLICATE) -> phase c connId d p r -> do liftIO . putStrLn $ "expected: " <> show p <> ", received: " <> show r - SWITCH _ _ _ <- pure r + SWITCH {} <- pure r pure () testSwitchAsync :: InitialAgentServers -> IO () testSwitchAsync servers = do - Right (aId, bId) <- withA $ \a -> withB $ \b -> runExceptT $ do + (aId, bId) <- withA $ \a -> withB $ \b -> runRight $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId pure (aId, bId) @@ -769,22 +758,20 @@ testSwitchAsync servers = do phase b aId QDSnd SPConfirmed phase b aId QDSnd SPCompleted withA' $ \a -> phase a bId QDRcv SPCompleted - Right () <- withA $ \a -> withB $ \b -> runExceptT $ do + withA $ \a -> withB $ \b -> runRight_ $ do subscribeConnection a bId subscribeConnection b aId exchangeGreetingsMsgId 10 a bId b aId - pure () where withAgent :: AgentConfig -> (AgentClient -> IO a) -> IO a withAgent cfg' = bracket (getSMPAgentClient cfg' servers) disconnectAgentClient session :: (forall a. (AgentClient -> IO a) -> IO a) -> ConnId -> (AgentClient -> ExceptT AgentErrorType IO ()) -> IO () - session withC connId a = do - Right () <- withC $ \c -> runExceptT $ do + session withC connId a = + withC $ \c -> runRight_ $ do subscribeConnection c connId r <- a c liftIO $ threadDelay 500000 pure r - pure () withA = withAgent agentCfg withB = withAgent agentCfg {database = testDB2, initialClientId = 1} @@ -792,7 +779,7 @@ testSwitchDelete :: InitialAgentServers -> IO () testSwitchDelete servers = do a <- getSMPAgentClient agentCfg servers b <- getSMPAgentClient agentCfg {database = testDB2, initialClientId = 1} servers - Right () <- runExceptT $ do + runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId disconnectAgentClient b @@ -801,13 +788,12 @@ testSwitchDelete servers = do deleteConnectionAsync a "1" bId ("1", bId', OK) <- get a liftIO $ bId `shouldBe` bId' - pure () testCreateQueueAuth :: (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int testCreateQueueAuth clnt1 clnt2 = do a <- getClient clnt1 b <- getClient clnt2 - Right created <- runExceptT $ do + runRight $ do tryError (createConnection a True SCMInvitation Nothing) >>= \case Left (SMP AUTH) -> pure 0 Left e -> throwError e @@ -823,7 +809,6 @@ testCreateQueueAuth clnt1 clnt2 = do get b ##> ("", aId, CON) exchangeGreetings a bId b aId pure 2 - pure created where getClient (clntAuth, clntVersion) = let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]} @@ -834,19 +819,17 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut testSMPServerConnectionTest t newQueueBasicAuth srv = withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running - Right r <- runExceptT $ testSMPServerConnection a srv - pure r + runRight $ testSMPServerConnection a srv testRatchetAdHash :: IO () testRatchetAdHash = do a <- getSMPAgentClient agentCfg initAgentServers b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (aId, bId) <- makeConnection a b ad1 <- getConnectionRatchetAdHash a bId ad2 <- getConnectionRatchetAdHash b aId liftIO $ ad1 `shouldBe` ad2 - pure () exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetings = exchangeGreetingsMsgId 4 diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index d340a3522..6196b6979 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -8,7 +8,7 @@ module AgentTests.NotificationTests where -- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging) -import AgentTests.FunctionalAPITests (exchangeGreetingsMsgId, get, makeConnection, switchComplete, testServerMatrix2, (##>), (=##>), pattern Msg) +import AgentTests.FunctionalAPITests (exchangeGreetingsMsgId, get, makeConnection, runRight, runRight_, switchComplete, testServerMatrix2, (##>), (=##>), pattern Msg) import Control.Concurrent (killThread, threadDelay) import Control.Monad.Except import qualified Data.Aeson as J @@ -91,7 +91,7 @@ notificationTests t = testNotificationToken :: APNSMockServer -> IO () testNotificationToken APNSMockServer {apnsQ} = do a <- getSMPAgentClient agentCfg initAgentServers - Right () <- runExceptT $ do + runRight_ $ do let tkn = DeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- @@ -105,7 +105,6 @@ testNotificationToken APNSMockServer {apnsQ} = do -- agent deleted this token Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn pure () - pure () (.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString v .-> key = do @@ -120,7 +119,7 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do -- setLogLevel LogError -- LogDebug -- withGlobalLogging logCfg $ do a <- getSMPAgentClient agentCfg initAgentServers - Right () <- runExceptT $ do + runRight_ $ do let tkn = DeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- @@ -138,7 +137,6 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do verifyNtfToken a tkn nonce verification NTActive <- checkNtfToken a tkn pure () - pure () testNtfTokenSecondRegistration :: APNSMockServer -> IO () testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do @@ -146,7 +144,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do -- withGlobalLogging logCfg $ do a <- getSMPAgentClient agentCfg initAgentServers a' <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do let tkn = DeviceToken PPApnsTest "abcd" NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- @@ -175,13 +173,12 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do -- and the second is active NTActive <- checkNtfToken a' tkn pure () - pure () testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO () testNtfTokenServerRestart t APNSMockServer {apnsQ} = do a <- getSMPAgentClient agentCfg initAgentServers let tkn = DeviceToken PPApnsTest "abcd" - Right ntfData <- withNtfServer t . runExceptT $ do + ntfData <- withNtfServer t . runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ @@ -193,7 +190,7 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do a' <- getSMPAgentClient agentCfg initAgentServers -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, -- so that repeat verification happens without restarting the clients, when notification arrives - Right () <- withNtfServer t . runExceptT $ do + withNtfServer t . runRight_ $ do verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" Left (NTF AUTH) <- tryE $ verifyNtfToken a' tkn nonce verification @@ -205,13 +202,12 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do verifyNtfToken a' tkn nonce' verification' NTActive <- checkNtfToken a' tkn pure () - pure () testNotificationSubscriptionExistingConnection :: APNSMockServer -> IO () testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right (bobId, aliceId, nonce, message) <- runExceptT $ do + (bobId, aliceId, nonce, message) <- runRight $ do -- establish connection (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" @@ -243,12 +239,12 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do -- aliceNtf client doesn't have subscription and is allowed to get notification message aliceNtf <- getSMPAgentClient agentCfg initAgentServers - Right () <- runExceptT $ do + runRight_ $ do (_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message pure () disconnectAgentClient aliceNtf - Right () <- runExceptT $ do + runRight_ $ do get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False ackMessage alice bobId $ baseId + 1 -- delete notification subscription @@ -259,7 +255,6 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do get bob ##> ("", aliceId, SENT $ baseId + 2) -- no notifications should follow noNotification apnsQ - pure () where baseId = 3 msgId = subtract baseId @@ -268,7 +263,7 @@ testNotificationSubscriptionNewConnection :: APNSMockServer -> IO () testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do -- alice registers notification token DeviceToken {} <- registerTestToken alice "abcd" NMInstant apnsQ -- bob registers notification token @@ -303,7 +298,6 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do ackMessage bob aliceId $ baseId + 2 -- no unexpected notifications should follow noNotification apnsQ - pure () where baseId = 3 msgId = subtract baseId @@ -325,7 +319,7 @@ testChangeNotificationsMode :: APNSMockServer -> IO () testChangeNotificationsMode APNSMockServer {apnsQ} = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right () <- runExceptT $ do + runRight_ $ do -- establish connection (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" @@ -381,7 +375,6 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do ackMessage alice bobId $ baseId + 5 -- no notifications should follow noNotification apnsQ - pure () where baseId = 3 msgId = subtract baseId @@ -390,7 +383,7 @@ testChangeToken :: APNSMockServer -> IO () testChangeToken APNSMockServer {apnsQ} = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right (aliceId, bobId) <- runExceptT $ do + (aliceId, bobId) <- runRight $ do -- establish connection (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing aliceId <- joinConnection bob True qInfo "bob's connInfo" @@ -412,7 +405,7 @@ testChangeToken APNSMockServer {apnsQ} = do disconnectAgentClient alice alice1 <- getSMPAgentClient agentCfg initAgentServers - Right () <- runExceptT $ do + runRight_ $ do subscribeConnection alice1 bobId -- change notification token void $ registerTestToken alice1 "bcde" NMInstant apnsQ @@ -425,7 +418,6 @@ testChangeToken APNSMockServer {apnsQ} = do ackMessage alice1 bobId $ baseId + 2 -- no notifications should follow noNotification apnsQ - pure () where baseId = 3 msgId = subtract baseId @@ -434,7 +426,7 @@ testNotificationsStoreLog :: ATransport -> APNSMockServer -> IO () testNotificationsStoreLog t APNSMockServer {apnsQ} = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right (aliceId, bobId) <- withNtfServerStoreLog t $ \threadId -> runExceptT $ do + (aliceId, bobId) <- withNtfServerStoreLog t $ \threadId -> runRight $ do (aliceId, bobId) <- makeConnection alice bob _ <- registerTestToken alice "abcd" NMInstant apnsQ liftIO $ threadDelay 250000 @@ -448,20 +440,19 @@ testNotificationsStoreLog t APNSMockServer {apnsQ} = do liftIO $ threadDelay 250000 - Right () <- withNtfServerStoreLog t $ \threadId -> runExceptT $ do + withNtfServerStoreLog t $ \threadId -> runRight_ $ do liftIO $ threadDelay 250000 5 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello again" get bob ##> ("", aliceId, SENT 5) void $ messageNotification apnsQ get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False liftIO $ killThread threadId - pure () testNotificationsSMPRestart :: ATransport -> APNSMockServer -> IO () testNotificationsSMPRestart t APNSMockServer {apnsQ} = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - Right (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \threadId -> runExceptT $ do + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \threadId -> runRight $ do (aliceId, bobId) <- makeConnection alice bob _ <- registerTestToken alice "abcd" NMInstant apnsQ liftIO $ threadDelay 250000 @@ -473,11 +464,11 @@ testNotificationsSMPRestart t APNSMockServer {apnsQ} = do liftIO $ killThread threadId pure (aliceId, bobId) - Right () <- runExceptT $ do + runRight_ $ do get alice =##> \case ("", "", DOWN _ [c]) -> c == bobId; _ -> False get bob =##> \case ("", "", DOWN _ [c]) -> c == aliceId; _ -> False - Right () <- withSmpServerStoreLogOn t testPort $ \threadId -> runExceptT $ do + withSmpServerStoreLogOn t testPort $ \threadId -> runRight_ $ do get alice =##> \case ("", "", UP _ [c]) -> c == bobId; _ -> False get bob =##> \case ("", "", UP _ [c]) -> c == aliceId; _ -> False liftIO $ threadDelay 1000000 @@ -486,13 +477,12 @@ testNotificationsSMPRestart t APNSMockServer {apnsQ} = do _ <- messageNotificationData alice apnsQ get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False liftIO $ killThread threadId - pure () testSwitchNotifications :: InitialAgentServers -> APNSMockServer -> IO () testSwitchNotifications servers APNSMockServer {apnsQ} = do a <- getSMPAgentClient agentCfg servers b <- getSMPAgentClient agentCfg {database = testDB2, initialClientId = 1} servers - Right () <- runExceptT $ do + runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId _ <- registerTestToken a "abcd" NMInstant apnsQ @@ -508,7 +498,6 @@ testSwitchNotifications servers APNSMockServer {apnsQ} = do switchComplete a bId b aId liftIO $ threadDelay 500000 testMessage "hello again" - pure () messageNotification :: TBQueue APNSMockRequest -> ExceptT AgentErrorType IO (C.CbNonce, ByteString) messageNotification apnsQ = do diff --git a/tests/CLITests.hs b/tests/CLITests.hs index 1834bf6e5..e18cf1963 100644 --- a/tests/CLITests.hs +++ b/tests/CLITests.hs @@ -6,6 +6,7 @@ import Data.Ini (lookupValue, readIniFile) import Data.List (isPrefixOf) import Simplex.Messaging.Notifications.Server.Main import Simplex.Messaging.Server.Main +import Simplex.Messaging.Transport (simplexMQVersion) import Simplex.Messaging.Util (catchAll_) import System.Directory (doesFileExist) import System.Environment (withArgs) @@ -51,7 +52,7 @@ smpServerTest storeLog basicAuth = do lookupValue "INACTIVE_CLIENTS" "disconnect" ini `shouldBe` Right "off" doesFileExist (cfgPath <> "/ca.key") `shouldReturn` True r <- lines <$> capture_ (withArgs ["start"] $ (100000 `timeout` smpServerCLI cfgPath logPath) `catchAll_` pure (Just ())) - r `shouldContain` ["SMP server v4.0.0"] + r `shouldContain` ["SMP server v" <> simplexMQVersion] r `shouldContain` (if storeLog then ["Store log: " <> logPath <> "/smp-server-store.log"] else ["Store log disabled."]) r `shouldContain` ["Listening on port 5223 (TLS)..."] r `shouldContain` ["not expiring inactive clients"] @@ -71,7 +72,7 @@ ntfServerTest storeLog = do lookupValue "TRANSPORT" "websockets" ini `shouldBe` Right "off" doesFileExist (ntfCfgPath <> "/ca.key") `shouldReturn` True r <- lines <$> capture_ (withArgs ["start"] $ (100000 `timeout` ntfServerCLI ntfCfgPath ntfLogPath) `catchAll_` pure (Just ())) - r `shouldContain` ["SMP notifications server v1.2.0"] + r `shouldContain` ["SMP notifications server v" <> ntfServerVersion] r `shouldContain` (if storeLog then ["Store log: " <> ntfLogPath <> "/ntf-server-store.log"] else ["Store log disabled."]) r `shouldContain` ["Listening on port 443 (TLS)..."] capture_ (withStdin "Y" . withArgs ["delete"] $ ntfServerCLI ntfCfgPath ntfLogPath) diff --git a/tests/CoreTests/CryptoTests.hs b/tests/CoreTests/CryptoTests.hs index 25975dc6f..e35b109ff 100644 --- a/tests/CoreTests/CryptoTests.hs +++ b/tests/CoreTests/CryptoTests.hs @@ -1,18 +1,20 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} module CoreTests.CryptoTests (cryptoTests) where +import qualified Data.ByteString.Char8 as B +import Data.Either (isRight) +import qualified Data.Text as T +import Data.Text.Encoding (encodeUtf8) import qualified Simplex.Messaging.Crypto as C import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck -import qualified Data.Text as T -import qualified Data.ByteString.Char8 as B -import Data.Text.Encoding (encodeUtf8) cryptoTests :: Spec -cryptoTests = modifyMaxSuccess (const 10000) $ do - describe "padding / unpadding" $ do +cryptoTests = do + modifyMaxSuccess (const 10000) . describe "padding / unpadding" $ do it "should pad / unpad string" . property $ \(s, paddedLen) -> let b = encodeUtf8 $ T.pack s len = B.length b @@ -34,3 +36,37 @@ cryptoTests = modifyMaxSuccess (const 10000) $ do it "unpad should fail on shorter string" $ do C.unPad "\000\003abc" `shouldBe` Right "abc" C.unPad "\000\003ab" `shouldBe` Left C.CryptoInvalidMsgError + describe "Ed signatures" $ do + describe "Ed25519" $ testSignature C.SEd25519 + describe "Ed448" $ testSignature C.SEd448 + describe "DH X25519 + cryptobox" $ + testDHCryptoBox + describe "X509 key encoding" $ do + describe "Ed25519" $ testEncoding C.SEd25519 + describe "Ed448" $ testEncoding C.SEd448 + describe "X25519" $ testEncoding C.SX25519 + describe "X448" $ testEncoding C.SX448 + +testSignature :: (C.AlgorithmI a, C.SignatureAlgorithm a) => C.SAlgorithm a -> Spec +testSignature alg = it "should sign / verify string" . ioProperty $ do + (k, pk) <- C.generateSignatureKeyPair alg + pure $ \s -> let b = encodeUtf8 $ T.pack s in C.verify k (C.sign pk b) b + +testDHCryptoBox :: Spec +testDHCryptoBox = it "should encrypt / decrypt string" . ioProperty $ do + (sk, spk) <- C.generateKeyPair' + (rk, rpk) <- C.generateKeyPair' + nonce <- C.randomCbNonce + pure $ \(s, pad) -> + let b = encodeUtf8 $ T.pack s + paddedLen = B.length b + abs pad + 2 + cipher = C.cbEncrypt (C.dh' rk spk) nonce b paddedLen + plain = C.cbDecrypt (C.dh' sk rpk) nonce =<< cipher + in isRight cipher && cipher /= plain && Right b == plain + +testEncoding :: (C.AlgorithmI a) => C.SAlgorithm a -> Spec +testEncoding alg = it "should encode / decode key" . ioProperty $ do + (k, pk) <- C.generateKeyPair alg + pure $ \(_ :: Int) -> + C.decodePubKey (C.encodePubKey k) == Right k + && C.decodePrivKey (C.encodePrivKey pk) == Right pk diff --git a/tests/CoreTests/RetryIntervalTests.hs b/tests/CoreTests/RetryIntervalTests.hs new file mode 100644 index 000000000..5495e2a3a --- /dev/null +++ b/tests/CoreTests/RetryIntervalTests.hs @@ -0,0 +1,61 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module CoreTests.RetryIntervalTests where + +import Control.Concurrent.STM +import Control.Monad (when) +import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds) +import Simplex.Messaging.Agent.RetryInterval +import Test.Hspec + +retryIntervalTests :: Spec +retryIntervalTests = do + describe "Retry interval with 2 modes and lock" $ do + testRetryIntervalSameMode + testRetryIntervalSwitchMode + +testRI :: RetryInterval2 +testRI = + RetryInterval2 + { riSlow = + RetryInterval + { initialInterval = 20000, + increaseAfter = 40000, + maxInterval = 40000 + }, + riFast = + RetryInterval + { initialInterval = 10000, + increaseAfter = 20000, + maxInterval = 40000 + } + } + +testRetryIntervalSameMode :: Spec +testRetryIntervalSameMode = + it "should increase elapased time and interval when the mode stays the same" $ do + lock <- newEmptyTMVarIO + intervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + withRetryLock2 testRI lock $ \loop -> do + ints <- addInterval intervals ts + when (length ints < 9) $ loop RIFast + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4, 4] + +testRetryIntervalSwitchMode :: Spec +testRetryIntervalSwitchMode = + it "should increase elapased time and interval when the mode stays the same" $ do + lock <- newEmptyTMVarIO + intervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + withRetryLock2 testRI lock $ \loop -> do + ints <- addInterval intervals ts + when (length ints < 11) $ loop $ if length ints <= 5 then RIFast else RISlow + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 2, 2, 3, 4, 4] + +addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int] +addInterval intervals ts = do + ts' <- getCurrentTime + atomically $ do + int :: Int <- truncate . (* 100) . nominalDiffTimeToSeconds <$> stateTVar ts (\t -> (diffUTCTime ts' t, ts')) + stateTVar intervals $ \ints -> (int : ints, int : ints) diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index 42c439987..7601652d1 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -10,7 +10,6 @@ module NtfServerTests where import Control.Concurrent (threadDelay) -import Control.Monad.Except (runExceptT) import qualified Data.Aeson as J import qualified Data.Aeson.Types as JT import Data.Bifunctor (first) @@ -77,8 +76,7 @@ sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse) signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) - Right sig <- runExceptT $ C.sign pk t - Right () <- tPut1 h (Just sig, t) + Right () <- tPut1 h (Just $ C.sign pk t, t) tGet1 h (.->) :: J.Value -> J.Key -> Either String ByteString diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 5dd98fe14..66315336a 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -14,7 +14,7 @@ module ServerTests where import Control.Concurrent (ThreadId, killThread, threadDelay) import Control.Concurrent.STM import Control.Exception (SomeException, try) -import Control.Monad.Except (forM, forM_, runExceptT) +import Control.Monad.Except (forM, forM_) import Control.Monad.IO.Class import Data.Bifunctor (first) import Data.ByteString.Base64 @@ -49,9 +49,10 @@ serverTests t@(ATransport t') = do describe "switch subscription to another TCP connection" $ testSwitchSub t describe "GET command" $ testGetCommand t' describe "GET & SUB commands" $ testGetSubCommands t' + describe "Exceeding queue quota" $ testExceedQueueQuota t' describe "Store log" $ testWithStoreLog t describe "Restore messages" $ testRestoreMessages t - describe "Restore messages (v2)" $ testRestoreMessagesV2 t + describe "Restore messages (old / v2)" $ testRestoreMessagesV2 t describe "Timing of AUTH error" $ testTiming t describe "Message notifications" $ testMessageNotifications t describe "Message expiration" $ do @@ -77,8 +78,7 @@ sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg) signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) - Right sig <- runExceptT $ C.sign pk t - Right () <- tPut1 h (Just sig, t) + Right () <- tPut1 h (Just $ C.sign pk t, t) tGet1 h tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) @@ -104,9 +104,11 @@ decryptMsgV2 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either C.Cry decryptMsgV2 dhShared = C.cbDecrypt dhShared . C.cbNonce decryptMsgV3 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either String MsgBody -decryptMsgV3 dhShared nonce body = do - ClientRcvMsgBody {msgBody} <- parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) - pure msgBody +decryptMsgV3 dhShared nonce body = + case parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) of + Right ClientRcvMsgBody {msgBody} -> Right msgBody + Right ClientRcvMsgQuota {} -> Left "ClientRcvMsgQuota" + Left e -> Left e testCreateSecureV2 :: forall c. Transport c => TProxy c -> Spec testCreateSecureV2 _ = @@ -494,6 +496,32 @@ testGetSubCommands t = Resp "12" _ OK <- signSendRecv rh2 rKey ("12", rId, GET) pure () +testExceedQueueQuota :: forall c. Transport c => TProxy c -> Spec +testExceedQueueQuota t = + it "should reply with ERR QUOTA to sender and send QUOTA message to the recipient" $ do + withSmpServerConfigOn (ATransport t) cfg {msgQueueQuota = 2} testPort $ \_ -> + testSMPClient @c $ \sh -> testSMPClient @c $ \rh -> do + (sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519 + (sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub + let dec = decryptMsgV3 dhShared + Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello 1") + Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, _SEND "hello 2") + Resp "3" _ (ERR QUOTA) <- signSendRecv sh sKey ("3", sId, _SEND "hello 3") + Resp "" _ (Msg mId1 msg1) <- tGet1 rh + (dec mId1 msg1, Right "hello 1") #== "hello 1" + Resp "4" _ (Msg mId2 msg2) <- signSendRecv rh rKey ("4", rId, ACK mId1) + (dec mId2 msg2, Right "hello 2") #== "hello 2" + Resp "5" _ (ERR QUOTA) <- signSendRecv sh sKey ("5", sId, _SEND "hello 3") + Resp "6" _ (Msg mId3 msg3) <- signSendRecv rh rKey ("6", rId, ACK mId2) + (dec mId3 msg3, Left "ClientRcvMsgQuota") #== "ClientRcvMsgQuota" + Resp "7" _ (ERR QUOTA) <- signSendRecv sh sKey ("7", sId, _SEND "hello 3") + Resp "8" _ OK <- signSendRecv rh rKey ("8", rId, ACK mId3) + Resp "9" _ OK <- signSendRecv sh sKey ("9", sId, _SEND "hello 3") + Resp "" _ (Msg mId4 msg4) <- tGet1 rh + (dec mId4 msg4, Right "hello 3") #== "hello 3" + Resp "10" _ OK <- signSendRecv rh rKey ("10", rId, ACK mId4) + pure () + testWithStoreLog :: ATransport -> Spec testWithStoreLog at@(ATransport t) = it "should store simplex queues to log and restore them after server restart" $ do @@ -600,10 +628,12 @@ testRestoreMessages at@(ATransport t) = Resp "2" _ OK <- signSendRecv h sKey ("2", sId, _SEND "hello 2") Resp "3" _ OK <- signSendRecv h sKey ("3", sId, _SEND "hello 3") Resp "4" _ OK <- signSendRecv h sKey ("4", sId, _SEND "hello 4") + Resp "5" _ OK <- signSendRecv h sKey ("5", sId, _SEND "hello 5") + Resp "6" _ (ERR QUOTA) <- signSendRecv h sKey ("6", sId, _SEND "hello 6") pure () logSize testStoreLogFile `shouldReturn` 2 - logSize testStoreMsgsFile `shouldReturn` 3 + logSize testStoreMsgsFile `shouldReturn` 5 withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do rId <- readTVarIO recipientId @@ -619,15 +649,21 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 -- the last message is not removed because it was not ACK'd - logSize testStoreMsgsFile `shouldReturn` 1 + logSize testStoreMsgsFile `shouldReturn` 3 withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do rId <- readTVarIO recipientId Just rKey <- readTVarIO recipientKey Just dh <- readTVarIO dhShared + let dec = decryptMsgV3 dh Resp "4" _ (Msg mId4 msg4) <- signSendRecv h rKey ("4", rId, SUB) - Resp "5" _ OK <- signSendRecv h rKey ("5", rId, ACK mId4) - (decryptMsgV3 dh mId4 msg4, Right "hello 4") #== "restored message delivered" + (dec mId4 msg4, Right "hello 4") #== "restored message delivered" + Resp "5" _ (Msg mId5 msg5) <- signSendRecv h rKey ("5", rId, ACK mId4) + (dec mId5 msg5, Right "hello 5") #== "restored message delivered" + Resp "6" _ (Msg mId6 msg6) <- signSendRecv h rKey ("6", rId, ACK mId5) + (dec mId6 msg6, Left "ClientRcvMsgQuota") #== "restored message delivered" + Resp "7" _ OK <- signSendRecv h rKey ("7", rId, ACK mId6) + pure () logSize testStoreLogFile `shouldReturn` 1 logSize testStoreMsgsFile `shouldReturn` 0 diff --git a/tests/Test.hs b/tests/Test.hs index fbf8a4af6..c710a58df 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,11 +1,12 @@ {-# LANGUAGE TypeApplications #-} import AgentTests (agentTests) --- import Control.Logger.Simple import CLITests +import Control.Logger.Simple import CoreTests.CryptoTests import CoreTests.EncodingTests import CoreTests.ProtocolErrorTests +import CoreTests.RetryIntervalTests import CoreTests.VersionRangeTests import NtfServerTests (ntfServerTests) import ServerTests @@ -15,25 +16,26 @@ import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import System.Environment (setEnv) import Test.Hspec --- logCfg :: LogConfig --- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} +logCfg :: LogConfig +logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} main :: IO () main = do - -- setLogLevel LogInfo -- LogError - -- withGlobalLogging logCfg $ do - createDirectoryIfMissing False "tests/tmp" - setEnv "APNS_KEY_ID" "H82WD9K9AQ" - setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8" - hspec $ do - describe "Core tests" $ do - describe "Encoding tests" encodingTests - describe "Protocol error tests" protocolErrorTests - describe "Version range" versionRangeTests - describe "Encryption tests" cryptoTests - describe "SMP server via TLS" $ serverTests (transport @TLS) - describe "SMP server via WebSockets" $ serverTests (transport @WS) - describe "Notifications server" $ ntfServerTests (transport @TLS) - describe "SMP client agent" $ agentTests (transport @TLS) - describe "Server CLIs" cliTests - removeDirectoryRecursive "tests/tmp" + setLogLevel LogError -- LogInfo + withGlobalLogging logCfg $ do + createDirectoryIfMissing False "tests/tmp" + setEnv "APNS_KEY_ID" "H82WD9K9AQ" + setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8" + hspec $ do + describe "Core tests" $ do + describe "Encoding tests" encodingTests + describe "Protocol error tests" protocolErrorTests + describe "Version range" versionRangeTests + describe "Encryption tests" cryptoTests + describe "Retry interval tests" retryIntervalTests + describe "SMP server via TLS" $ serverTests (transport @TLS) + describe "SMP server via WebSockets" $ serverTests (transport @WS) + describe "Notifications server" $ ntfServerTests (transport @TLS) + describe "SMP client agent" $ agentTests (transport @TLS) + describe "Server CLIs" cliTests + removeDirectoryRecursive "tests/tmp"