From 4a405a94bb7dda12248fb6c850249b18a30eb152 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Sun, 8 Jun 2025 21:14:56 +0100 Subject: [PATCH] smp server: batch commands (#1560) * smp server: batch commands verification into one DB transaction * ghc 8.10.7 * flatten transmission tuples --- src/Simplex/FileTransfer/Client.hs | 11 +- src/Simplex/FileTransfer/Protocol.hs | 24 +- src/Simplex/FileTransfer/Server.hs | 33 +- src/Simplex/Messaging/Agent/Client.hs | 9 +- src/Simplex/Messaging/Client.hs | 37 +- src/Simplex/Messaging/Client/Agent.hs | 28 +- src/Simplex/Messaging/Notifications/Server.hs | 75 ++--- .../Messaging/Notifications/Server/Env.hs | 6 +- src/Simplex/Messaging/Protocol.hs | 252 +++++++++----- src/Simplex/Messaging/Server.hs | 316 +++++++++--------- src/Simplex/Messaging/Server/Env/STM.hs | 5 +- .../Messaging/Server/MsgStore/Journal.hs | 2 + .../Messaging/Server/MsgStore/Types.hs | 20 +- .../Messaging/Server/QueueStore/Postgres.hs | 66 +++- .../Messaging/Server/QueueStore/STM.hs | 26 +- .../Messaging/Server/QueueStore/Types.hs | 5 +- src/Simplex/Messaging/Server/StoreLog.hs | 2 +- tests/CoreTests/BatchingTests.hs | 8 +- tests/CoreTests/StoreLogTests.hs | 2 +- tests/NtfClient.hs | 2 +- tests/NtfServerTests.hs | 12 +- tests/SMPClient.hs | 2 +- tests/SMPProxyTests.hs | 4 +- tests/ServerTests.hs | 38 +-- 24 files changed, 585 insertions(+), 400 deletions(-) diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index d3e43907a..01eca236c 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -194,7 +194,7 @@ sendXFTPCommand c@XFTPClient {thParams} pKey fId cmd chunkSpec_ = do let corrIdUsedAsNonce = "" t <- liftEither . first PCETransportError $ - xftpEncodeAuthTransmission thParams pKey ((corrIdUsedAsNonce, fId), FileCmd (sFileParty @p) cmd) + xftpEncodeAuthTransmission thParams pKey (corrIdUsedAsNonce, fId, FileCmd (sFileParty @p) cmd) sendXFTPTransmission c t chunkSpec_ sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body) @@ -204,11 +204,12 @@ sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = d HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req (Just reqTimeout) when (B.length bodyHead /= xftpBlockSize) $ throwE $ PCEResponseError BLOCK -- TODO validate that the file ID is the same as in the request? - liftEither (first PCEResponseError $ xftpDecodeTransmission thParams bodyHead) >>= \case - Right (_, (_, r)) -> case protocolError r of + (_, _fId, respOrErr) <-liftEither $ first PCEResponseError $ xftpDecodeTClient thParams bodyHead + case respOrErr of + Right r -> case protocolError r of Just e -> throwE $ PCEProtocolError e _ -> pure (r, body) - Left (_, e) -> throwE $ PCEResponseError e + Left e -> throwE $ PCEResponseError e where streamBody :: (Builder -> IO ()) -> IO () -> IO () streamBody send done = do @@ -283,7 +284,7 @@ pingXFTP :: XFTPClient -> ExceptT XFTPClientError IO () pingXFTP c@XFTPClient {thParams} = do t <- liftEither . first PCETransportError $ - xftpEncodeTransmission thParams (("", NoEntity), FileCmd SFRecipient PING) + xftpEncodeTransmission thParams ("", NoEntity, FileCmd SFRecipient PING) (r, _) <- sendXFTPTransmission c t Nothing case r of FRPong -> pure () diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index 246abfbea..65f85d710 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -44,6 +44,7 @@ import Simplex.Messaging.Protocol EntityId (..), RecipientId, SenderId, + RawTransmission, SentRawTransmission, SignedTransmissionOrError, SndPublicAuthKey, @@ -53,7 +54,8 @@ import Simplex.Messaging.Protocol encodeTransmission, encodeTransmissionForAuth, messageTagP, - tDecodeParseValidate, + tDecodeServer, + tDecodeClient, tEncodeBatch1, tParse, ) @@ -336,7 +338,7 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of _ -> Nothing xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion XFTPErrorType c => THandleParams XFTPVersion 'TClient -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString -xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey t@((corrId, _), _) = do +xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey t@(corrId, _, _) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams t xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth False (Just pKey) (C.cbNonce $ bs corrId) tForAuth @@ -347,11 +349,23 @@ xftpEncodeTransmission thParams t = xftpEncodeBatch1 (Nothing, encodeTransmissio xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 False t) xftpBlockSize -xftpDecodeTransmission :: ProtocolEncoding XFTPVersion XFTPErrorType c => THandleParams XFTPVersion p -> ByteString -> Either XFTPErrorType (SignedTransmissionOrError XFTPErrorType c) -xftpDecodeTransmission thParams t = do +xftpDecodeTServer :: THandleParams XFTPVersion 'TServer -> ByteString -> Either XFTPErrorType (SignedTransmissionOrError XFTPErrorType FileCmd) +xftpDecodeTServer = xftpDecodeTransmission tDecodeServer +{-# INLINE xftpDecodeTServer #-} + +xftpDecodeTClient :: THandleParams XFTPVersion 'TClient -> ByteString -> Either XFTPErrorType (Transmission (Either XFTPErrorType FileResponse)) +xftpDecodeTClient = xftpDecodeTransmission tDecodeClient +{-# INLINE xftpDecodeTClient #-} + +xftpDecodeTransmission :: + (THandleParams XFTPVersion p -> Either TransportError RawTransmission -> r) -> + THandleParams XFTPVersion p -> + ByteString -> + Either XFTPErrorType r +xftpDecodeTransmission tDecode thParams t = do t' <- first (const BLOCK) $ C.unPad t case tParse thParams t' of - t'' :| [] -> Right $ tDecodeParseValidate thParams t'' + t'' :| [] -> Right $ tDecode thParams t'' _ -> Left BLOCK $(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index c2f415de1..a5e5727e4 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -53,7 +53,7 @@ import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (CorrId (..), BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TAuthorizations, pattern NoEntity) +import Simplex.Messaging.Protocol (BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, SignedTransmission, pattern NoEntity) import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization) import Simplex.Messaging.Server.Control (CPClientRole (..)) import Simplex.Messaging.Server.Expiration @@ -316,20 +316,21 @@ data ServerFile = ServerFile processRequest :: XFTPTransportRequest -> M () processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHead}, sendResponse} - | B.length bodyHead /= xftpBlockSize = sendErr ("", NoEntity) BLOCK + | B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", NoEntity, FRErr BLOCK) Nothing | otherwise = - case xftpDecodeTransmission thParams bodyHead of - Right (Right (signed, (cfIds@(corrId, fId), cmd))) -> do + case xftpDecodeTServer thParams bodyHead of + Right (Right t@(_, _, (corrId, fId, _))) -> do let THandleParams {thAuth} = thParams - verifyXFTPTransmission ((,C.cbNonce (bs corrId)) <$> thAuth) signed fId cmd >>= \case - VRVerified req -> uncurry (sendXFTPResponse cfIds) =<< processXFTPRequest body req - VRFailed e -> sendErr cfIds e - Right (Left (cfIds, e)) -> sendErr cfIds e - Left e -> sendErr ("", NoEntity) e + verifyXFTPTransmission thAuth t >>= \case + VRVerified req -> uncurry send =<< processXFTPRequest body req + VRFailed e -> send (FRErr e) Nothing + where + send resp = sendXFTPResponse (corrId, fId, resp) + Right (Left (corrId, fId, e)) -> sendXFTPResponse (corrId, fId, FRErr e) Nothing + Left e -> sendXFTPResponse ("", NoEntity, FRErr e) Nothing where - sendErr cfIds e = sendXFTPResponse cfIds (FRErr e) Nothing - sendXFTPResponse cfIds resp serverFile_ = do - let t_ = xftpEncodeTransmission thParams (cfIds, resp) + sendXFTPResponse t' serverFile_ = do + let t_ = xftpEncodeTransmission thParams t' #ifdef slow_servers randomDelay #endif @@ -358,8 +359,8 @@ randomDelay = do data VerificationResult = VRVerified XFTPRequest | VRFailed XFTPErrorType -verifyXFTPTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> (Maybe TAuthorizations, ByteString) -> XFTPFileId -> FileCmd -> M VerificationResult -verifyXFTPTransmission auth_ (tAuth, authorized) fId cmd = +verifyXFTPTransmission :: Maybe (THandleAuth 'TServer) -> SignedTransmission FileCmd -> M VerificationResult +verifyXFTPTransmission thAuth (tAuth, authorized, (corrId, fId, cmd)) = case cmd of FileCmd SFSender (FNEW file rcps auth') -> pure $ XFTPReqNew file rcps auth' `verifyWith` sndKey file FileCmd SFRecipient PING -> pure $ VRVerified XFTPReqPing @@ -378,9 +379,9 @@ verifyXFTPTransmission auth_ (tAuth, authorized) fId cmd = EntityBlocked info -> VRFailed $ BLOCKED info EntityOff -> noFileAuth Left _ -> pure noFileAuth - noFileAuth = maybe False (dummyVerifyCmd Nothing authorized) tAuth `seq` VRFailed AUTH + noFileAuth = dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH -- TODO verify with DH authorization - req `verifyWith` k = if verifyCmdAuthorization auth_ tAuth authorized k then VRVerified req else VRFailed AUTH + req `verifyWith` k = if verifyCmdAuthorization thAuth tAuth authorized corrId k then VRVerified req else VRFailed AUTH processXFTPRequest :: HTTP2Body -> XFTPRequest -> M (FileResponse, Maybe ServerFile) processXFTPRequest HTTP2Body {bodyPart} = \case diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 3f770f468..28c4fd7a3 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -33,6 +33,7 @@ module Simplex.Messaging.Agent.Client withConnLocks, withInvLock, withLockMap, + withLocksMap, getMapLock, ipAddressProtected, closeAgentClient, @@ -1004,16 +1005,16 @@ withInvLock' AgentClient {invLocks} = withLockMap invLocks {-# INLINE withInvLock' #-} withConnLocks :: AgentClient -> Set ConnId -> Text -> AM' a -> AM' a -withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks +withConnLocks AgentClient {connLocks} = withLocksMap connLocks {-# INLINE withConnLocks #-} withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> Text -> m a -> m a withLockMap = withGetLock . getMapLock {-# INLINE withLockMap #-} -withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a -withLocksMap_ = withGetLocks . getMapLock -{-# INLINE withLocksMap_ #-} +withLocksMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a +withLocksMap = withGetLocks . getMapLock +{-# INLINE withLocksMap #-} getMapLock :: Ord k => TMap k Lock -> k -> STM Lock getMapLock locks key = TM.lookup key locks >>= maybe newLock pure diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 854a5810b..16b1169d4 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -183,7 +183,7 @@ data PClient v err msg = PClient clientCorrId :: TVar ChaChaDRG, sentCommands :: TMap CorrId (Request err msg), sndQ :: TBQueue (Maybe (Request err msg), ByteString), - rcvQ :: TBQueue (NonEmpty (SignedTransmissionOrError err msg)), + rcvQ :: TBQueue (NonEmpty (Transmission (Either err msg))), msgQ :: Maybe (TBQueue (ServerTransmissionBatch v err msg)) } @@ -615,7 +615,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize receive :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO () receive ProtocolClient {client_ = PClient {rcvQ, lastReceived, timeoutErrorCount}} h = forever $ do - tGet h >>= atomically . writeTBQueue rcvQ + tGetClient h >>= atomically . writeTBQueue rcvQ getCurrentTime >>= atomically . writeTVar lastReceived atomically $ writeTVar timeoutErrorCount 0 @@ -642,19 +642,14 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize process :: ProtocolClient v err msg -> IO () process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= processMsgs c - processMsgs :: ProtocolClient v err msg -> NonEmpty (SignedTransmissionOrError err msg) -> IO () + processMsgs :: ProtocolClient v err msg -> NonEmpty (Transmission (Either err msg)) -> IO () processMsgs c ts = do - ts' <- catMaybes <$> mapM (processMsg c . t) (L.toList ts) + ts' <- catMaybes <$> mapM (processMsg c) (L.toList ts) forM_ msgQ $ \q -> mapM_ (atomically . writeTBQueue q . serverTransmission c) (L.nonEmpty ts') - where - t :: SignedTransmissionOrError err msg -> Transmission (Either err msg) - t = \case - Left (ce, err) -> (ce, Left err) - Right (_, (ce, cmd)) -> (ce, Right cmd) processMsg :: ProtocolClient v err msg -> Transmission (Either err msg) -> IO (Maybe (EntityId, ServerTransmission err msg)) - processMsg ProtocolClient {client_ = PClient {sentCommands}} ((corrId, entId), respOrErr) + processMsg ProtocolClient {client_ = PClient {sentCommands}} (corrId, entId, respOrErr) | B.null $ bs corrId = sendMsg $ STEvent clientResp | otherwise = TM.lookupIO corrId sentCommands >>= \case @@ -772,7 +767,7 @@ createSMPQueue :: -- Maybe NewNtfCreds -> ExceptT SMPClientError IO QueueIdsKeys createSMPQueue c nonce_ (rKey, rpKey) dhKey auth subMode qrd = - sendProtocolCommand_ c nonce_ Nothing (Just rpKey) NoEntity (Cmd SRecipient $ NEW $ NewQueueReq rKey dhKey auth subMode (Just qrd)) >>= \case + sendProtocolCommand_ c nonce_ Nothing (Just rpKey) NoEntity (Cmd SCreator $ NEW $ NewQueueReq rKey dhKey auth subMode (Just qrd)) >>= \case IDS qik -> pure qik r -> throwE $ unexpectedResponse r @@ -853,7 +848,7 @@ nsubResponse_ = \case r' -> Left $ unexpectedResponse r' {-# INLINE nsubResponse_ #-} -subscribeService :: forall p. (PartyI p, SubscriberParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64 +subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64 subscribeService c party = case smpClientService c of Just THClientService {serviceId, serviceKey} -> do liftIO $ enablePings c @@ -863,8 +858,8 @@ subscribeService c party = case smpClientService c of where subCmd :: Command p subCmd = case party of - SRecipient -> SUBS - SNotifier -> NSUBS + SRecipientService -> SUBS + SNotifierService -> NSUBS Nothing -> throwE PCEServiceUnavailable smpClientService :: SMPClient -> Maybe THClientService @@ -1107,7 +1102,7 @@ proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {c let cmdSecret = C.dh' serverKey cmdPrivKey nonce@(C.CbNonce corrId) <- liftIO . atomically $ C.randomCbNonce g -- encode - let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth serverThParams ((CorrId corrId, sId), Cmd (sParty @p) command) + let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth serverThParams (CorrId corrId, sId, Cmd (sParty @p) command) -- serviceAuth is False here – proxied commands are not used with service certificates auth <- liftEitherWith PCETransportError $ authTransmission serverThAuth False spKey nonce tForAuth b <- case batchTransmissions serverThParams [Right (auth, tToSend)] of @@ -1124,11 +1119,11 @@ proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {c -- server interaction errors are thrown directly t' <- liftEitherWith PCECryptoError $ C.cbDecrypt cmdSecret (C.reverseNonce nonce) er case tParse serverThParams t' of - t'' :| [] -> case tDecodeParseValidate serverThParams t'' of - Right (_, (_, cmd)) -> case cmd of - ERR e -> throwE $ PCEProtocolError e -- this is the error from the destination relay - r' -> pure $ Right r' - Left (_, err) -> throwE $ PCEResponseError err + t'' :| [] -> case tDecodeClient serverThParams t'' of + (_, _, cmd) -> case cmd of + Right (ERR e) -> throwE $ PCEProtocolError e -- this is the error from the destination relay + Right r' -> pure $ Right r' + Left e -> throwE $ PCEResponseError e _ -> throwE $ PCETransportError TEBadBlock ERR e -> pure . Left $ ProxyProtocolError e -- this will not happen, this error is returned via Left _ -> pure . Left $ ProxyUnexpectedResponse $ take 32 $ show r @@ -1280,7 +1275,7 @@ mkTransmission c = mkTransmission_ c Nothing mkTransmission_ :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> Maybe C.CbNonce -> ClientCommand msg -> IO (PCTransmission err msg) mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} nonce_ (entityId, pKey_, command) = do nonce@(C.CbNonce corrId) <- maybe (atomically $ C.randomCbNonce clientCorrId) pure nonce_ - let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams ((CorrId corrId, entityId), command) + let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, entityId, command) auth = authTransmission (thAuth thParams) (useServiceAuth command) pKey_ nonce tForAuth r <- mkRequest (CorrId corrId) pure ((,tToSend) <$> auth, r) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index b63b69e32..998be2797 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -70,9 +70,9 @@ import Simplex.Messaging.Protocol QueueId, SMPServer, SParty (..), - SubscriberParty, - subscriberParty, - subscriberServiceRole + ServiceParty, + serviceParty, + partyServiceRole ) import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) @@ -331,11 +331,11 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO () reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) = - withSMP ca srv $ \smp -> liftIO $ case subscriberParty agentParty of + withSMP ca srv $ \smp -> liftIO $ case serviceParty agentParty of Just Dict -> resubscribe smp Nothing -> pure () where - resubscribe :: (PartyI p, SubscriberParty p) => SMPClient -> IO () + resubscribe :: (PartyI p, ServiceParty p) => SMPClient -> IO () resubscribe smp = do mapM_ (smpSubscribeService ca smp srv) sSub_ forM_ qSubs_ $ \qSubs -> do @@ -394,22 +394,22 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode $ host srv) <> "): " <> tshow e throwE e -subscribeQueuesNtfs :: SMPClientAgent 'Notifier -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO () +subscribeQueuesNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO () subscribeQueuesNtfs = subscribeQueues_ {-# INLINE subscribeQueuesNtfs #-} -subscribeQueues_ :: SubscriberParty p => SMPClientAgent p -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO () +subscribeQueues_ :: ServiceParty p => SMPClientAgent p -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO () subscribeQueues_ ca srv subs = do atomically $ addPendingSubs ca srv $ L.toList subs runExceptT (getSMPServerClient' ca srv) >>= \case Right smp -> smpSubscribeQueues ca smp srv subs Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that -smpSubscribeQueues :: SubscriberParty p => SMPClientAgent p -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO () +smpSubscribeQueues :: ServiceParty p => SMPClientAgent p -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO () smpSubscribeQueues ca smp srv subs = do rs <- case agentParty ca of - SRecipient -> subscribeSMPQueues smp subs - SNotifier -> subscribeSMPQueuesNtfs smp subs + SRecipientService -> subscribeSMPQueues smp subs + SNotifierService -> subscribeSMPQueuesNtfs smp subs rs' <- atomically $ ifM @@ -454,18 +454,18 @@ smpSubscribeQueues ca smp srv subs = do notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO () notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs -subscribeServiceNtfs :: SMPClientAgent 'Notifier -> SMPServer -> (ServiceId, Int64) -> IO () +subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> (ServiceId, Int64) -> IO () subscribeServiceNtfs = subscribeService_ {-# INLINE subscribeServiceNtfs #-} -subscribeService_ :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO () +subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO () subscribeService_ ca srv serviceSub = do atomically $ setPendingServiceSub ca srv $ Just serviceSub runExceptT (getSMPServerClient' ca srv) >>= \case Right smp -> smpSubscribeService ca smp srv serviceSub Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that -smpSubscribeService :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO () +smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO () smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of Just service | serviceAvailable service -> subscribe _ -> notifyUnavailable @@ -490,7 +490,7 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService setActiveServiceSub ca srv $ Just ((serviceId, n), sessId) setPendingServiceSub ca srv Nothing serviceAvailable THClientService {serviceRole, serviceId = serviceId'} = - serviceId == serviceId' && subscriberServiceRole (agentParty ca) == serviceRole + serviceId == serviceId' && partyServiceRole (agentParty ca) == serviceRole notifyUnavailable = do atomically $ setPendingServiceSub ca srv Nothing notify ca $ CAServiceUnavailable srv serviceSub -- this will resubscribe all queues directly diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index acb038be9..6909be55d 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag import Simplex.Messaging.Notifications.Server.Store.Postgres import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Transport -import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGet, tPut) +import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server import Simplex.Messaging.Server.Control (CPClientRole (..)) @@ -277,21 +277,21 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} apnsPushQLength } where - getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics + getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'NotifierService -> (SMPClientAgent 'NotifierService -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics getSMPServiceSubMetrics a sel subQueueCount = getSubMetrics_ a sel countSubs where countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TVar (Maybe sub)) -> IO (NtfSMPSubMetrics, S.Set Text) countSubs acc (srv, serviceSubs) = maybe acc (subMetricsResult a acc srv . fromIntegral . subQueueCount) <$> readTVarIO serviceSubs - getSMPSubMetrics :: SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics + getSMPSubMetrics :: SMPClientAgent 'NotifierService -> (SMPClientAgent 'NotifierService -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics getSMPSubMetrics a sel = getSubMetrics_ a sel countSubs where countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TMap NotifierId a) -> IO (NtfSMPSubMetrics, S.Set Text) countSubs acc (srv, queueSubs) = subMetricsResult a acc srv . M.size <$> readTVarIO queueSubs getSubMetrics_ :: - SMPClientAgent 'Notifier -> - (SMPClientAgent 'Notifier -> TVar (M.Map SMPServer sub')) -> + SMPClientAgent 'NotifierService -> + (SMPClientAgent 'NotifierService -> TVar (M.Map SMPServer sub')) -> ((NtfSMPSubMetrics, S.Set Text) -> (SMPServer, sub') -> IO (NtfSMPSubMetrics, S.Set Text)) -> IO NtfSMPSubMetrics getSubMetrics_ a sel countSubs = do @@ -300,7 +300,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} (metrics', otherSrvs) <- foldM countSubs (metrics, S.empty) $ M.assocs subs pure (metrics' :: NtfSMPSubMetrics) {otherServers = S.size otherSrvs} - subMetricsResult :: SMPClientAgent 'Notifier -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text) + subMetricsResult :: SMPClientAgent 'NotifierService -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text) subMetricsResult a acc@(metrics, !otherSrvs) srv@(SMPServer (h :| _) _ _) cnt | isOwnServer a srv = let !ownSrvSubs' = M.alter (Just . maybe cnt (+ cnt)) host ownSrvSubs @@ -314,9 +314,9 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} NtfSMPSubMetrics {ownSrvSubs, otherSrvSubCount} = metrics host = safeDecodeUtf8 $ strEncode h - getSMPWorkerMetrics :: SMPClientAgent 'Notifier -> TMap SMPServer a -> IO NtfSMPWorkerMetrics + getSMPWorkerMetrics :: SMPClientAgent 'NotifierService -> TMap SMPServer a -> IO NtfSMPWorkerMetrics getSMPWorkerMetrics a v = workerMetrics a . M.keys <$> readTVarIO v - workerMetrics :: SMPClientAgent 'Notifier -> [SMPServer] -> NtfSMPWorkerMetrics + workerMetrics :: SMPClientAgent 'NotifierService -> [SMPServer] -> NtfSMPWorkerMetrics workerMetrics a srvs = NtfSMPWorkerMetrics {ownServers = reverse ownSrvs, otherServers} where (ownSrvs, otherServers) = foldl' countSrv ([], 0) srvs @@ -455,7 +455,7 @@ resubscribe NtfSubscriber {smpAgent = ca} = do counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)" -subscribeSrvSubs :: SMPClientAgent 'Notifier -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int +subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv) logNote $ "Starting SMP resubscriptions for " <> srvStr @@ -709,25 +709,24 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte receive :: Transport c => NtfPostgresStore -> THandleNTF c 'TServer -> NtfServerClient -> IO () receive st th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do - ts <- L.toList <$> tGet th + ts <- L.toList <$> tGetServer th atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime (errs, cmds) <- partitionEithers <$> mapM cmdAction ts write sndQ errs write rcvQ cmds where - cmdAction = - \case - Left (ceIds, e) -> do - logError $ "invalid client request: " <> tshow e - pure $ Left (ceIds, NRErr e) - Right t@(_, (ceIds@(corrId, _), cmd)) -> - verified =<< verifyNtfTransmission st ((,C.cbNonce (SMP.bs corrId)) <$> thAuth) t cmd - where - verified = \case - VRVerified req -> pure $ Right req - VRFailed e -> do - logError "unauthorized client request" - pure $ Left (ceIds, NRErr e) + cmdAction = \case + Left (corrId, entId, e) -> do + logError $ "invalid client request: " <> tshow e + pure $ Left (corrId, entId, NRErr e) + Right t@(_, _, (corrId, entId, _)) -> + verified =<< verifyNtfTransmission st thAuth t + where + verified = \case + VRVerified req -> pure $ Right req + VRFailed e -> do + logError "unauthorized client request" + pure $ Left (corrId, entId, NRErr e) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty send :: Transport c => THandleNTF c 'TServer -> NtfServerClient -> IO () @@ -738,10 +737,10 @@ send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do data VerificationResult = VRVerified NtfRequest | VRFailed ErrorType -verifyNtfTransmission :: NtfPostgresStore -> Maybe (THandleAuth 'TServer, C.CbNonce) -> SignedTransmission NtfCmd -> NtfCmd -> IO VerificationResult -verifyNtfTransmission st auth_ ((tAuth, authorized), (ceIds@(corrId, entId), _)) = \case +verifyNtfTransmission :: NtfPostgresStore -> Maybe (THandleAuth 'TServer) -> SignedTransmission NtfCmd -> IO VerificationResult +verifyNtfTransmission st thAuth (tAuth, authorized, (corrId, entId, cmd)) = case cmd of NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _)) - | verifyCmdAuthorization auth_ tAuth authorized k -> + | verifyCmdAuthorization thAuth tAuth authorized corrId k -> result <$> findNtfTokenRegistration st tkn | otherwise -> pure $ VRFailed AUTH where @@ -761,19 +760,19 @@ verifyNtfTransmission st auth_ ((tAuth, authorized), (ceIds@(corrId, entId), _)) verify (t, s_) = verifyToken t $ case s_ of Nothing -> NtfReqNew corrId (ANE SSubscription sub) Just s -> subCmd s c - NtfCmd SSubscription PING -> pure $ VRVerified $ NtfReqPing ceIds + NtfCmd SSubscription PING -> pure $ VRVerified $ NtfReqPing corrId entId NtfCmd SSubscription c -> either err verify <$> getNtfSubscription st entId where verify (t, s) = verifyToken t $ subCmd s c where - tknCmd t c = NtfReqCmd SToken (NtfTkn t) (ceIds, c) - subCmd s c = NtfReqCmd SSubscription (NtfSub s) (ceIds, c) + tknCmd t c = NtfReqCmd SToken (NtfTkn t) (corrId, entId, c) + subCmd s c = NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c) verifyToken :: NtfTknRec -> NtfRequest -> VerificationResult verifyToken NtfTknRec {tknVerifyKey} r - | verifyCmdAuthorization auth_ tAuth authorized tknVerifyKey = VRVerified r + | verifyCmdAuthorization thAuth tAuth authorized corrId tknVerifyKey = VRVerified r | otherwise = VRFailed AUTH err = \case -- signature verification for AUTH errors mitigates timing attacks for existence checks - AUTH -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed AUTH + AUTH -> dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH e -> VRFailed e client :: NtfServerClient -> NtfSubscriber -> NtfPushServer -> M () @@ -785,7 +784,7 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ where processCommand :: NtfRequest -> M (Transmission NtfResponse) processCommand = \case - NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn token _ dhPubKey)) -> ((corrId, NoEntity),) <$> do + NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn token _ dhPubKey)) -> (corrId,NoEntity,) <$> do logDebug "TNEW - new token" (srvDhPubKey, srvDhPrivKey) <- atomically . C.generateKeyPair =<< asks random let dhSecret = C.dh' dhPubKey srvDhPrivKey @@ -798,8 +797,8 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ incNtfStatT token ntfVrfQueued incNtfStatT token tknCreated pure $ NRTknId tknId srvDhPubKey - NtfReqCmd SToken (NtfTkn tkn@NtfTknRec {token, ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhPrivKey}) (ctIds@(_, tknId), cmd) -> do - (ctIds,) <$> case cmd of + NtfReqCmd SToken (NtfTkn tkn@NtfTknRec {token, ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhPrivKey}) (corrId, tknId, cmd) -> do + (corrId,tknId,) <$> case cmd of TNEW (NewNtfTkn _ _ dhPubKey) -> do logDebug "TNEW - registered token" let dhSecret = C.dh' dhPubKey tknDhPrivKey @@ -860,9 +859,9 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ incNtfStat subCreated pure $ NRSubId subId False -> pure $ NRErr AUTH - pure ((corrId, NoEntity), resp) - NtfReqCmd SSubscription (NtfSub NtfSubRec {ntfSubId, smpQueue = SMPQueueNtf {smpServer, notifierId}, notifierKey = registeredNKey, subStatus}) (csIds@(_, subId), cmd) -> do - (csIds,) <$> case cmd of + pure (corrId, NoEntity, resp) + NtfReqCmd SSubscription (NtfSub NtfSubRec {ntfSubId, smpQueue = SMPQueueNtf {smpServer, notifierId}, notifierKey = registeredNKey, subStatus}) (corrId, subId, cmd) -> do + (corrId,subId,) <$> case cmd of SNEW (NewNtfSub _ _ notifierKey) -> do logDebug "SNEW - existing subscription" pure $ @@ -880,7 +879,7 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ incNtfStat subDeleted pure NROk PING -> pure NRPong - NtfReqPing ceIds -> pure (ceIds, NRPong) + NtfReqPing corrId entId -> pure (corrId, entId, NRPong) getId :: M NtfEntityId getId = fmap EntityId . randomBytes =<< asks (subIdBytes . config) getRegCode :: M NtfRegCode diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index e52f13a64..c80dd7741 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -127,7 +127,7 @@ newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbSt data NtfSubscriber = NtfSubscriber { smpSubscribers :: TMap SMPServer SMPSubscriberVar, subscriberSeq :: TVar Int, - smpAgent :: SMPClientAgent 'Notifier + smpAgent :: SMPClientAgent 'NotifierService } type SMPSubscriberVar = SessionVar SMPSubscriber @@ -136,7 +136,7 @@ newNtfSubscriber :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber newNtfSubscriber smpAgentCfg random = do smpSubscribers <- TM.emptyIO subscriberSeq <- newTVarIO 0 - smpAgent <- newSMPClientAgent SNotifier smpAgentCfg random + smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg random pure NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent} data SMPSubscriber = SMPSubscriber @@ -172,7 +172,7 @@ getPushClient s@NtfPushServer {pushClients} pp = data NtfRequest = NtfReqNew CorrId ANewNtfEntity | forall e. NtfEntityI e => NtfReqCmd (SNtfEntity e) (NtfEntityRec e) (Transmission (NtfCommand e)) - | NtfReqPing (CorrId, NtfEntityId) + | NtfReqPing CorrId NtfEntityId data NtfServerClient = NtfServerClient { rcvQ :: TBQueue (NonEmpty NtfRequest), diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index f4af0e5c9..b3a76f37e 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -66,8 +66,9 @@ module Simplex.Messaging.Protocol EncDataBytes (..), Party (..), Cmd (..), - DirectParty, - SubscriberParty, + QueueParty, + BatchParty, + ServiceParty, ASubscriberParty (..), BrokerMsg (..), SParty (..), @@ -80,6 +81,7 @@ module Simplex.Messaging.Protocol BrokerErrorType (..), BlockingInfo (..), BlockingReason (..), + RawTransmission, Transmission, TAuthorizations, TransmissionAuth (..), @@ -153,8 +155,11 @@ module Simplex.Messaging.Protocol currentSMPClientVersion, senderCanSecure, queueReqMode, - subscriberParty, - subscriberServiceRole, + queueParty, + batchParty, + serviceParty, + partyClientRole, + partyServiceRole, userProtocol, rcvMessageMeta, noMsgFlags, @@ -186,9 +191,11 @@ module Simplex.Messaging.Protocol TransportBatch (..), tPut, tPutLog, - tGet, + tGetServer, + tGetClient, tParse, - tDecodeParseValidate, + tDecodeServer, + tDecodeClient, tEncode, tEncodeBatch1, batchTransmissions, @@ -303,22 +310,40 @@ e2eEncMessageLength :: Int e2eEncMessageLength = 16000 -- 15988 .. 16005 -- | SMP protocol clients -data Party = Recipient | Sender | Notifier | LinkClient | ProxiedClient | ProxyService +data Party + = Creator + | Recipient + | RecipientService + | Sender + | IdleClient + | Notifier + | NotifierService + | LinkClient + | ProxiedClient + | ProxyService deriving (Show) -- | Singleton types for SMP protocol clients data SParty :: Party -> Type where + SCreator :: SParty Creator SRecipient :: SParty Recipient + SRecipientService :: SParty RecipientService SSender :: SParty Sender + SIdleClient :: SParty IdleClient SNotifier :: SParty Notifier + SNotifierService :: SParty NotifierService SSenderLink :: SParty LinkClient SProxiedClient :: SParty ProxiedClient SProxyService :: SParty ProxyService instance TestEquality SParty where + testEquality SCreator SCreator = Just Refl testEquality SRecipient SRecipient = Just Refl + testEquality SRecipientService SRecipientService = Just Refl testEquality SSender SSender = Just Refl + testEquality SIdleClient SIdleClient = Just Refl testEquality SNotifier SNotifier = Just Refl + testEquality SNotifierService SNotifierService = Just Refl testEquality SSenderLink SSenderLink = Just Refl testEquality SProxiedClient SProxiedClient = Just Refl testEquality SProxyService SProxyService = Just Refl @@ -328,34 +353,72 @@ deriving instance Show (SParty p) class PartyI (p :: Party) where sParty :: SParty p +instance PartyI Creator where sParty = SCreator + instance PartyI Recipient where sParty = SRecipient +instance PartyI RecipientService where sParty = SRecipientService + instance PartyI Sender where sParty = SSender +instance PartyI IdleClient where sParty = SIdleClient + instance PartyI Notifier where sParty = SNotifier +instance PartyI NotifierService where sParty = SNotifierService + instance PartyI LinkClient where sParty = SSenderLink instance PartyI ProxiedClient where sParty = SProxiedClient instance PartyI ProxyService where sParty = SProxyService -type family DirectParty (p :: Party) :: Constraint where - DirectParty Recipient = () - DirectParty Sender = () - DirectParty Notifier = () - DirectParty LinkClient = () - DirectParty ProxyService = () - DirectParty p = - (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not direct")) +-- command parties that can read queues +type family QueueParty (p :: Party) :: Constraint where + QueueParty Recipient = () + QueueParty Sender = () + QueueParty Notifier = () + QueueParty LinkClient = () + QueueParty p = + (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not QueueParty")) -type family SubscriberParty (p :: Party) :: Constraint where - SubscriberParty Recipient = () - SubscriberParty Notifier = () - SubscriberParty p = - (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not subscriber")) +queueParty :: SParty p -> Maybe (Dict (PartyI p, QueueParty p)) +queueParty = \case + SRecipient -> Just Dict + SSender -> Just Dict + SSenderLink -> Just Dict + SNotifier -> Just Dict + _ -> Nothing +{-# INLINE queueParty #-} -data ASubscriberParty = forall p. (PartyI p, SubscriberParty p) => ASP (SParty p) +type family BatchParty (p :: Party) :: Constraint where + BatchParty Recipient = () + BatchParty Notifier = () + BatchParty p = + (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not BatchParty")) + +batchParty :: SParty p -> Maybe (Dict (PartyI p, BatchParty p)) +batchParty = \case + SRecipient -> Just Dict + SNotifier -> Just Dict + _ -> Nothing +{-# INLINE batchParty #-} + +-- command parties that can subscribe to individual queues +type family ServiceParty (p :: Party) :: Constraint where + ServiceParty RecipientService = () + ServiceParty NotifierService = () + ServiceParty p = + (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not ServiceParty")) + +serviceParty :: SParty p -> Maybe (Dict (PartyI p, ServiceParty p)) +serviceParty = \case + SRecipientService -> Just Dict + SNotifierService -> Just Dict + _ -> Nothing +{-# INLINE serviceParty #-} + +data ASubscriberParty = forall p. (PartyI p, ServiceParty p) => ASP (SParty p) deriving instance Show ASubscriberParty @@ -364,30 +427,37 @@ instance Eq ASubscriberParty where instance Encoding ASubscriberParty where smpEncode = \case - ASP SRecipient -> "R" - ASP SNotifier -> "N" + ASP SRecipientService -> "R" + ASP SNotifierService -> "N" smpP = A.anyChar >>= \case - 'R' -> pure $ ASP SRecipient - 'N' -> pure $ ASP SNotifier + 'R' -> pure $ ASP SRecipientService + 'N' -> pure $ ASP SNotifierService _ -> fail "bad ASubscriberParty" instance StrEncoding ASubscriberParty where strEncode = smpEncode strP = smpP -subscriberParty :: SParty p -> Maybe (Dict (PartyI p, SubscriberParty p)) -subscriberParty = \case - SRecipient -> Just Dict - SNotifier -> Just Dict - _ -> Nothing -{-# INLINE subscriberParty #-} +partyClientRole :: SParty p -> Maybe SMPServiceRole +partyClientRole = \case + SCreator -> Just SRMessaging + SRecipient -> Just SRMessaging + SRecipientService -> Just SRMessaging + SSender -> Just SRMessaging + SIdleClient -> Nothing + SNotifier -> Just SRNotifier + SNotifierService -> Just SRNotifier + SSenderLink -> Just SRMessaging + SProxiedClient -> Just SRMessaging + SProxyService -> Just SRProxy +{-# INLINE partyClientRole #-} -subscriberServiceRole :: SubscriberParty p => SParty p -> SMPServiceRole -subscriberServiceRole = \case - SRecipient -> SRMessaging - SNotifier -> SRNotifier -{-# INLINE subscriberServiceRole #-} +partyServiceRole :: ServiceParty p => SParty p -> SMPServiceRole +partyServiceRole = \case + SRecipientService -> SRMessaging + SNotifierService -> SRNotifier +{-# INLINE partyServiceRole #-} -- | Type for client command of any participant. data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p) @@ -395,10 +465,10 @@ data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p) deriving instance Show Cmd -- | Parsed SMP transmission without signature, size and session ID. -type Transmission c = ((CorrId, EntityId), c) +type Transmission c = (CorrId, EntityId, c) -- | signed parsed transmission, with original raw bytes and parsing error. -type SignedTransmission c = ((Maybe TAuthorizations, Signed), Transmission c) +type SignedTransmission c = (Maybe TAuthorizations, Signed, Transmission c) type SignedTransmissionOrError e c = Either (Transmission e) (SignedTransmission c) @@ -465,10 +535,10 @@ data Command (p :: Party) where -- v6 of SMP servers only support signature algorithm for command authorization. -- v7 of SMP servers additionally support additional layer of authenticated encryption. -- RcvPublicAuthKey is defined as C.APublicKey - it can be either signature or DH public keys. - NEW :: NewQueueReq -> Command Recipient + NEW :: NewQueueReq -> Command Creator SUB :: Command Recipient -- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command. - SUBS :: Command Recipient + SUBS :: Command RecipientService KEY :: SndPublicAuthKey -> Command Recipient RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient LSET :: LinkId -> QueueLinkData -> Command Recipient @@ -485,14 +555,14 @@ data Command (p :: Party) where -- SEND v1 has to be supported for encoding/decoding -- SEND :: MsgBody -> Command Sender SEND :: MsgFlags -> MsgBody -> Command Sender - PING :: Command Sender + PING :: Command IdleClient -- Client accessing short links LKEY :: SndPublicAuthKey -> Command LinkClient LGET :: Command LinkClient -- SMP notification subscriber commands NSUB :: Command Notifier -- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command. - NSUBS :: Command Notifier + NSUBS :: Command NotifierService PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI -- Transmission to proxy: -- - entity ID: ID of the session with relay returned in PKEY (response to PRXY) @@ -826,9 +896,9 @@ noMsgFlags = MsgFlags {notification = False} -- * SMP command tags data CommandTag (p :: Party) where - NEW_ :: CommandTag Recipient + NEW_ :: CommandTag Creator SUB_ :: CommandTag Recipient - SUBS_ :: CommandTag Recipient + SUBS_ :: CommandTag RecipientService KEY_ :: CommandTag Recipient RKEY_ :: CommandTag Recipient LSET_ :: CommandTag Recipient @@ -842,14 +912,14 @@ data CommandTag (p :: Party) where QUE_ :: CommandTag Recipient SKEY_ :: CommandTag Sender SEND_ :: CommandTag Sender - PING_ :: CommandTag Sender + PING_ :: CommandTag IdleClient LKEY_ :: CommandTag LinkClient LGET_ :: CommandTag LinkClient PRXY_ :: CommandTag ProxiedClient PFWD_ :: CommandTag ProxiedClient RFWD_ :: CommandTag ProxyService NSUB_ :: CommandTag Notifier - NSUBS_ :: CommandTag Notifier + NSUBS_ :: CommandTag NotifierService data CmdTag = forall p. PartyI p => CT (SParty p) (CommandTag p) @@ -915,9 +985,9 @@ instance PartyI p => Encoding (CommandTag p) where instance ProtocolMsgTag CmdTag where decodeTag = \case - "NEW" -> Just $ CT SRecipient NEW_ + "NEW" -> Just $ CT SCreator NEW_ "SUB" -> Just $ CT SRecipient SUB_ - "SUBS" -> Just $ CT SRecipient SUBS_ + "SUBS" -> Just $ CT SRecipientService SUBS_ "KEY" -> Just $ CT SRecipient KEY_ "RKEY" -> Just $ CT SRecipient RKEY_ "LSET" -> Just $ CT SRecipient LSET_ @@ -931,14 +1001,14 @@ instance ProtocolMsgTag CmdTag where "QUE" -> Just $ CT SRecipient QUE_ "SKEY" -> Just $ CT SSender SKEY_ "SEND" -> Just $ CT SSender SEND_ - "PING" -> Just $ CT SSender PING_ + "PING" -> Just $ CT SIdleClient PING_ "LKEY" -> Just $ CT SSenderLink LKEY_ "LGET" -> Just $ CT SSenderLink LGET_ "PRXY" -> Just $ CT SProxiedClient PRXY_ "PFWD" -> Just $ CT SProxiedClient PFWD_ "RFWD" -> Just $ CT SProxyService RFWD_ "NSUB" -> Just $ CT SNotifier NSUB_ - "NSUBS" -> Just $ CT SNotifier NSUBS_ + "NSUBS" -> Just $ CT SNotifierService NSUBS_ _ -> Nothing instance Encoding CmdTag where @@ -1563,7 +1633,7 @@ instance Protocol SMPVersion ErrorType BrokerMsg where Cmd _ NSUB -> True _ -> False {-# INLINE useServiceAuth #-} - protocolPing = Cmd SSender PING + protocolPing = Cmd SIdleClient PING {-# INLINE protocolPing #-} protocolError = \case ERR e -> Just e @@ -1662,14 +1732,14 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where {-# INLINE encodeProtocol #-} protocolP v = \case - CT SRecipient tag -> - Cmd SRecipient <$> case tag of - NEW_ - | v >= shortLinksSMPVersion -> NEW <$> new smpP smpP - | v >= sndAuthKeySMPVersion -> NEW <$> new smpP (qReq <$> smpP) - | otherwise -> NEW <$> new auth (pure Nothing) + CT SCreator NEW_ -> Cmd SCreator <$> newCmd + where + newCmd + | v >= shortLinksSMPVersion = new smpP smpP + | v >= sndAuthKeySMPVersion = new smpP (qReq <$> smpP) + | otherwise = new auth (pure Nothing) where - new p1 p2 = do + new p1 p2 = NEW <$> do rcvAuthKey <- _smpP rcvDhKey <- smpP auth_ <- p1 @@ -1680,8 +1750,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where pure NewQueueReq {rcvAuthKey, rcvDhKey, auth_, subMode, queueReqData} -- ntfCreds auth = optional (A.char 'A' *> smpP) qReq sndSecure = Just $ if sndSecure then QRMessaging Nothing else QRContact Nothing + CT SRecipient tag -> + Cmd SRecipient <$> case tag of SUB_ -> pure SUB - SUBS_ -> pure SUBS KEY_ -> KEY <$> _smpP RKEY_ -> RKEY <$> _smpP LSET_ -> LSET <$> _smpP <*> smpP @@ -1693,11 +1764,12 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where OFF_ -> pure OFF DEL_ -> pure DEL QUE_ -> pure QUE + CT SRecipientService SUBS_ -> pure $ Cmd SRecipientService SUBS CT SSender tag -> Cmd SSender <$> case tag of SKEY_ -> SKEY <$> _smpP SEND_ -> SEND <$> _smpP <*> (unTail <$> _smpP) - PING_ -> pure PING + CT SIdleClient PING_ -> pure $ Cmd SIdleClient PING CT SProxyService RFWD_ -> Cmd SProxyService . RFWD . EncFwdTransmission . unTail <$> _smpP CT SSenderLink tag -> @@ -1708,10 +1780,8 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where Cmd SProxiedClient <$> case tag of PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP) PRXY_ -> PRXY <$> _smpP <*> smpP - CT SNotifier tag -> - pure $ Cmd SNotifier $ case tag of - NSUB_ -> NSUB - NSUBS_ -> NSUBS + CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB + CT SNotifierService NSUBS_ -> pure $ Cmd SNotifierService NSUBS fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} @@ -2056,7 +2126,7 @@ encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t = {-# INLINE encodeTransmission #-} encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString -encodeTransmission_ v ((CorrId corrId, queueId), command) = +encodeTransmission_ v (CorrId corrId, queueId, command) = smpEncode (corrId, queueId) <> encodeProtocol v command {-# INLINE encodeTransmission_ #-} @@ -2076,25 +2146,51 @@ tParse thParams@THandleParams {batch} s eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b) eitherList = either (\e -> [Left e]) --- | Receive client and server transmissions (determined by `cmd` type). -tGet :: forall v err cmd c p. (ProtocolEncoding v err cmd, Transport c) => THandle v c p -> IO (NonEmpty (SignedTransmissionOrError err cmd)) -tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th +-- | Receive server transmissions +tGetServer :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TServer -> IO (NonEmpty (SignedTransmissionOrError err cmd)) +tGetServer = tGet tDecodeServer +{-# INLINE tGetServer #-} -tDecodeParseValidate :: forall v p err cmd. ProtocolEncoding v err cmd => THandleParams v p -> Either TransportError RawTransmission -> SignedTransmissionOrError err cmd -tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case +-- | Receive client transmissions +tGetClient :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (NonEmpty (Transmission (Either err cmd))) +tGetClient = tGet tDecodeClient +{-# INLINE tGetClient #-} + +tGet :: + Transport c => + (THandleParams v p -> Either TransportError RawTransmission -> r) -> + THandle v c p -> + IO (NonEmpty r) +tGet tDecode th@THandle {params} = L.map (tDecode params) <$> tGetParse th +{-# INLINE tGet #-} + +tDecodeServer :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v 'TServer -> Either TransportError RawTransmission -> SignedTransmissionOrError err cmd +tDecodeServer THandleParams {sessionId, thVersion = v, implySessId} = \case Right RawTransmission {authenticator, serviceSig, authorized, sessId, corrId, entityId, command} | implySessId || sessId == sessionId -> case decodeTAuthBytes authenticator serviceSig of - Right tAuth -> bimap t (((tAuth, authorized),) . t) cmdOrErr + Right tAuth -> bimap t ((tAuth,authorized,) . t) cmdOrErr where cmdOrErr = parseProtocol @v @err @cmd v command >>= checkCredentials tAuth entityId - t :: a -> ((CorrId, EntityId), a) - t = ((corrId, entityId),) - Left _ -> tError corrId - | otherwise -> Left ((corrId, NoEntity), fromProtocolError @v @err @cmd PESession) - Left _ -> tError "" + t :: a -> (CorrId, EntityId, a) + t = (corrId,entityId,) + Left _ -> tError corrId PEBlock + | otherwise -> tError corrId PESession + Left _ -> tError "" PEBlock where - tError :: CorrId -> SignedTransmissionOrError err cmd - tError corrId = Left ((corrId, NoEntity), fromProtocolError @v @err @cmd PEBlock) + tError :: CorrId -> ProtocolErrorType -> SignedTransmissionOrError err cmd + tError corrId err = Left (corrId, NoEntity, fromProtocolError @v @err @cmd err) + +tDecodeClient :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v 'TClient -> Either TransportError RawTransmission -> Transmission (Either err cmd) +tDecodeClient THandleParams {sessionId, thVersion = v, implySessId} = \case + Right RawTransmission {sessId, corrId, entityId, command} + | implySessId || sessId == sessionId -> (corrId, entityId, cmdOrErr) + | otherwise -> tError corrId PESession + where + cmdOrErr = parseProtocol @v @err @cmd v command >>= checkCredentials Nothing entityId + Left _ -> tError "" PEBlock + where + tError :: CorrId -> ProtocolErrorType -> Transmission (Either err cmd) + tError corrId err = (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd err) $(J.deriveJSON defaultJSON ''MsgFlags) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 8c48b31f9..dc69ea339 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -52,12 +52,13 @@ import Control.Monad.Reader import Control.Monad.Trans.Except import Control.Monad.STM (retry) import Crypto.Random (ChaChaDRG) -import Data.Bifunctor (first) +import Data.Bifunctor (first, second) import Data.ByteString.Base64 (encode) import qualified Data.ByteString.Builder as BLD import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Constraint (Dict (..)) import Data.Dynamic (toDyn) import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) @@ -392,7 +393,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt ntfs -> do writeTVar v [] pure $ foldl' (\acc' ntf -> nmsg nId ntf : acc') acc ntfs -- reverses, to order by time - nmsg nId MsgNtf {ntfNonce, ntfEncMeta} = ((CorrId "", nId), NMSG ntfNonce ntfEncMeta) + nmsg nId MsgNtf {ntfNonce, ntfEncMeta} = (CorrId "", nId, NMSG ntfNonce ntfEncMeta) updateNtfStats :: Client s' -> Either SomeException Int -> IO () updateNtfStats Client {clientId} = \case Right 0 -> pure () @@ -425,7 +426,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt forkClient c ("sendPendingEvtsThread.queueEvts") $ atomically (writeTBQueue sndQ ts) >> updateEndStats where - ts = L.map (\(entId, evt) -> ((CorrId "", entId), evt)) evts + ts = L.map (\(entId, evt) -> (CorrId "", entId, evt)) evts -- this accounts for both END and DELD events updateEndStats = do let len = L.length evts @@ -1068,48 +1069,54 @@ cancelSub s = case subThread s of _ -> pure () ProhibitSub -> pure () +type VerifiedTransmissionOrError s = Either (Transmission BrokerMsg) (VerifiedTransmission s) + receive :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> s -> Client s -> M s () receive h@THandle {params = THandleParams {thAuth, sessionId}} ms Client {rcvQ, sndQ, rcvActiveAt} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" sa <- asks serverActive stats <- asks serverStats liftIO $ forever $ do - ts <- tGet h + ts <- tGetServer h unlessM (readTVarIO sa) $ throwIO $ userError "server stopped" atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime - let service = peerClientService =<< thAuth - (errs, cmds) <- partitionEithers <$> mapM (cmdAction stats service) (L.toList ts) - updateBatchStats stats cmds - write sndQ errs - write rcvQ cmds + let (es, ts') = partitionEithers $ L.toList ts + errs = map (second ERR) es + case ts' of + [] -> write sndQ errs + (_, _, (_, _, Cmd p cmd)) : _ -> do + let service = peerClientService =<< thAuth + (errs', cmds) <- partitionEithers <$> case batchParty p of + Just Dict | all (sameParty p) ts'-> do + updateBatchStats stats cmd -- even if nothing is verified + let queueId (_, _, (_, qId, _)) = qId + qs <- getQueueRecs ms p $ map queueId ts' + zipWithM (\t -> verified stats t . verifyLoadedQueue service thAuth t) ts' qs + _ -> mapM (\t -> verified stats t =<< verifyTransmission ms service thAuth t) ts' + write rcvQ cmds + write sndQ $ errs ++ errs' where - updateBatchStats :: ServerStats -> [(Maybe (StoreQueue s, QueueRec), Transmission Cmd)] -> IO () + sameParty :: SParty p -> SignedTransmission Cmd -> Bool + sameParty p (_, _, (_, _, Cmd p' _)) = isJust $ testEquality p p' + updateBatchStats :: ServerStats -> Command p -> IO () updateBatchStats stats = \case - (_, (_, (Cmd _ cmd))) : _ -> do - let sel_ = case cmd of - SUB -> Just qSubAllB - DEL -> Just qDeletedAllB - NSUB -> Just ntfSubB - NDEL -> Just ntfDeletedB - _ -> Nothing - mapM_ (\sel -> incStat $ sel stats) sel_ - [] -> pure () - cmdAction :: ServerStats -> Maybe THPeerClientService -> SignedTransmissionOrError ErrorType Cmd -> IO (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd)) - cmdAction stats service = - \case - Left (ceIds, e) -> pure $ Left (ceIds, ERR e) - Right (signed, (ceIds@(corrId, entId), cmd)) -> verified =<< verifyTransmission ms service ((,C.cbNonce (bs corrId)) <$> thAuth) signed entId cmd - where - verified = \case - VRVerified q -> pure $ Right (q, (ceIds, cmd)) - VRFailed e -> do - case cmd of - Cmd _ SEND {} -> incStat $ msgSentAuth stats - Cmd _ SUB -> incStat $ qSubAuth stats - Cmd _ NSUB -> incStat $ ntfSubAuth stats - Cmd _ GET -> incStat $ msgGetAuth stats - _ -> pure () - pure $ Left (ceIds, ERR e) + SUB -> incStat $ qSubAllB stats + DEL -> incStat $ qDeletedAllB stats + NDEL -> incStat $ ntfDeletedB stats + NSUB -> incStat $ ntfSubB stats + _ -> pure () + verified :: ServerStats -> SignedTransmission Cmd -> VerificationResult s -> IO (VerifiedTransmissionOrError s) + verified stats (_, _, t@(corrId, entId, Cmd _ command)) = \case + VRVerified q -> pure $ Right (q, t) + VRFailed AUTH -> do + case command of + SEND {} -> incStat $ msgSentAuth stats + SUB -> incStat $ qSubAuth stats + NSUB -> incStat $ ntfSubAuth stats + GET -> incStat $ msgGetAuth stats + _ -> pure () + pure $ Left (corrId, entId, ERR AUTH) + VRFailed e -> pure $ Left (corrId, entId, ERR e) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty send :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO () @@ -1132,9 +1139,9 @@ send th c@Client {sndQ, msgQ, clientTHParams = THandleParams {sessionId}} = do mapM_ (atomically . writeTBQueue msgQ) $ L.nonEmpty msgs_ where splitMessages :: [Transmission BrokerMsg] -> Transmission BrokerMsg -> ([Transmission BrokerMsg], Transmission BrokerMsg) - splitMessages msgs t@((corrId, entId), cmd) = case cmd of + splitMessages msgs t@(corrId, entId, cmd) = case cmd of -- replace MSG response with OK, accumulating MSG in a separate list. - MSG {} -> (((CorrId "", entId), cmd) : msgs, ((corrId, entId), OK)) + MSG {} -> ((CorrId "", entId, cmd) : msgs, (corrId, entId, OK)) _ -> (msgs, t) sendMsg :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO () @@ -1169,34 +1176,42 @@ data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFail -- - the queue or party key do not exist. -- In all cases, the time of the verification should depend only on the provided authorization type, -- a dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result. -verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer, C.CbNonce) -> (Maybe TAuthorizations, ByteString) -> QueueId -> Cmd -> IO (VerificationResult s) -verifyTransmission ms service auth_ (tAuth, authorized) queueId command@(Cmd party cmd) - | verifyServiceSig = case party of - SRecipient | hasRole SRMessaging -> case cmd of - NEW NewQueueReq {rcvAuthKey = k} -> pure $ Nothing `verifiedWith` k - SUB -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q) - SUBS -> pure verifyServiceCmd - _ -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q) - SSender | hasRole SRMessaging -> case cmd of - SKEY k -> verifySecure SSender k - -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command - SEND {} -> verifyQueue SSender $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified (Just q) else VRFailed AUTH - PING -> pure $ VRVerified Nothing - SSenderLink | hasRole SRMessaging -> case cmd of - LKEY k -> verifySecure SSenderLink k - LGET -> verifyQueue SSenderLink $ \q -> if isContactQueue (snd q) then VRVerified (Just q) else VRFailed AUTH - SNotifier | hasRole SRNotifier -> case cmd of - NSUB -> verifyQueue SNotifier $ \q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier $ snd q) - NSUBS -> pure verifyServiceCmd - SProxiedClient | hasRole SRMessaging -> pure $ VRVerified Nothing - SProxyService | hasRole SRProxy -> pure $ VRVerified Nothing - _ -> pure $ VRFailed $ CMD PROHIBITED - | otherwise = pure $ VRFailed SERVICE +verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> IO (VerificationResult s) +verifyTransmission ms service thAuth t@(_, _, (_, queueId, Cmd p _)) = case queueParty p of + Just Dict -> verifyLoadedQueue service thAuth t <$> getQueueRec ms p queueId + Nothing -> pure $ verifyQueueTransmission service thAuth t Nothing + +verifyLoadedQueue :: Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Either ErrorType (StoreQueue s, QueueRec) -> VerificationResult s +verifyLoadedQueue service thAuth t@(tAuth, authorized, (corrId, _, _)) = \case + Right q -> verifyQueueTransmission service thAuth t (Just q) + Left AUTH -> (dummyVerifyCmd thAuth tAuth authorized corrId) `seq` VRFailed AUTH + Left e -> VRFailed e + +verifyQueueTransmission :: forall s. Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Maybe (StoreQueue s, QueueRec) -> VerificationResult s +verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, _, cmd@(Cmd p command))) q_ + | not checkRole = VRFailed $ CMD PROHIBITED + | not verifyServiceSig = VRFailed SERVICE + | otherwise = vc p command where - hasRole role = case service of - Just THClientService {serviceRole} -> serviceRole == role - Nothing -> True - verify = verifyCmdAuthorization auth_ tAuth authorized' + vc :: SParty p -> Command p -> VerificationResult s -- this pattern match works with ghc8.10.7, flat case sees it as non-exhastive. + vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k + vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q) + vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q) + vc SRecipientService SUBS = verifyServiceCmd + vc SSender (SKEY k) = verifySecure k + -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command + vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH + vc SIdleClient PING = VRVerified Nothing + vc SSenderLink (LKEY k) = verifySecure k + vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH + vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q) + vc SNotifierService NSUBS = verifyServiceCmd + vc SProxiedClient _ = VRVerified Nothing + vc SProxyService (RFWD _) = VRVerified Nothing + checkRole = case (service, partyClientRole p) of + (Just THClientService {serviceRole}, Just role) -> serviceRole == role + _ -> True + verify = verifyCmdAuthorization thAuth tAuth authorized' corrId verifyServiceCmd :: VerificationResult s verifyServiceCmd = case (service, tAuth) of (Just THClientService {serviceKey = k}, Just (TASignature (C.ASignature C.SEd25519 s), Nothing)) @@ -1204,7 +1219,7 @@ verifyTransmission ms service auth_ (tAuth, authorized) queueId command@(Cmd par _ -> VRFailed SERVICE -- this function verify service signature for commands that use it in service sessions verifyServiceSig - | useServiceAuth command = case (service, serviceSig) of + | useServiceAuth cmd = case (service, serviceSig) of (Just THClientService {serviceKey = k}, Just s) -> C.verify' k s authorized (Nothing, Nothing) -> True _ -> False @@ -1214,20 +1229,17 @@ verifyTransmission ms service auth_ (tAuth, authorized) queueId command@(Cmd par (Just THClientService {serviceCertHash = XV.Fingerprint fp}, Just _) -> fp <> authorized _ -> authorized dummyVerify :: VerificationResult s - dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed AUTH - verifyQueue :: DirectParty p => SParty p -> ((StoreQueue s, QueueRec) -> VerificationResult s) -> IO (VerificationResult s) - verifyQueue p v = either err v <$> getQueueRec ms p queueId - where - -- this prevents reporting any STORE errors as AUTH errors - err = \case - AUTH -> dummyVerify - e -> VRFailed e - verifySecure :: DirectParty p => SParty p -> SndPublicAuthKey -> IO (VerificationResult s) - verifySecure p k = verifyQueue p $ \q -> if k `allowedKey` snd q then Just q `verifiedWith` k else dummyVerify - verifiedWith :: Maybe (StoreQueue s, QueueRec) -> C.APublicAuthKey -> VerificationResult s - verifiedWith q_ k = if verify k then VRVerified q_ else VRFailed AUTH - verifiedWithKeys :: Maybe (StoreQueue s, QueueRec) -> NonEmpty C.APublicAuthKey -> VerificationResult s - verifiedWithKeys q_ ks = if any verify ks then VRVerified q_ else VRFailed AUTH + dummyVerify = (dummyVerifyCmd thAuth tAuth authorized corrId) `seq` VRFailed AUTH + -- That a specific command requires queue signature verification is determined by `queueParty`, + -- it should be coordinated with the case in this function (`verifyQueueTransmission`) + verifyQueue :: ((StoreQueue s, QueueRec) -> VerificationResult s) -> VerificationResult s + verifyQueue v = maybe (VRFailed INTERNAL) v q_ + verifySecure :: SndPublicAuthKey -> VerificationResult s + verifySecure k = verifyQueue $ \q -> if k `allowedKey` snd q then verifiedWith k else dummyVerify + verifiedWith :: C.APublicAuthKey -> VerificationResult s + verifiedWith k = if verify k then VRVerified q_ else VRFailed AUTH + verifiedWithKeys :: NonEmpty C.APublicAuthKey -> VerificationResult s + verifiedWithKeys ks = if any verify ks then VRVerified q_ else VRFailed AUTH allowedKey k = \case QueueRec {queueMode = Just QMMessaging, senderKey} -> maybe True (k ==) senderKey _ -> False @@ -1243,8 +1255,9 @@ isSecuredMsgQueue QueueRec {queueMode, senderKey} = case queueMode of Just QMContact -> False _ -> isJust senderKey -verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> C.APublicAuthKey -> Bool -verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAuth +-- Random correlation ID is used as a nonce in case crypto_box authenticator is used to authorize transmission +verifyCmdAuthorization :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> C.APublicAuthKey -> Bool +verifyCmdAuthorization thAuth tAuth authorized corrId key = maybe False (verify key) tAuth where verify :: C.APublicAuthKey -> TAuthorizations -> Bool verify (C.APublicAuthKey a k) = \case @@ -1252,18 +1265,20 @@ verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAu Just Refl -> C.verify' k s authorized _ -> C.verify' (dummySignKey a') s authorized `seq` False (TAAuthenticator s, _) -> case a of - C.SX25519 -> verifyCmdAuth auth_ k s authorized - _ -> verifyCmdAuth auth_ dummyKeyX25519 s authorized `seq` False + C.SX25519 -> verifyCmdAuth thAuth k s authorized corrId + _ -> verifyCmdAuth thAuth dummyKeyX25519 s authorized corrId `seq` False -verifyCmdAuth :: Maybe (THandleAuth 'TServer, C.CbNonce) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> Bool -verifyCmdAuth auth_ k authenticator authorized = case auth_ of - Just (THAuthServer {serverPrivKey = pk}, nonce) -> C.cbVerify k pk nonce authenticator authorized +verifyCmdAuth :: Maybe (THandleAuth 'TServer) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> CorrId -> Bool +verifyCmdAuth thAuth k authenticator authorized (CorrId corrId) = case thAuth of + Just THAuthServer {serverPrivKey = pk} -> C.cbVerify k pk (C.cbNonce corrId) authenticator authorized Nothing -> False -dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TAuthorizations -> Bool -dummyVerifyCmd auth_ authorized = \case - (TASignature (C.ASignature a s), _) -> C.verify' (dummySignKey a) s authorized - (TAAuthenticator s, _) -> verifyCmdAuth auth_ dummyKeyX25519 s authorized +dummyVerifyCmd :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> Maybe Bool +dummyVerifyCmd thAuth tAuth authorized corrId = verify . fst <$> tAuth + where + verify = \case + TASignature (C.ASignature a s) -> C.verify' (dummySignKey a) s authorized + TAAuthenticator s -> verifyCmdAuth thAuth dummyKeyX25519 s authorized corrId -- These dummy keys are used with `dummyVerify` function to mitigate timing attacks -- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes @@ -1272,13 +1287,6 @@ dummySignKey = \case C.SEd25519 -> dummyKeyEd25519 C.SEd448 -> dummyKeyEd448 -dummyAuthKey :: Maybe TAuthorizations -> C.APublicAuthKey -dummyAuthKey = \case - Just (TASignature (C.ASignature a _), _) -> case a of - C.SEd25519 -> C.APublicAuthKey C.SEd25519 dummyKeyEd25519 - C.SEd448 -> C.APublicAuthKey C.SEd448 dummyKeyEd448 - _ -> C.APublicAuthKey C.SX25519 dummyKeyX25519 - dummyKeyEd25519 :: C.PublicKey 'C.Ed25519 dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs=" @@ -1313,7 +1321,7 @@ client reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m () reply = atomically . writeTBQueue sndQ processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M s (Maybe (Transmission BrokerMsg)) - processProxiedCmd (ceIds@(corrId, EntityId sessId), command) = (ceIds,) <$$> case command of + processProxiedCmd (corrId, EntityId sessId, command) = (corrId,EntityId sessId,) <$$> case command of PRXY srv auth -> ifM allowProxy getRelay (pure $ Just $ ERR $ PROXY BASIC_AUTH) where allowProxy = do @@ -1376,7 +1384,7 @@ client forkProxiedCmd cmdAction = do bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ do -- commands MUST be processed under a reasonable timeout or the client would halt - cmdAction >>= \t -> reply [(ceIds, t)] + cmdAction >>= \t -> reply [(corrId, EntityId sessId, t)] pure Nothing where wait = do @@ -1392,37 +1400,35 @@ client mkIncProxyStats ps psOwn own sel = do incStat $ sel ps when own $ incStat $ sel psOwn - processCommand :: Maybe THPeerClientService -> VersionSMP -> (Maybe (StoreQueue s, QueueRec), Transmission Cmd) -> M s (Maybe (Transmission BrokerMsg)) - processCommand service clntVersion (q_, (ceIds@(corrId, entId), cmd)) = case cmd of - Cmd SProxiedClient command -> processProxiedCmd (ceIds, command) + processCommand :: Maybe THPeerClientService -> VersionSMP -> VerifiedTransmission s -> M s (Maybe (Transmission BrokerMsg)) + processCommand service clntVersion (q_, (corrId, entId, cmd)) = case cmd of + Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command) Cmd SSender command -> Just <$> case command of SKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k SEND flags msgBody -> withQueue_ False $ sendMessage flags msgBody - PING -> pure ((corrId, NoEntity), PONG) - Cmd SProxyService (RFWD encBlock) -> Just . ((corrId, NoEntity),) <$> processForwardedCommand encBlock + Cmd SIdleClient PING -> pure $ Just (corrId, NoEntity, PONG) + Cmd SProxyService (RFWD encBlock) -> Just . (corrId,NoEntity,) <$> processForwardedCommand encBlock Cmd SSenderLink command -> Just <$> case command of LKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k $>> getQueueLink_ q qr LGET -> withQueue $ \q qr -> checkContact qr $ getQueueLink_ q qr - Cmd SNotifier command -> Just . (ceIds,) <$> case command of - NSUB -> case q_ of - Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds - _ -> pure $ ERR INTERNAL - NSUBS -> case service of - Just s -> subscribeServiceNotifications s - Nothing -> pure $ ERR INTERNAL + Cmd SNotifier NSUB -> Just . (corrId,entId,) <$> case q_ of + Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds + _ -> pure $ ERR INTERNAL + Cmd SNotifierService NSUBS -> Just . (corrId,entId,) <$> case service of + Just s -> subscribeServiceNotifications s + Nothing -> pure $ ERR INTERNAL + Cmd SCreator (NEW nqr@NewQueueReq {auth_}) -> + Just <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH)) + where + allowNew = do + ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config + pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth Cmd SRecipient command -> Just <$> case command of - NEW nqr@NewQueueReq {auth_} -> - ifM allowNew (createQueue nqr) (pure (ceIds, ERR AUTH)) - where - allowNew = do - ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config - pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth SUB -> withQueue subscribeQueue - SUBS -> pure $ err (CMD PROHIBITED) -- "TODO [certs rcv]" GET -> withQueue getMessage ACK msgId -> withQueue $ acknowledgeMsg msgId - KEY sKey -> withQueue $ \q _ -> either err (ceIds,) <$> secureQueue_ q sKey + KEY sKey -> withQueue $ \q _ -> either err (corrId,entId,) <$> secureQueue_ q sKey RKEY rKeys -> withQueue $ \q qr -> checkMode QMContact qr $ OK <$$ liftIO (updateKeys (queueStore ms) q rKeys) LSET lnkId d -> withQueue $ \q qr -> case queueData qr of @@ -1437,11 +1443,12 @@ client NDEL -> withQueue $ \q _ -> deleteQueueNotifier_ q OFF -> maybe (pure $ err INTERNAL) suspendQueue_ q_ DEL -> maybe (pure $ err INTERNAL) delQueueAndMsgs q_ - QUE -> withQueue $ \q qr -> (ceIds,) <$> getQueueInfo q qr + QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr + Cmd SRecipientService SUBS -> pure $ Just $ err (CMD PROHIBITED) -- "TODO [certs rcv]" where createQueue :: NewQueueReq -> M s (Transmission BrokerMsg) createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData} - | isJust service && subMode == SMOnlyCreate = pure (ceIds, ERR $ CMD PROHIBITED) + | isJust service && subMode == SMOnlyCreate = pure (corrId, entId, ERR $ CMD PROHIBITED) | otherwise = time "NEW" $ do g <- asks random idSize <- asks $ queueIdBytes . config @@ -1492,7 +1499,7 @@ client | clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied | otherwise -> tryCreate (n - 1) Left e -> pure $ ERR e - Right q -> do + Right _q -> do stats <- asks serverStats incStat $ qCreated stats incStat $ qCount stats @@ -1500,20 +1507,20 @@ client -- when (isJust ntf) $ incStat $ ntfCreated stats case subMode of SMOnlyCreate -> pure () - SMSubscribe -> void $ subscribeQueue q qr + SMSubscribe -> void $ subscribeNewQueue rcvId qr -- no need to check if message is available, it's a new queue pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData, serviceId = rcvServiceId} -- , serverNtfCreds = snd <$> ntf - (ceIds,) <$> tryCreate (3 :: Int) + (corrId,entId,) <$> tryCreate (3 :: Int) -- this check allows to support contact queues created prior to SKEY, -- using `queueMode == Just QMContact` would prevent it, as they have queueMode `Nothing`. checkContact :: QueueRec -> M s (Either ErrorType BrokerMsg) -> M s (Transmission BrokerMsg) checkContact qr a = - either err (ceIds,) + either err (corrId,entId,) <$> if isContactQueue qr then a else pure $ Left AUTH checkMode :: QueueMode -> QueueRec -> M s (Either ErrorType BrokerMsg) -> M s (Transmission BrokerMsg) checkMode qm QueueRec {queueMode} a = - either err (ceIds,) + either err (corrId,entId,) <$> if queueMode == Just qm then a else pure $ Left AUTH secureQueue_ :: StoreQueue s -> SndPublicAuthKey -> M s (Either ErrorType BrokerMsg) @@ -1528,7 +1535,7 @@ client addQueueNotifier_ q notifierKey dhKey = time "NKEY" $ do (rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random let rcvNtfDhSecret = C.dh' dhKey privDhKey - (ceIds,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret + (corrId,entId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret where addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M s BrokerMsg addNotifierRetry 0 _ _ = pure $ ERR INTERNAL @@ -1563,35 +1570,36 @@ client -- TODO [certs rcv] if serviceId is passed, associate with the service and respond with SOK subscribeQueue :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) - subscribeQueue q qr@QueueRec {rcvServiceId} = + subscribeQueue q qr = liftIO (TM.lookupIO rId subscriptions) >>= \case - Nothing -> newSub >>= deliver True + Nothing -> subscribeNewQueue rId qr >>= deliver True Just s@Sub {subThread} -> do stats <- asks serverStats case subThread of ProhibitSub -> do -- cannot use SUB in the same connection where GET was used incStat $ qSubProhibited stats - pure ((corrId, rId), ERR $ CMD PROHIBITED) + pure (corrId, rId, ERR $ CMD PROHIBITED) _ -> do incStat $ qSubDuplicate stats atomically (tryTakeTMVar $ delivered s) >> deliver False s where rId = recipientId q - newSub :: M s Sub - newSub = time "SUB newSub" . atomically $ do - writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId) - sub <- newSubscription NoSub - TM.insert rId sub subscriptions - pure sub deliver :: Bool -> Sub -> M s (Transmission BrokerMsg) deliver inc sub = do stats <- asks serverStats - fmap (either (\e -> ((corrId, rId), ERR e)) id) $ liftIO $ runExceptT $ do + fmap (either (\e -> (corrId, rId, ERR e)) id) $ liftIO $ runExceptT $ do msg_ <- tryPeekMsg ms q liftIO $ when (inc && isJust msg_) $ incStat (qSub stats) liftIO $ deliverMessage "SUB" qr rId sub msg_ + subscribeNewQueue :: RecipientId -> QueueRec -> M s Sub + subscribeNewQueue rId QueueRec {rcvServiceId} = time "SUB newSub" . atomically $ do + writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId) + sub <- newSubscription NoSub + TM.insert rId sub subscriptions + pure sub + -- clients that use GET are not added to server subscribers getMessage :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) getMessage q qr = time "GET" $ do @@ -1626,7 +1634,7 @@ client Just msg -> do let encMsg = encryptMsg qr msg incStat $ (if isJust delivered_ then msgGetDuplicate else msgGet) stats - atomically $ setDelivered s msg $> (ceIds, MSG encMsg) + atomically $ setDelivered s msg $> (corrId, entId, MSG encMsg) Nothing -> incStat (msgGetNoMsg stats) $> ok withQueue :: (StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)) -> M s (Transmission BrokerMsg) @@ -1660,7 +1668,7 @@ client pure $ SOK $ Just serviceId | otherwise -> -- new or updated queue-service association - liftIO (setQueueService (queueStore ms) q SNotifier (Just serviceId)) >>= \case + liftIO (setQueueService (queueStore ms) q SNotifierService (Just serviceId)) >>= \case Left e -> pure $ ERR e Right () -> do hasSub <- atomically $ (<$ newServiceQueueSub) =<< hasServiceSub @@ -1677,7 +1685,7 @@ client modifyTVar' (totalServiceSubs ntfSubscribers) (+ 1) -- server count for all services Nothing -> case ntfServiceId of Just _ -> - liftIO (setQueueService (queueStore ms) q SNotifier Nothing) >>= \case + liftIO (setQueueService (queueStore ms) q SNotifierService Nothing) >>= \case Left e -> pure $ ERR e Right () -> do -- hasSubscription should never be True in this branch, because queue was associated with service. @@ -1849,7 +1857,7 @@ client _ -> pure Nothing deliver sndQ' s = do let encMsg = encryptMsg qr msg - writeTBQueue sndQ' [((CorrId "", rId), MSG encMsg)] + writeTBQueue sndQ' [(CorrId "", rId, MSG encMsg)] void $ setDelivered s msg forkDeliver (rc@Client {sndQ = sndQ'}, s@Sub {delivered}, st) = do t <- mkWeakThreadId =<< forkIO deliverThread @@ -1900,7 +1908,7 @@ client let clntTHParams = smpTHParamsSetVersion fwdVersion thParams' -- only allowing single forwarded transactions t' <- case tParse clntTHParams b of - t :| [] -> pure $ tDecodeParseValidate clntTHParams t + t :| [] -> pure $ tDecodeServer clntTHParams t _ -> throwE BLOCK let clntThAuth = Just $ THAuthServer {serverPrivKey, peerClientService = Nothing, sessSecret' = Just clientSecret} -- process forwarded command @@ -1909,7 +1917,7 @@ client Left r -> pure r -- rejectOrVerify filters allowed commands, no need to repeat it here. -- INTERNAL is used because processCommand never returns Nothing for sender commands (could be extracted for better types). - Right t''@(_, (ceIds', _)) -> fromMaybe (ceIds', ERR INTERNAL) <$> lift (processCommand Nothing fwdVersion t'') + Right t''@(_, (corrId', entId', _)) -> fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand Nothing fwdVersion t'') -- encode response r' <- case batchTransmissions clntTHParams [Right (Nothing, encodeTransmission clntTHParams r)] of [] -> throwE INTERNAL -- at least 1 item is guaranteed from NonEmpty/Right @@ -1925,13 +1933,13 @@ client incStat $ pMsgFwdsRecv stats pure r3 where - rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmissionOrError ErrorType Cmd -> M s (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd)) + rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmissionOrError ErrorType Cmd -> M s (VerifiedTransmissionOrError s) rejectOrVerify clntThAuth = \case - Left (ceIds', e) -> pure $ Left (ceIds', ERR e) - Right (signed, (ceIds'@(corrId', entId'), cmd')) - | allowed -> liftIO $ verified <$> verifyTransmission ms Nothing ((,C.cbNonce (bs corrId')) <$> clntThAuth) signed entId' cmd' - | otherwise -> pure $ Left (ceIds', ERR $ CMD PROHIBITED) + Left (corrId', entId', e) -> pure $ Left (corrId', entId', ERR e) + Right t'@(_, _, t''@(corrId', entId', cmd')) + | allowed -> liftIO $ verified <$> verifyTransmission ms Nothing clntThAuth t' + | otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED) where allowed = case cmd' of Cmd SSender SEND {} -> True @@ -1940,8 +1948,8 @@ client Cmd SSenderLink LGET -> True _ -> False verified = \case - VRVerified q -> Right (q, (ceIds', cmd')) - VRFailed e -> Left (ceIds', ERR e) + VRVerified q -> Right (q, t'') + VRFailed e -> Left (corrId', entId', ERR e) deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> IO (Transmission BrokerMsg) deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $ @@ -1950,10 +1958,10 @@ client _ -> case msg_ of Just msg -> let encMsg = encryptMsg qr msg - in setDelivered s msg $> ((corrId, rId), MSG encMsg) + in setDelivered s msg $> (corrId, rId, MSG encMsg) _ -> pure resp where - resp = ((corrId, rId), OK) + resp = (corrId, rId, OK) time :: MonadIO m => T.Text -> m a -> m a time name = timed name entId @@ -2014,10 +2022,10 @@ client pure QSub {qSubThread, qDelivered} ok :: Transmission BrokerMsg - ok = (ceIds, OK) + ok = (corrId, entId, OK) err :: ErrorType -> Transmission BrokerMsg - err e = (ceIds, ERR e) + err e = (corrId, entId, ERR e) updateDeletedStats :: QueueRec -> M s () updateDeletedStats q = do diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 7819c297e..627c6079a 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -39,6 +39,7 @@ module Simplex.Messaging.Server.Env.STM MsgStoreType, MsgStore (..), AStoreType (..), + VerifiedTransmission, newEnv, mkJournalStoreConfig, msgStore, @@ -390,7 +391,7 @@ data Client s = Client ntfSubscriptions :: TMap NotifierId (), serviceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count ntfServiceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count - rcvQ :: TBQueue (NonEmpty (Maybe (StoreQueue s, QueueRec), Transmission Cmd)), + rcvQ :: TBQueue (NonEmpty (VerifiedTransmission s)), sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), msgQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), procThreads :: TVar Int, @@ -403,6 +404,8 @@ data Client s = Client sndActiveAt :: TVar SystemTime } +type VerifiedTransmission s = (Maybe (StoreQueue s, QueueRec), Transmission Cmd) + data ServerSub = ServerSub (TVar SubscriptionThread) | ProhibitSub data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index c1fc94c08..59357f0af 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -324,6 +324,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where {-# INLINE addQueue_ #-} getQueue_ = withQS getQueue_ {-# INLINE getQueue_ #-} + getQueues_ = withQS getQueues_ + {-# INLINE getQueues_ #-} addQueueLinkData = withQS addQueueLinkData {-# INLINE addQueueLinkData #-} getQueueLinkData = withQS getQueueLinkData diff --git a/src/Simplex/Messaging/Server/MsgStore/Types.hs b/src/Simplex/Messaging/Server/MsgStore/Types.hs index e0d32482d..ef0ee4822 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Types.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Types.hs @@ -18,6 +18,7 @@ module Simplex.Messaging.Server.MsgStore.Types where import Control.Concurrent.STM +import Control.Monad import Control.Monad.Trans.Except import Data.Functor (($>)) import Data.Int (Int64) @@ -107,14 +108,23 @@ addQueue :: MsgStoreClass s => s -> RecipientId -> QueueRec -> IO (Either ErrorT addQueue st = addQueue_ (queueStore st) (mkQueue st True) {-# INLINE addQueue #-} -getQueue :: (MsgStoreClass s, DirectParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s)) +getQueue :: (MsgStoreClass s, QueueParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s)) getQueue st = getQueue_ (queueStore st) (mkQueue st) {-# INLINE getQueue #-} -getQueueRec :: (MsgStoreClass s, DirectParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s, QueueRec)) -getQueueRec st party qId = - getQueue st party qId - $>>= (\q -> maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q)) +getQueues :: (MsgStoreClass s, BatchParty p) => s -> SParty p -> [QueueId] -> IO [Either ErrorType (StoreQueue s)] +getQueues st = getQueues_ (queueStore st) (mkQueue st) +{-# INLINE getQueues #-} + +getQueueRec :: (MsgStoreClass s, QueueParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s, QueueRec)) +getQueueRec st party qId = getQueue st party qId $>>= readQueueRec + +getQueueRecs :: (MsgStoreClass s, BatchParty p) => s -> SParty p -> [QueueId] -> IO [Either ErrorType (StoreQueue s, QueueRec)] +getQueueRecs st party qIds = getQueues st party qIds >>= mapM (fmap join . mapM readQueueRec) + +readQueueRec :: StoreQueueClass q => q -> IO (Either ErrorType (q, QueueRec)) +readQueueRec q = maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q) +{-# INLINE readQueueRec #-} getQueueSize :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO Int getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst) diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index 20307ac9d..35667fa62 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -37,18 +37,20 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class import Control.Monad.Trans.Except +import Data.Bifunctor (first) import Data.ByteString.Builder (Builder) import qualified Data.ByteString.Builder as BB import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Lazy as LB import Data.Bitraversable (bimapM) -import Data.Either (fromRight) +import Data.Either (fromRight, lefts, rights) import Data.Functor (($>)) import Data.Int (Int64) import Data.List (foldl', intersperse, partition) import Data.List.NonEmpty (NonEmpty) +import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe) +import Data.Maybe (catMaybes, fromMaybe, mapMaybe) import qualified Data.Set as S import Data.Text (Text) import Data.Time.Clock.System (SystemTime (..), getSystemTime) @@ -62,7 +64,7 @@ import Database.PostgreSQL.Simple.ToField (Action (..), ToField (..)) import Database.PostgreSQL.Simple.Errors (ConstraintViolation (..), constraintViolation) import Database.PostgreSQL.Simple.SqlQQ (sql) import GHC.IO (catchAny) -import Simplex.Messaging.Agent.Client (withLockMap) +import Simplex.Messaging.Agent.Client (withLockMap, withLocksMap) import Simplex.Messaging.Agent.Lock (Lock) import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore) @@ -81,7 +83,7 @@ import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPServiceRole (..)) -import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>)) +import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>), ($>>=)) import System.Exit (exitFailure) import System.IO (IOMode (..), hFlush, stdout) import UnliftIO.STM @@ -180,18 +182,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where -- hasId = anyM [TM.memberIO rId queues, TM.memberIO senderId senders, hasNotifier] -- hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.memberIO notifierId notifiers) notifier - getQueue_ :: DirectParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) + getQueue_ :: QueueParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) getQueue_ st mkQ party qId = case party of SRecipient -> getRcvQueue qId - SSender -> getSndQueue - SProxyService -> getSndQueue + SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue -- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>)) where PostgresQueueStore {queues, senders, links, notifiers} = st getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right) - getSndQueue = TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue loadRcvQueue = do (rId, qRec) <- loadQueue " WHERE recipient_id = ?" liftIO $ cacheQueue rId qRec $ \_ -> pure () -- recipient map already checked, not caching sender ref @@ -228,6 +228,48 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where TM.insert rId sq queues pure sq + getQueues_ :: forall p. BatchParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q] + getQueues_ st mkQ party qIds = case party of + SRecipient -> do + qs <- readTVarIO queues + let qs' = map (\qId -> get qs qId qId) qIds + E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue + SNotifier -> do + ns <- readTVarIO notifiers + qs <- readTVarIO queues + let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds + E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) -> + forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query + (nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs) + where + PostgresQueueStore {queues, notifiers} = st + get :: M.Map QueueId a -> QueueId -> QueueId -> Either QueueId a + get m qId = maybe (Left qId) Right . (`M.lookup` m) + loadQueues :: [Either QueueId q] -> Query -> ((RecipientId, QueueRec) -> IO (Maybe (QueueId, q))) -> IO [Either ErrorType q] + loadQueues qs' cond mkCacheQueue = do + let qIds' = lefts qs' + if null qIds' + then pure $ map (first (const INTERNAL)) qs' + else do + qs_ <- + runExceptT $ fmap M.fromList $ + withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds'))) + >>= liftIO . fmap catMaybes . mapM (mkCacheQueue . rowToQueueRec) + pure $ map (result qs_) qs' + where + result :: Either ErrorType (M.Map QueueId q) -> Either QueueId q -> Either ErrorType q + result qs_ = \case + Right q -> Right q + Left qId -> maybe (Left AUTH) Right . M.lookup qId =<< qs_ + cacheRcvQueue (rId, qRec) = do + sq <- mkQ True rId qRec + sq' <- withQueueLock sq "getQueue_" $ atomically $ + -- checking the cache again for concurrent reads, use previously loaded queue if exists. + TM.lookup rId queues >>= \case + Just sq' -> pure sq' + Nothing -> sq <$ TM.insert rId sq queues + pure $ Just (rId, sq') + getQueueLinkData :: PostgresQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData) getQueueLinkData st sq lnkId = runExceptT $ do qr <- ExceptT $ readQueueRecIO $ queueRec sq @@ -333,7 +375,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where db [sql| UPDATE msg_queues - SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL + SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL, ntf_service_id = NULL WHERE recipient_id = ? AND deleted_at IS NULL |] (Only rId) @@ -402,15 +444,15 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where when new $ withLog "getCreateService" st (`logNewService` sr) pure serviceId - setQueueService :: (PartyI p, SubscriberParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) + setQueueService :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) setQueueService st sq party serviceId = withQueueRec sq "setQueueService" $ \q -> case party of - SRecipient + SRecipientService | rcvServiceId q == serviceId -> pure () | otherwise -> do assertUpdated $ withDB' "setQueueService" st $ \db -> DB.execute db "UPDATE msg_queues SET rcv_service_id = ? WHERE recipient_id = ? AND deleted_at IS NULL" (serviceId, rId) updateQueueRec q {rcvServiceId = serviceId} - SNotifier -> case notifier q of + SNotifierService -> case notifier q of Nothing -> throwE AUTH Just nc@NtfCreds {ntfServiceId = prevSrvId} | prevSrvId == serviceId -> pure () diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 522f2f28e..5c16825d2 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -128,17 +128,29 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.member notifierId notifiers) notifier hasLink = maybe (pure False) (\(lnkId, _) -> TM.member lnkId links) queueData - getQueue_ :: DirectParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) + getQueue_ :: QueueParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) getQueue_ st _ party qId = maybe (Left AUTH) Right <$> case party of SRecipient -> TM.lookupIO qId queues - SSender -> getSndQueue - SProxyService -> getSndQueue + SSender -> TM.lookupIO qId senders $>>= (`TM.lookupIO` queues) SNotifier -> TM.lookupIO qId notifiers $>>= (`TM.lookupIO` queues) SSenderLink -> TM.lookupIO qId links $>>= (`TM.lookupIO` queues) where STMQueueStore {queues, senders, notifiers, links} = st - getSndQueue = TM.lookupIO qId senders $>>= (`TM.lookupIO` queues) + + getQueues_ :: BatchParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q] + getQueues_ st _ party qIds = case party of + SRecipient -> do + qs <- readTVarIO queues + pure $ map (get qs) qIds + SNotifier -> do + ns <- readTVarIO notifiers + qs <- readTVarIO queues + pure $ map (get qs <=< get ns) qIds + where + STMQueueStore {queues, notifiers} = st + get :: M.Map QueueId a -> QueueId -> Either ErrorType a + get m = maybe (Left AUTH) Right . (`M.lookup` m) getQueueLinkData :: STMQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData) getQueueLinkData _ q lnkId = atomically $ readQueueRec (queueRec q) $>>= pure . getData @@ -292,7 +304,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where serviceNtfQueues <- newTVar S.empty pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues} - setQueueService :: (PartyI p, SubscriberParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) + setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) setQueueService st sq party serviceId = atomically (readQueueRec qr $>>= setService) $>> withLog "setQueueService" st (\sl -> logQueueService sl rId party serviceId) @@ -301,13 +313,13 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where rId = recipientId sq setService :: QueueRec -> STM (Either ErrorType ()) setService q@QueueRec {rcvServiceId = prevSrvId} = case party of - SRecipient + SRecipientService | prevSrvId == serviceId -> pure $ Right () | otherwise -> do updateServiceQueues serviceRcvQueues rId prevSrvId let !q' = Just q {rcvServiceId = serviceId} writeTVar qr q' $> Right () - SNotifier -> case notifier q of + SNotifierService -> case notifier q of Nothing -> pure $ Left AUTH Just nc@NtfCreds {notifierId = nId, ntfServiceId = prevNtfSrvId} | prevNtfSrvId == serviceId -> pure $ Right () diff --git a/src/Simplex/Messaging/Server/QueueStore/Types.hs b/src/Simplex/Messaging/Server/QueueStore/Types.hs index e8af996cb..104d62267 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Types.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Types.hs @@ -31,7 +31,8 @@ class StoreQueueClass q => QueueStoreClass q s where loadedQueues :: s -> TMap RecipientId q compactQueues :: s -> IO Int64 addQueue_ :: s -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q) - getQueue_ :: DirectParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) + getQueue_ :: QueueParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) + getQueues_ :: BatchParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q] getQueueLinkData :: s -> q -> LinkId -> IO (Either ErrorType QueueLinkData) addQueueLinkData :: s -> q -> LinkId -> QueueLinkData -> IO (Either ErrorType ()) deleteQueueLinkData :: s -> q -> IO (Either ErrorType ()) @@ -45,7 +46,7 @@ class StoreQueueClass q => QueueStoreClass q s where updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec) deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q))) getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId) - setQueueService :: (PartyI p, SubscriberParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) + setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)])) getNtfServiceQueueCount :: s -> ServiceId -> IO (Either ErrorType Int64) diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs index 0baad8a11..6ea015066 100644 --- a/src/Simplex/Messaging/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -286,7 +286,7 @@ logUpdateQueueTime s qId t = writeStoreLogRecord s $ UpdateTime qId t logNewService :: StoreLog 'WriteMode -> ServiceRec -> IO () logNewService s = writeStoreLogRecord s . NewService -logQueueService :: (PartyI p, SubscriberParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO () +logQueueService :: (PartyI p, ServiceParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO () logQueueService s rId party = writeStoreLogRecord s . QueueService rId (ASP party) readWriteStoreLog :: (FilePath -> s -> IO ()) -> (StoreLog 'WriteMode -> s -> IO ()) -> FilePath -> s -> IO (StoreLog 'WriteMode) diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index dd3e8df7f..9069cfc89 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -328,7 +328,7 @@ randomSUB_ a v sessId = do (rKey, rpKey) <- atomically $ C.generateAuthKeyPair a g thAuth_ <- testTHandleAuth v g rKey let thParams = testTHandleParams v sessId - TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams ((CorrId corrId, EntityId rId), Cmd SRecipient SUB) + TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, EntityId rId, Cmd SRecipient SUB) pure $ (,tToSend) <$> authTransmission thAuth_ True (Just rpKey) nonce tForAuth randomSUBCmdV6 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) @@ -348,7 +348,7 @@ randomENDCmd :: IO (Transmission BrokerMsg) randomENDCmd = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g - pure ((CorrId "", EntityId rId), END) + pure (CorrId "", EntityId rId, END) randomNMSGCmd :: SystemTime -> IO (Transmission BrokerMsg) randomNMSGCmd ts = do @@ -359,7 +359,7 @@ randomNMSGCmd ts = do nonce <- atomically $ C.randomCbNonce g let msgMeta = NMsgMeta {msgId, msgTs = ts} Right encNMsgMeta <- pure $ C.cbEncrypt (C.dh' k pk) nonce (smpEncode msgMeta) 128 - pure ((CorrId "", EntityId nId), NMSG nonce encNMsgMeta) + pure (CorrId "", EntityId nId, NMSG nonce encNMsgMeta) randomSENDv6 :: ByteString -> Int -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSENDv6 = randomSEND_ C.SEd25519 minServerSMPRelayVersion @@ -376,7 +376,7 @@ randomSEND_ a v sessId len = do thAuth_ <- testTHandleAuth v g sKey msg <- atomically $ C.randomBytes len g let thParams = testTHandleParams v sessId - TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams ((CorrId corrId, EntityId sId), Cmd SSender $ SEND noMsgFlags msg) + TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, EntityId sId, Cmd SSender $ SEND noMsgFlags msg) pure $ (,tToSend) <$> authTransmission thAuth_ False (Just spKey) nonce tForAuth testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion 'TClient diff --git a/tests/CoreTests/StoreLogTests.hs b/tests/CoreTests/StoreLogTests.hs index f03f1d2ee..3a898ef6a 100644 --- a/tests/CoreTests/StoreLogTests.hs +++ b/tests/CoreTests/StoreLogTests.hs @@ -122,7 +122,7 @@ storeLogTests = }, SLTC { name = "create queue, add notifier, register and associate notification service", - saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifier) (Just serviceId)], + saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifierService) (Just serviceId)], compacted = [NewService sr, CreateQueue rId qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}}], state = M.fromList [(rId, qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}})] }, diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 8feaa301c..30b648401 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -206,7 +206,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h [Right ()] <- tPut h [Right (sig, t')] pure () tGet' h = do - [Right ((Nothing, _), ((CorrId corrId, EntityId qId), cmd))] <- tGet h + [(CorrId corrId, EntityId qId, Right cmd)] <- tGetClient h pure (Nothing, corrId, qId, cmd) ntfTest :: Transport c => TProxy c 'TServer -> (THandleNTF c 'TClient -> IO ()) -> Expectation diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index 0631c0589..a4f0a7d62 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -72,20 +72,20 @@ ntfSyntaxTests (ATransport t) = do Expectation command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response -pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmissionOrError ErrorType NtfResponse -pattern RespNtf corrId queueId command <- Right (_, ((corrId, queueId), command)) +pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> Transmission (Either ErrorType NtfResponse) +pattern RespNtf corrId queueId command <- (corrId, queueId, Right command) deriving instance Eq NtfResponse -sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmissionOrError ErrorType NtfResponse) +sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (Transmission (Either ErrorType NtfResponse)) sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do - let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params ((CorrId corrId, qId), cmd) + let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> C.APrivateAuthKey -> (ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmissionOrError ErrorType NtfResponse) +signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> C.APrivateAuthKey -> (ByteString, NtfEntityId, NtfCommand e) -> IO (Transmission (Either ErrorType NtfResponse)) signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do - let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params ((CorrId corrId, qId), cmd) + let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) tGet1 h where diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 71a36e316..50f11e25f 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -394,7 +394,7 @@ smpServerTest _ t = runSmpTest (ASType SQSMemory SMSJournal) $ \h -> tPut' h t > [Right ()] <- tPut h [Right (sig, t')] pure () tGet' h = do - [Right ((Nothing, _), ((CorrId corrId, EntityId qId), cmd))] <- tGet h + [(CorrId corrId, EntityId qId, Right cmd)] <- tGetClient h pure (Nothing, corrId, qId, cmd) smpTest :: (HasCallStack, Transport c) => TProxy c 'TServer -> AStoreType -> (HasCallStack => THandleSMP c 'TClient -> IO ()) -> Expectation diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index b767842d8..e8a4c60cd 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -435,14 +435,14 @@ testNoProxy :: AStoreType -> IO () testNoProxy msType = do withSmpServerConfigOn (transport @TLS) (cfgMS msType) testPort2 $ \_ -> do testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do - Right (_, (_ceIds, reply)) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing) + (_, _, Right reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing) reply `shouldBe` (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) testProxyAuth :: AStoreType -> IO () testProxyAuth msType = do withSmpServerConfigOn (transport @TLS) proxyCfgAuth testPort $ \_ -> do testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do - Right (_, (_ceIds, reply)) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong") + (_, _, Right reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong") reply `shouldBe` (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) where proxyCfgAuth = updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {newQueueBasicAuth = Just "correct"} diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 1af571910..e7f2c35c6 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -92,10 +92,10 @@ serverTests = do testInvQueueLinkData testContactQueueLinkData -pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmissionOrError ErrorType BrokerMsg -pattern Resp corrId queueId command <- Right (_, ((corrId, queueId), command)) +pattern Resp :: CorrId -> QueueId -> BrokerMsg -> Transmission (Either ErrorType BrokerMsg) +pattern Resp corrId queueId command <- (corrId, queueId, Right command) -pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Recipient +pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Creator pattern New rPub dhPub = NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just (QRMessaging Nothing))) pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg @@ -104,21 +104,21 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh _sndSecure _linkId Nothing) pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (SignedTransmissionOrError ErrorType BrokerMsg) +sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg)) sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do - let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params ((CorrId corrId, qId), cmd) + let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (SignedTransmissionOrError ErrorType BrokerMsg) +signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg)) signSendRecv h pk = signSendRecv_ h pk Nothing -serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmissionOrError ErrorType BrokerMsg) +serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg)) serviceSignSendRecv h pk = signSendRecv_ h pk . Just -signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmissionOrError ErrorType BrokerMsg) +signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg)) signSendRecv_ h@THandle {params} (C.APrivateAuthKey a pk) serviceKey_ (corrId, qId, cmd) = do - let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params ((CorrId corrId, qId), cmd) + let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) tGet1 h where @@ -139,9 +139,9 @@ tPut1 h t = do [r] <- tPut h [Right t] pure r -tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (SignedTransmissionOrError err cmd) +tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (Transmission (Either err cmd)) tGet1 h = do - [r] <- liftIO $ tGet h + [r] <- liftIO $ tGetClient h pure r (#==) :: (HasCallStack, Eq a, Show a) => (a, a) -> String -> Assertion @@ -519,7 +519,7 @@ testSwitchSub = Resp "" rId' DELD <- tGet1 rh2 (rId', rId) #== "connection deleted event delivered to subscribed client" - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" @@ -1017,7 +1017,7 @@ testMessageNotifications = Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2) (dec mId2 msg2, Right "hello again") #== "delivered from queue again" Resp "" _ (NMSG _ _) <- tGet1 nh2 - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) @@ -1027,7 +1027,7 @@ testMessageNotifications = Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3) - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" (nPub'', nKey'') <- atomically $ C.generateAuthKeyPair C.SEd25519 g @@ -1069,7 +1069,7 @@ testMessageServiceNotifications = Resp "" serviceId2 (ENDS 1) <- tGet1 nh1 serviceId2 `shouldBe` serviceId deliverMessage rh rId rKey sh sId sKey nh2 "hello again" dec - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) @@ -1079,7 +1079,7 @@ testMessageServiceNotifications = Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3) - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" -- new notification credentials @@ -1133,7 +1133,7 @@ testMsgExpireOnSend = testSMPClient @c $ \rh -> do Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" @@ -1153,7 +1153,7 @@ testMsgExpireOnInterval = signSendRecv rh rKey ("2", rId, SUB) >>= \case Resp "2" _ OK -> pure () r -> unexpected r - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing should be delivered" @@ -1172,7 +1172,7 @@ testMsgNOTExpireOnInterval = testSMPClient @c $ \rh -> do Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered"