smp server: batch commands (#1559)

* protocol: refactor types and encoding

* clean

* smp server: batch commands (#1560)

* smp server: batch commands verification into one DB transaction

* ghc 8.10.7

* flatten transmission tuples

* diff

* only use batch logic if there is more than one transmission

* func

* reset NTF service when adding notifier

* version

* Revert "smp server: use separate database pool for reading queues and creating service records (#1561)"

This reverts commit 3df2425162.

* version

* Revert "version"

This reverts commit d80a6b74c5.
This commit is contained in:
Evgeny
2025-06-12 23:05:04 +01:00
committed by GitHub
parent 1658048c2c
commit da37384335
24 changed files with 556 additions and 377 deletions

View File

@@ -204,7 +204,7 @@ sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = d
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req (Just reqTimeout)
when (B.length bodyHead /= xftpBlockSize) $ throwE $ PCEResponseError BLOCK
-- TODO validate that the file ID is the same as in the request?
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
(_, _fId, respOrErr) <-liftEither $ first PCEResponseError $ xftpDecodeTClient thParams bodyHead
case respOrErr of
Right r -> case protocolError r of
Just e -> throwE $ PCEProtocolError e

View File

@@ -44,8 +44,9 @@ import Simplex.Messaging.Protocol
EntityId (..),
RecipientId,
SenderId,
RawTransmission,
SentRawTransmission,
SignedTransmission,
SignedTransmissionOrError,
SndPublicAuthKey,
Transmission,
TransmissionForAuth (..),
@@ -53,7 +54,8 @@ import Simplex.Messaging.Protocol
encodeTransmission,
encodeTransmissionForAuth,
messageTagP,
tDecodeParseValidate,
tDecodeServer,
tDecodeClient,
tEncodeBatch1,
tParse,
)
@@ -197,7 +199,7 @@ instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, EntityId fileId, _) cmd = case cmd of
checkCredentials auth (EntityId fileId) cmd = case cmd of
-- FNEW must not have signature and chunk ID
FNEW {}
| isNothing auth -> Left $ CMD NO_AUTH
@@ -231,7 +233,7 @@ instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
{-# INLINE fromProtocolError #-}
checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c
checkCredentials tAuth entId (FileCmd p c) = FileCmd p <$> checkCredentials tAuth entId c
{-# INLINE checkCredentials #-}
instance Encoding FileInfo where
@@ -310,7 +312,7 @@ instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where
PEBlock -> BLOCK
{-# INLINE fromProtocolError #-}
checkCredentials (_, _, EntityId entId, _) cmd = case cmd of
checkCredentials _ (EntityId entId) cmd = case cmd of
FRSndIds {} -> noEntity
-- ERR response does not always have entity ID
FRErr _ -> Right cmd
@@ -335,25 +337,35 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of
Just Refl -> Just c
_ -> Nothing
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion 'TClient -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey (corrId, fId, msg) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg)
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion XFTPErrorType c => THandleParams XFTPVersion 'TClient -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey t@(corrId, _, _) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams t
xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth False (Just pKey) (C.cbNonce $ bs corrId) tForAuth
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion p -> Transmission c -> Either TransportError ByteString
xftpEncodeTransmission thParams (corrId, fId, msg) = do
let t = encodeTransmission thParams (corrId, fId, msg)
xftpEncodeBatch1 (Nothing, t)
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion XFTPErrorType c => THandleParams XFTPVersion p -> Transmission c -> Either TransportError ByteString
xftpEncodeTransmission thParams t = xftpEncodeBatch1 (Nothing, encodeTransmission thParams t)
-- this function uses batch syntax but puts only one transmission in the batch
xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString
xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 False t) xftpBlockSize
xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion p -> ByteString -> Either XFTPErrorType (SignedTransmission e c)
xftpDecodeTransmission thParams t = do
xftpDecodeTServer :: THandleParams XFTPVersion 'TServer -> ByteString -> Either XFTPErrorType (SignedTransmissionOrError XFTPErrorType FileCmd)
xftpDecodeTServer = xftpDecodeTransmission tDecodeServer
{-# INLINE xftpDecodeTServer #-}
xftpDecodeTClient :: THandleParams XFTPVersion 'TClient -> ByteString -> Either XFTPErrorType (Transmission (Either XFTPErrorType FileResponse))
xftpDecodeTClient = xftpDecodeTransmission tDecodeClient
{-# INLINE xftpDecodeTClient #-}
xftpDecodeTransmission ::
(THandleParams XFTPVersion p -> Either TransportError RawTransmission -> r) ->
THandleParams XFTPVersion p ->
ByteString ->
Either XFTPErrorType r
xftpDecodeTransmission tDecode thParams t = do
t' <- first (const BLOCK) $ C.unPad t
case tParse thParams t' of
t'' :| [] -> Right $ tDecodeParseValidate thParams t''
t'' :| [] -> Right $ tDecode thParams t''
_ -> Left BLOCK
$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty)

View File

@@ -53,7 +53,7 @@ import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (CorrId (..), BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TAuthorizations, pattern NoEntity)
import Simplex.Messaging.Protocol (BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, SignedTransmission, pattern NoEntity)
import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization)
import Simplex.Messaging.Server.Control (CPClientRole (..))
import Simplex.Messaging.Server.Expiration
@@ -317,22 +317,20 @@ data ServerFile = ServerFile
processRequest :: XFTPTransportRequest -> M ()
processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHead}, sendResponse}
| B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", NoEntity, FRErr BLOCK) Nothing
| otherwise = do
case xftpDecodeTransmission thParams bodyHead of
Right (sig_, signed, (corrId, fId, cmdOrErr)) ->
case cmdOrErr of
Right cmd -> do
let THandleParams {thAuth} = thParams
verifyXFTPTransmission ((,C.cbNonce (bs corrId)) <$> thAuth) sig_ signed fId cmd >>= \case
VRVerified req -> uncurry send =<< processXFTPRequest body req
VRFailed e -> send (FRErr e) Nothing
Left e -> send (FRErr e) Nothing
| otherwise =
case xftpDecodeTServer thParams bodyHead of
Right (Right t@(_, _, (corrId, fId, _))) -> do
let THandleParams {thAuth} = thParams
verifyXFTPTransmission thAuth t >>= \case
VRVerified req -> uncurry send =<< processXFTPRequest body req
VRFailed e -> send (FRErr e) Nothing
where
send resp = sendXFTPResponse (corrId, fId, resp)
Right (Left (corrId, fId, e)) -> sendXFTPResponse (corrId, fId, FRErr e) Nothing
Left e -> sendXFTPResponse ("", NoEntity, FRErr e) Nothing
where
sendXFTPResponse (corrId, fId, resp) serverFile_ = do
let t_ = xftpEncodeTransmission thParams (corrId, fId, resp)
sendXFTPResponse t' serverFile_ = do
let t_ = xftpEncodeTransmission thParams t'
#ifdef slow_servers
randomDelay
#endif
@@ -361,8 +359,8 @@ randomDelay = do
data VerificationResult = VRVerified XFTPRequest | VRFailed XFTPErrorType
verifyXFTPTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult
verifyXFTPTransmission auth_ tAuth authorized fId cmd =
verifyXFTPTransmission :: Maybe (THandleAuth 'TServer) -> SignedTransmission FileCmd -> M VerificationResult
verifyXFTPTransmission thAuth (tAuth, authorized, (corrId, fId, cmd)) =
case cmd of
FileCmd SFSender (FNEW file rcps auth') -> pure $ XFTPReqNew file rcps auth' `verifyWith` sndKey file
FileCmd SFRecipient PING -> pure $ VRVerified XFTPReqPing
@@ -381,9 +379,9 @@ verifyXFTPTransmission auth_ tAuth authorized fId cmd =
EntityBlocked info -> VRFailed $ BLOCKED info
EntityOff -> noFileAuth
Left _ -> pure noFileAuth
noFileAuth = maybe False (dummyVerifyCmd Nothing authorized) tAuth `seq` VRFailed AUTH
noFileAuth = dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
-- TODO verify with DH authorization
req `verifyWith` k = if verifyCmdAuthorization auth_ tAuth authorized k then VRVerified req else VRFailed AUTH
req `verifyWith` k = if verifyCmdAuthorization thAuth tAuth authorized corrId k then VRVerified req else VRFailed AUTH
processXFTPRequest :: HTTP2Body -> XFTPRequest -> M (FileResponse, Maybe ServerFile)
processXFTPRequest HTTP2Body {bodyPart} = \case

View File

@@ -33,6 +33,7 @@ module Simplex.Messaging.Agent.Client
withConnLocks,
withInvLock,
withLockMap,
withLocksMap,
getMapLock,
ipAddressProtected,
closeAgentClient,
@@ -1004,16 +1005,16 @@ withInvLock' AgentClient {invLocks} = withLockMap invLocks
{-# INLINE withInvLock' #-}
withConnLocks :: AgentClient -> Set ConnId -> Text -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks
withConnLocks AgentClient {connLocks} = withLocksMap connLocks
{-# INLINE withConnLocks #-}
withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> Text -> m a -> m a
withLockMap = withGetLock . getMapLock
{-# INLINE withLockMap #-}
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a
withLocksMap_ = withGetLocks . getMapLock
{-# INLINE withLocksMap_ #-}
withLocksMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a
withLocksMap = withGetLocks . getMapLock
{-# INLINE withLocksMap #-}
getMapLock :: Ord k => TMap k Lock -> k -> STM Lock
getMapLock locks key = TM.lookup key locks >>= maybe newLock pure

View File

@@ -183,7 +183,7 @@ data PClient v err msg = PClient
clientCorrId :: TVar ChaChaDRG,
sentCommands :: TMap CorrId (Request err msg),
sndQ :: TBQueue (Maybe (Request err msg), ByteString),
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
rcvQ :: TBQueue (NonEmpty (Transmission (Either err msg))),
msgQ :: Maybe (TBQueue (ServerTransmissionBatch v err msg))
}
@@ -615,7 +615,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
receive :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO ()
receive ProtocolClient {client_ = PClient {rcvQ, lastReceived, timeoutErrorCount}} h = forever $ do
tGet h >>= atomically . writeTBQueue rcvQ
tGetClient h >>= atomically . writeTBQueue rcvQ
getCurrentTime >>= atomically . writeTVar lastReceived
atomically $ writeTVar timeoutErrorCount 0
@@ -642,14 +642,14 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
process :: ProtocolClient v err msg -> IO ()
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= processMsgs c
processMsgs :: ProtocolClient v err msg -> NonEmpty (SignedTransmission err msg) -> IO ()
processMsgs :: ProtocolClient v err msg -> NonEmpty (Transmission (Either err msg)) -> IO ()
processMsgs c ts = do
ts' <- catMaybes <$> mapM (processMsg c) (L.toList ts)
forM_ msgQ $ \q ->
mapM_ (atomically . writeTBQueue q . serverTransmission c) (L.nonEmpty ts')
processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO (Maybe (EntityId, ServerTransmission err msg))
processMsg ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr))
processMsg :: ProtocolClient v err msg -> Transmission (Either err msg) -> IO (Maybe (EntityId, ServerTransmission err msg))
processMsg ProtocolClient {client_ = PClient {sentCommands}} (corrId, entId, respOrErr)
| B.null $ bs corrId = sendMsg $ STEvent clientResp
| otherwise =
TM.lookupIO corrId sentCommands >>= \case
@@ -767,7 +767,7 @@ createSMPQueue ::
-- Maybe NewNtfCreds ->
ExceptT SMPClientError IO QueueIdsKeys
createSMPQueue c nonce_ (rKey, rpKey) dhKey auth subMode qrd =
sendProtocolCommand_ c nonce_ Nothing (Just rpKey) NoEntity (Cmd SRecipient $ NEW $ NewQueueReq rKey dhKey auth subMode (Just qrd)) >>= \case
sendProtocolCommand_ c nonce_ Nothing (Just rpKey) NoEntity (Cmd SCreator $ NEW $ NewQueueReq rKey dhKey auth subMode (Just qrd)) >>= \case
IDS qik -> pure qik
r -> throwE $ unexpectedResponse r
@@ -848,7 +848,7 @@ nsubResponse_ = \case
r' -> Left $ unexpectedResponse r'
{-# INLINE nsubResponse_ #-}
subscribeService :: forall p. (PartyI p, SubscriberParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64
subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64
subscribeService c party = case smpClientService c of
Just THClientService {serviceId, serviceKey} -> do
liftIO $ enablePings c
@@ -858,8 +858,8 @@ subscribeService c party = case smpClientService c of
where
subCmd :: Command p
subCmd = case party of
SRecipient -> SUBS
SNotifier -> NSUBS
SRecipientService -> SUBS
SNotifierService -> NSUBS
Nothing -> throwE PCEServiceUnavailable
smpClientService :: SMPClient -> Maybe THClientService
@@ -1119,8 +1119,8 @@ proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {c
-- server interaction errors are thrown directly
t' <- liftEitherWith PCECryptoError $ C.cbDecrypt cmdSecret (C.reverseNonce nonce) er
case tParse serverThParams t' of
t'' :| [] -> case tDecodeParseValidate serverThParams t'' of
(_auth, _signed, (_c, _e, cmd)) -> case cmd of
t'' :| [] -> case tDecodeClient serverThParams t'' of
(_, _, cmd) -> case cmd of
Right (ERR e) -> throwE $ PCEProtocolError e -- this is the error from the destination relay
Right r' -> pure $ Right r'
Left e -> throwE $ PCEResponseError e

View File

@@ -70,9 +70,9 @@ import Simplex.Messaging.Protocol
QueueId,
SMPServer,
SParty (..),
SubscriberParty,
subscriberParty,
subscriberServiceRole
ServiceParty,
serviceParty,
partyServiceRole
)
import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
@@ -331,11 +331,11 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO ()
reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) =
withSMP ca srv $ \smp -> liftIO $ case subscriberParty agentParty of
withSMP ca srv $ \smp -> liftIO $ case serviceParty agentParty of
Just Dict -> resubscribe smp
Nothing -> pure ()
where
resubscribe :: (PartyI p, SubscriberParty p) => SMPClient -> IO ()
resubscribe :: (PartyI p, ServiceParty p) => SMPClient -> IO ()
resubscribe smp = do
mapM_ (smpSubscribeService ca smp srv) sSub_
forM_ qSubs_ $ \qSubs -> do
@@ -394,22 +394,22 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE
logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode $ host srv) <> "): " <> tshow e
throwE e
subscribeQueuesNtfs :: SMPClientAgent 'Notifier -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO ()
subscribeQueuesNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO ()
subscribeQueuesNtfs = subscribeQueues_
{-# INLINE subscribeQueuesNtfs #-}
subscribeQueues_ :: SubscriberParty p => SMPClientAgent p -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
subscribeQueues_ :: ServiceParty p => SMPClientAgent p -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
subscribeQueues_ ca srv subs = do
atomically $ addPendingSubs ca srv $ L.toList subs
runExceptT (getSMPServerClient' ca srv) >>= \case
Right smp -> smpSubscribeQueues ca smp srv subs
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
smpSubscribeQueues :: SubscriberParty p => SMPClientAgent p -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
smpSubscribeQueues :: ServiceParty p => SMPClientAgent p -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
smpSubscribeQueues ca smp srv subs = do
rs <- case agentParty ca of
SRecipient -> subscribeSMPQueues smp subs
SNotifier -> subscribeSMPQueuesNtfs smp subs
SRecipientService -> subscribeSMPQueues smp subs
SNotifierService -> subscribeSMPQueuesNtfs smp subs
rs' <-
atomically $
ifM
@@ -454,18 +454,18 @@ smpSubscribeQueues ca smp srv subs = do
notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO ()
notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs
subscribeServiceNtfs :: SMPClientAgent 'Notifier -> SMPServer -> (ServiceId, Int64) -> IO ()
subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> (ServiceId, Int64) -> IO ()
subscribeServiceNtfs = subscribeService_
{-# INLINE subscribeServiceNtfs #-}
subscribeService_ :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO ()
subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO ()
subscribeService_ ca srv serviceSub = do
atomically $ setPendingServiceSub ca srv $ Just serviceSub
runExceptT (getSMPServerClient' ca srv) >>= \case
Right smp -> smpSubscribeService ca smp srv serviceSub
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
smpSubscribeService :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO ()
smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO ()
smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of
Just service | serviceAvailable service -> subscribe
_ -> notifyUnavailable
@@ -490,7 +490,7 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService
setActiveServiceSub ca srv $ Just ((serviceId, n), sessId)
setPendingServiceSub ca srv Nothing
serviceAvailable THClientService {serviceRole, serviceId = serviceId'} =
serviceId == serviceId' && subscriberServiceRole (agentParty ca) == serviceRole
serviceId == serviceId' && partyServiceRole (agentParty ca) == serviceRole
notifyUnavailable = do
atomically $ setPendingServiceSub ca srv Nothing
notify ca $ CAServiceUnavailable srv serviceSub -- this will resubscribe all queues directly

View File

@@ -214,7 +214,7 @@ instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) wh
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, EntityId entityId, _) cmd = case cmd of
checkCredentials auth (EntityId entityId) cmd = case cmd of
-- TNEW and SNEW must have signature but NOT token/subscription IDs
TNEW {} -> sigNoEntity
SNEW {} -> sigNoEntity
@@ -254,7 +254,7 @@ instance ProtocolEncoding NTFVersion ErrorType NtfCmd where
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
checkCredentials tAuth entId (NtfCmd e c) = NtfCmd e <$> checkCredentials tAuth entId c
data NtfResponseTag
= NRTknId_
@@ -334,7 +334,7 @@ instance ProtocolEncoding NTFVersion ErrorType NtfResponse where
PEBlock -> BLOCK
{-# INLINE fromProtocolError #-}
checkCredentials (_, _, EntityId entId, _) cmd = case cmd of
checkCredentials _ (EntityId entId) cmd = case cmd of
-- IDTKN response must not have queue ID
NRTknId {} -> noEntity
-- IDSUB response must not have queue ID

View File

@@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag
import Simplex.Messaging.Notifications.Server.Store.Postgres
import Simplex.Messaging.Notifications.Server.Store.Types
import Simplex.Messaging.Notifications.Transport
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGet, tPut)
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Server
import Simplex.Messaging.Server.Control (CPClientRole (..))
@@ -277,21 +277,21 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
apnsPushQLength
}
where
getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics
getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'NotifierService -> (SMPClientAgent 'NotifierService -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics
getSMPServiceSubMetrics a sel subQueueCount = getSubMetrics_ a sel countSubs
where
countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TVar (Maybe sub)) -> IO (NtfSMPSubMetrics, S.Set Text)
countSubs acc (srv, serviceSubs) = subMetricsResult a acc srv . fromIntegral . maybe 0 subQueueCount <$> readTVarIO serviceSubs
getSMPSubMetrics :: SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics
getSMPSubMetrics :: SMPClientAgent 'NotifierService -> (SMPClientAgent 'NotifierService -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics
getSMPSubMetrics a sel = getSubMetrics_ a sel countSubs
where
countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TMap NotifierId a) -> IO (NtfSMPSubMetrics, S.Set Text)
countSubs acc (srv, queueSubs) = subMetricsResult a acc srv . M.size <$> readTVarIO queueSubs
getSubMetrics_ ::
SMPClientAgent 'Notifier ->
(SMPClientAgent 'Notifier -> TVar (M.Map SMPServer sub')) ->
SMPClientAgent 'NotifierService ->
(SMPClientAgent 'NotifierService -> TVar (M.Map SMPServer sub')) ->
((NtfSMPSubMetrics, S.Set Text) -> (SMPServer, sub') -> IO (NtfSMPSubMetrics, S.Set Text)) ->
IO NtfSMPSubMetrics
getSubMetrics_ a sel countSubs = do
@@ -300,7 +300,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
(metrics', otherSrvs) <- foldM countSubs (metrics, S.empty) $ M.assocs subs
pure (metrics' :: NtfSMPSubMetrics) {otherServers = S.size otherSrvs}
subMetricsResult :: SMPClientAgent 'Notifier -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text)
subMetricsResult :: SMPClientAgent 'NotifierService -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text)
subMetricsResult a acc@(metrics, !otherSrvs) srv@(SMPServer (h :| _) _ _) cnt
| isOwnServer a srv =
let !ownSrvSubs' = M.alter (Just . maybe cnt (+ cnt)) host ownSrvSubs
@@ -314,9 +314,9 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
NtfSMPSubMetrics {ownSrvSubs, otherSrvSubCount} = metrics
host = safeDecodeUtf8 $ strEncode h
getSMPWorkerMetrics :: SMPClientAgent 'Notifier -> TMap SMPServer a -> IO NtfSMPWorkerMetrics
getSMPWorkerMetrics :: SMPClientAgent 'NotifierService -> TMap SMPServer a -> IO NtfSMPWorkerMetrics
getSMPWorkerMetrics a v = workerMetrics a . M.keys <$> readTVarIO v
workerMetrics :: SMPClientAgent 'Notifier -> [SMPServer] -> NtfSMPWorkerMetrics
workerMetrics :: SMPClientAgent 'NotifierService -> [SMPServer] -> NtfSMPWorkerMetrics
workerMetrics a srvs = NtfSMPWorkerMetrics {ownServers = reverse ownSrvs, otherServers}
where
(ownSrvs, otherServers) = foldl' countSrv ([], 0) srvs
@@ -455,7 +455,7 @@ resubscribe NtfSubscriber {smpAgent = ca} = do
counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs
logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)"
subscribeSrvSubs :: SMPClientAgent 'Notifier -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int
subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int
subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do
let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv)
logNote $ "Starting SMP resubscriptions for " <> srvStr
@@ -722,25 +722,24 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
receive :: Transport c => NtfPostgresStore -> THandleNTF c 'TServer -> NtfServerClient -> IO ()
receive st th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do
ts <- L.toList <$> tGet th
ts <- L.toList <$> tGetServer th
atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime
(errs, cmds) <- partitionEithers <$> mapM cmdAction ts
write sndQ errs
write rcvQ cmds
where
cmdAction t@(_, _, (corrId, entId, cmdOrError)) =
case cmdOrError of
Left e -> do
logError $ "invalid client request: " <> tshow e
pure $ Left (corrId, entId, NRErr e)
Right cmd ->
verified =<< verifyNtfTransmission st ((,C.cbNonce (SMP.bs corrId)) <$> thAuth) t cmd
where
verified = \case
VRVerified req -> pure $ Right req
VRFailed e -> do
logError "unauthorized client request"
pure $ Left (corrId, entId, NRErr e)
cmdAction = \case
Left (corrId, entId, e) -> do
logError $ "invalid client request: " <> tshow e
pure $ Left (corrId, entId, NRErr e)
Right t@(_, _, (corrId, entId, _)) ->
verified =<< verifyNtfTransmission st thAuth t
where
verified = \case
VRVerified req -> pure $ Right req
VRFailed e -> do
logError "unauthorized client request"
pure $ Left (corrId, entId, NRErr e)
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
send :: Transport c => THandleNTF c 'TServer -> NtfServerClient -> IO ()
@@ -751,10 +750,10 @@ send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do
data VerificationResult = VRVerified NtfRequest | VRFailed ErrorType
verifyNtfTransmission :: NtfPostgresStore -> Maybe (THandleAuth 'TServer, C.CbNonce) -> SignedTransmission ErrorType NtfCmd -> NtfCmd -> IO VerificationResult
verifyNtfTransmission st auth_ (tAuth, authorized, (corrId, entId, _)) = \case
verifyNtfTransmission :: NtfPostgresStore -> Maybe (THandleAuth 'TServer) -> SignedTransmission NtfCmd -> IO VerificationResult
verifyNtfTransmission st thAuth (tAuth, authorized, (corrId, entId, cmd)) = case cmd of
NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _))
| verifyCmdAuthorization auth_ tAuth authorized k ->
| verifyCmdAuthorization thAuth tAuth authorized corrId k ->
result <$> findNtfTokenRegistration st tkn
| otherwise -> pure $ VRFailed AUTH
where
@@ -783,10 +782,10 @@ verifyNtfTransmission st auth_ (tAuth, authorized, (corrId, entId, _)) = \case
subCmd s c = NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c)
verifyToken :: NtfTknRec -> NtfRequest -> VerificationResult
verifyToken NtfTknRec {tknVerifyKey} r
| verifyCmdAuthorization auth_ tAuth authorized tknVerifyKey = VRVerified r
| verifyCmdAuthorization thAuth tAuth authorized corrId tknVerifyKey = VRVerified r
| otherwise = VRFailed AUTH
err = \case -- signature verification for AUTH errors mitigates timing attacks for existence checks
AUTH -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed AUTH
AUTH -> dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
e -> VRFailed e
client :: NtfServerClient -> NtfSubscriber -> NtfPushServer -> M ()

View File

@@ -127,7 +127,7 @@ newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbSt
data NtfSubscriber = NtfSubscriber
{ smpSubscribers :: TMap SMPServer SMPSubscriberVar,
subscriberSeq :: TVar Int,
smpAgent :: SMPClientAgent 'Notifier
smpAgent :: SMPClientAgent 'NotifierService
}
type SMPSubscriberVar = SessionVar SMPSubscriber
@@ -136,7 +136,7 @@ newNtfSubscriber :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber
newNtfSubscriber smpAgentCfg random = do
smpSubscribers <- TM.emptyIO
subscriberSeq <- newTVarIO 0
smpAgent <- newSMPClientAgent SNotifier smpAgentCfg random
smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg random
pure NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent}
data SMPSubscriber = SMPSubscriber

View File

@@ -66,8 +66,9 @@ module Simplex.Messaging.Protocol
EncDataBytes (..),
Party (..),
Cmd (..),
DirectParty,
SubscriberParty,
QueueParty,
BatchParty,
ServiceParty,
ASubscriberParty (..),
BrokerMsg (..),
SParty (..),
@@ -80,12 +81,13 @@ module Simplex.Messaging.Protocol
BrokerErrorType (..),
BlockingInfo (..),
BlockingReason (..),
RawTransmission,
Transmission,
TAuthorizations,
TransmissionAuth (..),
SignedTransmission,
SignedTransmissionOrError,
SentRawTransmission,
SignedRawTransmission,
ClientMsgEnvelope (..),
PubHeader (..),
ClientMessage (..),
@@ -153,8 +155,11 @@ module Simplex.Messaging.Protocol
currentSMPClientVersion,
senderCanSecure,
queueReqMode,
subscriberParty,
subscriberServiceRole,
queueParty,
batchParty,
serviceParty,
partyClientRole,
partyServiceRole,
userProtocol,
rcvMessageMeta,
noMsgFlags,
@@ -186,9 +191,11 @@ module Simplex.Messaging.Protocol
TransportBatch (..),
tPut,
tPutLog,
tGet,
tGetServer,
tGetClient,
tParse,
tDecodeParseValidate,
tDecodeServer,
tDecodeClient,
tEncode,
tEncodeBatch1,
batchTransmissions,
@@ -208,7 +215,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.Bifunctor (bimap, first)
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -303,22 +310,40 @@ e2eEncMessageLength :: Int
e2eEncMessageLength = 16000 -- 15988 .. 16005
-- | SMP protocol clients
data Party = Recipient | Sender | Notifier | LinkClient | ProxiedClient | ProxyService
data Party
= Creator
| Recipient
| RecipientService
| Sender
| IdleClient
| Notifier
| NotifierService
| LinkClient
| ProxiedClient
| ProxyService
deriving (Show)
-- | Singleton types for SMP protocol clients
data SParty :: Party -> Type where
SCreator :: SParty Creator
SRecipient :: SParty Recipient
SRecipientService :: SParty RecipientService
SSender :: SParty Sender
SIdleClient :: SParty IdleClient
SNotifier :: SParty Notifier
SNotifierService :: SParty NotifierService
SSenderLink :: SParty LinkClient
SProxiedClient :: SParty ProxiedClient
SProxyService :: SParty ProxyService
instance TestEquality SParty where
testEquality SCreator SCreator = Just Refl
testEquality SRecipient SRecipient = Just Refl
testEquality SRecipientService SRecipientService = Just Refl
testEquality SSender SSender = Just Refl
testEquality SIdleClient SIdleClient = Just Refl
testEquality SNotifier SNotifier = Just Refl
testEquality SNotifierService SNotifierService = Just Refl
testEquality SSenderLink SSenderLink = Just Refl
testEquality SProxiedClient SProxiedClient = Just Refl
testEquality SProxyService SProxyService = Just Refl
@@ -328,34 +353,72 @@ deriving instance Show (SParty p)
class PartyI (p :: Party) where sParty :: SParty p
instance PartyI Creator where sParty = SCreator
instance PartyI Recipient where sParty = SRecipient
instance PartyI RecipientService where sParty = SRecipientService
instance PartyI Sender where sParty = SSender
instance PartyI IdleClient where sParty = SIdleClient
instance PartyI Notifier where sParty = SNotifier
instance PartyI NotifierService where sParty = SNotifierService
instance PartyI LinkClient where sParty = SSenderLink
instance PartyI ProxiedClient where sParty = SProxiedClient
instance PartyI ProxyService where sParty = SProxyService
type family DirectParty (p :: Party) :: Constraint where
DirectParty Recipient = ()
DirectParty Sender = ()
DirectParty Notifier = ()
DirectParty LinkClient = ()
DirectParty ProxyService = ()
DirectParty p =
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not direct"))
-- command parties that can read queues
type family QueueParty (p :: Party) :: Constraint where
QueueParty Recipient = ()
QueueParty Sender = ()
QueueParty Notifier = ()
QueueParty LinkClient = ()
QueueParty p =
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not QueueParty"))
type family SubscriberParty (p :: Party) :: Constraint where
SubscriberParty Recipient = ()
SubscriberParty Notifier = ()
SubscriberParty p =
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not subscriber"))
queueParty :: SParty p -> Maybe (Dict (PartyI p, QueueParty p))
queueParty = \case
SRecipient -> Just Dict
SSender -> Just Dict
SSenderLink -> Just Dict
SNotifier -> Just Dict
_ -> Nothing
{-# INLINE queueParty #-}
data ASubscriberParty = forall p. (PartyI p, SubscriberParty p) => ASP (SParty p)
type family BatchParty (p :: Party) :: Constraint where
BatchParty Recipient = ()
BatchParty Notifier = ()
BatchParty p =
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not BatchParty"))
batchParty :: SParty p -> Maybe (Dict (PartyI p, BatchParty p))
batchParty = \case
SRecipient -> Just Dict
SNotifier -> Just Dict
_ -> Nothing
{-# INLINE batchParty #-}
-- command parties that can subscribe to individual queues
type family ServiceParty (p :: Party) :: Constraint where
ServiceParty RecipientService = ()
ServiceParty NotifierService = ()
ServiceParty p =
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not ServiceParty"))
serviceParty :: SParty p -> Maybe (Dict (PartyI p, ServiceParty p))
serviceParty = \case
SRecipientService -> Just Dict
SNotifierService -> Just Dict
_ -> Nothing
{-# INLINE serviceParty #-}
data ASubscriberParty = forall p. (PartyI p, ServiceParty p) => ASP (SParty p)
deriving instance Show ASubscriberParty
@@ -364,30 +427,37 @@ instance Eq ASubscriberParty where
instance Encoding ASubscriberParty where
smpEncode = \case
ASP SRecipient -> "R"
ASP SNotifier -> "N"
ASP SRecipientService -> "R"
ASP SNotifierService -> "N"
smpP =
A.anyChar >>= \case
'R' -> pure $ ASP SRecipient
'N' -> pure $ ASP SNotifier
'R' -> pure $ ASP SRecipientService
'N' -> pure $ ASP SNotifierService
_ -> fail "bad ASubscriberParty"
instance StrEncoding ASubscriberParty where
strEncode = smpEncode
strP = smpP
subscriberParty :: SParty p -> Maybe (Dict (PartyI p, SubscriberParty p))
subscriberParty = \case
SRecipient -> Just Dict
SNotifier -> Just Dict
_ -> Nothing
{-# INLINE subscriberParty #-}
partyClientRole :: SParty p -> Maybe SMPServiceRole
partyClientRole = \case
SCreator -> Just SRMessaging
SRecipient -> Just SRMessaging
SRecipientService -> Just SRMessaging
SSender -> Just SRMessaging
SIdleClient -> Nothing
SNotifier -> Just SRNotifier
SNotifierService -> Just SRNotifier
SSenderLink -> Just SRMessaging
SProxiedClient -> Just SRMessaging
SProxyService -> Just SRProxy
{-# INLINE partyClientRole #-}
subscriberServiceRole :: SubscriberParty p => SParty p -> SMPServiceRole
subscriberServiceRole = \case
SRecipient -> SRMessaging
SNotifier -> SRNotifier
{-# INLINE subscriberServiceRole #-}
partyServiceRole :: ServiceParty p => SParty p -> SMPServiceRole
partyServiceRole = \case
SRecipientService -> SRMessaging
SNotifierService -> SRNotifier
{-# INLINE partyServiceRole #-}
-- | Type for client command of any participant.
data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p)
@@ -398,7 +468,9 @@ deriving instance Show Cmd
type Transmission c = (CorrId, EntityId, c)
-- | signed parsed transmission, with original raw bytes and parsing error.
type SignedTransmission e c = (Maybe TAuthorizations, Signed, Transmission (Either e c))
type SignedTransmission c = (Maybe TAuthorizations, Signed, Transmission c)
type SignedTransmissionOrError e c = Either (Transmission e) (SignedTransmission c)
type Signed = ByteString
@@ -439,9 +511,6 @@ decodeTAuthBytes s serviceSig
| B.length s == C.cbAuthenticatorSize = Right $ Just (TAAuthenticator (C.CbAuthenticator s), serviceSig)
| otherwise = (\sig -> Just (TASignature sig, serviceSig)) <$> C.decodeSignature s
-- | unparsed sent SMP transmission with signature, without session ID.
type SignedRawTransmission = (Maybe TAuthorizations, CorrId, EntityId, ByteString)
-- | unparsed sent SMP transmission with signature.
type SentRawTransmission = (Maybe TAuthorizations, ByteString)
@@ -466,10 +535,10 @@ data Command (p :: Party) where
-- v6 of SMP servers only support signature algorithm for command authorization.
-- v7 of SMP servers additionally support additional layer of authenticated encryption.
-- RcvPublicAuthKey is defined as C.APublicKey - it can be either signature or DH public keys.
NEW :: NewQueueReq -> Command Recipient
NEW :: NewQueueReq -> Command Creator
SUB :: Command Recipient
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
SUBS :: Command Recipient
SUBS :: Command RecipientService
KEY :: SndPublicAuthKey -> Command Recipient
RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient
LSET :: LinkId -> QueueLinkData -> Command Recipient
@@ -486,14 +555,14 @@ data Command (p :: Party) where
-- SEND v1 has to be supported for encoding/decoding
-- SEND :: MsgBody -> Command Sender
SEND :: MsgFlags -> MsgBody -> Command Sender
PING :: Command Sender
PING :: Command IdleClient
-- Client accessing short links
LKEY :: SndPublicAuthKey -> Command LinkClient
LGET :: Command LinkClient
-- SMP notification subscriber commands
NSUB :: Command Notifier
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
NSUBS :: Command Notifier
NSUBS :: Command NotifierService
PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI
-- Transmission to proxy:
-- - entity ID: ID of the session with relay returned in PKEY (response to PRXY)
@@ -827,9 +896,9 @@ noMsgFlags = MsgFlags {notification = False}
-- * SMP command tags
data CommandTag (p :: Party) where
NEW_ :: CommandTag Recipient
NEW_ :: CommandTag Creator
SUB_ :: CommandTag Recipient
SUBS_ :: CommandTag Recipient
SUBS_ :: CommandTag RecipientService
KEY_ :: CommandTag Recipient
RKEY_ :: CommandTag Recipient
LSET_ :: CommandTag Recipient
@@ -843,14 +912,14 @@ data CommandTag (p :: Party) where
QUE_ :: CommandTag Recipient
SKEY_ :: CommandTag Sender
SEND_ :: CommandTag Sender
PING_ :: CommandTag Sender
PING_ :: CommandTag IdleClient
LKEY_ :: CommandTag LinkClient
LGET_ :: CommandTag LinkClient
PRXY_ :: CommandTag ProxiedClient
PFWD_ :: CommandTag ProxiedClient
RFWD_ :: CommandTag ProxyService
NSUB_ :: CommandTag Notifier
NSUBS_ :: CommandTag Notifier
NSUBS_ :: CommandTag NotifierService
data CmdTag = forall p. PartyI p => CT (SParty p) (CommandTag p)
@@ -916,9 +985,9 @@ instance PartyI p => Encoding (CommandTag p) where
instance ProtocolMsgTag CmdTag where
decodeTag = \case
"NEW" -> Just $ CT SRecipient NEW_
"NEW" -> Just $ CT SCreator NEW_
"SUB" -> Just $ CT SRecipient SUB_
"SUBS" -> Just $ CT SRecipient SUBS_
"SUBS" -> Just $ CT SRecipientService SUBS_
"KEY" -> Just $ CT SRecipient KEY_
"RKEY" -> Just $ CT SRecipient RKEY_
"LSET" -> Just $ CT SRecipient LSET_
@@ -932,14 +1001,14 @@ instance ProtocolMsgTag CmdTag where
"QUE" -> Just $ CT SRecipient QUE_
"SKEY" -> Just $ CT SSender SKEY_
"SEND" -> Just $ CT SSender SEND_
"PING" -> Just $ CT SSender PING_
"PING" -> Just $ CT SIdleClient PING_
"LKEY" -> Just $ CT SSenderLink LKEY_
"LGET" -> Just $ CT SSenderLink LGET_
"PRXY" -> Just $ CT SProxiedClient PRXY_
"PFWD" -> Just $ CT SProxiedClient PFWD_
"RFWD" -> Just $ CT SProxyService RFWD_
"NSUB" -> Just $ CT SNotifier NSUB_
"NSUBS" -> Just $ CT SNotifier NSUBS_
"NSUBS" -> Just $ CT SNotifierService NSUBS_
_ -> Nothing
instance Encoding CmdTag where
@@ -1564,7 +1633,7 @@ instance Protocol SMPVersion ErrorType BrokerMsg where
Cmd _ NSUB -> True
_ -> False
{-# INLINE useServiceAuth #-}
protocolPing = Cmd SSender PING
protocolPing = Cmd SIdleClient PING
{-# INLINE protocolPing #-}
protocolError = \case
ERR e -> Just e
@@ -1576,7 +1645,7 @@ class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -
encodeProtocol :: Version v -> msg -> ByteString
protocolP :: Version v -> Tag msg -> Parser msg
fromProtocolError :: ProtocolErrorType -> err
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
checkCredentials :: Maybe TAuthorizations -> EntityId -> msg -> Either err msg
instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
type Tag (Command p) = CommandTag p
@@ -1620,7 +1689,7 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, EntityId entId, _) cmd = case cmd of
checkCredentials auth (EntityId entId) cmd = case cmd of
-- NEW must have signature but NOT queue ID
NEW {}
| isNothing auth -> Left $ CMD NO_AUTH
@@ -1663,14 +1732,14 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
{-# INLINE encodeProtocol #-}
protocolP v = \case
CT SRecipient tag ->
Cmd SRecipient <$> case tag of
NEW_
| v >= shortLinksSMPVersion -> NEW <$> new smpP smpP
| v >= sndAuthKeySMPVersion -> NEW <$> new smpP (qReq <$> smpP)
| otherwise -> NEW <$> new auth (pure Nothing)
CT SCreator NEW_ -> Cmd SCreator <$> newCmd
where
newCmd
| v >= shortLinksSMPVersion = new smpP smpP
| v >= sndAuthKeySMPVersion = new smpP (qReq <$> smpP)
| otherwise = new auth (pure Nothing)
where
new p1 p2 = do
new p1 p2 = NEW <$> do
rcvAuthKey <- _smpP
rcvDhKey <- smpP
auth_ <- p1
@@ -1681,8 +1750,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
pure NewQueueReq {rcvAuthKey, rcvDhKey, auth_, subMode, queueReqData} -- ntfCreds
auth = optional (A.char 'A' *> smpP)
qReq sndSecure = Just $ if sndSecure then QRMessaging Nothing else QRContact Nothing
CT SRecipient tag ->
Cmd SRecipient <$> case tag of
SUB_ -> pure SUB
SUBS_ -> pure SUBS
KEY_ -> KEY <$> _smpP
RKEY_ -> RKEY <$> _smpP
LSET_ -> LSET <$> _smpP <*> smpP
@@ -1694,11 +1764,12 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
OFF_ -> pure OFF
DEL_ -> pure DEL
QUE_ -> pure QUE
CT SRecipientService SUBS_ -> pure $ Cmd SRecipientService SUBS
CT SSender tag ->
Cmd SSender <$> case tag of
SKEY_ -> SKEY <$> _smpP
SEND_ -> SEND <$> _smpP <*> (unTail <$> _smpP)
PING_ -> pure PING
CT SIdleClient PING_ -> pure $ Cmd SIdleClient PING
CT SProxyService RFWD_ ->
Cmd SProxyService . RFWD . EncFwdTransmission . unTail <$> _smpP
CT SSenderLink tag ->
@@ -1709,15 +1780,13 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
Cmd SProxiedClient <$> case tag of
PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP)
PRXY_ -> PRXY <$> _smpP <*> smpP
CT SNotifier tag ->
pure $ Cmd SNotifier $ case tag of
NSUB_ -> NSUB
NSUBS_ -> NSUBS
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
CT SNotifierService NSUBS_ -> pure $ Cmd SNotifierService NSUBS
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
checkCredentials tAuth entId (Cmd p c) = Cmd p <$> checkCredentials tAuth entId c
{-# INLINE checkCredentials #-}
instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
@@ -1804,7 +1873,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
PEBlock -> BLOCK
{-# INLINE fromProtocolError #-}
checkCredentials (_, _, EntityId entId, _) cmd = case cmd of
checkCredentials _ (EntityId entId) cmd = case cmd of
-- IDS response should not have queue ID
IDS _ -> Right cmd
-- ERR response does not always have queue ID
@@ -2077,26 +2146,51 @@ tParse thParams@THandleParams {batch} s
eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b)
eitherList = either (\e -> [Left e])
-- | Receive client and server transmissions (determined by `cmd` type).
tGet :: forall v err cmd c p. (ProtocolEncoding v err cmd, Transport c) => THandle v c p -> IO (NonEmpty (SignedTransmission err cmd))
tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th
-- | Receive server transmissions
tGetServer :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TServer -> IO (NonEmpty (SignedTransmissionOrError err cmd))
tGetServer = tGet tDecodeServer
{-# INLINE tGetServer #-}
tDecodeParseValidate :: forall v p err cmd. ProtocolEncoding v err cmd => THandleParams v p -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case
-- | Receive client transmissions
tGetClient :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (NonEmpty (Transmission (Either err cmd)))
tGetClient = tGet tDecodeClient
{-# INLINE tGetClient #-}
tGet ::
Transport c =>
(THandleParams v p -> Either TransportError RawTransmission -> r) ->
THandle v c p ->
IO (NonEmpty r)
tGet tDecode th@THandle {params} = L.map (tDecode params) <$> tGetParse th
{-# INLINE tGet #-}
tDecodeServer :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v 'TServer -> Either TransportError RawTransmission -> SignedTransmissionOrError err cmd
tDecodeServer THandleParams {sessionId, thVersion = v, implySessId} = \case
Right RawTransmission {authenticator, serviceSig, authorized, sessId, corrId, entityId, command}
| implySessId || sessId == sessionId ->
let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator serviceSig
in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission
| otherwise -> (Nothing, "", (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd PESession))
Left _ -> tError ""
| implySessId || sessId == sessionId -> case decodeTAuthBytes authenticator serviceSig of
Right tAuth -> bimap t ((tAuth,authorized,) . t) cmdOrErr
where
cmdOrErr = parseProtocol @v @err @cmd v command >>= checkCredentials tAuth entityId
t :: a -> (CorrId, EntityId, a)
t = (corrId,entityId,)
Left _ -> tError corrId PEBlock
| otherwise -> tError corrId PESession
Left _ -> tError "" PEBlock
where
tError :: CorrId -> SignedTransmission err cmd
tError corrId = (Nothing, "", (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd PEBlock))
tError :: CorrId -> ProtocolErrorType -> SignedTransmissionOrError err cmd
tError corrId err = Left (corrId, NoEntity, fromProtocolError @v @err @cmd err)
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
tParseValidate signed t@(sig, corrId, entityId, command) =
let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t
in (sig, signed, (corrId, entityId, cmd))
tDecodeClient :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v 'TClient -> Either TransportError RawTransmission -> Transmission (Either err cmd)
tDecodeClient THandleParams {sessionId, thVersion = v, implySessId} = \case
Right RawTransmission {sessId, corrId, entityId, command}
| implySessId || sessId == sessionId -> (corrId, entityId, cmdOrErr)
| otherwise -> tError corrId PESession
where
cmdOrErr = parseProtocol @v @err @cmd v command >>= checkCredentials Nothing entityId
Left _ -> tError "" PEBlock
where
tError :: CorrId -> ProtocolErrorType -> Transmission (Either err cmd)
tError corrId err = (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd err)
$(J.deriveJSON defaultJSON ''MsgFlags)

View File

@@ -52,12 +52,13 @@ import Control.Monad.Reader
import Control.Monad.Trans.Except
import Control.Monad.STM (retry)
import Crypto.Random (ChaChaDRG)
import Data.Bifunctor (first)
import Data.Bifunctor (first, second)
import Data.ByteString.Base64 (encode)
import qualified Data.ByteString.Builder as BLD
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Constraint (Dict (..))
import Data.Dynamic (toDyn)
import Data.Either (fromRight, partitionEithers)
import Data.Functor (($>))
@@ -1068,48 +1069,53 @@ cancelSub s = case subThread s of
_ -> pure ()
ProhibitSub -> pure ()
type VerifiedTransmissionOrError s = Either (Transmission BrokerMsg) (VerifiedTransmission s)
receive :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> s -> Client s -> M s ()
receive h@THandle {params = THandleParams {thAuth, sessionId}} ms Client {rcvQ, sndQ, rcvActiveAt} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
sa <- asks serverActive
stats <- asks serverStats
liftIO $ forever $ do
ts <- tGet h
ts <- tGetServer h
unlessM (readTVarIO sa) $ throwIO $ userError "server stopped"
atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime
let service = peerClientService =<< thAuth
(errs, cmds) <- partitionEithers <$> mapM (cmdAction stats service) (L.toList ts)
updateBatchStats stats cmds
write sndQ errs
write rcvQ cmds
let (es, ts') = partitionEithers $ L.toList ts
errs = map (second ERR) es
case ts' of
(_, _, (_, _, Cmd p cmd)) : rest -> do
let service = peerClientService =<< thAuth
(errs', cmds) <- partitionEithers <$> case batchParty p of
Just Dict | not (null rest) && all (sameParty p) ts'-> do
updateBatchStats stats cmd -- even if nothing is verified
let queueId (_, _, (_, qId, _)) = qId
qs <- getQueueRecs ms p $ map queueId ts'
zipWithM (\t -> verified stats t . verifyLoadedQueue service thAuth t) ts' qs
_ -> mapM (\t -> verified stats t =<< verifyTransmission ms service thAuth t) ts'
write rcvQ cmds
write sndQ $ errs ++ errs'
[] -> write sndQ errs
where
updateBatchStats :: ServerStats -> [(Maybe (StoreQueue s, QueueRec), Transmission Cmd)] -> IO ()
sameParty :: SParty p -> SignedTransmission Cmd -> Bool
sameParty p (_, _, (_, _, Cmd p' _)) = isJust $ testEquality p p'
updateBatchStats :: ServerStats -> Command p -> IO ()
updateBatchStats stats = \case
(_, (_, _, (Cmd _ cmd))) : _ -> do
let sel_ = case cmd of
SUB -> Just qSubAllB
DEL -> Just qDeletedAllB
NSUB -> Just ntfSubB
NDEL -> Just ntfDeletedB
_ -> Nothing
mapM_ (\sel -> incStat $ sel stats) sel_
[] -> pure ()
cmdAction :: ServerStats -> Maybe THPeerClientService -> SignedTransmission ErrorType Cmd -> IO (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd))
cmdAction stats service (tAuth, authorized, (corrId, entId, cmdOrError)) =
case cmdOrError of
Left e -> pure $ Left (corrId, entId, ERR e)
Right cmd -> verified =<< verifyTransmission ms service ((,C.cbNonce (bs corrId)) <$> thAuth) tAuth authorized entId cmd
where
verified = \case
VRVerified q -> pure $ Right (q, (corrId, entId, cmd))
VRFailed e -> do
case cmd of
Cmd _ SEND {} -> incStat $ msgSentAuth stats
Cmd _ SUB -> incStat $ qSubAuth stats
Cmd _ NSUB -> incStat $ ntfSubAuth stats
Cmd _ GET -> incStat $ msgGetAuth stats
_ -> pure ()
pure $ Left (corrId, entId, ERR e)
SUB -> incStat $ qSubAllB stats
DEL -> incStat $ qDeletedAllB stats
NDEL -> incStat $ ntfDeletedB stats
NSUB -> incStat $ ntfSubB stats
_ -> pure ()
verified :: ServerStats -> SignedTransmission Cmd -> VerificationResult s -> IO (VerifiedTransmissionOrError s)
verified stats (_, _, t@(corrId, entId, Cmd _ command)) = \case
VRVerified q -> pure $ Right (q, t)
VRFailed e -> Left (corrId, entId, ERR e) <$ when (e == AUTH) incAuthStat
where
incAuthStat = case command of
SEND {} -> incStat $ msgSentAuth stats
SUB -> incStat $ qSubAuth stats
NSUB -> incStat $ ntfSubAuth stats
GET -> incStat $ msgGetAuth stats
_ -> pure ()
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
send :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO ()
@@ -1169,34 +1175,42 @@ data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFail
-- - the queue or party key do not exist.
-- In all cases, the time of the verification should depend only on the provided authorization type,
-- a dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result.
verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> QueueId -> Cmd -> IO (VerificationResult s)
verifyTransmission ms service auth_ tAuth authorized queueId command@(Cmd party cmd)
| verifyServiceSig = case party of
SRecipient | hasRole SRMessaging -> case cmd of
NEW NewQueueReq {rcvAuthKey = k} -> pure $ Nothing `verifiedWith` k
SUB -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q)
SUBS -> pure verifyServiceCmd
_ -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q)
SSender | hasRole SRMessaging -> case cmd of
SKEY k -> verifySecure SSender k
-- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
SEND {} -> verifyQueue SSender $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified (Just q) else VRFailed AUTH
PING -> pure $ VRVerified Nothing
SSenderLink | hasRole SRMessaging -> case cmd of
LKEY k -> verifySecure SSenderLink k
LGET -> verifyQueue SSenderLink $ \q -> if isContactQueue (snd q) then VRVerified (Just q) else VRFailed AUTH
SNotifier | hasRole SRNotifier -> case cmd of
NSUB -> verifyQueue SNotifier $ \q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier $ snd q)
NSUBS -> pure verifyServiceCmd
SProxiedClient | hasRole SRMessaging -> pure $ VRVerified Nothing
SProxyService | hasRole SRProxy -> pure $ VRVerified Nothing
_ -> pure $ VRFailed $ CMD PROHIBITED
| otherwise = pure $ VRFailed SERVICE
verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> IO (VerificationResult s)
verifyTransmission ms service thAuth t@(_, _, (_, queueId, Cmd p _)) = case queueParty p of
Just Dict -> verifyLoadedQueue service thAuth t <$> getQueueRec ms p queueId
Nothing -> pure $ verifyQueueTransmission service thAuth t Nothing
verifyLoadedQueue :: Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Either ErrorType (StoreQueue s, QueueRec) -> VerificationResult s
verifyLoadedQueue service thAuth t@(tAuth, authorized, (corrId, _, _)) = \case
Right q -> verifyQueueTransmission service thAuth t (Just q)
Left AUTH -> dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
Left e -> VRFailed e
verifyQueueTransmission :: forall s. Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Maybe (StoreQueue s, QueueRec) -> VerificationResult s
verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, _, command@(Cmd p cmd))) q_
| not checkRole = VRFailed $ CMD PROHIBITED
| not verifyServiceSig = VRFailed SERVICE
| otherwise = vc p cmd
where
hasRole role = case service of
Just THClientService {serviceRole} -> serviceRole == role
Nothing -> True
verify = verifyCmdAuthorization auth_ tAuth authorized'
vc :: SParty p -> Command p -> VerificationResult s -- this pattern match works with ghc8.10.7, flat case sees it as non-exhastive.
vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k
vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
vc SRecipientService SUBS = verifyServiceCmd
vc SSender (SKEY k) = verifySecure k
-- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH
vc SIdleClient PING = VRVerified Nothing
vc SSenderLink (LKEY k) = verifySecure k
vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH
vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q)
vc SNotifierService NSUBS = verifyServiceCmd
vc SProxiedClient _ = VRVerified Nothing
vc SProxyService (RFWD _) = VRVerified Nothing
checkRole = case (service, partyClientRole p) of
(Just THClientService {serviceRole}, Just role) -> serviceRole == role
_ -> True
verify = verifyCmdAuthorization thAuth tAuth authorized' corrId
verifyServiceCmd :: VerificationResult s
verifyServiceCmd = case (service, tAuth) of
(Just THClientService {serviceKey = k}, Just (TASignature (C.ASignature C.SEd25519 s), Nothing))
@@ -1214,20 +1228,17 @@ verifyTransmission ms service auth_ tAuth authorized queueId command@(Cmd party
(Just THClientService {serviceCertHash = XV.Fingerprint fp}, Just _) -> fp <> authorized
_ -> authorized
dummyVerify :: VerificationResult s
dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed AUTH
verifyQueue :: DirectParty p => SParty p -> ((StoreQueue s, QueueRec) -> VerificationResult s) -> IO (VerificationResult s)
verifyQueue p v = either err v <$> getQueueRec ms p queueId
where
-- this prevents reporting any STORE errors as AUTH errors
err = \case
AUTH -> dummyVerify
e -> VRFailed e
verifySecure :: DirectParty p => SParty p -> SndPublicAuthKey -> IO (VerificationResult s)
verifySecure p k = verifyQueue p $ \q -> if k `allowedKey` snd q then Just q `verifiedWith` k else dummyVerify
verifiedWith :: Maybe (StoreQueue s, QueueRec) -> C.APublicAuthKey -> VerificationResult s
verifiedWith q_ k = if verify k then VRVerified q_ else VRFailed AUTH
verifiedWithKeys :: Maybe (StoreQueue s, QueueRec) -> NonEmpty C.APublicAuthKey -> VerificationResult s
verifiedWithKeys q_ ks = if any verify ks then VRVerified q_ else VRFailed AUTH
dummyVerify = dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
-- That a specific command requires queue signature verification is determined by `queueParty`,
-- it should be coordinated with the case in this function (`verifyQueueTransmission`)
verifyQueue :: ((StoreQueue s, QueueRec) -> VerificationResult s) -> VerificationResult s
verifyQueue v = maybe (VRFailed INTERNAL) v q_
verifySecure :: SndPublicAuthKey -> VerificationResult s
verifySecure k = verifyQueue $ \q -> if k `allowedKey` snd q then verifiedWith k else dummyVerify
verifiedWith :: C.APublicAuthKey -> VerificationResult s
verifiedWith k = if verify k then VRVerified q_ else VRFailed AUTH
verifiedWithKeys :: NonEmpty C.APublicAuthKey -> VerificationResult s
verifiedWithKeys ks = if any verify ks then VRVerified q_ else VRFailed AUTH
allowedKey k = \case
QueueRec {queueMode = Just QMMessaging, senderKey} -> maybe True (k ==) senderKey
_ -> False
@@ -1243,8 +1254,9 @@ isSecuredMsgQueue QueueRec {queueMode, senderKey} = case queueMode of
Just QMContact -> False
_ -> isJust senderKey
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> C.APublicAuthKey -> Bool
verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAuth
-- Random correlation ID is used as a nonce in case crypto_box authenticator is used to authorize transmission
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> C.APublicAuthKey -> Bool
verifyCmdAuthorization thAuth tAuth authorized corrId key = maybe False (verify key) tAuth
where
verify :: C.APublicAuthKey -> TAuthorizations -> Bool
verify (C.APublicAuthKey a k) = \case
@@ -1252,18 +1264,20 @@ verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAu
Just Refl -> C.verify' k s authorized
_ -> C.verify' (dummySignKey a') s authorized `seq` False
(TAAuthenticator s, _) -> case a of
C.SX25519 -> verifyCmdAuth auth_ k s authorized
_ -> verifyCmdAuth auth_ dummyKeyX25519 s authorized `seq` False
C.SX25519 -> verifyCmdAuth thAuth k s authorized corrId
_ -> verifyCmdAuth thAuth dummyKeyX25519 s authorized corrId `seq` False
verifyCmdAuth :: Maybe (THandleAuth 'TServer, C.CbNonce) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> Bool
verifyCmdAuth auth_ k authenticator authorized = case auth_ of
Just (THAuthServer {serverPrivKey = pk}, nonce) -> C.cbVerify k pk nonce authenticator authorized
verifyCmdAuth :: Maybe (THandleAuth 'TServer) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> CorrId -> Bool
verifyCmdAuth thAuth k authenticator authorized (CorrId corrId) = case thAuth of
Just THAuthServer {serverPrivKey = pk} -> C.cbVerify k pk (C.cbNonce corrId) authenticator authorized
Nothing -> False
dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TAuthorizations -> Bool
dummyVerifyCmd auth_ authorized = \case
(TASignature (C.ASignature a s), _) -> C.verify' (dummySignKey a) s authorized
(TAAuthenticator s, _) -> verifyCmdAuth auth_ dummyKeyX25519 s authorized
dummyVerifyCmd :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> Maybe Bool
dummyVerifyCmd thAuth tAuth authorized corrId = verify <$> tAuth
where
verify = \case
(TASignature (C.ASignature a s), _) -> C.verify' (dummySignKey a) s authorized
(TAAuthenticator s, _) -> verifyCmdAuth thAuth dummyKeyX25519 s authorized corrId
-- These dummy keys are used with `dummyVerify` function to mitigate timing attacks
-- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes
@@ -1272,13 +1286,6 @@ dummySignKey = \case
C.SEd25519 -> dummyKeyEd25519
C.SEd448 -> dummyKeyEd448
dummyAuthKey :: Maybe TAuthorizations -> C.APublicAuthKey
dummyAuthKey = \case
Just (TASignature (C.ASignature a _), _) -> case a of
C.SEd25519 -> C.APublicAuthKey C.SEd25519 dummyKeyEd25519
C.SEd448 -> C.APublicAuthKey C.SEd448 dummyKeyEd448
_ -> C.APublicAuthKey C.SX25519 dummyKeyX25519
dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="
@@ -1392,34 +1399,32 @@ client
mkIncProxyStats ps psOwn own sel = do
incStat $ sel ps
when own $ incStat $ sel psOwn
processCommand :: Maybe THPeerClientService -> VersionSMP -> (Maybe (StoreQueue s, QueueRec), Transmission Cmd) -> M s (Maybe (Transmission BrokerMsg))
processCommand :: Maybe THPeerClientService -> VersionSMP -> VerifiedTransmission s -> M s (Maybe (Transmission BrokerMsg))
processCommand service clntVersion (q_, (corrId, entId, cmd)) = case cmd of
Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command)
Cmd SSender command -> Just <$> case command of
SKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k
SEND flags msgBody -> withQueue_ False $ sendMessage flags msgBody
PING -> pure (corrId, NoEntity, PONG)
Cmd SProxyService (RFWD encBlock) -> Just . (corrId, NoEntity,) <$> processForwardedCommand encBlock
Cmd SIdleClient PING -> pure $ Just (corrId, NoEntity, PONG)
Cmd SProxyService (RFWD encBlock) -> Just . (corrId,NoEntity,) <$> processForwardedCommand encBlock
Cmd SSenderLink command -> Just <$> case command of
LKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k $>> getQueueLink_ q qr
LGET -> withQueue $ \q qr -> checkContact qr $ getQueueLink_ q qr
Cmd SNotifier command -> Just . (corrId,entId,) <$> case command of
NSUB -> case q_ of
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
_ -> pure $ ERR INTERNAL
NSUBS -> case service of
Just s -> subscribeServiceNotifications s
Nothing -> pure $ ERR INTERNAL
Cmd SNotifier NSUB -> Just . (corrId,entId,) <$> case q_ of
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
_ -> pure $ ERR INTERNAL
Cmd SNotifierService NSUBS -> Just . (corrId,entId,) <$> case service of
Just s -> subscribeServiceNotifications s
Nothing -> pure $ ERR INTERNAL
Cmd SCreator (NEW nqr@NewQueueReq {auth_}) ->
Just <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
where
allowNew = do
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth
Cmd SRecipient command ->
Just <$> case command of
NEW nqr@NewQueueReq {auth_} ->
ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
where
allowNew = do
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth
SUB -> withQueue subscribeQueue
SUBS -> pure $ err (CMD PROHIBITED) -- "TODO [certs rcv]"
GET -> withQueue getMessage
ACK msgId -> withQueue $ acknowledgeMsg msgId
KEY sKey -> withQueue $ \q _ -> either err (corrId,entId,) <$> secureQueue_ q sKey
@@ -1438,6 +1443,7 @@ client
OFF -> maybe (pure $ err INTERNAL) suspendQueue_ q_
DEL -> maybe (pure $ err INTERNAL) delQueueAndMsgs q_
QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr
Cmd SRecipientService SUBS -> pure $ Just $ err (CMD PROHIBITED) -- "TODO [certs rcv]"
where
createQueue :: NewQueueReq -> M s (Transmission BrokerMsg)
createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData}
@@ -1492,7 +1498,7 @@ client
| clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied
| otherwise -> tryCreate (n - 1)
Left e -> pure $ ERR e
Right q -> do
Right _q -> do
stats <- asks serverStats
incStat $ qCreated stats
incStat $ qCount stats
@@ -1500,7 +1506,7 @@ client
-- when (isJust ntf) $ incStat $ ntfCreated stats
case subMode of
SMOnlyCreate -> pure ()
SMSubscribe -> void $ subscribeQueue q qr
SMSubscribe -> void $ subscribeNewQueue rcvId qr -- no need to check if message is available, it's a new queue
pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData, serviceId = rcvServiceId} -- , serverNtfCreds = snd <$> ntf
(corrId,entId,) <$> tryCreate (3 :: Int)
@@ -1563,9 +1569,9 @@ client
-- TODO [certs rcv] if serviceId is passed, associate with the service and respond with SOK
subscribeQueue :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
subscribeQueue q qr@QueueRec {rcvServiceId} =
subscribeQueue q qr =
liftIO (TM.lookupIO rId subscriptions) >>= \case
Nothing -> newSub >>= deliver True
Nothing -> subscribeNewQueue rId qr >>= deliver True
Just s@Sub {subThread} -> do
stats <- asks serverStats
case subThread of
@@ -1578,12 +1584,6 @@ client
atomically (tryTakeTMVar $ delivered s) >> deliver False s
where
rId = recipientId q
newSub :: M s Sub
newSub = time "SUB newSub" . atomically $ do
writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId)
sub <- newSubscription NoSub
TM.insert rId sub subscriptions
pure sub
deliver :: Bool -> Sub -> M s (Transmission BrokerMsg)
deliver inc sub = do
stats <- asks serverStats
@@ -1592,6 +1592,13 @@ client
liftIO $ when (inc && isJust msg_) $ incStat (qSub stats)
liftIO $ deliverMessage "SUB" qr rId sub msg_
subscribeNewQueue :: RecipientId -> QueueRec -> M s Sub
subscribeNewQueue rId QueueRec {rcvServiceId} = time "SUB newSub" . atomically $ do
writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId)
sub <- newSubscription NoSub
TM.insert rId sub subscriptions
pure sub
-- clients that use GET are not added to server subscribers
getMessage :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
getMessage q qr = time "GET" $ do
@@ -1660,7 +1667,7 @@ client
pure $ SOK $ Just serviceId
| otherwise ->
-- new or updated queue-service association
liftIO (setQueueService (queueStore ms) q SNotifier (Just serviceId)) >>= \case
liftIO (setQueueService (queueStore ms) q SNotifierService (Just serviceId)) >>= \case
Left e -> pure $ ERR e
Right () -> do
hasSub <- atomically $ (<$ newServiceQueueSub) =<< hasServiceSub
@@ -1677,7 +1684,7 @@ client
modifyTVar' (totalServiceSubs ntfSubscribers) (+ 1) -- server count for all services
Nothing -> case ntfServiceId of
Just _ ->
liftIO (setQueueService (queueStore ms) q SNotifier Nothing) >>= \case
liftIO (setQueueService (queueStore ms) q SNotifierService Nothing) >>= \case
Left e -> pure $ ERR e
Right () -> do
-- hasSubscription should never be True in this branch, because queue was associated with service.
@@ -1900,7 +1907,7 @@ client
let clntTHParams = smpTHParamsSetVersion fwdVersion thParams'
-- only allowing single forwarded transactions
t' <- case tParse clntTHParams b of
t :| [] -> pure $ tDecodeParseValidate clntTHParams t
t :| [] -> pure $ tDecodeServer clntTHParams t
_ -> throwE BLOCK
let clntThAuth = Just $ THAuthServer {serverPrivKey, peerClientService = Nothing, sessSecret' = Just clientSecret}
-- process forwarded command
@@ -1925,23 +1932,22 @@ client
incStat $ pMsgFwdsRecv stats
pure r3
where
rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmission ErrorType Cmd -> M s (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd))
rejectOrVerify clntThAuth (tAuth, authorized, (corrId', entId', cmdOrError)) =
case cmdOrError of
Left e -> pure $ Left (corrId', entId', ERR e)
Right cmd'
| allowed -> liftIO $ verified <$> verifyTransmission ms Nothing ((,C.cbNonce (bs corrId')) <$> clntThAuth) tAuth authorized entId' cmd'
| otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
where
allowed = case cmd' of
Cmd SSender SEND {} -> True
Cmd SSender (SKEY _) -> True
Cmd SSenderLink (LKEY _) -> True
Cmd SSenderLink LGET -> True
_ -> False
verified = \case
VRVerified q -> Right (q, (corrId', entId', cmd'))
VRFailed e -> Left (corrId', entId', ERR e)
rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmissionOrError ErrorType Cmd -> M s (VerifiedTransmissionOrError s)
rejectOrVerify clntThAuth = \case
Left (corrId', entId', e) -> pure $ Left (corrId', entId', ERR e)
Right t'@(_, _, t''@(corrId', entId', cmd'))
| allowed -> liftIO $ verified <$> verifyTransmission ms Nothing clntThAuth t'
| otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
where
allowed = case cmd' of
Cmd SSender SEND {} -> True
Cmd SSender (SKEY _) -> True
Cmd SSenderLink (LKEY _) -> True
Cmd SSenderLink LGET -> True
_ -> False
verified = \case
VRVerified q -> Right (q, t'')
VRFailed e -> Left (corrId', entId', ERR e)
deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> IO (Transmission BrokerMsg)
deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $

View File

@@ -39,6 +39,7 @@ module Simplex.Messaging.Server.Env.STM
MsgStoreType,
MsgStore (..),
AStoreType (..),
VerifiedTransmission,
newEnv,
mkJournalStoreConfig,
msgStore,
@@ -390,7 +391,7 @@ data Client s = Client
ntfSubscriptions :: TMap NotifierId (),
serviceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count
ntfServiceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count
rcvQ :: TBQueue (NonEmpty (Maybe (StoreQueue s, QueueRec), Transmission Cmd)),
rcvQ :: TBQueue (NonEmpty (VerifiedTransmission s)),
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
msgQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
procThreads :: TVar Int,
@@ -403,6 +404,8 @@ data Client s = Client
sndActiveAt :: TVar SystemTime
}
type VerifiedTransmission s = (Maybe (StoreQueue s, QueueRec), Transmission Cmd)
data ServerSub = ServerSub (TVar SubscriptionThread) | ProhibitSub
data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId)

View File

@@ -324,6 +324,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where
{-# INLINE addQueue_ #-}
getQueue_ = withQS getQueue_
{-# INLINE getQueue_ #-}
getQueues_ = withQS getQueues_
{-# INLINE getQueues_ #-}
addQueueLinkData = withQS addQueueLinkData
{-# INLINE addQueueLinkData #-}
getQueueLinkData = withQS getQueueLinkData

View File

@@ -18,6 +18,7 @@
module Simplex.Messaging.Server.MsgStore.Types where
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.Trans.Except
import Data.Functor (($>))
import Data.Int (Int64)
@@ -107,14 +108,23 @@ addQueue :: MsgStoreClass s => s -> RecipientId -> QueueRec -> IO (Either ErrorT
addQueue st = addQueue_ (queueStore st) (mkQueue st True)
{-# INLINE addQueue #-}
getQueue :: (MsgStoreClass s, DirectParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s))
getQueue :: (MsgStoreClass s, QueueParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s))
getQueue st = getQueue_ (queueStore st) (mkQueue st)
{-# INLINE getQueue #-}
getQueueRec :: (MsgStoreClass s, DirectParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s, QueueRec))
getQueueRec st party qId =
getQueue st party qId
$>>= (\q -> maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q))
getQueueRec :: (MsgStoreClass s, QueueParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s, QueueRec))
getQueueRec st party qId = getQueue st party qId $>>= readQueueRec
getQueues :: (MsgStoreClass s, BatchParty p) => s -> SParty p -> [QueueId] -> IO [Either ErrorType (StoreQueue s)]
getQueues st = getQueues_ (queueStore st) (mkQueue st)
{-# INLINE getQueues #-}
getQueueRecs :: (MsgStoreClass s, BatchParty p) => s -> SParty p -> [QueueId] -> IO [Either ErrorType (StoreQueue s, QueueRec)]
getQueueRecs st party qIds = getQueues st party qIds >>= mapM (fmap join . mapM readQueueRec)
readQueueRec :: StoreQueueClass q => q -> IO (Either ErrorType (q, QueueRec))
readQueueRec q = maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q)
{-# INLINE readQueueRec #-}
getQueueSize :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO Int
getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst)

View File

@@ -37,18 +37,20 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import Control.Monad.Trans.Except
import Data.Bifunctor (first)
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as BB
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy as LB
import Data.Bitraversable (bimapM)
import Data.Either (fromRight)
import Data.Either (fromRight, lefts, rights)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (foldl', intersperse, partition)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe)
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import qualified Data.Set as S
import Data.Text (Text)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
@@ -62,7 +64,7 @@ import Database.PostgreSQL.Simple.ToField (Action (..), ToField (..))
import Database.PostgreSQL.Simple.Errors (ConstraintViolation (..), constraintViolation)
import Database.PostgreSQL.Simple.SqlQQ (sql)
import GHC.IO (catchAny)
import Simplex.Messaging.Agent.Client (withLockMap)
import Simplex.Messaging.Agent.Client (withLockMap, withLocksMap)
import Simplex.Messaging.Agent.Lock (Lock)
import Simplex.Messaging.Agent.Store.AgentStore ()
import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore)
@@ -81,7 +83,7 @@ import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPServiceRole (..))
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>))
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>), ($>>=))
import System.Exit (exitFailure)
import System.IO (IOMode (..), hFlush, stdout)
import UnliftIO.STM
@@ -180,18 +182,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
-- hasId = anyM [TM.memberIO rId queues, TM.memberIO senderId senders, hasNotifier]
-- hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.memberIO notifierId notifiers) notifier
getQueue_ :: DirectParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ :: QueueParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ st mkQ party qId = case party of
SRecipient -> getRcvQueue qId
SSender -> getSndQueue
SProxyService -> getSndQueue
SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue
-- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server
SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>))
where
PostgresQueueStore {queues, senders, links, notifiers} = st
getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right)
getSndQueue = TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
loadRcvQueue = do
(rId, qRec) <- loadQueue " WHERE recipient_id = ?"
liftIO $ cacheQueue rId qRec $ \_ -> pure () -- recipient map already checked, not caching sender ref
@@ -228,6 +228,47 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
TM.insert rId sq queues
pure sq
getQueues_ :: forall p. BatchParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
getQueues_ st mkQ party qIds = case party of
SRecipient -> do
qs <- readTVarIO queues
let qs' = map (\qId -> get qs qId qId) qIds
E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue
SNotifier -> do
ns <- readTVarIO notifiers
qs <- readTVarIO queues
let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds
E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) ->
forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query
(nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs)
where
PostgresQueueStore {queues, notifiers} = st
get :: M.Map QueueId a -> QueueId -> QueueId -> Either QueueId a
get m qId = maybe (Left qId) Right . (`M.lookup` m)
loadQueues :: [Either QueueId q] -> Query -> ((RecipientId, QueueRec) -> IO (Maybe (QueueId, q))) -> IO [Either ErrorType q]
loadQueues qs' cond mkCacheQueue = do
let qIds' = lefts qs'
if null qIds'
then pure $ map (first (const INTERNAL)) qs'
else do
qs_ <-
runExceptT $ fmap M.fromList $
withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds')))
>>= liftIO . fmap catMaybes . mapM (mkCacheQueue . rowToQueueRec)
pure $ map (result qs_) qs'
where
result :: Either ErrorType (M.Map QueueId q) -> Either QueueId q -> Either ErrorType q
result _ (Right q) = Right q
result qs_ (Left qId) = maybe (Left AUTH) Right . M.lookup qId =<< qs_
cacheRcvQueue (rId, qRec) = do
sq <- mkQ True rId qRec
sq' <- withQueueLock sq "getQueue_" $ atomically $
-- checking the cache again for concurrent reads, use previously loaded queue if exists.
TM.lookup rId queues >>= \case
Just sq' -> pure sq'
Nothing -> sq <$ TM.insert rId sq queues
pure $ Just (rId, sq')
getQueueLinkData :: PostgresQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
getQueueLinkData st sq lnkId = runExceptT $ do
qr <- ExceptT $ readQueueRecIO $ queueRec sq
@@ -311,7 +352,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
db
[sql|
UPDATE msg_queues
SET notifier_id = ?, notifier_key = ?, rcv_ntf_dh_secret = ?
SET notifier_id = ?, notifier_key = ?, rcv_ntf_dh_secret = ?, ntf_service_id = NULL
WHERE recipient_id = ? AND deleted_at IS NULL
|]
(nId, notifierKey, rcvNtfDhSecret, rId)
@@ -333,7 +374,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
db
[sql|
UPDATE msg_queues
SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL
SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL, ntf_service_id = NULL
WHERE recipient_id = ? AND deleted_at IS NULL
|]
(Only rId)
@@ -402,15 +443,15 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
when new $ withLog "getCreateService" st (`logNewService` sr)
pure serviceId
setQueueService :: (PartyI p, SubscriberParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService st sq party serviceId = withQueueRec sq "setQueueService" $ \q -> case party of
SRecipient
SRecipientService
| rcvServiceId q == serviceId -> pure ()
| otherwise -> do
assertUpdated $ withDB' "setQueueService" st $ \db ->
DB.execute db "UPDATE msg_queues SET rcv_service_id = ? WHERE recipient_id = ? AND deleted_at IS NULL" (serviceId, rId)
updateQueueRec q {rcvServiceId = serviceId}
SNotifier -> case notifier q of
SNotifierService -> case notifier q of
Nothing -> throwE AUTH
Just nc@NtfCreds {ntfServiceId = prevSrvId}
| prevSrvId == serviceId -> pure ()

View File

@@ -128,17 +128,29 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.member notifierId notifiers) notifier
hasLink = maybe (pure False) (\(lnkId, _) -> TM.member lnkId links) queueData
getQueue_ :: DirectParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ :: QueueParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ st _ party qId =
maybe (Left AUTH) Right <$> case party of
SRecipient -> TM.lookupIO qId queues
SSender -> getSndQueue
SProxyService -> getSndQueue
SSender -> TM.lookupIO qId senders $>>= (`TM.lookupIO` queues)
SNotifier -> TM.lookupIO qId notifiers $>>= (`TM.lookupIO` queues)
SSenderLink -> TM.lookupIO qId links $>>= (`TM.lookupIO` queues)
where
STMQueueStore {queues, senders, notifiers, links} = st
getSndQueue = TM.lookupIO qId senders $>>= (`TM.lookupIO` queues)
getQueues_ :: BatchParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
getQueues_ st _ party qIds = case party of
SRecipient -> do
qs <- readTVarIO queues
pure $ map (get qs) qIds
SNotifier -> do
ns <- readTVarIO notifiers
qs <- readTVarIO queues
pure $ map (get qs <=< get ns) qIds
where
STMQueueStore {queues, notifiers} = st
get :: M.Map QueueId a -> QueueId -> Either ErrorType a
get m = maybe (Left AUTH) Right . (`M.lookup` m)
getQueueLinkData :: STMQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
getQueueLinkData _ q lnkId = atomically $ readQueueRec (queueRec q) $>>= pure . getData
@@ -292,7 +304,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
serviceNtfQueues <- newTVar S.empty
pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues}
setQueueService :: (PartyI p, SubscriberParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService st sq party serviceId =
atomically (readQueueRec qr $>>= setService)
$>> withLog "setQueueService" st (\sl -> logQueueService sl rId party serviceId)
@@ -301,13 +313,13 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
rId = recipientId sq
setService :: QueueRec -> STM (Either ErrorType ())
setService q@QueueRec {rcvServiceId = prevSrvId} = case party of
SRecipient
SRecipientService
| prevSrvId == serviceId -> pure $ Right ()
| otherwise -> do
updateServiceQueues serviceRcvQueues rId prevSrvId
let !q' = Just q {rcvServiceId = serviceId}
writeTVar qr q' $> Right ()
SNotifier -> case notifier q of
SNotifierService -> case notifier q of
Nothing -> pure $ Left AUTH
Just nc@NtfCreds {notifierId = nId, ntfServiceId = prevNtfSrvId}
| prevNtfSrvId == serviceId -> pure $ Right ()

View File

@@ -31,7 +31,8 @@ class StoreQueueClass q => QueueStoreClass q s where
loadedQueues :: s -> TMap RecipientId q
compactQueues :: s -> IO Int64
addQueue_ :: s -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q)
getQueue_ :: DirectParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueue_ :: QueueParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
getQueues_ :: BatchParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
getQueueLinkData :: s -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
addQueueLinkData :: s -> q -> LinkId -> QueueLinkData -> IO (Either ErrorType ())
deleteQueueLinkData :: s -> q -> IO (Either ErrorType ())
@@ -45,7 +46,7 @@ class StoreQueueClass q => QueueStoreClass q s where
updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec)
deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q)))
getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId)
setQueueService :: (PartyI p, SubscriberParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]))
getNtfServiceQueueCount :: s -> ServiceId -> IO (Either ErrorType Int64)

