Merge branch 'master' into xftp

This commit is contained in:
Evgeny Poberezkin
2023-01-16 19:29:55 +00:00
35 changed files with 1012 additions and 604 deletions
+45
View File
@@ -1,3 +1,48 @@
# 4.3.0
SMP server:
- additional server usage statistics.
SMP agent:
- increase retry interval when sending messages after ERR QUOTA.
# 4.2.0
SMP agent and server:
- reduce sender traffic in cases when queue quota is exceeded:
- server sends quota exceeded message to the recipient when sender receives ERR QUOTA.
- recipient sends QCONT message to the send once the queue is drained (via reply queue).
- sender retry delays are increased, reducing traffic, but sender instantly resumes delivery where QCONT is received.
SMP server:
- increase internal queue sizes.
SMP agent:
- deduplicate connection IDs in connect/disconnect responses.
- unit tests for Crypto.hs.
- fix connection switch to another queue: correctly set primary send queue.
Notification server (v1.3.0):
- check token status when sending verification notificaiton.
# 4.1.0
SMP agent and server:
- option to toggle TLS handshake error logs (disabled by default).
SMP agent:
- include server address in BROKER error.
- api to get hash of double ratchet associated data (for connection verification).
- api to get agent statistics.
# 4.0.0
SMP server:
+1 -1
View File
@@ -1,5 +1,5 @@
name: simplexmq
version: 4.0.0
version: 4.3.0
synopsis: SimpleXMQ message broker
description: |
This package includes <./docs/Simplex-Messaging-Server.html server>,
+2 -1
View File
@@ -5,7 +5,7 @@ cabal-version: 1.12
-- see: https://github.com/sol/hpack
name: simplexmq
version: 4.0.0
version: 4.3.0
synopsis: SimpleXMQ message broker
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
<./docs/Simplex-Messaging-Client.html client> and
@@ -366,6 +366,7 @@ test-suite smp-server-test
CoreTests.CryptoTests
CoreTests.EncodingTests
CoreTests.ProtocolErrorTests
CoreTests.RetryIntervalTests
CoreTests.VersionRangeTests
NtfClient
NtfServerTests
+123 -107
View File
@@ -504,7 +504,7 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
pure connId'
Left e -> do
-- TODO recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
unless asyncMode $ withStore' c (`deleteConn` connId')
throwError e
where
@@ -791,7 +791,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
atomically $ beginAgentOperation c AOSndNetwork
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd
Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd
where
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
processCmd ri corrId connId cmdId command = case command of
@@ -905,7 +905,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
writeTVar usedSrvs used'
writeTVar usedSrvs $! used'
action srvAuth
-- ^ ^ ^ async command processing /
@@ -964,22 +964,22 @@ queuePendingMsgs c sq msgIds = atomically $ do
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
-- s <- readTVar (msgDeliveryOp c)
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
q <- getPendingMsgQ c sq
mapM_ (writeTQueue q) msgIds
(mq, _) <- getPendingMsgQ c sq
mapM_ (writeTQueue mq) msgIds
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId)
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId, TMVar ())
getPendingMsgQ c SndQueue {server, sndId} = do
let qKey = (server, sndId)
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
where
newMsgQueue qKey = do
mq <- newTQueue
TM.insert qKey mq $ smpQueueMsgQueues c
pure mq
q <- (,) <$> newTQueue <*> newEmptyTMVar
TM.insert qKey q $ smpQueueMsgQueues c
pure q
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
mq <- atomically $ getPendingMsgQ c sq
(mq, qLock) <- atomically $ getPendingMsgQ c sq
ri <- asks $ messageRetryInterval . config
forever $ do
atomically $ endAgentOperation c AOSndNetwork
@@ -993,7 +993,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
Left (e :: E.SomeException) ->
notify $ MERR mId (INTERNAL $ show e)
Right (rq_, PendingMsgData {msgType, msgBody, msgFlags, internalTs}) ->
withRetryInterval ri $ \loop -> do
withRetryLock2 ri qLock $ \loop -> do
resp <- tryError $ case msgType of
AM_CONN_INFO -> sendConfirmation c sq msgBody
_ -> sendAgentMessage c sq msgFlags msgBody
@@ -1004,7 +1004,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
SMP SMP.QUOTA -> case msgType of
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
_ -> retrySndOp c loop
_ -> retrySndOp c $ loop RISlow
SMP SMP.AUTH -> case msgType of
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
@@ -1013,7 +1013,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- because the queue must be secured by the time the confirmation or the first HELLO is received
| duplexHandshake == Just True -> connErr
| otherwise ->
ifM (msgExpired helloTimeout) connErr (retrySndOp c loop)
ifM (msgExpired helloTimeout) connErr (retrySndOp c $ loop RIFast)
where
connErr = case rq_ of
-- party initiating connection
@@ -1022,6 +1022,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
_ -> connError msgId NOT_ACCEPTED
AM_REPLY_ -> notifyDel msgId err
AM_A_MSG_ -> notifyDel msgId err
AM_QCONT_ -> notifyDel msgId err
AM_QADD_ -> qError msgId "QADD: AUTH"
AM_QKEY_ -> qError msgId "QKEY: AUTH"
AM_QUSE_ -> qError msgId "QUSE: AUTH"
@@ -1031,7 +1032,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- the message sending would be retried
| temporaryOrHostError e -> do
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c loop)
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndOp c $ loop RIFast)
| otherwise -> notifyDel msgId err
where
msgExpired timeoutSel = do
@@ -1044,7 +1045,6 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
withStore' c $ \db -> do
setSndQueueStatus db sq Confirmed
when (isJust rq_) $ removeConfirmations db connId
-- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification
unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO
AM_CONN_INFO_REPLY -> pure ()
AM_REPLY_ -> pure ()
@@ -1071,6 +1071,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
AM_QCONT_ -> pure ()
AM_QADD_ -> pure ()
AM_QKEY_ -> pure ()
AM_QUSE_ -> pure ()
@@ -1080,16 +1081,18 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
case conn of
DuplexConnection cData' rqs sqs -> do
-- remove old snd queue from connection once QTEST is sent to the new queue
case findQ (qAddress sq) sqs of
let addr = qAddress sq
case findQ addr sqs of
-- this is the same queue where this loop delivers messages to but with updated state
Just SndQueue {dbReplaceQueueId = Just replacedId, primary} ->
case removeQP (\SndQueue {dbQueueId} -> dbQueueId == replacedId) sqs of
-- second part of this condition is a sanity check because dbReplaceQueueId cannot point to the same queue, see switchConnection'
case removeQP (\sq'@SndQueue {dbQueueId} -> dbQueueId == replacedId && not (sameQueue addr sq')) sqs of
Nothing -> internalErr msgId "sent QTEST: queue not found in connection"
Just (sq', sq'' : sqs') -> do
-- remove the delivery from the map to stop the thread when the delivery loop is complete
atomically $ TM.delete (qAddress sq') $ smpQueueMsgQueues c
withStore' c $ \db -> do
when primary $ setSndQueuePrimary db connId sq'
when primary $ setSndQueuePrimary db connId sq
deletePendingMsgs db connId sq'
deleteConnSndQueue db connId sq'
let sqs'' = sq'' :| sqs'
@@ -1225,7 +1228,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
(Nothing, Just NTARegister) -> do
when (savedDeviceToken /= suppliedDeviceToken) $ withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken
registerToken tkn $> NTRegistered
-- TODO minimal time before repeat registration
-- possible improvement: add minimal time before repeat registration
(Just tknId, Nothing)
| savedDeviceToken == suppliedDeviceToken ->
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
@@ -1243,8 +1246,8 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
agentNtfEnableCron c tknId tkn cron
when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
pure ntfTknStatus -- TODO
-- agentNtfCheckToken c tknId tkn >>= \case
-- possible improvement: get updated token status from the server, or maybe TCRON could return the current status
pure ntfTknStatus
| otherwise -> replaceToken tknId
(Just tknId, Just NTADelete) -> do
agentNtfDeleteToken c tknId tkn
@@ -1411,11 +1414,6 @@ sendNtfConnCommands c cmd = do
_ ->
atomically $ writeTBQueue (subQ c) ("", connId, ERR $ INTERNAL "no connection data")
-- TODO
-- There should probably be another function to cancel all subscriptions that would flush the queue first,
-- so that supervisor stops processing pending commands?
-- It is an optimization, but I am thinking how it would behave if a user were to flip on/off quickly several times.
setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m ()
setNtfServers' c = atomically . writeTVar (ntfServers c)
@@ -1492,88 +1490,96 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} <- decryptSMPMessage v rq msg
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
parseMessage msgBody
clientVRange <- asks $ smpClientVRange . config
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case (e2eDhSecret, e2ePubKey_) of
(Nothing, Just e2ePubKey) -> do
let e2eDh = C.dh' e2ePubKey e2ePrivKey
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) ->
smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
smpInvitation connReq connInfo >> ack
_ -> prohibited >> ack
(Just e2eDh, Nothing) -> do
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
let RcvQueue {primary, dbReplaceQueueId} = rq
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
case (conn, dbReplaceQueueId) of
(DuplexConnection _ rqs _, Just replacedId) -> do
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of
Just RcvQueue {server, rcvId} -> do
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
_ -> pure ()
tryError agentClientMsg >>= \case
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
HELLO -> helloMsg >> ackDel msgId
REPLY cReq -> replyMsg cReq >> ackDel msgId
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
QADD qs -> qDuplex "QADD" $ qAddMsg qs
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
-- no action needed for QTEST
-- any message in the new queue will mark it active and trigger deletion of the old queue
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
where
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplex name a = case conn of
DuplexConnection {} -> a conn >> ackDel msgId
_ -> qError $ name <> ": message must be sent to duplex connection"
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
| userAck -> ackDel internalId
| otherwise -> do
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
_ -> throwError e
Left e -> throwError e
where
agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage))
agentClientMsg = withStore c $ \db -> runExceptT $ do
agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
let msgType = agentMessageType agentMsg
internalHash = C.sha256Hash agentMsgBody
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
recipient = (unId internalId, internalTs)
broker = (srvMsgId, systemToUTCTime srvTs)
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
liftIO $ createRcvMsg db connId rq rcvMsg
pure $ Just (internalId, msgMeta, aMessage)
_ -> pure Nothing
_ -> prohibited >> ack
_ -> prohibited >> ack
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $
decryptSMPMessage v rq msg >>= \case
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody
SMP.ClientRcvMsgQuota {} -> queueDrained >> ack
where
queueDrained = case conn of
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq)
_ -> pure ()
processClientMsg srvTs msgFlags msgBody = do
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
parseMessage msgBody
clientVRange <- asks $ smpClientVRange . config
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case (e2eDhSecret, e2ePubKey_) of
(Nothing, Just e2ePubKey) -> do
let e2eDh = C.dh' e2ePubKey e2ePrivKey
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption, encConnInfo, agentVersion}) ->
smpConfirmation senderKey e2ePubKey e2eEncryption encConnInfo phVer agentVersion >> ack
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
smpInvitation connReq connInfo >> ack
_ -> prohibited >> ack
(Just e2eDh, Nothing) -> do
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
let RcvQueue {primary, dbReplaceQueueId} = rq
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
case (conn, dbReplaceQueueId) of
(DuplexConnection _ rqs _, Just replacedId) -> do
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
case find (\RcvQueue {dbQueueId} -> dbQueueId == replacedId) rqs of
Just RcvQueue {server, rcvId} -> do
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
_ -> pure ()
tryError agentClientMsg >>= \case
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
HELLO -> helloMsg >> ackDel msgId
REPLY cReq -> replyMsg cReq >> ackDel msgId
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
QCONT addr -> qDuplex "QCONT" $ continueSending addr
QADD qs -> qDuplex "QADD" $ qAddMsg qs
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
-- no action needed for QTEST
-- any message in the new queue will mark it active and trigger deletion of the old queue
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
where
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplex name a = case conn of
DuplexConnection {} -> a conn >> ackDel msgId
_ -> qError $ name <> ": message must be sent to duplex connection"
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
| userAck -> ackDel internalId
| otherwise -> do
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
_ -> throwError e
Left e -> throwError e
where
agentClientMsg :: m (Maybe (InternalId, MsgMeta, AMessage))
agentClientMsg = withStore c $ \db -> runExceptT $ do
agentMsgBody <- agentRatchetDecrypt db connId encAgentMsg
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
let msgType = agentMessageType agentMsg
internalHash = C.sha256Hash agentMsgBody
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
recipient = (unId internalId, internalTs)
broker = (srvMsgId, systemToUTCTime srvTs)
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
liftIO $ createRcvMsg db connId rq rcvMsg
pure $ Just (internalId, msgMeta, aMessage)
_ -> pure Nothing
_ -> prohibited >> ack
_ -> prohibited >> ack
ack :: m ()
ack = enqueueCmd $ ICAck rId srvMsgId
ackDel :: InternalId -> m ()
@@ -1698,6 +1704,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR)
_ -> prohibited
continueSending :: (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
continueSending addr (DuplexConnection _ _ sqs) =
case findQ addr sqs of
Just sq -> do
logServer "<--" c srv rId "MSG <QCONT>"
atomically $ do
(_, qLock) <- getPendingMsgQ c sq
void $ tryPutTMVar qLock ()
Nothing -> qError "QCONT: queue address not found"
-- processed by queue sender
qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m ()
qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
+16 -19
View File
@@ -98,7 +98,7 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight, partitionEithers)
import Data.Functor (($>))
import Data.List (partition, (\\))
import Data.List (partition)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
@@ -180,7 +180,7 @@ data AgentClient = AgentClient
activeSubs :: TRcvQueues,
pendingSubs :: TRcvQueues,
pendingMsgsQueued :: TMap SndQAddr Bool,
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId),
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId, TMVar ()),
smpQueueMsgDeliveries :: TMap SndQAddr (Async ()),
connCmdsQueued :: TMap ConnId Bool,
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
@@ -303,12 +303,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
qs <- RQ.getDelSrvQueues srv $ activeSubs c
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
mapM_ (`RQ.addQueue` pendingSubs c) qs
cs <- RQ.getConns (activeSubs c)
-- TODO deduplicate conns
let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs
pure (qs, conns)
pure (qs, S.toList conns)
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
serverDown (qs, conns) = whenM (readTVarIO active) $ do
@@ -345,8 +342,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
unless connected . forM_ client_ $ \cl -> do
incClientStat c cl "CONNECT" ""
notifySub "" $ hostEvent CONNECT cl
-- TODO deduplicate okConns
let conns = okConns \\ S.toList cs
let conns = S.toList $ S.fromList okConns `S.difference` cs
unless (null conns) $ notifySub "" $ UP srv conns
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
@@ -611,7 +607,7 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
modifyTVar' (subscrConns c) $ S.insert connId
RQ.addQueue rq $ pendingSubs c
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
@@ -647,9 +643,8 @@ temporaryOrHostError = \case
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do
-- TODO check server is correct
modifyTVar (subscrConns c) $ S.insert connId
forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do
modifyTVar' (subscrConns c) $ S.insert connId
RQ.addQueue rq $ pendingSubs c
case L.nonEmpty qs_ of
Just qs' -> do
@@ -667,14 +662,16 @@ subscribeQueues c srv qs = do
pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs
_ -> pure (Nothing, errs)
where
checkQueue rq@RcvQueue {rcvId, server} = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
checkQueue rq@RcvQueue {rcvId, server}
| server == srv = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
| otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter")
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
addSubscription c rq@RcvQueue {connId} = atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
modifyTVar' (subscrConns c) $ S.insert connId
RQ.addQueue rq $ activeSubs c
RQ.deleteQueue rq $ pendingSubs c
@@ -683,7 +680,7 @@ hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c
removeSubscription :: AgentClient -> ConnId -> STM ()
removeSubscription c connId = do
modifyTVar (subscrConns c) $ S.delete connId
modifyTVar' (subscrConns c) $ S.delete connId
RQ.deleteConn connId $ activeSubs c
RQ.deleteConn connId $ pendingSubs c
@@ -948,7 +945,7 @@ storeError = \case
incStat :: AgentClient -> Int -> AgentStatsKey -> STM ()
incStat AgentClient {agentStats} n k = do
TM.lookup k agentStats >>= \case
Just v -> modifyTVar v (+ n)
Just v -> modifyTVar' v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
+18 -6
View File
@@ -83,7 +83,7 @@ data AgentConfig = AgentConfig
smpCfg :: ProtocolClientConfig,
ntfCfg :: ProtocolClientConfig,
reconnectInterval :: RetryInterval,
messageRetryInterval :: RetryInterval,
messageRetryInterval :: RetryInterval2,
messageTimeout :: NominalDiffTime,
helloTimeout :: NominalDiffTime,
ntfCron :: Word16,
@@ -108,12 +108,24 @@ defaultReconnectInterval =
maxInterval = 180_000000
}
defaultMessageRetryInterval :: RetryInterval
defaultMessageRetryInterval :: RetryInterval2
defaultMessageRetryInterval =
RetryInterval
{ initialInterval = 1_000000,
increaseAfter = 10_000000,
maxInterval = 60_000000
RetryInterval2
{ riFast =
RetryInterval
{ initialInterval = 1_000000,
increaseAfter = 10_000000,
maxInterval = 60_000000
},
riSlow =
-- TODO: these timeouts can be increased in v4.6 once most clients are updated
-- to resume sending on QCONT messages.
-- After that local message expiration period should be also increased.
RetryInterval
{ initialInterval = 30_000000,
increaseAfter = 30_000000,
maxInterval = 600_000000
}
}
defaultAgentConfig :: AgentConfig
@@ -191,7 +191,7 @@ runNtfWorker c srv doWork = do
case clientNtfCreds of
Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do
nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey
-- TODO smaller retry until Active, less frequently (daily?) once Active
-- possible improvement: smaller retry until Active, less frequently (daily?) once Active
let actionTs' = addUTCTime 30 ts
withStore' c $ \db ->
updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NtfSubNTFAction NSACheck) actionTs'
+11 -1
View File
@@ -577,6 +577,7 @@ data AgentMessageType
| AM_HELLO_
| AM_REPLY_
| AM_A_MSG_
| AM_QCONT_
| AM_QADD_
| AM_QKEY_
| AM_QUSE_
@@ -590,6 +591,7 @@ instance Encoding AgentMessageType where
AM_HELLO_ -> "H"
AM_REPLY_ -> "R"
AM_A_MSG_ -> "M"
AM_QCONT_ -> "QC"
AM_QADD_ -> "QA"
AM_QKEY_ -> "QK"
AM_QUSE_ -> "QU"
@@ -603,6 +605,7 @@ instance Encoding AgentMessageType where
'M' -> pure AM_A_MSG_
'Q' ->
A.anyChar >>= \case
'C' -> pure AM_QCONT_
'A' -> pure AM_QADD_
'K' -> pure AM_QKEY_
'U' -> pure AM_QUSE_
@@ -623,6 +626,7 @@ agentMessageType = \case
-- REPLY is only used in v1
REPLY _ -> AM_REPLY_
A_MSG _ -> AM_A_MSG_
QCONT _ -> AM_QCONT_
QADD _ -> AM_QADD_
QKEY _ -> AM_QKEY_
QUSE _ -> AM_QUSE_
@@ -645,6 +649,7 @@ data AMsgType
= HELLO_
| REPLY_
| A_MSG_
| QCONT_
| QADD_
| QKEY_
| QUSE_
@@ -656,6 +661,7 @@ instance Encoding AMsgType where
HELLO_ -> "H"
REPLY_ -> "R"
A_MSG_ -> "M"
QCONT_ -> "QC"
QADD_ -> "QA"
QKEY_ -> "QK"
QUSE_ -> "QU"
@@ -667,6 +673,7 @@ instance Encoding AMsgType where
'M' -> pure A_MSG_
'Q' ->
A.anyChar >>= \case
'C' -> pure QCONT_
'A' -> pure QADD_
'K' -> pure QKEY_
'U' -> pure QUSE_
@@ -684,6 +691,8 @@ data AMessage
REPLY (L.NonEmpty SMPQueueInfo)
| -- | agent envelope for the client message
A_MSG MsgBody
| -- | the message instructing the client to continue sending messages (after ERR QUOTA)
QCONT SndQAddr
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr))
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
@@ -701,6 +710,7 @@ instance Encoding AMessage where
HELLO -> smpEncode HELLO_
REPLY smpQueues -> smpEncode (REPLY_, smpQueues)
A_MSG body -> smpEncode (A_MSG_, Tail body)
QCONT addr -> smpEncode (QCONT_, addr)
QADD qs -> smpEncode (QADD_, qs)
QKEY qs -> smpEncode (QKEY_, qs)
QUSE qs -> smpEncode (QUSE_, qs)
@@ -711,6 +721,7 @@ instance Encoding AMessage where
HELLO_ -> pure HELLO
REPLY_ -> REPLY <$> smpP
A_MSG_ -> A_MSG . unTail <$> smpP
QCONT_ -> QCONT <$> smpP
QADD_ -> QADD <$> smpP
QKEY_ -> QKEY <$> smpP
QUSE_ -> QUSE <$> smpP
@@ -1119,7 +1130,6 @@ instance ToJSON BrokerErrorType where
toEncoding = J.genericToEncoding $ sumTypeJSON id
-- | Errors of another SMP agent.
-- TODO encode/decode without A prefix
data SMPAgentError
= -- | client or agent message that failed to parse
A_MESSAGE
+55 -10
View File
@@ -1,10 +1,21 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Agent.RetryInterval where
module Simplex.Messaging.Agent.RetryInterval
( RetryInterval (..),
RetryInterval2 (..),
RetryIntervalMode (..),
withRetryInterval,
withRetryLock2,
)
where
import Control.Concurrent (threadDelay)
import Control.Concurrent (forkIO, threadDelay)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Simplex.Messaging.Util (whenM)
import UnliftIO.STM
data RetryInterval = RetryInterval
{ initialInterval :: Int,
@@ -12,17 +23,51 @@ data RetryInterval = RetryInterval
maxInterval :: Int
}
data RetryInterval2 = RetryInterval2
{ riSlow :: RetryInterval,
riFast :: RetryInterval
}
data RetryIntervalMode = RISlow | RIFast
deriving (Eq)
withRetryInterval :: forall m. MonadIO m => RetryInterval -> (m () -> m ()) -> m ()
withRetryInterval RetryInterval {initialInterval, increaseAfter, maxInterval} action =
callAction 0 initialInterval
withRetryInterval ri action = callAction 0 $ initialInterval ri
where
callAction :: Int -> Int -> m ()
callAction elapsedTime delay = action loop
callAction elapsed delay = action loop
where
loop = do
let newDelay =
if elapsedTime < increaseAfter || delay == maxInterval
then delay
else min (delay * 3 `div` 2) maxInterval
liftIO $ threadDelay delay
callAction (elapsedTime + delay) newDelay
let elapsed' = elapsed + delay
callAction elapsed' $ nextDelay elapsed' delay ri
-- This function allows action to toggle between slow and fast retry intervals.
withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> ((RetryIntervalMode -> m ()) -> m ()) -> m ()
withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
callAction (0, initialInterval riSlow) (0, initialInterval riFast)
where
callAction :: (Int, Int) -> (Int, Int) -> m ()
callAction slow fast = action loop
where
loop = \case
RISlow -> run slow riSlow (`callAction` fast)
RIFast -> run fast riFast (callAction slow)
run (elapsed, delay) ri call = do
wait delay
let elapsed' = elapsed + delay
call (elapsed', nextDelay elapsed' delay ri)
wait delay = do
waiting <- newTVarIO True
_ <- liftIO . forkIO $ do
threadDelay delay
atomically $ whenM (readTVar waiting) $ void $ tryPutTMVar lock ()
atomically $ do
takeTMVar lock
writeTVar waiting False
nextDelay :: Int -> Int -> RetryInterval -> Int
nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
if elapsed < increaseAfter || delay == maxInterval
then delay
else min (delay * 3 `div` 2) maxInterval
+4 -4
View File
@@ -40,9 +40,9 @@ getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
where
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty)
where
addQ (removed, qs') rq@RcvQueue {server, rcvId}
| srv == server = (rq : removed, qs')
addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId}
| srv == server = ((rq : remQs, S.insert connId remConns), qs')
| otherwise = (removed, M.insert (server, rcvId) rq qs')
+4 -9
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DisambiguateRecordFields #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
@@ -94,7 +93,7 @@ import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client (SocksProxy, TransportClientConfig (..), TransportHost (..), runTransportClient)
import Simplex.Messaging.Transport.KeepAlive
import Simplex.Messaging.Transport.WebSockets (WS)
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
import Simplex.Messaging.Util (bshow, raceAny_)
import Simplex.Messaging.Version
import System.Timeout (timeout)
@@ -541,7 +540,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}} pKey
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do
corrId <- liftIO $ atomically getNextCorrId
t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
r <- liftIO . atomically $ mkRequest corrId
pure (t, r)
where
@@ -549,12 +548,8 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo
getNextCorrId = do
i <- stateTVar clientCorrId $ \i -> (i, i + 1)
pure . CorrId $ bshow i
signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission
signTransmission t = case pKey of
Nothing -> pure (Nothing, t)
Just pk -> do
sig <- liftError PCESignatureError $ C.sign pk t
return (Just sig, t)
signTransmission :: ByteString -> SentRawTransmission
signTransmission t = ((`C.sign` t) <$> pKey, t)
mkRequest :: CorrId -> STM (TMVar (Response msg))
mkRequest corrId = do
r <- newEmptyTMVar
+8 -10
View File
@@ -67,7 +67,9 @@ module Simplex.Messaging.Crypto
-- * key encoding/decoding
encodePubKey,
decodePubKey,
encodePrivKey,
decodePrivKey,
pubKeyBytes,
-- * sign/verify
@@ -888,12 +890,12 @@ cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError
-- | Message signing.
--
-- Used by SMP clients to sign SMP commands and by SMP agents to sign messages.
sign' :: SignatureAlgorithm a => PrivateKey a -> ByteString -> ExceptT CryptoError IO (Signature a)
sign' (PrivateKeyEd25519 pk k) msg = pure . SignatureEd25519 $ Ed25519.sign pk k msg
sign' (PrivateKeyEd448 pk k) msg = pure . SignatureEd448 $ Ed448.sign pk k msg
sign' :: SignatureAlgorithm a => PrivateKey a -> ByteString -> Signature a
sign' (PrivateKeyEd25519 pk k) msg = SignatureEd25519 $ Ed25519.sign pk k msg
sign' (PrivateKeyEd448 pk k) msg = SignatureEd448 $ Ed448.sign pk k msg
sign :: APrivateSignKey -> ByteString -> ExceptT CryptoError IO ASignature
sign (APrivateSignKey a k) = fmap (ASignature a) . sign' k
sign :: APrivateSignKey -> ByteString -> ASignature
sign (APrivateSignKey a k) = ASignature a . sign' k
-- | Signature verification.
--
@@ -962,11 +964,7 @@ pseudoRandomCbNonce :: TVar ChaChaDRG -> STM CbNonce
pseudoRandomCbNonce gVar = CbNonce <$> pseudoRandomBytes 24 gVar
pseudoRandomBytes :: Int -> TVar ChaChaDRG -> STM ByteString
pseudoRandomBytes n gVar = do
g <- readTVar gVar
let (bytes, g') = randomBytesGenerate n g
writeTVar gVar g'
return bytes
pseudoRandomBytes n gVar = stateTVar gVar $ randomBytesGenerate n
instance Encoding CbNonce where
smpEncode = unCbNonce
+17 -17
View File
@@ -257,24 +257,27 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
(tkn@NtfTknData {ntfTknId, token = DeviceToken pp _, tknStatus}, ntf) <- atomically (readTBQueue pushQ)
liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp)
status <- readTVarIO tknStatus
case (status, ntf) of
(_, PNVerification _) ->
-- TODO check token status
deliverNotification pp tkn ntf >>= \case
Right _ -> do
status_ <- atomically $ stateTVar tknStatus $ \status' -> if status' == NTActive then (Nothing, NTActive) else (Just NTConfirmed, NTConfirmed)
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
_ -> pure ()
(NTActive, PNCheckMessages) ->
case ntf of
PNVerification _
| status /= NTInvalid && status /= NTExpired ->
deliverNotification pp tkn ntf >>= \case
Right _ -> do
status_ <- atomically $ stateTVar tknStatus $ \status' -> if status' == NTActive then (Nothing, NTActive) else (Just NTConfirmed, NTConfirmed)
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
_ -> pure ()
| otherwise -> logError "bad notification token status"
PNCheckMessages -> checkActiveTkn status $ do
void $ deliverNotification pp tkn ntf
(NTActive, PNMessage {}) -> do
PNMessage {} -> checkActiveTkn status $ do
stats <- asks serverStats
atomically $ updatePeriodStats (activeTokens stats) ntfTknId
void $ deliverNotification pp tkn ntf
incNtfStat ntfDelivered
_ ->
liftIO $ logError "bad notification token status"
where
checkActiveTkn :: NtfTknStatus -> M () -> M ()
checkActiveTkn status action
| status == NTActive = action
| otherwise = liftIO $ logError "bad notification token status"
deliverNotification :: PushProvider -> NtfTknData -> PushNotification -> M (Either PushProviderError ())
deliverNotification pp tkn@NtfTknData {ntfTknId, tknStatus} ntf = do
deliver <- liftIO $ getPushClient s pp
@@ -361,13 +364,11 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
s_ <- atomically $ findNtfSubscription st smpQueue
case s_ of
Nothing -> do
-- TODO move active token check here to differentiate error
t_ <- atomically $ getActiveNtfToken st tknId
verifyToken' t_ $ VRVerified (NtfReqNew corrId (ANE SSubscription sub))
Just s@NtfSubData {tokenId = subTknId} ->
if subTknId == tknId
then do
-- TODO move active token check here to differentiate error
t_ <- atomically $ getActiveNtfToken st subTknId
verifyToken' t_ $ verifiedSubCmd s c
else pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
@@ -375,7 +376,6 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
s_ <- atomically $ getNtfSubscription st entId
case s_ of
Just s@NtfSubData {tokenId = subTknId} -> do
-- TODO move active token check here to differentiate error
t_ <- atomically $ getActiveNtfToken st subTknId
verifyToken' t_ $ verifiedSubCmd s c
_ -> pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
@@ -512,7 +512,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
(corrId,subId,) <$> case cmd of
SNEW (NewNtfSub _ _ notifierKey) -> do
logDebug "SNEW - existing subscription"
-- TODO retry if subscription failed, if pending or AUTH do nothing
-- possible improvement: retry if subscription failed, if pending or AUTH do nothing
pure $
if notifierKey == registeredNKey
then NRSubId subId
@@ -548,7 +548,7 @@ withNtfLog action = liftIO . mapM_ action =<< asks storeLog
incNtfStat :: (NtfServerStats -> TVar Int) -> M ()
incNtfStat statSel = do
stats <- asks serverStats
atomically $ modifyTVar (statSel stats) (+ 1)
atomically $ modifyTVar' (statSel stats) (+ 1)
saveServerStats :: M ()
saveServerStats =
@@ -27,6 +27,9 @@ import System.FilePath (combine)
import System.IO (BufferMode (..), hSetBuffering, stderr, stdout)
import Text.Read (readMaybe)
ntfServerVersion :: String
ntfServerVersion = "1.3.0"
ntfServerCLI :: FilePath -> FilePath -> IO ()
ntfServerCLI cfgPath logPath =
getCliCommand' (cliCommandP cfgPath logPath iniFile) serverVersion >>= \case
@@ -45,7 +48,7 @@ ntfServerCLI cfgPath logPath =
putStrLn "Deleted configuration and log files"
where
iniFile = combine cfgPath "ntf-server.ini"
serverVersion = "SMP notifications server v1.2.0"
serverVersion = "SMP notifications server v" <> ntfServerVersion
defaultServerPort = "443"
executableName = "ntf-server"
storeLogFilePath = combine logPath "ntf-server-store.log"
@@ -97,7 +97,7 @@ readECPrivateKey f = do
data PushNotification
= PNVerification NtfRegCode
| PNMessage PNMessageData
| PNAlert Text
-- | PNAlert Text
| PNCheckMessages
deriving (Show)
@@ -287,14 +287,14 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
PNMessage pnMessageData ->
encrypt (strEncode pnMessageData) $ \ntfData ->
apn apnMutableContent . Just $ J.object ["nonce" .= nonce, "message" .= ntfData]
PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing
-- PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing
PNCheckMessages -> Right $ apn APNSBackground {contentAvailable = 1} . Just $ J.object ["checkMessages" .= True]
where
encrypt :: ByteString -> (Text -> APNSNotification) -> Either C.CryptoError APNSNotification
encrypt ntfData f = f . safeDecodeUtf8 . U.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen
apn aps notificationData = APNSNotification {aps, notificationData}
apnMutableContent = APNSMutableContent {mutableContent = 1, alert = APNSAlertText "Encrypted message or another app event", category = Just ntfCategoryCheckMessage}
apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
-- apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request
apnsRequest c tkn ntf@APNSNotification {aps} = do
@@ -70,14 +70,14 @@ getNtfServerStatsData s = do
setNtfServerStats :: NtfServerStats -> NtfServerStatsData -> STM ()
setNtfServerStats s d = do
writeTVar (fromTime (s :: NtfServerStats)) (_fromTime (d :: NtfServerStatsData))
writeTVar (tknCreated s) (_tknCreated d)
writeTVar (tknVerified s) (_tknVerified d)
writeTVar (tknDeleted s) (_tknDeleted d)
writeTVar (subCreated s) (_subCreated d)
writeTVar (subDeleted s) (_subDeleted d)
writeTVar (ntfReceived s) (_ntfReceived d)
writeTVar (ntfDelivered s) (_ntfDelivered d)
writeTVar (fromTime (s :: NtfServerStats)) $! _fromTime (d :: NtfServerStatsData)
writeTVar (tknCreated s) $! _tknCreated d
writeTVar (tknVerified s) $! _tknVerified d
writeTVar (tknDeleted s) $! _tknDeleted d
writeTVar (subCreated s) $! _subCreated d
writeTVar (subDeleted s) $! _subDeleted d
writeTVar (ntfReceived s) $! _ntfReceived d
writeTVar (ntfDelivered s) $! _ntfDelivered d
setPeriodStats (activeTokens s) (_activeTokens d)
setPeriodStats (activeSubs s) (_activeSubs d)
+85 -42
View File
@@ -141,6 +141,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isPrint, isSpace)
import Data.Functor (($>))
import Data.Kind
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
@@ -303,12 +304,17 @@ data RcvMessage = RcvMessage
deriving (Eq, Show)
-- | received message without server/recipient encryption
data Message = Message
{ msgId :: MsgId,
msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
data Message
= Message
{ msgId :: MsgId,
msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
| MessageQuota
{ msgId :: MsgId,
msgTs :: SystemTime
}
instance StrEncoding RcvMessage where
strEncode RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} =
@@ -328,44 +334,72 @@ instance StrEncoding RcvMessage where
newtype EncRcvMsgBody = EncRcvMsgBody ByteString
deriving (Eq, Show)
data RcvMsgBody = RcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
data RcvMsgBody
= RcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: C.MaxLenBS MaxMessageLen
}
| RcvMsgQuota
{ msgTs :: SystemTime
}
msgQuotaTag :: ByteString
msgQuotaTag = "QUOTA"
encodeRcvMsgBody :: RcvMsgBody -> C.MaxLenBS MaxRcvMessageLen
encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} =
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
in C.appendMaxLenBS rcvMeta msgBody
encodeRcvMsgBody = \case
RcvMsgBody {msgTs, msgFlags, msgBody} ->
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
in C.appendMaxLenBS rcvMeta msgBody
RcvMsgQuota {msgTs} ->
C.unsafeMaxLenBS $ msgQuotaTag <> " " <> smpEncode msgTs
data ClientRcvMsgBody = ClientRcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: ByteString
}
data ClientRcvMsgBody
= ClientRcvMsgBody
{ msgTs :: SystemTime,
msgFlags :: MsgFlags,
msgBody :: ByteString
}
| ClientRcvMsgQuota
{ msgTs :: SystemTime
}
clientRcvMsgBodyP :: Parser ClientRcvMsgBody
clientRcvMsgBodyP = do
msgTs <- smpP
msgFlags <- smpP
Tail msgBody <- _smpP
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
clientRcvMsgBodyP = msgQuotaP <|> msgBodyP
where
msgQuotaP = A.string msgQuotaTag *> (ClientRcvMsgQuota <$> _smpP)
msgBodyP = do
msgTs <- smpP
msgFlags <- smpP
Tail msgBody <- _smpP
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
instance StrEncoding Message where
strEncode Message {msgId, msgTs, msgFlags, msgBody} =
B.unwords
[ strEncode msgId,
strEncode msgTs,
"flags=" <> strEncode msgFlags,
strEncode msgBody
]
strEncode = \case
Message {msgId, msgTs, msgFlags, msgBody} ->
B.unwords
[ strEncode msgId,
strEncode msgTs,
"flags=" <> strEncode msgFlags,
strEncode msgBody
]
MessageQuota {msgId, msgTs} ->
B.unwords
[ strEncode msgId,
strEncode msgTs,
"quota"
]
strP = do
msgId <- strP_
msgTs <- strP_
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
msgBody <- strP
pure Message {msgId, msgTs, msgFlags, msgBody}
msgQuotaP msgId msgTs <|> msgP msgId msgTs
where
msgQuotaP msgId msgTs = "quota" $> MessageQuota {msgId, msgTs}
msgP msgId msgTs = do
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
msgBody <- strP
pure Message {msgId, msgTs, msgFlags, msgBody}
type EncNMsgMeta = ByteString
@@ -377,7 +411,9 @@ data SMPMsgMeta = SMPMsgMeta
deriving (Show)
rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta
rcvMessageMeta msgId ClientRcvMsgBody {msgTs, msgFlags} = SMPMsgMeta {msgId, msgTs, msgFlags}
rcvMessageMeta msgId = \case
ClientRcvMsgBody {msgTs, msgFlags} -> SMPMsgMeta {msgId, msgTs, msgFlags}
ClientRcvMsgQuota {msgTs} -> SMPMsgMeta {msgId, msgTs, msgFlags = noMsgFlags}
data NMsgMeta = NMsgMeta
{ msgId :: MsgId,
@@ -860,7 +896,7 @@ data ErrorType
| -- | internal server error
INTERNAL
| -- | used internally, never returned by the server (to be removed)
DUPLICATE_ -- TODO remove, not part of SMP protocol
DUPLICATE_ -- not part of SMP protocol, used internally
deriving (Eq, Generic, Read, Show)
instance ToJSON ErrorType where
@@ -1132,26 +1168,33 @@ instance Encoding CommandError where
_ -> fail "bad command error type"
-- | Send signed SMP transmission to TCP transport.
tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO (NonEmpty (Either TransportError ()))
tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO [Either TransportError ()]
tPut th trs
| batch th = tPutBatch [] $ L.map tEncode trs
| otherwise = forM trs $ tPutBlock th . tEncode
| otherwise = forM (L.toList trs) $ tPutLog . tEncode
where
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ()))
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO [Either TransportError ()]
tPutBatch rs ts = do
let (n, s, ts_) = encodeBatch 0 "" ts
r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s)
r <- if n == 0 then largeMsg else replicate n <$> tPutLog (lenEncode n `B.cons` s)
let rs' = rs <> r
case ts_ of
Just ts' -> tPutBatch rs' ts'
_ -> pure $ L.fromList rs'
_ -> pure rs'
largeMsg = putStrLn "tPut error: large message" >> pure [Left TELargeMsg]
tPutLog s = do
r <- tPutBlock th s
case r of
Left e -> putStrLn ("tPut error: " <> show e)
_ -> pure ()
pure r
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
encodeBatch n s ts@(t :| ts_)
| n == 255 = (n, s, Just ts)
| otherwise =
let s' = s <> smpEncode (Large t)
n' = n + 1
in if B.length s' > blockSize th - 1
in if B.length s' > blockSize th - 1 -- one byte is reserved for the number of messages in the batch
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
else case L.nonEmpty ts_ of
Just ts' -> encodeBatch n' s' ts'
+96 -48
View File
@@ -65,9 +65,9 @@ import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
import Simplex.Messaging.Server.MsgStore.STM
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
import Simplex.Messaging.Server.QueueStore.STM
import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
@@ -104,8 +104,8 @@ type M a = ReaderT Env IO a
smpServer :: TMVar Bool -> ServerConfig -> M ()
smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
s <- asks server
restoreServerStats
restoreServerMessages
restoreServerStats
raceAny_
( serverThread s subscribedQ subscribers subscriptions cancelSub :
serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) :
@@ -174,7 +174,7 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
threadDelay $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues} <- asks serverStats
ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount} <- asks serverStats
let interval = 1000000 * logInterval
withFile statsFilePath AppendMode $ \h -> liftIO $ do
hSetBuffering h LineBuffering
@@ -187,7 +187,31 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
msgSent' <- atomically $ swapTVar msgSent 0
msgRecv' <- atomically $ swapTVar msgRecv 0
ps <- atomically $ periodStatCounts activeQueues ts
hPutStrLn h $ intercalate "," [iso8601Show $ utctDay fromTime', show qCreated', show qSecured', show qDeleted', show msgSent', show msgRecv', dayCount ps, weekCount ps, monthCount ps]
msgSentNtf' <- atomically $ swapTVar msgSentNtf 0
msgRecvNtf' <- atomically $ swapTVar msgRecvNtf 0
psNtf <- atomically $ periodStatCounts activeQueuesNtf ts
qCount' <- readTVarIO qCount
msgCount' <- readTVarIO msgCount
hPutStrLn h $
intercalate
","
[ iso8601Show $ utctDay fromTime',
show qCreated',
show qSecured',
show qDeleted',
show msgSent',
show msgRecv',
dayCount ps,
weekCount ps,
monthCount ps,
show msgSentNtf',
show msgRecvNtf',
dayCount psNtf,
weekCount psNtf,
monthCount psNtf,
show qCount',
show msgCount'
]
threadDelay interval
runClient :: Transport c => TProxy c -> c -> M ()
@@ -256,7 +280,6 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do
send :: Transport c => THandle c -> Client -> IO ()
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
-- TODO the line below can return Lefts, but we ignore it and do not disconnect the client
void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts
atomically . writeTVar activeAt =<< liftIO getSystemTime
where
@@ -387,7 +410,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
Right _ -> do
withLog (`logCreateById` rId)
stats <- asks serverStats
atomically $ modifyTVar (qCreated stats) (+ 1)
atomically $ modifyTVar' (qCreated stats) (+ 1)
atomically $ modifyTVar' (qCount stats) (+ 1)
subscribeQueue qr rId $> IDS (qik ids)
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
@@ -405,7 +429,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
secureQueue_ st sKey = time "KEY" $ do
withLog $ \s -> logSecureQueue s queueId sKey
stats <- asks serverStats
atomically $ modifyTVar (qSecured stats) (+ 1)
atomically $ modifyTVar' (qSecured stats) (+ 1)
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
@@ -510,12 +534,12 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
q <- getStoreMsgQueue "ACK" queueId
case s of
Sub {subThread = ProhibitSub} -> do
msgDeleted <- atomically $ tryDelMsg q msgId
when msgDeleted updateStats
deletedMsg_ <- atomically $ tryDelMsg q msgId
mapM_ updateStats deletedMsg_
pure ok
_ -> do
(msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId
when msgDeleted updateStats
(deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId
mapM_ updateStats deletedMsg_
deliverMessage "ACK" qr queueId sub q msg_
_ -> pure $ err NO_MSG
where
@@ -526,11 +550,17 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
if msgId == msgId' || B.null msgId
then pure $ Just s
else putTMVar delivered msgId' $> Nothing
updateStats :: m ()
updateStats = do
stats <- asks serverStats
atomically $ modifyTVar (msgRecv stats) (+ 1)
atomically $ updatePeriodStats (activeQueues stats) queueId
updateStats :: Message -> m ()
updateStats = \case
MessageQuota {} -> pure ()
Message {msgFlags} -> do
stats <- asks serverStats
atomically $ modifyTVar' (msgRecv stats) (+ 1)
atomically $ modifyTVar' (msgCount stats) (+ 1)
atomically $ updatePeriodStats (activeQueues stats) queueId
when (notification msgFlags) $ do
atomically $ modifyTVar' (msgRecvNtf stats) (+ 1)
atomically $ updatePeriodStats (activeQueuesNtf stats) queueId
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
sendMessage qr msgFlags msgBody
@@ -538,20 +568,25 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
| otherwise = case status qr of
QueueOff -> return $ err AUTH
QueueActive ->
mapM mkMessage (C.maxLenBS msgBody) >>= \case
case C.maxLenBS msgBody of
Left _ -> pure $ err LARGE_MSG
Right msg -> do
resp@(_, _, sent) <- time "SEND" $ do
Right body -> do
msg_ <- time "SEND" $ do
q <- getStoreMsgQueue "SEND" $ recipientId qr
expireMessages q
atomically $ ifM (isFull q) (pure $ err QUOTA) (writeMsg q msg $> ok)
when (sent == OK) . time "SEND ok" $ do
when (notification msgFlags) $
atomically . trySendNotification msg =<< asks idsDrg
stats <- asks serverStats
atomically $ modifyTVar (msgSent stats) (+ 1)
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure resp
atomically . writeMsg q =<< mkMessage body
case msg_ of
Nothing -> pure $ err QUOTA
Just msg -> time "SEND ok" $ do
stats <- asks serverStats
when (notification msgFlags) $ do
atomically . trySendNotification msg =<< asks idsDrg
atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
atomically $ modifyTVar' (msgSent stats) (+ 1)
atomically $ modifyTVar' (msgCount stats) (subtract 1)
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure ok
where
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
mkMessage body = do
@@ -572,12 +607,14 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
unlessM (isFullTBQueue q) $ do
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
unlessM (isFullTBQueue q) $ case msg of
Message {msgId, msgTs} -> do
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
_ -> pure ()
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
mkMessageNotification :: ByteString -> SystemTime -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
mkMessageNotification msgId msgTs rcvNtfDhSecret ntfNonceDrg = do
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
let msgMeta = NMsgMeta {msgId, msgTs}
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
@@ -596,9 +633,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
where
forkSub :: m ()
forkSub = do
atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
t <- mkWeakThreadId =<< forkIO subscriber
atomically . modifyTVar sub $ \case
atomically . modifyTVar' sub $ \case
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
s -> s
where
@@ -609,23 +646,28 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
s <- readTVar sub
void $ setDelivered s msg
writeTVar sub s {subThread = NoSub}
writeTVar sub $! s {subThread = NoSub}
time :: T.Text -> m a -> m a
time name = timed name queueId
encryptMsg :: QueueRec -> Message -> RcvMessage
encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody}
| thVersion == 1 || thVersion == 2 = encrypt msgBody
| otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody}
encryptMsg qr msg = case msg of
Message {msgFlags, msgBody}
| thVersion == 1 || thVersion == 2 -> encrypt msgFlags msgBody
| otherwise -> encrypt msgFlags $ encodeRcvMsgBody RcvMsgBody {msgTs = msgTs', msgFlags, msgBody}
MessageQuota {} ->
encrypt noMsgFlags $ encodeRcvMsgBody (RcvMsgQuota msgTs')
where
encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
encrypt body =
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body
in RcvMessage msgId msgTs msgFlags encBody
encrypt :: KnownNat i => MsgFlags -> C.MaxLenBS i -> RcvMessage
encrypt msgFlags body =
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId') body
in RcvMessage msgId' msgTs' msgFlags encBody
msgId' = msgId (msg :: Message)
msgTs' = msgTs (msg :: Message)
setDelivered :: Sub -> Message -> STM Bool
setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId
setDelivered s msg = tryPutTMVar (delivered s) $ msgId (msg :: Message)
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
@@ -638,7 +680,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
withLog (`logDeleteQueue` queueId)
ms <- asks msgStore
stats <- asks serverStats
atomically $ modifyTVar (qDeleted stats) (+ 1)
atomically $ modifyTVar' (qDeleted stats) (+ 1)
atomically $ modifyTVar' (qCount stats) (subtract 1)
atomically $
deleteQueue st queueId >>= \case
Left e -> pure $ err e
@@ -717,8 +760,11 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages
addToMsgQueue rId msg = do
full <- atomically $ do
q <- getMsgQueue ms rId quota
ifM (isFull q) (pure True) (writeMsg q msg $> False)
when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
isNothing <$> writeMsg q msg
case msg of
Message {} ->
when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
MessageQuota {} -> pure ()
updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do
let nonce = C.cbNonce msgId
msgBody <- liftEither . first (msgErr "v1 message decryption") $ C.maxLenBS =<< C.cbDecrypt rcvDhSecret nonce body
@@ -744,7 +790,9 @@ restoreServerStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStat
liftIO (strDecode <$> B.readFile f) >>= \case
Right d -> do
s <- asks serverStats
atomically $ setServerStats s d
_qCount <- fmap (length . M.keys) . readTVarIO . queues =<< asks queueStore
_msgCount <- foldM (\n q -> (n +) <$> readTVarIO (size q)) 0 =<< readTVarIO =<< asks msgStore
atomically $ setServerStats s d {_qCount, _msgCount}
renameFile f $ f <> ".bak"
logInfo "server stats restored"
Left e -> do
+3 -3
View File
@@ -39,7 +39,7 @@ data ServerConfig = ServerConfig
{ transports :: [(ServiceName, ATransport)],
tbqSize :: Natural,
serverTbqSize :: Natural,
msgQueueQuota :: Natural,
msgQueueQuota :: Int,
queueIdBytes :: Int,
msgIdBytes :: Int,
storeLogFile :: Maybe FilePath,
@@ -164,8 +164,8 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile,
(qs, s') <- liftIO $ readWriteStoreLog s
atomically $ do
writeTVar queues =<< mapM newTVar qs
writeTVar senders $ M.foldr' addSender M.empty qs
writeTVar notifiers $ M.foldr' addNotifier M.empty qs
writeTVar senders $! M.foldr' addSender M.empty qs
writeTVar notifiers $! M.foldr' addNotifier M.empty qs
pure s'
addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId
addSender q = M.insert (senderId q) (recipientId q)
+2 -2
View File
@@ -160,8 +160,8 @@ smpServerCLI cfgPath logPath =
serverConfig =
ServerConfig
{ transports = iniTransports ini,
tbqSize = 16,
serverTbqSize = 64,
tbqSize = 32,
serverTbqSize = 128,
msgQueueQuota = 128,
queueIdBytes = 24,
msgIdBytes = 24, -- must be at least 24 bytes, it is used as 192-bit nonce for XSalsa20
+1 -18
View File
@@ -1,14 +1,11 @@
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Server.MsgStore where
import Control.Applicative ((<|>))
import Data.Int (Int64)
import Numeric.Natural
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (Message (..), MsgId, RcvMessage (..), RecipientId)
import Simplex.Messaging.Protocol (Message (..), RcvMessage (..), RecipientId)
data MsgLogRecord = MLRv3 RecipientId Message | MLRv1 RecipientId RcvMessage
@@ -17,17 +14,3 @@ instance StrEncoding MsgLogRecord where
MLRv3 rId msg -> strEncode (Str "v3", rId, msg)
MLRv1 rId msg -> strEncode (rId, msg)
strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP
class MonadMsgStore s q m | s -> q where
getMsgQueue :: s -> RecipientId -> Natural -> m q
delMsgQueue :: s -> RecipientId -> m ()
flushMsgQueue :: s -> RecipientId -> m [Message]
class MonadMsgQueue q m where
isFull :: q -> m Bool
writeMsg :: q -> Message -> m () -- non blocking
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
peekMsg :: q -> m Message -- blocking
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available
deleteExpiredMsgs :: q -> Int64 -> m ()
+86 -49
View File
@@ -1,83 +1,120 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Server.MsgStore.STM where
module Simplex.Messaging.Server.MsgStore.STM
( STMMsgStore,
MsgQueue (..),
newMsgStore,
getMsgQueue,
delMsgQueue,
flushMsgQueue,
writeMsg,
tryPeekMsg,
peekMsg,
tryDelMsg,
tryDelPeekMsg,
deleteExpiredMsgs,
)
where
import Control.Concurrent.STM.TBQueue (flushTBQueue)
import Control.Concurrent.STM.TQueue (flushTQueue)
import Control.Monad (when)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Int (Int64)
import Data.Time.Clock.System (SystemTime (systemSeconds))
import Numeric.Natural
import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId)
import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import UnliftIO.STM
newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message}
data MsgQueue = MsgQueue
{ msgQueue :: TQueue Message,
quota :: Int,
canWrite :: TVar Bool,
size :: TVar Int
}
type STMMsgStore = TMap RecipientId MsgQueue
newMsgStore :: STM STMMsgStore
newMsgStore = TM.empty
instance MonadMsgStore STMMsgStore MsgQueue STM where
getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
where
newQ = do
q <- MsgQueue <$> newTBQueue quota
TM.insert rId q st
return q
getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue
getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st
where
newQ = do
msgQueue <- newTQueue
canWrite <- newTVar True
size <- newTVar 0
let q = MsgQueue {msgQueue, quota, canWrite, size}
TM.insert rId q st
pure q
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
delMsgQueue st rId = TM.delete rId st
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
delMsgQueue st rId = TM.delete rId st
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
flushMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (flushTBQueue . msgQueue)
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
instance MonadMsgQueue MsgQueue STM where
isFull :: MsgQueue -> STM Bool
isFull = isFullTBQueue . msgQueue
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
canWrt <- readTVar canWrite
empty <- isEmptyTQueue q
if canWrt || empty
then do
canWrt' <- (quota >) <$> readTVar size
writeTVar canWrite $! canWrt'
modifyTVar' size (+ 1)
if canWrt'
then writeTQueue q msg $> Just msg
else writeTQueue q msgQuota $> Nothing
else pure Nothing
where
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
writeMsg :: MsgQueue -> Message -> STM ()
writeMsg = writeTBQueue . msgQueue
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
tryPeekMsg = tryPeekTQueue . msgQueue
{-# INLINE tryPeekMsg #-}
tryPeekMsg :: MsgQueue -> STM (Maybe Message)
tryPeekMsg = tryPeekTBQueue . msgQueue
peekMsg :: MsgQueue -> STM Message
peekMsg = peekTQueue . msgQueue
{-# INLINE peekMsg #-}
peekMsg :: MsgQueue -> STM Message
peekMsg = peekTBQueue . msgQueue
tryDelMsg :: MsgQueue -> MsgId -> STM (Maybe Message)
tryDelMsg mq msgId' =
tryPeekMsg mq >>= \case
msg_@(Just msg)
| msgId msg == msgId' || B.null msgId' -> tryDeleteMsg mq >> pure msg_
| otherwise -> pure Nothing
_ -> pure Nothing
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
tryDelMsg (MsgQueue q) msgId' =
tryPeekTBQueue q >>= \case
Just Message {msgId}
| msgId == msgId' || B.null msgId' -> tryReadTBQueue q $> True
| otherwise -> pure False
_ -> pure False
-- atomic delete (== read) last and peek next message if available
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Maybe Message, Maybe Message)
tryDelPeekMsg mq msgId' =
tryPeekMsg mq >>= \case
msg_@(Just msg)
| msgId msg == msgId' || B.null msgId' -> (msg_,) <$> (tryDeleteMsg mq >> tryPeekMsg mq)
| otherwise -> pure (Nothing, msg_)
_ -> pure (Nothing, Nothing)
-- atomic delete (== read) last and peek next message if available
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
tryDelPeekMsg (MsgQueue q) msgId' =
tryPeekTBQueue q >>= \case
msg_@(Just Message {msgId})
| msgId == msgId' || B.null msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q)
| otherwise -> pure (False, msg_)
_ -> pure (False, Nothing)
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
deleteExpiredMsgs (MsgQueue q) old = loop
where
loop = tryPeekTBQueue q >>= mapM_ delOldMsg
delOldMsg Message {msgTs} =
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
deleteExpiredMsgs mq old = loop
where
loop = tryPeekMsg mq >>= mapM_ delOldMsg
delOldMsg = \case
Message {msgTs} ->
when (systemSeconds msgTs < old) $
tryReadTBQueue q >> loop
tryDeleteMsg mq >> loop
_ -> pure ()
tryDeleteMsg :: MsgQueue -> STM ()
tryDeleteMsg MsgQueue {msgQueue = q, size} =
tryReadTQueue q >>= \case
Just _ -> modifyTVar' size (subtract 1)
_ -> pure ()
+10 -19
View File
@@ -9,20 +9,20 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
data QueueRec = QueueRec
{ recipientId :: RecipientId,
recipientKey :: RcvPublicVerifyKey,
rcvDhSecret :: RcvDhSecret,
senderId :: SenderId,
senderKey :: Maybe SndPublicVerifyKey,
notifier :: Maybe NtfCreds,
status :: ServerQueueStatus
{ recipientId :: !RecipientId,
recipientKey :: !RcvPublicVerifyKey,
rcvDhSecret :: !RcvDhSecret,
senderId :: !SenderId,
senderKey :: !(Maybe SndPublicVerifyKey),
notifier :: !(Maybe NtfCreds),
status :: !ServerQueueStatus
}
deriving (Eq, Show)
data NtfCreds = NtfCreds
{ notifierId :: NotifierId,
notifierKey :: NtfPublicVerifyKey,
rcvNtfDhSecret :: RcvNtfDhSecret
{ notifierId :: !NotifierId,
notifierKey :: !NtfPublicVerifyKey,
rcvNtfDhSecret :: !RcvNtfDhSecret
}
deriving (Eq, Show)
@@ -33,12 +33,3 @@ instance StrEncoding NtfCreds where
pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
data ServerQueueStatus = QueueActive | QueueOff deriving (Eq, Show)
class MonadQueueStore s m where
addQueue :: s -> QueueRec -> m (Either ErrorType ())
getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec)
secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec)
addQueueNotifier :: s -> RecipientId -> NtfCreds -> m (Either ErrorType QueueRec)
deleteQueueNotifier :: s -> RecipientId -> m (Either ErrorType ())
suspendQueue :: s -> RecipientId -> m (Either ErrorType ())
deleteQueue :: s -> RecipientId -> m (Either ErrorType ())
+64 -55
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@@ -10,7 +9,18 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
module Simplex.Messaging.Server.QueueStore.STM where
module Simplex.Messaging.Server.QueueStore.STM
( QueueStore (..),
newQueueStore,
addQueue,
getQueue,
secureQueue,
addQueueNotifier,
deleteQueueNotifier,
suspendQueue,
deleteQueue,
)
where
import Control.Monad
import Data.Functor (($>))
@@ -34,66 +44,65 @@ newQueueStore = do
notifiers <- TM.empty
pure QueueStore {queues, senders, notifiers}
instance MonadQueueStore QueueStore STM where
addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ())
addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do
ifM hasId (pure $ Left DUPLICATE_) $ do
qVar <- newTVar q
TM.insert rId qVar queues
TM.insert sId rId senders
pure $ Right ()
where
hasId = (||) <$> TM.member rId queues <*> TM.member sId senders
addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ())
addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do
ifM hasId (pure $ Left DUPLICATE_) $ do
qVar <- newTVar q
TM.insert rId qVar queues
TM.insert sId rId senders
pure $ Right ()
where
hasId = (||) <$> TM.member rId queues <*> TM.member sId senders
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
getQueue QueueStore {queues, senders, notifiers} party qId =
toResult <$> (mapM readTVar =<< getVar)
where
getVar = case party of
SRecipient -> TM.lookup qId queues
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec)
getQueue QueueStore {queues, senders, notifiers} party qId =
toResult <$> (mapM readTVar =<< getVar)
where
getVar = case party of
SRecipient -> TM.lookup qId queues
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
secureQueue QueueStore {queues} rId sKey =
withQueue rId queues $ \qVar ->
readTVar qVar >>= \q -> case senderKey q of
Just k -> pure $ if sKey == k then Just q else Nothing
_ ->
let q' = q {senderKey = Just sKey}
in writeTVar qVar q' $> Just q'
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
secureQueue QueueStore {queues} rId sKey =
withQueue rId queues $ \qVar ->
readTVar qVar >>= \q -> case senderKey q of
Just k -> pure $ if sKey == k then Just q else Nothing
_ ->
let q' = q {senderKey = Just sKey}
in writeTVar qVar q' $> Just q'
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
withQueue rId queues $ \qVar -> do
q <- readTVar qVar
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
writeTVar qVar q {notifier = Just ntfCreds}
TM.insert nId rId notifiers
pure $ Just q
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueueNotifier QueueStore {queues, notifiers} rId =
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
withQueue rId queues $ \qVar -> do
q <- readTVar qVar
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
writeTVar qVar q {notifier = Nothing}
pure $ Just ()
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
writeTVar qVar $! q {notifier = Just ntfCreds}
TM.insert nId rId notifiers
pure $ Just q
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
suspendQueue QueueStore {queues} rId =
withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just ()
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueueNotifier QueueStore {queues, notifiers} rId =
withQueue rId queues $ \qVar -> do
q <- readTVar qVar
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
writeTVar qVar $! q {notifier = Nothing}
pure $ Just ()
deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueue QueueStore {queues, senders, notifiers} rId = do
TM.lookupDelete rId queues >>= \case
Just qVar ->
readTVar qVar >>= \q -> do
TM.delete (senderId q) senders
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
pure $ Right ()
_ -> pure $ Left AUTH
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
suspendQueue QueueStore {queues} rId =
withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just ()
deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
deleteQueue QueueStore {queues, senders, notifiers} rId = do
TM.lookupDelete rId queues >>= \case
Just qVar ->
readTVar qVar >>= \q -> do
TM.delete (senderId q) senders
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
pure $ Right ()
_ -> pure $ Left AUTH
toResult :: Maybe a -> Either ErrorType a
toResult = maybe (Left AUTH) Right
+66 -27
View File
@@ -1,3 +1,4 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -5,7 +6,7 @@
module Simplex.Messaging.Server.Stats where
import Control.Applicative (optional)
import Control.Applicative (optional, (<|>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Char8 as B
import Data.Set (Set)
@@ -24,7 +25,12 @@ data ServerStats = ServerStats
qDeleted :: TVar Int,
msgSent :: TVar Int,
msgRecv :: TVar Int,
activeQueues :: PeriodStats RecipientId
activeQueues :: PeriodStats RecipientId,
msgSentNtf :: TVar Int,
msgRecvNtf :: TVar Int,
activeQueuesNtf :: PeriodStats RecipientId,
qCount :: TVar Int,
msgCount :: TVar Int
}
data ServerStatsData = ServerStatsData
@@ -34,7 +40,12 @@ data ServerStatsData = ServerStatsData
_qDeleted :: Int,
_msgSent :: Int,
_msgRecv :: Int,
_activeQueues :: PeriodStatsData RecipientId
_activeQueues :: PeriodStatsData RecipientId,
_msgSentNtf :: Int,
_msgRecvNtf :: Int,
_activeQueuesNtf :: PeriodStatsData RecipientId,
_qCount :: Int,
_msgCount :: Int
}
newServerStats :: UTCTime -> STM ServerStats
@@ -46,7 +57,12 @@ newServerStats ts = do
msgSent <- newTVar 0
msgRecv <- newTVar 0
activeQueues <- newPeriodStats
pure ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues}
msgSentNtf <- newTVar 0
msgRecvNtf <- newTVar 0
activeQueuesNtf <- newPeriodStats
qCount <- newTVar 0
msgCount <- newTVar 0
pure ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount}
getServerStatsData :: ServerStats -> STM ServerStatsData
getServerStatsData s = do
@@ -57,20 +73,30 @@ getServerStatsData s = do
_msgSent <- readTVar $ msgSent s
_msgRecv <- readTVar $ msgRecv s
_activeQueues <- getPeriodStatsData $ activeQueues s
pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues}
_msgSentNtf <- readTVar $ msgSentNtf s
_msgRecvNtf <- readTVar $ msgRecvNtf s
_activeQueuesNtf <- getPeriodStatsData $ activeQueuesNtf s
_qCount <- readTVar $ qCount s
_msgCount <- readTVar $ msgCount s
pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues, _msgSentNtf, _msgRecvNtf, _activeQueuesNtf, _qCount, _msgCount}
setServerStats :: ServerStats -> ServerStatsData -> STM ()
setServerStats s d = do
writeTVar (fromTime s) (_fromTime d)
writeTVar (qCreated s) (_qCreated d)
writeTVar (qSecured s) (_qSecured d)
writeTVar (qDeleted s) (_qDeleted d)
writeTVar (msgSent s) (_msgSent d)
writeTVar (msgRecv s) (_msgRecv d)
setPeriodStats (activeQueues s) (_activeQueues d)
writeTVar (fromTime s) $! _fromTime d
writeTVar (qCreated s) $! _qCreated d
writeTVar (qSecured s) $! _qSecured d
writeTVar (qDeleted s) $! _qDeleted d
writeTVar (msgSent s) $! _msgSent d
writeTVar (msgRecv s) $! _msgRecv d
setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d)
writeTVar (msgSentNtf s) $! _msgSentNtf d
writeTVar (msgRecvNtf s) $! _msgRecvNtf d
setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d)
writeTVar (qCount s) $! _qCount d
writeTVar (msgCount s) $! _qCount d
instance StrEncoding ServerStatsData where
strEncode ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues} =
strEncode ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _msgSentNtf, _msgRecvNtf, _activeQueues, _activeQueuesNtf} =
B.unlines
[ "fromTime=" <> strEncode _fromTime,
"qCreated=" <> strEncode _qCreated,
@@ -78,8 +104,12 @@ instance StrEncoding ServerStatsData where
"qDeleted=" <> strEncode _qDeleted,
"msgSent=" <> strEncode _msgSent,
"msgRecv=" <> strEncode _msgRecv,
"msgSentNtf=" <> strEncode _msgSentNtf,
"msgRecvNtf=" <> strEncode _msgRecvNtf,
"activeQueues:",
strEncode _activeQueues
strEncode _activeQueues,
"activeQueuesNtf:",
strEncode _activeQueuesNtf
]
strP = do
_fromTime <- "fromTime=" *> strP <* A.endOfLine
@@ -88,15 +118,21 @@ instance StrEncoding ServerStatsData where
_qDeleted <- "qDeleted=" *> strP <* A.endOfLine
_msgSent <- "msgSent=" *> strP <* A.endOfLine
_msgRecv <- "msgRecv=" *> strP <* A.endOfLine
r <- optional ("activeQueues:" <* A.endOfLine)
_activeQueues <- case r of
Just _ -> strP <* optional A.endOfLine
_ -> do
_day <- "dayMsgQueues=" *> strP <* A.endOfLine
_week <- "weekMsgQueues=" *> strP <* A.endOfLine
_month <- "monthMsgQueues=" *> strP <* optional A.endOfLine
pure PeriodStatsData {_day, _week, _month}
pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _activeQueues}
_msgSentNtf <- "msgSentNtf=" *> strP <* A.endOfLine <|> pure 0
_msgRecvNtf <- "msgRecvNtf=" *> strP <* A.endOfLine <|> pure 0
_activeQueues <-
optional ("activeQueues:" <* A.endOfLine) >>= \case
Just _ -> strP <* optional A.endOfLine
_ -> do
_day <- "dayMsgQueues=" *> strP <* A.endOfLine
_week <- "weekMsgQueues=" *> strP <* A.endOfLine
_month <- "monthMsgQueues=" *> strP <* optional A.endOfLine
pure PeriodStatsData {_day, _week, _month}
_activeQueuesNtf <-
optional ("activeQueuesNtf:" <* A.endOfLine) >>= \case
Just _ -> strP <* optional A.endOfLine
_ -> pure newPeriodStatsData
pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _msgSentNtf, _msgRecvNtf, _activeQueues, _activeQueuesNtf, _qCount = 0, _msgCount = 0}
data PeriodStats a = PeriodStats
{ day :: TVar (Set a),
@@ -117,6 +153,9 @@ data PeriodStatsData a = PeriodStatsData
_month :: Set a
}
newPeriodStatsData :: PeriodStatsData a
newPeriodStatsData = PeriodStatsData {_day = S.empty, _week = S.empty, _month = S.empty}
getPeriodStatsData :: PeriodStats a -> STM (PeriodStatsData a)
getPeriodStatsData s = do
_day <- readTVar $ day s
@@ -126,9 +165,9 @@ getPeriodStatsData s = do
setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM ()
setPeriodStats s d = do
writeTVar (day s) (_day d)
writeTVar (week s) (_week d)
writeTVar (month s) (_month d)
writeTVar (day s) $! _day d
writeTVar (week s) $! _week d
writeTVar (month s) $! _month d
instance (Ord a, StrEncoding a) => StrEncoding (PeriodStatsData a) where
strEncode PeriodStatsData {_day, _week, _month} =
@@ -165,4 +204,4 @@ updatePeriodStats stats pId = do
updatePeriod week
updatePeriod month
where
updatePeriod pSel = modifyTVar (pSel stats) (S.insert pId)
updatePeriod pSel = modifyTVar' (pSel stats) (S.insert pId)
+5 -3
View File
@@ -75,12 +75,14 @@ import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as BL
import Data.Default (def)
import Data.Functor (($>))
import Data.Version (showVersion)
import GHC.Generics (Generic)
import GHC.IO.Handle.Internals (ioe_EOF)
import Generic.Random (genericArbitraryU)
import Network.Socket
import qualified Network.TLS as T
import qualified Network.TLS.Extra as TE
import qualified Paths_simplexmq as SMQ
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON)
@@ -100,7 +102,7 @@ supportedSMPServerVRange :: VersionRange
supportedSMPServerVRange = mkVersionRange 1 5
simplexMQVersion :: String
simplexMQVersion = "4.0.0"
simplexMQVersion = showVersion SMQ.version
-- * Transport connection class
@@ -214,7 +216,7 @@ instance Transport TLS where
$ do
b <- readChunks =<< readTVarIO buffer
let (s, b') = B.splitAt n b
atomically $ writeTVar buffer b'
atomically $ writeTVar buffer $! b'
pure s
where
readChunks :: ByteString -> IO ByteString
@@ -237,7 +239,7 @@ instance Transport TLS where
$ do
b <- readChunks =<< readTVarIO buffer
let (s, b') = B.break (== '\n') b
atomically $ writeTVar buffer (B.drop 1 b') -- drop '\n' we made a break at
atomically $ writeTVar buffer $! B.drop 1 b' -- drop '\n' we made a break at
pure $ trimCR s
where
readChunks :: ByteString -> IO ByteString
+28
View File
@@ -78,6 +78,8 @@ agentTests (ATransport t) = do
smpAgentTest2_2_1 $ testConcurrentMsgDelivery t
it "should deliver messages if one of connections has quota exceeded" $
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
it "should resume delivering messages after exceeding quota once all messages are received" $
smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t
tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent)
tGetAgent h = do
@@ -430,6 +432,32 @@ testMsgDeliveryQuotaExceeded _ alice bob = do
-- if delivery is blocked it won't go further
alice <# ("", "bob2", SENT 4)
testResumeDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO ()
testResumeDeliveryQuotaExceeded _ alice bob = do
connect (alice, "alice") (bob, "bob")
forM_ [1 .. 4 :: Int] $ \i -> do
let corrId = bshow i
msg = "message " <> bshow i
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg)
alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False
("5", "bob", Right (MID 8)) <- alice #: ("5", "bob", "SEND F :over quota")
alice #:# "the last message not sent yet"
bob <#= \case ("", "alice", Msg "message 1") -> True; _ -> False
bob #: ("1", "alice", "ACK 4") #> ("1", "alice", OK)
alice #:# "the last message not sent"
bob <#= \case ("", "alice", Msg "message 2") -> True; _ -> False
bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK)
alice #:# "the last message not sent"
bob <#= \case ("", "alice", Msg "message 3") -> True; _ -> False
bob #: ("3", "alice", "ACK 6") #> ("3", "alice", OK)
alice #:# "the last message not sent"
bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False
bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK)
alice <# ("", "bob", SENT 8)
bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False
-- message 8 is skipped because of alice agent sending "QCONT" message
bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK)
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV")
+52 -69
View File
@@ -15,6 +15,8 @@ module AgentTests.FunctionalAPITests
makeConnection,
exchangeGreetingsMsgId,
switchComplete,
runRight,
runRight_,
get,
(##>),
(=##>),
@@ -80,6 +82,15 @@ agentCfgRatchetV1 = agentCfg {e2eEncryptVRange = vr11}
vr11 :: VersionRange
vr11 = mkVersionRange 1 1
runRight_ :: ExceptT AgentErrorType IO () -> Expectation
runRight_ action = runExceptT action `shouldReturn` Right ()
runRight :: ExceptT AgentErrorType IO a -> IO a
runRight action =
runExceptT action >>= \case
Right x -> pure x
Left e -> error $ "Unexpected error: " <> show e
functionalAPITests :: ATransport -> Spec
functionalAPITests t = do
describe "Establishing duplex connection" $
@@ -217,7 +228,7 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do
runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientTest alice bob baseId = do
Right () <- runExceptT $ do
runRight_ $ do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
@@ -247,13 +258,12 @@ runAgentClientTest alice bob baseId = do
get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH))
deleteConnection alice bobId
liftIO $ noMessages alice "nothing else should be delivered to alice"
pure ()
where
msgId = subtract baseId
runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientContactTest alice bob baseId = do
Right () <- runExceptT $ do
runRight_ $ do
(_, qInfo) <- createConnection alice True SCMContact Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
("", _, REQ invId _ "bob's connInfo") <- get alice
@@ -285,7 +295,6 @@ runAgentClientContactTest alice bob baseId = do
get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH))
deleteConnection alice bobId
liftIO $ noMessages alice "nothing else should be delivered to alice"
pure ()
where
msgId = subtract baseId
@@ -301,7 +310,7 @@ testAsyncInitiatingOffline :: IO ()
testAsyncInitiatingOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
@@ -313,13 +322,12 @@ testAsyncInitiatingOffline = do
get bob ##> ("", aliceId, INFO "alice's connInfo")
get bob ##> ("", aliceId, CON)
exchangeGreetings alice' bobId bob aliceId
pure ()
testAsyncJoiningOfflineBeforeActivation :: IO ()
testAsyncJoiningOfflineBeforeActivation = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
disconnectAgentClient bob
@@ -331,13 +339,12 @@ testAsyncJoiningOfflineBeforeActivation = do
get bob' ##> ("", aliceId, INFO "alice's connInfo")
get bob' ##> ("", aliceId, CON)
exchangeGreetings alice bobId bob' aliceId
pure ()
testAsyncBothOffline :: IO ()
testAsyncBothOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
@@ -352,22 +359,21 @@ testAsyncBothOffline = do
get bob' ##> ("", aliceId, INFO "alice's connInfo")
get bob' ##> ("", aliceId, CON)
exchangeGreetings alice' bobId bob' aliceId
pure ()
testAsyncServerOffline :: ATransport -> IO ()
testAsyncServerOffline t = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
-- create connection and shutdown the server
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
runExceptT $ createConnection alice True SCMInvitation Nothing
(bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
runRight $ createConnection alice True SCMInvitation Nothing
-- connection fails
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo"
("", "", DOWN srv conns) <- get alice
srv `shouldBe` testSMPServer
conns `shouldBe` [bobId]
-- connection succeeds after server start
Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
("", "", UP srv1 conns1) <- get alice
liftIO $ do
srv1 `shouldBe` testSMPServer
@@ -379,27 +385,25 @@ testAsyncServerOffline t = do
get bob ##> ("", aliceId, INFO "alice's connInfo")
get bob ##> ("", aliceId, CON)
exchangeGreetings alice bobId bob aliceId
pure ()
testAsyncHelloTimeout :: IO ()
testAsyncHelloTimeout = do
-- this test would only work if any of the agent is v1, there is no HELLO timeout in v2
alice <- getSMPAgentClient agentCfgV1 initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(_, cReq) <- createConnection alice True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
pure ()
testDuplicateMessage :: ATransport -> IO ()
testDuplicateMessage t = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
(aliceId, bobId, bob1) <- withSmpServerStoreMsgLogOn t testPort $ \_ -> do
Right (aliceId, bobId) <- runExceptT $ makeConnection alice bob
Right () <- runExceptT $ do
(aliceId, bobId) <- runRight $ makeConnection alice bob
runRight_ $ do
4 <- sendMessage alice bobId SMP.noMsgFlags "hello"
get alice ##> ("", bobId, SENT 4)
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
@@ -407,7 +411,7 @@ testDuplicateMessage t = do
-- if the agent user did not send ACK, the message will be delivered again
bob1 <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
subscribeConnection bob1 aliceId
get bob1 =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
ackMessage bob1 aliceId 4
@@ -419,7 +423,7 @@ testDuplicateMessage t = do
get alice =##> \case ("", "", DOWN _ [c]) -> c == bobId; _ -> False
get bob1 =##> \case ("", "", DOWN _ [c]) -> c == aliceId; _ -> False
-- commenting two lines below and uncommenting further two lines would also pass,
-- commenting two lines below and uncommenting further two lines would also runRight_,
-- it is the scenario tested above, when the message was not acknowledged by the user
threadDelay 200000
Left (BROKER _ TIMEOUT) <- runExceptT $ ackMessage bob1 aliceId 5
@@ -431,7 +435,7 @@ testDuplicateMessage t = do
bob2 <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
withSmpServerStoreMsgLogOn t testPort $ \_ -> do
Right () <- runExceptT $ do
runRight_ $ do
subscribeConnection bob2 aliceId
subscribeConnection alice2 bobId
-- get bob2 =##> \case ("", c, Msg "hello 2") -> c == aliceId; _ -> False
@@ -440,7 +444,6 @@ testDuplicateMessage t = do
6 <- sendMessage alice2 bobId SMP.noMsgFlags "hello 3"
get alice2 ##> ("", bobId, SENT 6)
get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False
pure ()
makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
makeConnection alice bob = do
@@ -458,10 +461,9 @@ testInactiveClientDisconnected t = do
let cfg' = cfg {inactiveClientExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 1}}
withSmpServerConfigOn t cfg' testPort $ \_ -> do
alice <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
get alice ##> ("", "", DOWN testSMPServer [connId])
pure ()
testActiveClientNotDisconnected :: ATransport -> IO ()
testActiveClientNotDisconnected t = do
@@ -469,10 +471,9 @@ testActiveClientNotDisconnected t = do
withSmpServerConfigOn t cfg' testPort $ \_ -> do
alice <- getSMPAgentClient agentCfg initAgentServers
ts <- getSystemTime
Right () <- runExceptT $ do
runRight_ $ do
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
keepSubscribing alice connId ts
pure ()
where
keepSubscribing :: AgentClient -> ConnId -> SystemTime -> ExceptT AgentErrorType IO ()
keepSubscribing alice connId ts = do
@@ -495,7 +496,7 @@ testSuspendingAgent :: IO ()
testSuspendingAgent = do
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(aId, bId) <- makeConnection a b
4 <- sendMessage a bId SMP.noMsgFlags "hello"
get a ##> ("", bId, SENT 4)
@@ -508,13 +509,12 @@ testSuspendingAgent = do
Nothing <- 100000 `timeout` get b
activateAgent b
get b =##> \case ("", c, Msg "hello 2") -> c == aId; _ -> False
pure ()
testSuspendingAgentCompleteSending :: ATransport -> IO ()
testSuspendingAgentCompleteSending t = do
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do
(aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
(aId, bId) <- makeConnection a b
4 <- sendMessage a bId SMP.noMsgFlags "hello"
get a ##> ("", bId, SENT 4)
@@ -522,7 +522,7 @@ testSuspendingAgentCompleteSending t = do
ackMessage b aId 4
pure (aId, bId)
Right () <- runExceptT $ do
runRight_ $ do
("", "", DOWN {}) <- get a
("", "", DOWN {}) <- get b
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
@@ -530,7 +530,7 @@ testSuspendingAgentCompleteSending t = do
liftIO $ threadDelay 100000
suspendAgent b 5000000
Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False
get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False
get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False
@@ -544,13 +544,11 @@ testSuspendingAgentCompleteSending t = do
get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False
ackMessage a bId 6
pure ()
testSuspendingAgentTimeout :: ATransport -> IO ()
testSuspendingAgentTimeout t = do
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (aId, _) <- withSmpServer t . runExceptT $ do
(aId, _) <- withSmpServer t . runRight $ do
(aId, bId) <- makeConnection a b
4 <- sendMessage a bId SMP.noMsgFlags "hello"
get a ##> ("", bId, SENT 4)
@@ -558,7 +556,7 @@ testSuspendingAgentTimeout t = do
ackMessage b aId 4
pure (aId, bId)
Right () <- runExceptT $ do
runRight_ $ do
("", "", DOWN {}) <- get a
("", "", DOWN {}) <- get b
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
@@ -567,13 +565,11 @@ testSuspendingAgentTimeout t = do
("", "", SUSPENDED) <- get b
pure ()
pure ()
testBatchedSubscriptions :: ATransport -> IO ()
testBatchedSubscriptions t = do
a <- getSMPAgentClient agentCfg initAgentServers2
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers2
Right conns <- runServers $ do
conns <- runServers $ do
conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
forM_ (take 10 conns) $ \(aId, bId) -> do
@@ -585,7 +581,7 @@ testBatchedSubscriptions t = do
("", "", DOWN {}) <- get a
("", "", DOWN {}) <- get b
("", "", DOWN {}) <- get b
Right () <- runServers $ do
runServers $ do
("", "", UP {}) <- get a
("", "", UP {}) <- get a
("", "", UP {}) <- get b
@@ -594,7 +590,6 @@ testBatchedSubscriptions t = do
subscribe a $ map snd conns
subscribe b $ map fst conns
forM_ (drop 10 conns) $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId
pure ()
where
subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
subscribe c cs = do
@@ -604,13 +599,11 @@ testBatchedSubscriptions t = do
all (== Right ()) (M.withoutKeys r dc) `shouldBe` True
all (== Left (CONN NOT_FOUND)) (M.restrictKeys r dc) `shouldBe` True
M.keys r `shouldMatchList` cs
runServers :: ExceptT AgentErrorType IO a -> IO (Either AgentErrorType a)
runServers :: ExceptT AgentErrorType IO a -> IO a
runServers a = do
withSmpServerStoreLogOn t testPort $ \t1 -> do
res <- withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \t2 -> do
res <- runExceptT a
killThread t2
pure res
res <- withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \t2 ->
runRight a `finally` killThread t2
killThread t1
pure res
@@ -618,7 +611,7 @@ testAsyncCommands :: IO ()
testAsyncCommands = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
bobId <- createConnectionAsync alice "1" True SCMInvitation
("1", bobId', INV (ACR _ qInfo)) <- get alice
liftIO $ bobId' `shouldBe` bobId
@@ -655,7 +648,6 @@ testAsyncCommands = do
deleteConnectionAsync alice "8" bobId
("8", _, OK) <- get alice
liftIO $ noMessages alice "nothing else should be delivered to alice"
pure ()
where
baseId = 3
msgId = subtract baseId
@@ -663,22 +655,21 @@ testAsyncCommands = do
testAsyncCommandsRestore :: ATransport -> IO ()
testAsyncCommandsRestore t = do
alice <- getSMPAgentClient agentCfg initAgentServers
Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation
bobId <- runRight $ createConnectionAsync alice "1" True SCMInvitation
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
disconnectAgentClient alice
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
withSmpServerStoreLogOn t testPort $ \_ -> do
Right () <- runExceptT $ do
runRight_ $ do
subscribeConnection alice' bobId
("1", _, INV _) <- get alice'
pure ()
pure ()
testAcceptContactAsync :: IO ()
testAcceptContactAsync = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(_, qInfo) <- createConnection alice True SCMContact Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
("", _, REQ invId _ "bob's connInfo") <- get alice
@@ -712,7 +703,6 @@ testAcceptContactAsync = do
get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH))
deleteConnection alice bobId
liftIO $ noMessages alice "nothing else should be delivered to alice"
pure ()
where
baseId = 3
msgId = subtract baseId
@@ -721,13 +711,12 @@ testSwitchConnection :: InitialAgentServers -> IO ()
testSwitchConnection servers = do
a <- getSMPAgentClient agentCfg servers
b <- getSMPAgentClient agentCfg {database = testDB2, initialClientId = 1} servers
Right () <- runExceptT $ do
runRight_ $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
switchConnectionAsync a "" bId
switchComplete a bId b aId
exchangeGreetingsMsgId 10 a bId b aId
pure ()
switchComplete :: AgentClient -> ByteString -> AgentClient -> ByteString -> ExceptT AgentErrorType IO ()
switchComplete a bId b aId = do
@@ -749,12 +738,12 @@ phase c connId d p =
ERR (AGENT A_DUPLICATE) -> phase c connId d p
r -> do
liftIO . putStrLn $ "expected: " <> show p <> ", received: " <> show r
SWITCH _ _ _ <- pure r
SWITCH {} <- pure r
pure ()
testSwitchAsync :: InitialAgentServers -> IO ()
testSwitchAsync servers = do
Right (aId, bId) <- withA $ \a -> withB $ \b -> runExceptT $ do
(aId, bId) <- withA $ \a -> withB $ \b -> runRight $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
pure (aId, bId)
@@ -769,22 +758,20 @@ testSwitchAsync servers = do
phase b aId QDSnd SPConfirmed
phase b aId QDSnd SPCompleted
withA' $ \a -> phase a bId QDRcv SPCompleted
Right () <- withA $ \a -> withB $ \b -> runExceptT $ do
withA $ \a -> withB $ \b -> runRight_ $ do
subscribeConnection a bId
subscribeConnection b aId
exchangeGreetingsMsgId 10 a bId b aId
pure ()
where
withAgent :: AgentConfig -> (AgentClient -> IO a) -> IO a
withAgent cfg' = bracket (getSMPAgentClient cfg' servers) disconnectAgentClient
session :: (forall a. (AgentClient -> IO a) -> IO a) -> ConnId -> (AgentClient -> ExceptT AgentErrorType IO ()) -> IO ()
session withC connId a = do
Right () <- withC $ \c -> runExceptT $ do
session withC connId a =
withC $ \c -> runRight_ $ do
subscribeConnection c connId
r <- a c
liftIO $ threadDelay 500000
pure r
pure ()
withA = withAgent agentCfg
withB = withAgent agentCfg {database = testDB2, initialClientId = 1}
@@ -792,7 +779,7 @@ testSwitchDelete :: InitialAgentServers -> IO ()
testSwitchDelete servers = do
a <- getSMPAgentClient agentCfg servers
b <- getSMPAgentClient agentCfg {database = testDB2, initialClientId = 1} servers
Right () <- runExceptT $ do
runRight_ $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
disconnectAgentClient b
@@ -801,13 +788,12 @@ testSwitchDelete servers = do
deleteConnectionAsync a "1" bId
("1", bId', OK) <- get a
liftIO $ bId `shouldBe` bId'
pure ()
testCreateQueueAuth :: (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int
testCreateQueueAuth clnt1 clnt2 = do
a <- getClient clnt1
b <- getClient clnt2
Right created <- runExceptT $ do
runRight $ do
tryError (createConnection a True SCMInvitation Nothing) >>= \case
Left (SMP AUTH) -> pure 0
Left e -> throwError e
@@ -823,7 +809,6 @@ testCreateQueueAuth clnt1 clnt2 = do
get b ##> ("", aId, CON)
exchangeGreetings a bId b aId
pure 2
pure created
where
getClient (clntAuth, clntVersion) =
let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]}
@@ -834,19 +819,17 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
testSMPServerConnectionTest t newQueueBasicAuth srv =
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running
Right r <- runExceptT $ testSMPServerConnection a srv
pure r
runRight $ testSMPServerConnection a srv
testRatchetAdHash :: IO ()
testRatchetAdHash = do
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(aId, bId) <- makeConnection a b
ad1 <- getConnectionRatchetAdHash a bId
ad2 <- getConnectionRatchetAdHash b aId
liftIO $ ad1 `shouldBe` ad2
pure ()
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings = exchangeGreetingsMsgId 4
+19 -30
View File
@@ -8,7 +8,7 @@
module AgentTests.NotificationTests where
-- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging)
import AgentTests.FunctionalAPITests (exchangeGreetingsMsgId, get, makeConnection, switchComplete, testServerMatrix2, (##>), (=##>), pattern Msg)
import AgentTests.FunctionalAPITests (exchangeGreetingsMsgId, get, makeConnection, runRight, runRight_, switchComplete, testServerMatrix2, (##>), (=##>), pattern Msg)
import Control.Concurrent (killThread, threadDelay)
import Control.Monad.Except
import qualified Data.Aeson as J
@@ -91,7 +91,7 @@ notificationTests t =
testNotificationToken :: APNSMockServer -> IO ()
testNotificationToken APNSMockServer {apnsQ} = do
a <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
let tkn = DeviceToken PPApnsTest "abcd"
NTRegistered <- registerNtfToken a tkn NMPeriodic
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
@@ -105,7 +105,6 @@ testNotificationToken APNSMockServer {apnsQ} = do
-- agent deleted this token
Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn
pure ()
pure ()
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
v .-> key = do
@@ -120,7 +119,7 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do
-- setLogLevel LogError -- LogDebug
-- withGlobalLogging logCfg $ do
a <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
let tkn = DeviceToken PPApnsTest "abcd"
NTRegistered <- registerNtfToken a tkn NMPeriodic
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
@@ -138,7 +137,6 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do
verifyNtfToken a tkn nonce verification
NTActive <- checkNtfToken a tkn
pure ()
pure ()
testNtfTokenSecondRegistration :: APNSMockServer -> IO ()
testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
@@ -146,7 +144,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
-- withGlobalLogging logCfg $ do
a <- getSMPAgentClient agentCfg initAgentServers
a' <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
let tkn = DeviceToken PPApnsTest "abcd"
NTRegistered <- registerNtfToken a tkn NMPeriodic
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
@@ -175,13 +173,12 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
-- and the second is active
NTActive <- checkNtfToken a' tkn
pure ()
pure ()
testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO ()
testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
a <- getSMPAgentClient agentCfg initAgentServers
let tkn = DeviceToken PPApnsTest "abcd"
Right ntfData <- withNtfServer t . runExceptT $ do
ntfData <- withNtfServer t . runRight $ do
NTRegistered <- registerNtfToken a tkn NMPeriodic
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
atomically $ readTBQueue apnsQ
@@ -193,7 +190,7 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
a' <- getSMPAgentClient agentCfg initAgentServers
-- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token,
-- so that repeat verification happens without restarting the clients, when notification arrives
Right () <- withNtfServer t . runExceptT $ do
withNtfServer t . runRight_ $ do
verification <- ntfData .-> "verification"
nonce <- C.cbNonce <$> ntfData .-> "nonce"
Left (NTF AUTH) <- tryE $ verifyNtfToken a' tkn nonce verification
@@ -205,13 +202,12 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
verifyNtfToken a' tkn nonce' verification'
NTActive <- checkNtfToken a' tkn
pure ()
pure ()
testNotificationSubscriptionExistingConnection :: APNSMockServer -> IO ()
testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (bobId, aliceId, nonce, message) <- runExceptT $ do
(bobId, aliceId, nonce, message) <- runRight $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
@@ -243,12 +239,12 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
-- aliceNtf client doesn't have subscription and is allowed to get notification message
aliceNtf <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
(_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message
pure ()
disconnectAgentClient aliceNtf
Right () <- runExceptT $ do
runRight_ $ do
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
ackMessage alice bobId $ baseId + 1
-- delete notification subscription
@@ -259,7 +255,6 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
get bob ##> ("", aliceId, SENT $ baseId + 2)
-- no notifications should follow
noNotification apnsQ
pure ()
where
baseId = 3
msgId = subtract baseId
@@ -268,7 +263,7 @@ testNotificationSubscriptionNewConnection :: APNSMockServer -> IO ()
testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
-- alice registers notification token
DeviceToken {} <- registerTestToken alice "abcd" NMInstant apnsQ
-- bob registers notification token
@@ -303,7 +298,6 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
ackMessage bob aliceId $ baseId + 2
-- no unexpected notifications should follow
noNotification apnsQ
pure ()
where
baseId = 3
msgId = subtract baseId
@@ -325,7 +319,7 @@ testChangeNotificationsMode :: APNSMockServer -> IO ()
testChangeNotificationsMode APNSMockServer {apnsQ} = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
@@ -381,7 +375,6 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do
ackMessage alice bobId $ baseId + 5
-- no notifications should follow
noNotification apnsQ
pure ()
where
baseId = 3
msgId = subtract baseId
@@ -390,7 +383,7 @@ testChangeToken :: APNSMockServer -> IO ()
testChangeToken APNSMockServer {apnsQ} = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (aliceId, bobId) <- runExceptT $ do
(aliceId, bobId) <- runRight $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
@@ -412,7 +405,7 @@ testChangeToken APNSMockServer {apnsQ} = do
disconnectAgentClient alice
alice1 <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
runRight_ $ do
subscribeConnection alice1 bobId
-- change notification token
void $ registerTestToken alice1 "bcde" NMInstant apnsQ
@@ -425,7 +418,6 @@ testChangeToken APNSMockServer {apnsQ} = do
ackMessage alice1 bobId $ baseId + 2
-- no notifications should follow
noNotification apnsQ
pure ()
where
baseId = 3
msgId = subtract baseId
@@ -434,7 +426,7 @@ testNotificationsStoreLog :: ATransport -> APNSMockServer -> IO ()
testNotificationsStoreLog t APNSMockServer {apnsQ} = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (aliceId, bobId) <- withNtfServerStoreLog t $ \threadId -> runExceptT $ do
(aliceId, bobId) <- withNtfServerStoreLog t $ \threadId -> runRight $ do
(aliceId, bobId) <- makeConnection alice bob
_ <- registerTestToken alice "abcd" NMInstant apnsQ
liftIO $ threadDelay 250000
@@ -448,20 +440,19 @@ testNotificationsStoreLog t APNSMockServer {apnsQ} = do
liftIO $ threadDelay 250000
Right () <- withNtfServerStoreLog t $ \threadId -> runExceptT $ do
withNtfServerStoreLog t $ \threadId -> runRight_ $ do
liftIO $ threadDelay 250000
5 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello again"
get bob ##> ("", aliceId, SENT 5)
void $ messageNotification apnsQ
get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False
liftIO $ killThread threadId
pure ()
testNotificationsSMPRestart :: ATransport -> APNSMockServer -> IO ()
testNotificationsSMPRestart t APNSMockServer {apnsQ} = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \threadId -> runExceptT $ do
(aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \threadId -> runRight $ do
(aliceId, bobId) <- makeConnection alice bob
_ <- registerTestToken alice "abcd" NMInstant apnsQ
liftIO $ threadDelay 250000
@@ -473,11 +464,11 @@ testNotificationsSMPRestart t APNSMockServer {apnsQ} = do
liftIO $ killThread threadId
pure (aliceId, bobId)
Right () <- runExceptT $ do
runRight_ $ do
get alice =##> \case ("", "", DOWN _ [c]) -> c == bobId; _ -> False
get bob =##> \case ("", "", DOWN _ [c]) -> c == aliceId; _ -> False
Right () <- withSmpServerStoreLogOn t testPort $ \threadId -> runExceptT $ do
withSmpServerStoreLogOn t testPort $ \threadId -> runRight_ $ do
get alice =##> \case ("", "", UP _ [c]) -> c == bobId; _ -> False
get bob =##> \case ("", "", UP _ [c]) -> c == aliceId; _ -> False
liftIO $ threadDelay 1000000
@@ -486,13 +477,12 @@ testNotificationsSMPRestart t APNSMockServer {apnsQ} = do
_ <- messageNotificationData alice apnsQ
get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False
liftIO $ killThread threadId
pure ()
testSwitchNotifications :: InitialAgentServers -> APNSMockServer -> IO ()
testSwitchNotifications servers APNSMockServer {apnsQ} = do
a <- getSMPAgentClient agentCfg servers
b <- getSMPAgentClient agentCfg {database = testDB2, initialClientId = 1} servers
Right () <- runExceptT $ do
runRight_ $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
_ <- registerTestToken a "abcd" NMInstant apnsQ
@@ -508,7 +498,6 @@ testSwitchNotifications servers APNSMockServer {apnsQ} = do
switchComplete a bId b aId
liftIO $ threadDelay 500000
testMessage "hello again"
pure ()
messageNotification :: TBQueue APNSMockRequest -> ExceptT AgentErrorType IO (C.CbNonce, ByteString)
messageNotification apnsQ = do
+3 -2
View File
@@ -6,6 +6,7 @@ import Data.Ini (lookupValue, readIniFile)
import Data.List (isPrefixOf)
import Simplex.Messaging.Notifications.Server.Main
import Simplex.Messaging.Server.Main
import Simplex.Messaging.Transport (simplexMQVersion)
import Simplex.Messaging.Util (catchAll_)
import System.Directory (doesFileExist)
import System.Environment (withArgs)
@@ -51,7 +52,7 @@ smpServerTest storeLog basicAuth = do
lookupValue "INACTIVE_CLIENTS" "disconnect" ini `shouldBe` Right "off"
doesFileExist (cfgPath <> "/ca.key") `shouldReturn` True
r <- lines <$> capture_ (withArgs ["start"] $ (100000 `timeout` smpServerCLI cfgPath logPath) `catchAll_` pure (Just ()))
r `shouldContain` ["SMP server v4.0.0"]
r `shouldContain` ["SMP server v" <> simplexMQVersion]
r `shouldContain` (if storeLog then ["Store log: " <> logPath <> "/smp-server-store.log"] else ["Store log disabled."])
r `shouldContain` ["Listening on port 5223 (TLS)..."]
r `shouldContain` ["not expiring inactive clients"]
@@ -71,7 +72,7 @@ ntfServerTest storeLog = do
lookupValue "TRANSPORT" "websockets" ini `shouldBe` Right "off"
doesFileExist (ntfCfgPath <> "/ca.key") `shouldReturn` True
r <- lines <$> capture_ (withArgs ["start"] $ (100000 `timeout` ntfServerCLI ntfCfgPath ntfLogPath) `catchAll_` pure (Just ()))
r `shouldContain` ["SMP notifications server v1.2.0"]
r `shouldContain` ["SMP notifications server v" <> ntfServerVersion]
r `shouldContain` (if storeLog then ["Store log: " <> ntfLogPath <> "/ntf-server-store.log"] else ["Store log disabled."])
r `shouldContain` ["Listening on port 443 (TLS)..."]
capture_ (withStdin "Y" . withArgs ["delete"] $ ntfServerCLI ntfCfgPath ntfLogPath)
+41 -5
View File
@@ -1,18 +1,20 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module CoreTests.CryptoTests (cryptoTests) where
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import qualified Simplex.Messaging.Crypto as C
import Test.Hspec
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck
import qualified Data.Text as T
import qualified Data.ByteString.Char8 as B
import Data.Text.Encoding (encodeUtf8)
cryptoTests :: Spec
cryptoTests = modifyMaxSuccess (const 10000) $ do
describe "padding / unpadding" $ do
cryptoTests = do
modifyMaxSuccess (const 10000) . describe "padding / unpadding" $ do
it "should pad / unpad string" . property $ \(s, paddedLen) ->
let b = encodeUtf8 $ T.pack s
len = B.length b
@@ -34,3 +36,37 @@ cryptoTests = modifyMaxSuccess (const 10000) $ do
it "unpad should fail on shorter string" $ do
C.unPad "\000\003abc" `shouldBe` Right "abc"
C.unPad "\000\003ab" `shouldBe` Left C.CryptoInvalidMsgError
describe "Ed signatures" $ do
describe "Ed25519" $ testSignature C.SEd25519
describe "Ed448" $ testSignature C.SEd448
describe "DH X25519 + cryptobox" $
testDHCryptoBox
describe "X509 key encoding" $ do
describe "Ed25519" $ testEncoding C.SEd25519
describe "Ed448" $ testEncoding C.SEd448
describe "X25519" $ testEncoding C.SX25519
describe "X448" $ testEncoding C.SX448
testSignature :: (C.AlgorithmI a, C.SignatureAlgorithm a) => C.SAlgorithm a -> Spec
testSignature alg = it "should sign / verify string" . ioProperty $ do
(k, pk) <- C.generateSignatureKeyPair alg
pure $ \s -> let b = encodeUtf8 $ T.pack s in C.verify k (C.sign pk b) b
testDHCryptoBox :: Spec
testDHCryptoBox = it "should encrypt / decrypt string" . ioProperty $ do
(sk, spk) <- C.generateKeyPair'
(rk, rpk) <- C.generateKeyPair'
nonce <- C.randomCbNonce
pure $ \(s, pad) ->
let b = encodeUtf8 $ T.pack s
paddedLen = B.length b + abs pad + 2
cipher = C.cbEncrypt (C.dh' rk spk) nonce b paddedLen
plain = C.cbDecrypt (C.dh' sk rpk) nonce =<< cipher
in isRight cipher && cipher /= plain && Right b == plain
testEncoding :: (C.AlgorithmI a) => C.SAlgorithm a -> Spec
testEncoding alg = it "should encode / decode key" . ioProperty $ do
(k, pk) <- C.generateKeyPair alg
pure $ \(_ :: Int) ->
C.decodePubKey (C.encodePubKey k) == Right k
&& C.decodePrivKey (C.encodePrivKey pk) == Right pk
+61
View File
@@ -0,0 +1,61 @@
{-# LANGUAGE ScopedTypeVariables #-}
module CoreTests.RetryIntervalTests where
import Control.Concurrent.STM
import Control.Monad (when)
import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds)
import Simplex.Messaging.Agent.RetryInterval
import Test.Hspec
retryIntervalTests :: Spec
retryIntervalTests = do
describe "Retry interval with 2 modes and lock" $ do
testRetryIntervalSameMode
testRetryIntervalSwitchMode
testRI :: RetryInterval2
testRI =
RetryInterval2
{ riSlow =
RetryInterval
{ initialInterval = 20000,
increaseAfter = 40000,
maxInterval = 40000
},
riFast =
RetryInterval
{ initialInterval = 10000,
increaseAfter = 20000,
maxInterval = 40000
}
}
testRetryIntervalSameMode :: Spec
testRetryIntervalSameMode =
it "should increase elapased time and interval when the mode stays the same" $ do
lock <- newEmptyTMVarIO
intervals <- newTVarIO []
ts <- newTVarIO =<< getCurrentTime
withRetryLock2 testRI lock $ \loop -> do
ints <- addInterval intervals ts
when (length ints < 9) $ loop RIFast
(reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4, 4]
testRetryIntervalSwitchMode :: Spec
testRetryIntervalSwitchMode =
it "should increase elapased time and interval when the mode stays the same" $ do
lock <- newEmptyTMVarIO
intervals <- newTVarIO []
ts <- newTVarIO =<< getCurrentTime
withRetryLock2 testRI lock $ \loop -> do
ints <- addInterval intervals ts
when (length ints < 11) $ loop $ if length ints <= 5 then RIFast else RISlow
(reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 2, 2, 3, 4, 4]
addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int]
addInterval intervals ts = do
ts' <- getCurrentTime
atomically $ do
int :: Int <- truncate . (* 100) . nominalDiffTimeToSeconds <$> stateTVar ts (\t -> (diffUTCTime ts' t, ts'))
stateTVar intervals $ \ints -> (int : ints, int : ints)
+1 -3
View File
@@ -10,7 +10,6 @@
module NtfServerTests where
import Control.Concurrent (threadDelay)
import Control.Monad.Except (runExceptT)
import qualified Data.Aeson as J
import qualified Data.Aeson.Types as JT
import Data.Bifunctor (first)
@@ -77,8 +76,7 @@ sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
Right sig <- runExceptT $ C.sign pk t
Right () <- tPut1 h (Just sig, t)
Right () <- tPut1 h (Just $ C.sign pk t, t)
tGet1 h
(.->) :: J.Value -> J.Key -> Either String ByteString
+47 -11
View File
@@ -14,7 +14,7 @@ module ServerTests where
import Control.Concurrent (ThreadId, killThread, threadDelay)
import Control.Concurrent.STM
import Control.Exception (SomeException, try)
import Control.Monad.Except (forM, forM_, runExceptT)
import Control.Monad.Except (forM, forM_)
import Control.Monad.IO.Class
import Data.Bifunctor (first)
import Data.ByteString.Base64
@@ -49,9 +49,10 @@ serverTests t@(ATransport t') = do
describe "switch subscription to another TCP connection" $ testSwitchSub t
describe "GET command" $ testGetCommand t'
describe "GET & SUB commands" $ testGetSubCommands t'
describe "Exceeding queue quota" $ testExceedQueueQuota t'
describe "Store log" $ testWithStoreLog t
describe "Restore messages" $ testRestoreMessages t
describe "Restore messages (v2)" $ testRestoreMessagesV2 t
describe "Restore messages (old / v2)" $ testRestoreMessagesV2 t
describe "Timing of AUTH error" $ testTiming t
describe "Message notifications" $ testMessageNotifications t
describe "Message expiration" $ do
@@ -77,8 +78,7 @@ sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
Right sig <- runExceptT $ C.sign pk t
Right () <- tPut1 h (Just sig, t)
Right () <- tPut1 h (Just $ C.sign pk t, t)
tGet1 h
tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
@@ -104,9 +104,11 @@ decryptMsgV2 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either C.Cry
decryptMsgV2 dhShared = C.cbDecrypt dhShared . C.cbNonce
decryptMsgV3 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either String MsgBody
decryptMsgV3 dhShared nonce body = do
ClientRcvMsgBody {msgBody} <- parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body)
pure msgBody
decryptMsgV3 dhShared nonce body =
case parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body) of
Right ClientRcvMsgBody {msgBody} -> Right msgBody
Right ClientRcvMsgQuota {} -> Left "ClientRcvMsgQuota"
Left e -> Left e
testCreateSecureV2 :: forall c. Transport c => TProxy c -> Spec
testCreateSecureV2 _ =
@@ -494,6 +496,32 @@ testGetSubCommands t =
Resp "12" _ OK <- signSendRecv rh2 rKey ("12", rId, GET)
pure ()
testExceedQueueQuota :: forall c. Transport c => TProxy c -> Spec
testExceedQueueQuota t =
it "should reply with ERR QUOTA to sender and send QUOTA message to the recipient" $ do
withSmpServerConfigOn (ATransport t) cfg {msgQueueQuota = 2} testPort $ \_ ->
testSMPClient @c $ \sh -> testSMPClient @c $ \rh -> do
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
(sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub
let dec = decryptMsgV3 dhShared
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello 1")
Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, _SEND "hello 2")
Resp "3" _ (ERR QUOTA) <- signSendRecv sh sKey ("3", sId, _SEND "hello 3")
Resp "" _ (Msg mId1 msg1) <- tGet1 rh
(dec mId1 msg1, Right "hello 1") #== "hello 1"
Resp "4" _ (Msg mId2 msg2) <- signSendRecv rh rKey ("4", rId, ACK mId1)
(dec mId2 msg2, Right "hello 2") #== "hello 2"
Resp "5" _ (ERR QUOTA) <- signSendRecv sh sKey ("5", sId, _SEND "hello 3")
Resp "6" _ (Msg mId3 msg3) <- signSendRecv rh rKey ("6", rId, ACK mId2)
(dec mId3 msg3, Left "ClientRcvMsgQuota") #== "ClientRcvMsgQuota"
Resp "7" _ (ERR QUOTA) <- signSendRecv sh sKey ("7", sId, _SEND "hello 3")
Resp "8" _ OK <- signSendRecv rh rKey ("8", rId, ACK mId3)
Resp "9" _ OK <- signSendRecv sh sKey ("9", sId, _SEND "hello 3")
Resp "" _ (Msg mId4 msg4) <- tGet1 rh
(dec mId4 msg4, Right "hello 3") #== "hello 3"
Resp "10" _ OK <- signSendRecv rh rKey ("10", rId, ACK mId4)
pure ()
testWithStoreLog :: ATransport -> Spec
testWithStoreLog at@(ATransport t) =
it "should store simplex queues to log and restore them after server restart" $ do
@@ -600,10 +628,12 @@ testRestoreMessages at@(ATransport t) =
Resp "2" _ OK <- signSendRecv h sKey ("2", sId, _SEND "hello 2")
Resp "3" _ OK <- signSendRecv h sKey ("3", sId, _SEND "hello 3")
Resp "4" _ OK <- signSendRecv h sKey ("4", sId, _SEND "hello 4")
Resp "5" _ OK <- signSendRecv h sKey ("5", sId, _SEND "hello 5")
Resp "6" _ (ERR QUOTA) <- signSendRecv h sKey ("6", sId, _SEND "hello 6")
pure ()
logSize testStoreLogFile `shouldReturn` 2
logSize testStoreMsgsFile `shouldReturn` 3
logSize testStoreMsgsFile `shouldReturn` 5
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
rId <- readTVarIO recipientId
@@ -619,15 +649,21 @@ testRestoreMessages at@(ATransport t) =
logSize testStoreLogFile `shouldReturn` 1
-- the last message is not removed because it was not ACK'd
logSize testStoreMsgsFile `shouldReturn` 1
logSize testStoreMsgsFile `shouldReturn` 3
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
rId <- readTVarIO recipientId
Just rKey <- readTVarIO recipientKey
Just dh <- readTVarIO dhShared
let dec = decryptMsgV3 dh
Resp "4" _ (Msg mId4 msg4) <- signSendRecv h rKey ("4", rId, SUB)
Resp "5" _ OK <- signSendRecv h rKey ("5", rId, ACK mId4)
(decryptMsgV3 dh mId4 msg4, Right "hello 4") #== "restored message delivered"
(dec mId4 msg4, Right "hello 4") #== "restored message delivered"
Resp "5" _ (Msg mId5 msg5) <- signSendRecv h rKey ("5", rId, ACK mId4)
(dec mId5 msg5, Right "hello 5") #== "restored message delivered"
Resp "6" _ (Msg mId6 msg6) <- signSendRecv h rKey ("6", rId, ACK mId5)
(dec mId6 msg6, Left "ClientRcvMsgQuota") #== "restored message delivered"
Resp "7" _ OK <- signSendRecv h rKey ("7", rId, ACK mId6)
pure ()
logSize testStoreLogFile `shouldReturn` 1
logSize testStoreMsgsFile `shouldReturn` 0
+22 -20
View File
@@ -1,11 +1,12 @@
{-# LANGUAGE TypeApplications #-}
import AgentTests (agentTests)
-- import Control.Logger.Simple
import CLITests
import Control.Logger.Simple
import CoreTests.CryptoTests
import CoreTests.EncodingTests
import CoreTests.ProtocolErrorTests
import CoreTests.RetryIntervalTests
import CoreTests.VersionRangeTests
import NtfServerTests (ntfServerTests)
import ServerTests
@@ -15,25 +16,26 @@ import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
import System.Environment (setEnv)
import Test.Hspec
-- logCfg :: LogConfig
-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
logCfg :: LogConfig
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
main :: IO ()
main = do
-- setLogLevel LogInfo -- LogError
-- withGlobalLogging logCfg $ do
createDirectoryIfMissing False "tests/tmp"
setEnv "APNS_KEY_ID" "H82WD9K9AQ"
setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"
hspec $ do
describe "Core tests" $ do
describe "Encoding tests" encodingTests
describe "Protocol error tests" protocolErrorTests
describe "Version range" versionRangeTests
describe "Encryption tests" cryptoTests
describe "SMP server via TLS" $ serverTests (transport @TLS)
describe "SMP server via WebSockets" $ serverTests (transport @WS)
describe "Notifications server" $ ntfServerTests (transport @TLS)
describe "SMP client agent" $ agentTests (transport @TLS)
describe "Server CLIs" cliTests
removeDirectoryRecursive "tests/tmp"
setLogLevel LogError -- LogInfo
withGlobalLogging logCfg $ do
createDirectoryIfMissing False "tests/tmp"
setEnv "APNS_KEY_ID" "H82WD9K9AQ"
setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"
hspec $ do
describe "Core tests" $ do
describe "Encoding tests" encodingTests
describe "Protocol error tests" protocolErrorTests
describe "Version range" versionRangeTests
describe "Encryption tests" cryptoTests
describe "Retry interval tests" retryIntervalTests
describe "SMP server via TLS" $ serverTests (transport @TLS)
describe "SMP server via WebSockets" $ serverTests (transport @WS)
describe "Notifications server" $ ntfServerTests (transport @TLS)
describe "SMP client agent" $ agentTests (transport @TLS)
describe "Server CLIs" cliTests
removeDirectoryRecursive "tests/tmp"