View File

@@ -286,7 +286,7 @@ logUpdateQueueTime s qId t = writeStoreLogRecord s $ UpdateTime qId t
logNewService :: StoreLog 'WriteMode -> ServiceRec -> IO ()
logNewService s = writeStoreLogRecord s . NewService
logQueueService :: (PartyI p, SubscriberParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO ()
logQueueService :: (PartyI p, ServiceParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO ()
logQueueService s rId party = writeStoreLogRecord s . QueueService rId (ASP party)
readWriteStoreLog :: (FilePath -> s -> IO ()) -> (StoreLog 'WriteMode -> s -> IO ()) -> FilePath -> s -> IO (StoreLog 'WriteMode)

View File

@@ -122,7 +122,7 @@ storeLogTests =
},
SLTC
{ name = "create queue, add notifier, register and associate notification service",
saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifier) (Just serviceId)],
saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifierService) (Just serviceId)],
compacted = [NewService sr, CreateQueue rId qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}}],
state = M.fromList [(rId, qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}})]
},

View File

@@ -206,7 +206,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
[Right ()] <- tPut h [Right (sig, t')]
pure ()
tGet' h = do
[(Nothing, _, (CorrId corrId, EntityId qId, Right cmd))] <- tGet h
[(CorrId corrId, EntityId qId, Right cmd)] <- tGetClient h
pure (Nothing, corrId, qId, cmd)
ntfTest :: Transport c => TProxy c 'TServer -> (THandleNTF c 'TClient -> IO ()) -> Expectation

View File

@@ -72,18 +72,18 @@ ntfSyntaxTests (ATransport t) = do
Expectation
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> Transmission (Either ErrorType NtfResponse)
pattern RespNtf corrId queueId command <- (corrId, queueId, Right command)
deriving instance Eq NtfResponse
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (Transmission (Either ErrorType NtfResponse))
sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (sgn, tToSend)
tGet1 h
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> C.APrivateAuthKey -> (ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> C.APrivateAuthKey -> (ByteString, NtfEntityId, NtfCommand e) -> IO (Transmission (Either ErrorType NtfResponse))
signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (authorize tForAuth, tToSend)

View File

@@ -394,7 +394,7 @@ smpServerTest _ t = runSmpTest (ASType SQSMemory SMSJournal) $ \h -> tPut' h t >
[Right ()] <- tPut h [Right (sig, t')]
pure ()
tGet' h = do
[(Nothing, _, (CorrId corrId, EntityId qId, Right cmd))] <- tGet h
[(CorrId corrId, EntityId qId, Right cmd)] <- tGetClient h
pure (Nothing, corrId, qId, cmd)
smpTest :: (HasCallStack, Transport c) => TProxy c 'TServer -> AStoreType -> (HasCallStack => THandleSMP c 'TClient -> IO ()) -> Expectation

View File

@@ -435,14 +435,14 @@ testNoProxy :: AStoreType -> IO ()
testNoProxy msType = do
withSmpServerConfigOn (transport @TLS) (cfgMS msType) testPort2 $ \_ -> do
testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do
(_, _, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing)
(_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing)
reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH)
testProxyAuth :: AStoreType -> IO ()
testProxyAuth msType = do
withSmpServerConfigOn (transport @TLS) proxyCfgAuth testPort $ \_ -> do
testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do
(_, _s, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong")
(_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong")
reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH)
where
proxyCfgAuth = updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {newQueueBasicAuth = Just "correct"}

View File

@@ -92,10 +92,10 @@ serverTests = do
testInvQueueLinkData
testContactQueueLinkData
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission ErrorType BrokerMsg
pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command))
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> Transmission (Either ErrorType BrokerMsg)
pattern Resp corrId queueId command <- (corrId, queueId, Right command)
pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Recipient
pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Creator
pattern New rPub dhPub = NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just (QRMessaging Nothing)))
pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg
@@ -104,19 +104,19 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh _sndSecure _linkId Nothing)
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (sgn, tToSend)
tGet1 h
signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
signSendRecv h pk = signSendRecv_ h pk Nothing
serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
serviceSignSendRecv h pk = signSendRecv_ h pk . Just
signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
signSendRecv_ h@THandle {params} (C.APrivateAuthKey a pk) serviceKey_ (corrId, qId, cmd) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (authorize tForAuth, tToSend)
@@ -139,9 +139,9 @@ tPut1 h t = do
[r] <- tPut h [Right t]
pure r
tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (SignedTransmission err cmd)
tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (Transmission (Either err cmd))
tGet1 h = do
[r] <- liftIO $ tGet h
[r] <- liftIO $ tGetClient h
pure r
(#==) :: (HasCallStack, Eq a, Show a) => (a, a) -> String -> Assertion
@@ -519,7 +519,7 @@ testSwitchSub =
Resp "" rId' DELD <- tGet1 rh2
(rId', rId) #== "connection deleted event delivered to subscribed client"
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case
Nothing -> return ()
Just _ -> error "nothing else is delivered to the 1st TCP connection"
@@ -1017,7 +1017,7 @@ testMessageNotifications =
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
Resp "" _ (NMSG _ _) <- tGet1 nh2
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
@@ -1027,7 +1027,7 @@ testMessageNotifications =
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3)
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
(nPub'', nKey'') <- atomically $ C.generateAuthKeyPair C.SEd25519 g
@@ -1069,7 +1069,7 @@ testMessageServiceNotifications =
Resp "" serviceId2 (ENDS 1) <- tGet1 nh1
serviceId2 `shouldBe` serviceId
deliverMessage rh rId rKey sh sId sKey nh2 "hello again" dec
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
@@ -1079,7 +1079,7 @@ testMessageServiceNotifications =
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3)
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
-- new notification credentials
@@ -1133,7 +1133,7 @@ testMsgExpireOnSend =
testSMPClient @c $ \rh -> do
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing else should be delivered"
@@ -1153,7 +1153,7 @@ testMsgExpireOnInterval =
signSendRecv rh rKey ("2", rId, SUB) >>= \case
Resp "2" _ OK -> pure ()
r -> unexpected r
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing should be delivered"
@@ -1172,7 +1172,7 @@ testMsgNOTExpireOnInterval =
testSMPClient @c $ \rh -> do
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing else should be delivered"