mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 14:16:00 +00:00
smp server: batch commands (#1559)
* protocol: refactor types and encoding * clean * smp server: batch commands (#1560) * smp server: batch commands verification into one DB transaction * ghc 8.10.7 * flatten transmission tuples * diff * only use batch logic if there is more than one transmission * func * reset NTF service when adding notifier * version * Revert "smp server: use separate database pool for reading queues and creating service records (#1561)" This reverts commit3df2425162. * version * Revert "version" This reverts commitd80a6b74c5.
This commit is contained in:
@@ -204,7 +204,7 @@ sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = d
|
||||
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req (Just reqTimeout)
|
||||
when (B.length bodyHead /= xftpBlockSize) $ throwE $ PCEResponseError BLOCK
|
||||
-- TODO validate that the file ID is the same as in the request?
|
||||
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
|
||||
(_, _fId, respOrErr) <-liftEither $ first PCEResponseError $ xftpDecodeTClient thParams bodyHead
|
||||
case respOrErr of
|
||||
Right r -> case protocolError r of
|
||||
Just e -> throwE $ PCEProtocolError e
|
||||
|
||||
@@ -44,8 +44,9 @@ import Simplex.Messaging.Protocol
|
||||
EntityId (..),
|
||||
RecipientId,
|
||||
SenderId,
|
||||
RawTransmission,
|
||||
SentRawTransmission,
|
||||
SignedTransmission,
|
||||
SignedTransmissionOrError,
|
||||
SndPublicAuthKey,
|
||||
Transmission,
|
||||
TransmissionForAuth (..),
|
||||
@@ -53,7 +54,8 @@ import Simplex.Messaging.Protocol
|
||||
encodeTransmission,
|
||||
encodeTransmissionForAuth,
|
||||
messageTagP,
|
||||
tDecodeParseValidate,
|
||||
tDecodeServer,
|
||||
tDecodeClient,
|
||||
tEncodeBatch1,
|
||||
tParse,
|
||||
)
|
||||
@@ -197,7 +199,7 @@ instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand
|
||||
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (auth, _, EntityId fileId, _) cmd = case cmd of
|
||||
checkCredentials auth (EntityId fileId) cmd = case cmd of
|
||||
-- FNEW must not have signature and chunk ID
|
||||
FNEW {}
|
||||
| isNothing auth -> Left $ CMD NO_AUTH
|
||||
@@ -231,7 +233,7 @@ instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where
|
||||
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c
|
||||
checkCredentials tAuth entId (FileCmd p c) = FileCmd p <$> checkCredentials tAuth entId c
|
||||
{-# INLINE checkCredentials #-}
|
||||
|
||||
instance Encoding FileInfo where
|
||||
@@ -310,7 +312,7 @@ instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where
|
||||
PEBlock -> BLOCK
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (_, _, EntityId entId, _) cmd = case cmd of
|
||||
checkCredentials _ (EntityId entId) cmd = case cmd of
|
||||
FRSndIds {} -> noEntity
|
||||
-- ERR response does not always have entity ID
|
||||
FRErr _ -> Right cmd
|
||||
@@ -335,25 +337,35 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of
|
||||
Just Refl -> Just c
|
||||
_ -> Nothing
|
||||
|
||||
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion 'TClient -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey (corrId, fId, msg) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg)
|
||||
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion XFTPErrorType c => THandleParams XFTPVersion 'TClient -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey t@(corrId, _, _) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams t
|
||||
xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth False (Just pKey) (C.cbNonce $ bs corrId) tForAuth
|
||||
|
||||
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion p -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeTransmission thParams (corrId, fId, msg) = do
|
||||
let t = encodeTransmission thParams (corrId, fId, msg)
|
||||
xftpEncodeBatch1 (Nothing, t)
|
||||
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion XFTPErrorType c => THandleParams XFTPVersion p -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeTransmission thParams t = xftpEncodeBatch1 (Nothing, encodeTransmission thParams t)
|
||||
|
||||
-- this function uses batch syntax but puts only one transmission in the batch
|
||||
xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString
|
||||
xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 False t) xftpBlockSize
|
||||
|
||||
xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion p -> ByteString -> Either XFTPErrorType (SignedTransmission e c)
|
||||
xftpDecodeTransmission thParams t = do
|
||||
xftpDecodeTServer :: THandleParams XFTPVersion 'TServer -> ByteString -> Either XFTPErrorType (SignedTransmissionOrError XFTPErrorType FileCmd)
|
||||
xftpDecodeTServer = xftpDecodeTransmission tDecodeServer
|
||||
{-# INLINE xftpDecodeTServer #-}
|
||||
|
||||
xftpDecodeTClient :: THandleParams XFTPVersion 'TClient -> ByteString -> Either XFTPErrorType (Transmission (Either XFTPErrorType FileResponse))
|
||||
xftpDecodeTClient = xftpDecodeTransmission tDecodeClient
|
||||
{-# INLINE xftpDecodeTClient #-}
|
||||
|
||||
xftpDecodeTransmission ::
|
||||
(THandleParams XFTPVersion p -> Either TransportError RawTransmission -> r) ->
|
||||
THandleParams XFTPVersion p ->
|
||||
ByteString ->
|
||||
Either XFTPErrorType r
|
||||
xftpDecodeTransmission tDecode thParams t = do
|
||||
t' <- first (const BLOCK) $ C.unPad t
|
||||
case tParse thParams t' of
|
||||
t'' :| [] -> Right $ tDecodeParseValidate thParams t''
|
||||
t'' :| [] -> Right $ tDecode thParams t''
|
||||
_ -> Left BLOCK
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty)
|
||||
|
||||
@@ -53,7 +53,7 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (CorrId (..), BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TAuthorizations, pattern NoEntity)
|
||||
import Simplex.Messaging.Protocol (BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, SignedTransmission, pattern NoEntity)
|
||||
import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization)
|
||||
import Simplex.Messaging.Server.Control (CPClientRole (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
@@ -317,22 +317,20 @@ data ServerFile = ServerFile
|
||||
processRequest :: XFTPTransportRequest -> M ()
|
||||
processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHead}, sendResponse}
|
||||
| B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", NoEntity, FRErr BLOCK) Nothing
|
||||
| otherwise = do
|
||||
case xftpDecodeTransmission thParams bodyHead of
|
||||
Right (sig_, signed, (corrId, fId, cmdOrErr)) ->
|
||||
case cmdOrErr of
|
||||
Right cmd -> do
|
||||
let THandleParams {thAuth} = thParams
|
||||
verifyXFTPTransmission ((,C.cbNonce (bs corrId)) <$> thAuth) sig_ signed fId cmd >>= \case
|
||||
VRVerified req -> uncurry send =<< processXFTPRequest body req
|
||||
VRFailed e -> send (FRErr e) Nothing
|
||||
Left e -> send (FRErr e) Nothing
|
||||
| otherwise =
|
||||
case xftpDecodeTServer thParams bodyHead of
|
||||
Right (Right t@(_, _, (corrId, fId, _))) -> do
|
||||
let THandleParams {thAuth} = thParams
|
||||
verifyXFTPTransmission thAuth t >>= \case
|
||||
VRVerified req -> uncurry send =<< processXFTPRequest body req
|
||||
VRFailed e -> send (FRErr e) Nothing
|
||||
where
|
||||
send resp = sendXFTPResponse (corrId, fId, resp)
|
||||
Right (Left (corrId, fId, e)) -> sendXFTPResponse (corrId, fId, FRErr e) Nothing
|
||||
Left e -> sendXFTPResponse ("", NoEntity, FRErr e) Nothing
|
||||
where
|
||||
sendXFTPResponse (corrId, fId, resp) serverFile_ = do
|
||||
let t_ = xftpEncodeTransmission thParams (corrId, fId, resp)
|
||||
sendXFTPResponse t' serverFile_ = do
|
||||
let t_ = xftpEncodeTransmission thParams t'
|
||||
#ifdef slow_servers
|
||||
randomDelay
|
||||
#endif
|
||||
@@ -361,8 +359,8 @@ randomDelay = do
|
||||
|
||||
data VerificationResult = VRVerified XFTPRequest | VRFailed XFTPErrorType
|
||||
|
||||
verifyXFTPTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult
|
||||
verifyXFTPTransmission auth_ tAuth authorized fId cmd =
|
||||
verifyXFTPTransmission :: Maybe (THandleAuth 'TServer) -> SignedTransmission FileCmd -> M VerificationResult
|
||||
verifyXFTPTransmission thAuth (tAuth, authorized, (corrId, fId, cmd)) =
|
||||
case cmd of
|
||||
FileCmd SFSender (FNEW file rcps auth') -> pure $ XFTPReqNew file rcps auth' `verifyWith` sndKey file
|
||||
FileCmd SFRecipient PING -> pure $ VRVerified XFTPReqPing
|
||||
@@ -381,9 +379,9 @@ verifyXFTPTransmission auth_ tAuth authorized fId cmd =
|
||||
EntityBlocked info -> VRFailed $ BLOCKED info
|
||||
EntityOff -> noFileAuth
|
||||
Left _ -> pure noFileAuth
|
||||
noFileAuth = maybe False (dummyVerifyCmd Nothing authorized) tAuth `seq` VRFailed AUTH
|
||||
noFileAuth = dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
|
||||
-- TODO verify with DH authorization
|
||||
req `verifyWith` k = if verifyCmdAuthorization auth_ tAuth authorized k then VRVerified req else VRFailed AUTH
|
||||
req `verifyWith` k = if verifyCmdAuthorization thAuth tAuth authorized corrId k then VRVerified req else VRFailed AUTH
|
||||
|
||||
processXFTPRequest :: HTTP2Body -> XFTPRequest -> M (FileResponse, Maybe ServerFile)
|
||||
processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
|
||||
@@ -33,6 +33,7 @@ module Simplex.Messaging.Agent.Client
|
||||
withConnLocks,
|
||||
withInvLock,
|
||||
withLockMap,
|
||||
withLocksMap,
|
||||
getMapLock,
|
||||
ipAddressProtected,
|
||||
closeAgentClient,
|
||||
@@ -1004,16 +1005,16 @@ withInvLock' AgentClient {invLocks} = withLockMap invLocks
|
||||
{-# INLINE withInvLock' #-}
|
||||
|
||||
withConnLocks :: AgentClient -> Set ConnId -> Text -> AM' a -> AM' a
|
||||
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks
|
||||
withConnLocks AgentClient {connLocks} = withLocksMap connLocks
|
||||
{-# INLINE withConnLocks #-}
|
||||
|
||||
withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> Text -> m a -> m a
|
||||
withLockMap = withGetLock . getMapLock
|
||||
{-# INLINE withLockMap #-}
|
||||
|
||||
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a
|
||||
withLocksMap_ = withGetLocks . getMapLock
|
||||
{-# INLINE withLocksMap_ #-}
|
||||
withLocksMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a
|
||||
withLocksMap = withGetLocks . getMapLock
|
||||
{-# INLINE withLocksMap #-}
|
||||
|
||||
getMapLock :: Ord k => TMap k Lock -> k -> STM Lock
|
||||
getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
|
||||
|
||||
@@ -183,7 +183,7 @@ data PClient v err msg = PClient
|
||||
clientCorrId :: TVar ChaChaDRG,
|
||||
sentCommands :: TMap CorrId (Request err msg),
|
||||
sndQ :: TBQueue (Maybe (Request err msg), ByteString),
|
||||
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
|
||||
rcvQ :: TBQueue (NonEmpty (Transmission (Either err msg))),
|
||||
msgQ :: Maybe (TBQueue (ServerTransmissionBatch v err msg))
|
||||
}
|
||||
|
||||
@@ -615,7 +615,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
|
||||
receive :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO ()
|
||||
receive ProtocolClient {client_ = PClient {rcvQ, lastReceived, timeoutErrorCount}} h = forever $ do
|
||||
tGet h >>= atomically . writeTBQueue rcvQ
|
||||
tGetClient h >>= atomically . writeTBQueue rcvQ
|
||||
getCurrentTime >>= atomically . writeTVar lastReceived
|
||||
atomically $ writeTVar timeoutErrorCount 0
|
||||
|
||||
@@ -642,14 +642,14 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
process :: ProtocolClient v err msg -> IO ()
|
||||
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= processMsgs c
|
||||
|
||||
processMsgs :: ProtocolClient v err msg -> NonEmpty (SignedTransmission err msg) -> IO ()
|
||||
processMsgs :: ProtocolClient v err msg -> NonEmpty (Transmission (Either err msg)) -> IO ()
|
||||
processMsgs c ts = do
|
||||
ts' <- catMaybes <$> mapM (processMsg c) (L.toList ts)
|
||||
forM_ msgQ $ \q ->
|
||||
mapM_ (atomically . writeTBQueue q . serverTransmission c) (L.nonEmpty ts')
|
||||
|
||||
processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO (Maybe (EntityId, ServerTransmission err msg))
|
||||
processMsg ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr))
|
||||
processMsg :: ProtocolClient v err msg -> Transmission (Either err msg) -> IO (Maybe (EntityId, ServerTransmission err msg))
|
||||
processMsg ProtocolClient {client_ = PClient {sentCommands}} (corrId, entId, respOrErr)
|
||||
| B.null $ bs corrId = sendMsg $ STEvent clientResp
|
||||
| otherwise =
|
||||
TM.lookupIO corrId sentCommands >>= \case
|
||||
@@ -767,7 +767,7 @@ createSMPQueue ::
|
||||
-- Maybe NewNtfCreds ->
|
||||
ExceptT SMPClientError IO QueueIdsKeys
|
||||
createSMPQueue c nonce_ (rKey, rpKey) dhKey auth subMode qrd =
|
||||
sendProtocolCommand_ c nonce_ Nothing (Just rpKey) NoEntity (Cmd SRecipient $ NEW $ NewQueueReq rKey dhKey auth subMode (Just qrd)) >>= \case
|
||||
sendProtocolCommand_ c nonce_ Nothing (Just rpKey) NoEntity (Cmd SCreator $ NEW $ NewQueueReq rKey dhKey auth subMode (Just qrd)) >>= \case
|
||||
IDS qik -> pure qik
|
||||
r -> throwE $ unexpectedResponse r
|
||||
|
||||
@@ -848,7 +848,7 @@ nsubResponse_ = \case
|
||||
r' -> Left $ unexpectedResponse r'
|
||||
{-# INLINE nsubResponse_ #-}
|
||||
|
||||
subscribeService :: forall p. (PartyI p, SubscriberParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64
|
||||
subscribeService :: forall p. (PartyI p, ServiceParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64
|
||||
subscribeService c party = case smpClientService c of
|
||||
Just THClientService {serviceId, serviceKey} -> do
|
||||
liftIO $ enablePings c
|
||||
@@ -858,8 +858,8 @@ subscribeService c party = case smpClientService c of
|
||||
where
|
||||
subCmd :: Command p
|
||||
subCmd = case party of
|
||||
SRecipient -> SUBS
|
||||
SNotifier -> NSUBS
|
||||
SRecipientService -> SUBS
|
||||
SNotifierService -> NSUBS
|
||||
Nothing -> throwE PCEServiceUnavailable
|
||||
|
||||
smpClientService :: SMPClient -> Maybe THClientService
|
||||
@@ -1119,8 +1119,8 @@ proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {c
|
||||
-- server interaction errors are thrown directly
|
||||
t' <- liftEitherWith PCECryptoError $ C.cbDecrypt cmdSecret (C.reverseNonce nonce) er
|
||||
case tParse serverThParams t' of
|
||||
t'' :| [] -> case tDecodeParseValidate serverThParams t'' of
|
||||
(_auth, _signed, (_c, _e, cmd)) -> case cmd of
|
||||
t'' :| [] -> case tDecodeClient serverThParams t'' of
|
||||
(_, _, cmd) -> case cmd of
|
||||
Right (ERR e) -> throwE $ PCEProtocolError e -- this is the error from the destination relay
|
||||
Right r' -> pure $ Right r'
|
||||
Left e -> throwE $ PCEResponseError e
|
||||
|
||||
@@ -70,9 +70,9 @@ import Simplex.Messaging.Protocol
|
||||
QueueId,
|
||||
SMPServer,
|
||||
SParty (..),
|
||||
SubscriberParty,
|
||||
subscriberParty,
|
||||
subscriberServiceRole
|
||||
ServiceParty,
|
||||
serviceParty,
|
||||
partyServiceRole
|
||||
)
|
||||
import Simplex.Messaging.Session
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
@@ -331,11 +331,11 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
|
||||
|
||||
reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO ()
|
||||
reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) =
|
||||
withSMP ca srv $ \smp -> liftIO $ case subscriberParty agentParty of
|
||||
withSMP ca srv $ \smp -> liftIO $ case serviceParty agentParty of
|
||||
Just Dict -> resubscribe smp
|
||||
Nothing -> pure ()
|
||||
where
|
||||
resubscribe :: (PartyI p, SubscriberParty p) => SMPClient -> IO ()
|
||||
resubscribe :: (PartyI p, ServiceParty p) => SMPClient -> IO ()
|
||||
resubscribe smp = do
|
||||
mapM_ (smpSubscribeService ca smp srv) sSub_
|
||||
forM_ qSubs_ $ \qSubs -> do
|
||||
@@ -394,22 +394,22 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE
|
||||
logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode $ host srv) <> "): " <> tshow e
|
||||
throwE e
|
||||
|
||||
subscribeQueuesNtfs :: SMPClientAgent 'Notifier -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO ()
|
||||
subscribeQueuesNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO ()
|
||||
subscribeQueuesNtfs = subscribeQueues_
|
||||
{-# INLINE subscribeQueuesNtfs #-}
|
||||
|
||||
subscribeQueues_ :: SubscriberParty p => SMPClientAgent p -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
|
||||
subscribeQueues_ :: ServiceParty p => SMPClientAgent p -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
|
||||
subscribeQueues_ ca srv subs = do
|
||||
atomically $ addPendingSubs ca srv $ L.toList subs
|
||||
runExceptT (getSMPServerClient' ca srv) >>= \case
|
||||
Right smp -> smpSubscribeQueues ca smp srv subs
|
||||
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
|
||||
|
||||
smpSubscribeQueues :: SubscriberParty p => SMPClientAgent p -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
|
||||
smpSubscribeQueues :: ServiceParty p => SMPClientAgent p -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
|
||||
smpSubscribeQueues ca smp srv subs = do
|
||||
rs <- case agentParty ca of
|
||||
SRecipient -> subscribeSMPQueues smp subs
|
||||
SNotifier -> subscribeSMPQueuesNtfs smp subs
|
||||
SRecipientService -> subscribeSMPQueues smp subs
|
||||
SNotifierService -> subscribeSMPQueuesNtfs smp subs
|
||||
rs' <-
|
||||
atomically $
|
||||
ifM
|
||||
@@ -454,18 +454,18 @@ smpSubscribeQueues ca smp srv subs = do
|
||||
notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO ()
|
||||
notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs
|
||||
|
||||
subscribeServiceNtfs :: SMPClientAgent 'Notifier -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
subscribeServiceNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
subscribeServiceNtfs = subscribeService_
|
||||
{-# INLINE subscribeServiceNtfs #-}
|
||||
|
||||
subscribeService_ :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
subscribeService_ :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
subscribeService_ ca srv serviceSub = do
|
||||
atomically $ setPendingServiceSub ca srv $ Just serviceSub
|
||||
runExceptT (getSMPServerClient' ca srv) >>= \case
|
||||
Right smp -> smpSubscribeService ca smp srv serviceSub
|
||||
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
|
||||
|
||||
smpSubscribeService :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
smpSubscribeService :: (PartyI p, ServiceParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO ()
|
||||
smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of
|
||||
Just service | serviceAvailable service -> subscribe
|
||||
_ -> notifyUnavailable
|
||||
@@ -490,7 +490,7 @@ smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService
|
||||
setActiveServiceSub ca srv $ Just ((serviceId, n), sessId)
|
||||
setPendingServiceSub ca srv Nothing
|
||||
serviceAvailable THClientService {serviceRole, serviceId = serviceId'} =
|
||||
serviceId == serviceId' && subscriberServiceRole (agentParty ca) == serviceRole
|
||||
serviceId == serviceId' && partyServiceRole (agentParty ca) == serviceRole
|
||||
notifyUnavailable = do
|
||||
atomically $ setPendingServiceSub ca srv Nothing
|
||||
notify ca $ CAServiceUnavailable srv serviceSub -- this will resubscribe all queues directly
|
||||
|
||||
@@ -214,7 +214,7 @@ instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) wh
|
||||
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (auth, _, EntityId entityId, _) cmd = case cmd of
|
||||
checkCredentials auth (EntityId entityId) cmd = case cmd of
|
||||
-- TNEW and SNEW must have signature but NOT token/subscription IDs
|
||||
TNEW {} -> sigNoEntity
|
||||
SNEW {} -> sigNoEntity
|
||||
@@ -254,7 +254,7 @@ instance ProtocolEncoding NTFVersion ErrorType NtfCmd where
|
||||
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
|
||||
checkCredentials tAuth entId (NtfCmd e c) = NtfCmd e <$> checkCredentials tAuth entId c
|
||||
|
||||
data NtfResponseTag
|
||||
= NRTknId_
|
||||
@@ -334,7 +334,7 @@ instance ProtocolEncoding NTFVersion ErrorType NtfResponse where
|
||||
PEBlock -> BLOCK
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (_, _, EntityId entId, _) cmd = case cmd of
|
||||
checkCredentials _ (EntityId entId) cmd = case cmd of
|
||||
-- IDTKN response must not have queue ID
|
||||
NRTknId {} -> noEntity
|
||||
-- IDSUB response must not have queue ID
|
||||
|
||||
@@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag
|
||||
import Simplex.Messaging.Notifications.Server.Store.Postgres
|
||||
import Simplex.Messaging.Notifications.Server.Store.Types
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGet, tPut)
|
||||
import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGetServer, tPut)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server
|
||||
import Simplex.Messaging.Server.Control (CPClientRole (..))
|
||||
@@ -277,21 +277,21 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
|
||||
apnsPushQLength
|
||||
}
|
||||
where
|
||||
getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics
|
||||
getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'NotifierService -> (SMPClientAgent 'NotifierService -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics
|
||||
getSMPServiceSubMetrics a sel subQueueCount = getSubMetrics_ a sel countSubs
|
||||
where
|
||||
countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TVar (Maybe sub)) -> IO (NtfSMPSubMetrics, S.Set Text)
|
||||
countSubs acc (srv, serviceSubs) = subMetricsResult a acc srv . fromIntegral . maybe 0 subQueueCount <$> readTVarIO serviceSubs
|
||||
|
||||
getSMPSubMetrics :: SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics
|
||||
getSMPSubMetrics :: SMPClientAgent 'NotifierService -> (SMPClientAgent 'NotifierService -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics
|
||||
getSMPSubMetrics a sel = getSubMetrics_ a sel countSubs
|
||||
where
|
||||
countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TMap NotifierId a) -> IO (NtfSMPSubMetrics, S.Set Text)
|
||||
countSubs acc (srv, queueSubs) = subMetricsResult a acc srv . M.size <$> readTVarIO queueSubs
|
||||
|
||||
getSubMetrics_ ::
|
||||
SMPClientAgent 'Notifier ->
|
||||
(SMPClientAgent 'Notifier -> TVar (M.Map SMPServer sub')) ->
|
||||
SMPClientAgent 'NotifierService ->
|
||||
(SMPClientAgent 'NotifierService -> TVar (M.Map SMPServer sub')) ->
|
||||
((NtfSMPSubMetrics, S.Set Text) -> (SMPServer, sub') -> IO (NtfSMPSubMetrics, S.Set Text)) ->
|
||||
IO NtfSMPSubMetrics
|
||||
getSubMetrics_ a sel countSubs = do
|
||||
@@ -300,7 +300,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
|
||||
(metrics', otherSrvs) <- foldM countSubs (metrics, S.empty) $ M.assocs subs
|
||||
pure (metrics' :: NtfSMPSubMetrics) {otherServers = S.size otherSrvs}
|
||||
|
||||
subMetricsResult :: SMPClientAgent 'Notifier -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text)
|
||||
subMetricsResult :: SMPClientAgent 'NotifierService -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text)
|
||||
subMetricsResult a acc@(metrics, !otherSrvs) srv@(SMPServer (h :| _) _ _) cnt
|
||||
| isOwnServer a srv =
|
||||
let !ownSrvSubs' = M.alter (Just . maybe cnt (+ cnt)) host ownSrvSubs
|
||||
@@ -314,9 +314,9 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions}
|
||||
NtfSMPSubMetrics {ownSrvSubs, otherSrvSubCount} = metrics
|
||||
host = safeDecodeUtf8 $ strEncode h
|
||||
|
||||
getSMPWorkerMetrics :: SMPClientAgent 'Notifier -> TMap SMPServer a -> IO NtfSMPWorkerMetrics
|
||||
getSMPWorkerMetrics :: SMPClientAgent 'NotifierService -> TMap SMPServer a -> IO NtfSMPWorkerMetrics
|
||||
getSMPWorkerMetrics a v = workerMetrics a . M.keys <$> readTVarIO v
|
||||
workerMetrics :: SMPClientAgent 'Notifier -> [SMPServer] -> NtfSMPWorkerMetrics
|
||||
workerMetrics :: SMPClientAgent 'NotifierService -> [SMPServer] -> NtfSMPWorkerMetrics
|
||||
workerMetrics a srvs = NtfSMPWorkerMetrics {ownServers = reverse ownSrvs, otherServers}
|
||||
where
|
||||
(ownSrvs, otherServers) = foldl' countSrv ([], 0) srvs
|
||||
@@ -455,7 +455,7 @@ resubscribe NtfSubscriber {smpAgent = ca} = do
|
||||
counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs
|
||||
logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)"
|
||||
|
||||
subscribeSrvSubs :: SMPClientAgent 'Notifier -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int
|
||||
subscribeSrvSubs :: SMPClientAgent 'NotifierService -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int
|
||||
subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do
|
||||
let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv)
|
||||
logNote $ "Starting SMP resubscriptions for " <> srvStr
|
||||
@@ -722,25 +722,24 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
|
||||
|
||||
receive :: Transport c => NtfPostgresStore -> THandleNTF c 'TServer -> NtfServerClient -> IO ()
|
||||
receive st th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do
|
||||
ts <- L.toList <$> tGet th
|
||||
ts <- L.toList <$> tGetServer th
|
||||
atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime
|
||||
(errs, cmds) <- partitionEithers <$> mapM cmdAction ts
|
||||
write sndQ errs
|
||||
write rcvQ cmds
|
||||
where
|
||||
cmdAction t@(_, _, (corrId, entId, cmdOrError)) =
|
||||
case cmdOrError of
|
||||
Left e -> do
|
||||
logError $ "invalid client request: " <> tshow e
|
||||
pure $ Left (corrId, entId, NRErr e)
|
||||
Right cmd ->
|
||||
verified =<< verifyNtfTransmission st ((,C.cbNonce (SMP.bs corrId)) <$> thAuth) t cmd
|
||||
where
|
||||
verified = \case
|
||||
VRVerified req -> pure $ Right req
|
||||
VRFailed e -> do
|
||||
logError "unauthorized client request"
|
||||
pure $ Left (corrId, entId, NRErr e)
|
||||
cmdAction = \case
|
||||
Left (corrId, entId, e) -> do
|
||||
logError $ "invalid client request: " <> tshow e
|
||||
pure $ Left (corrId, entId, NRErr e)
|
||||
Right t@(_, _, (corrId, entId, _)) ->
|
||||
verified =<< verifyNtfTransmission st thAuth t
|
||||
where
|
||||
verified = \case
|
||||
VRVerified req -> pure $ Right req
|
||||
VRFailed e -> do
|
||||
logError "unauthorized client request"
|
||||
pure $ Left (corrId, entId, NRErr e)
|
||||
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
|
||||
|
||||
send :: Transport c => THandleNTF c 'TServer -> NtfServerClient -> IO ()
|
||||
@@ -751,10 +750,10 @@ send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do
|
||||
|
||||
data VerificationResult = VRVerified NtfRequest | VRFailed ErrorType
|
||||
|
||||
verifyNtfTransmission :: NtfPostgresStore -> Maybe (THandleAuth 'TServer, C.CbNonce) -> SignedTransmission ErrorType NtfCmd -> NtfCmd -> IO VerificationResult
|
||||
verifyNtfTransmission st auth_ (tAuth, authorized, (corrId, entId, _)) = \case
|
||||
verifyNtfTransmission :: NtfPostgresStore -> Maybe (THandleAuth 'TServer) -> SignedTransmission NtfCmd -> IO VerificationResult
|
||||
verifyNtfTransmission st thAuth (tAuth, authorized, (corrId, entId, cmd)) = case cmd of
|
||||
NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _))
|
||||
| verifyCmdAuthorization auth_ tAuth authorized k ->
|
||||
| verifyCmdAuthorization thAuth tAuth authorized corrId k ->
|
||||
result <$> findNtfTokenRegistration st tkn
|
||||
| otherwise -> pure $ VRFailed AUTH
|
||||
where
|
||||
@@ -783,10 +782,10 @@ verifyNtfTransmission st auth_ (tAuth, authorized, (corrId, entId, _)) = \case
|
||||
subCmd s c = NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c)
|
||||
verifyToken :: NtfTknRec -> NtfRequest -> VerificationResult
|
||||
verifyToken NtfTknRec {tknVerifyKey} r
|
||||
| verifyCmdAuthorization auth_ tAuth authorized tknVerifyKey = VRVerified r
|
||||
| verifyCmdAuthorization thAuth tAuth authorized corrId tknVerifyKey = VRVerified r
|
||||
| otherwise = VRFailed AUTH
|
||||
err = \case -- signature verification for AUTH errors mitigates timing attacks for existence checks
|
||||
AUTH -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed AUTH
|
||||
AUTH -> dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
|
||||
e -> VRFailed e
|
||||
|
||||
client :: NtfServerClient -> NtfSubscriber -> NtfPushServer -> M ()
|
||||
|
||||
@@ -127,7 +127,7 @@ newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbSt
|
||||
data NtfSubscriber = NtfSubscriber
|
||||
{ smpSubscribers :: TMap SMPServer SMPSubscriberVar,
|
||||
subscriberSeq :: TVar Int,
|
||||
smpAgent :: SMPClientAgent 'Notifier
|
||||
smpAgent :: SMPClientAgent 'NotifierService
|
||||
}
|
||||
|
||||
type SMPSubscriberVar = SessionVar SMPSubscriber
|
||||
@@ -136,7 +136,7 @@ newNtfSubscriber :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber
|
||||
newNtfSubscriber smpAgentCfg random = do
|
||||
smpSubscribers <- TM.emptyIO
|
||||
subscriberSeq <- newTVarIO 0
|
||||
smpAgent <- newSMPClientAgent SNotifier smpAgentCfg random
|
||||
smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg random
|
||||
pure NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent}
|
||||
|
||||
data SMPSubscriber = SMPSubscriber
|
||||
|
||||
@@ -66,8 +66,9 @@ module Simplex.Messaging.Protocol
|
||||
EncDataBytes (..),
|
||||
Party (..),
|
||||
Cmd (..),
|
||||
DirectParty,
|
||||
SubscriberParty,
|
||||
QueueParty,
|
||||
BatchParty,
|
||||
ServiceParty,
|
||||
ASubscriberParty (..),
|
||||
BrokerMsg (..),
|
||||
SParty (..),
|
||||
@@ -80,12 +81,13 @@ module Simplex.Messaging.Protocol
|
||||
BrokerErrorType (..),
|
||||
BlockingInfo (..),
|
||||
BlockingReason (..),
|
||||
RawTransmission,
|
||||
Transmission,
|
||||
TAuthorizations,
|
||||
TransmissionAuth (..),
|
||||
SignedTransmission,
|
||||
SignedTransmissionOrError,
|
||||
SentRawTransmission,
|
||||
SignedRawTransmission,
|
||||
ClientMsgEnvelope (..),
|
||||
PubHeader (..),
|
||||
ClientMessage (..),
|
||||
@@ -153,8 +155,11 @@ module Simplex.Messaging.Protocol
|
||||
currentSMPClientVersion,
|
||||
senderCanSecure,
|
||||
queueReqMode,
|
||||
subscriberParty,
|
||||
subscriberServiceRole,
|
||||
queueParty,
|
||||
batchParty,
|
||||
serviceParty,
|
||||
partyClientRole,
|
||||
partyServiceRole,
|
||||
userProtocol,
|
||||
rcvMessageMeta,
|
||||
noMsgFlags,
|
||||
@@ -186,9 +191,11 @@ module Simplex.Messaging.Protocol
|
||||
TransportBatch (..),
|
||||
tPut,
|
||||
tPutLog,
|
||||
tGet,
|
||||
tGetServer,
|
||||
tGetClient,
|
||||
tParse,
|
||||
tDecodeParseValidate,
|
||||
tDecodeServer,
|
||||
tDecodeClient,
|
||||
tEncode,
|
||||
tEncodeBatch1,
|
||||
batchTransmissions,
|
||||
@@ -208,7 +215,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -303,22 +310,40 @@ e2eEncMessageLength :: Int
|
||||
e2eEncMessageLength = 16000 -- 15988 .. 16005
|
||||
|
||||
-- | SMP protocol clients
|
||||
data Party = Recipient | Sender | Notifier | LinkClient | ProxiedClient | ProxyService
|
||||
data Party
|
||||
= Creator
|
||||
| Recipient
|
||||
| RecipientService
|
||||
| Sender
|
||||
| IdleClient
|
||||
| Notifier
|
||||
| NotifierService
|
||||
| LinkClient
|
||||
| ProxiedClient
|
||||
| ProxyService
|
||||
deriving (Show)
|
||||
|
||||
-- | Singleton types for SMP protocol clients
|
||||
data SParty :: Party -> Type where
|
||||
SCreator :: SParty Creator
|
||||
SRecipient :: SParty Recipient
|
||||
SRecipientService :: SParty RecipientService
|
||||
SSender :: SParty Sender
|
||||
SIdleClient :: SParty IdleClient
|
||||
SNotifier :: SParty Notifier
|
||||
SNotifierService :: SParty NotifierService
|
||||
SSenderLink :: SParty LinkClient
|
||||
SProxiedClient :: SParty ProxiedClient
|
||||
SProxyService :: SParty ProxyService
|
||||
|
||||
instance TestEquality SParty where
|
||||
testEquality SCreator SCreator = Just Refl
|
||||
testEquality SRecipient SRecipient = Just Refl
|
||||
testEquality SRecipientService SRecipientService = Just Refl
|
||||
testEquality SSender SSender = Just Refl
|
||||
testEquality SIdleClient SIdleClient = Just Refl
|
||||
testEquality SNotifier SNotifier = Just Refl
|
||||
testEquality SNotifierService SNotifierService = Just Refl
|
||||
testEquality SSenderLink SSenderLink = Just Refl
|
||||
testEquality SProxiedClient SProxiedClient = Just Refl
|
||||
testEquality SProxyService SProxyService = Just Refl
|
||||
@@ -328,34 +353,72 @@ deriving instance Show (SParty p)
|
||||
|
||||
class PartyI (p :: Party) where sParty :: SParty p
|
||||
|
||||
instance PartyI Creator where sParty = SCreator
|
||||
|
||||
instance PartyI Recipient where sParty = SRecipient
|
||||
|
||||
instance PartyI RecipientService where sParty = SRecipientService
|
||||
|
||||
instance PartyI Sender where sParty = SSender
|
||||
|
||||
instance PartyI IdleClient where sParty = SIdleClient
|
||||
|
||||
instance PartyI Notifier where sParty = SNotifier
|
||||
|
||||
instance PartyI NotifierService where sParty = SNotifierService
|
||||
|
||||
instance PartyI LinkClient where sParty = SSenderLink
|
||||
|
||||
instance PartyI ProxiedClient where sParty = SProxiedClient
|
||||
|
||||
instance PartyI ProxyService where sParty = SProxyService
|
||||
|
||||
type family DirectParty (p :: Party) :: Constraint where
|
||||
DirectParty Recipient = ()
|
||||
DirectParty Sender = ()
|
||||
DirectParty Notifier = ()
|
||||
DirectParty LinkClient = ()
|
||||
DirectParty ProxyService = ()
|
||||
DirectParty p =
|
||||
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not direct"))
|
||||
-- command parties that can read queues
|
||||
type family QueueParty (p :: Party) :: Constraint where
|
||||
QueueParty Recipient = ()
|
||||
QueueParty Sender = ()
|
||||
QueueParty Notifier = ()
|
||||
QueueParty LinkClient = ()
|
||||
QueueParty p =
|
||||
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not QueueParty"))
|
||||
|
||||
type family SubscriberParty (p :: Party) :: Constraint where
|
||||
SubscriberParty Recipient = ()
|
||||
SubscriberParty Notifier = ()
|
||||
SubscriberParty p =
|
||||
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not subscriber"))
|
||||
queueParty :: SParty p -> Maybe (Dict (PartyI p, QueueParty p))
|
||||
queueParty = \case
|
||||
SRecipient -> Just Dict
|
||||
SSender -> Just Dict
|
||||
SSenderLink -> Just Dict
|
||||
SNotifier -> Just Dict
|
||||
_ -> Nothing
|
||||
{-# INLINE queueParty #-}
|
||||
|
||||
data ASubscriberParty = forall p. (PartyI p, SubscriberParty p) => ASP (SParty p)
|
||||
type family BatchParty (p :: Party) :: Constraint where
|
||||
BatchParty Recipient = ()
|
||||
BatchParty Notifier = ()
|
||||
BatchParty p =
|
||||
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not BatchParty"))
|
||||
|
||||
batchParty :: SParty p -> Maybe (Dict (PartyI p, BatchParty p))
|
||||
batchParty = \case
|
||||
SRecipient -> Just Dict
|
||||
SNotifier -> Just Dict
|
||||
_ -> Nothing
|
||||
{-# INLINE batchParty #-}
|
||||
|
||||
-- command parties that can subscribe to individual queues
|
||||
type family ServiceParty (p :: Party) :: Constraint where
|
||||
ServiceParty RecipientService = ()
|
||||
ServiceParty NotifierService = ()
|
||||
ServiceParty p =
|
||||
(Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not ServiceParty"))
|
||||
|
||||
serviceParty :: SParty p -> Maybe (Dict (PartyI p, ServiceParty p))
|
||||
serviceParty = \case
|
||||
SRecipientService -> Just Dict
|
||||
SNotifierService -> Just Dict
|
||||
_ -> Nothing
|
||||
{-# INLINE serviceParty #-}
|
||||
|
||||
data ASubscriberParty = forall p. (PartyI p, ServiceParty p) => ASP (SParty p)
|
||||
|
||||
deriving instance Show ASubscriberParty
|
||||
|
||||
@@ -364,30 +427,37 @@ instance Eq ASubscriberParty where
|
||||
|
||||
instance Encoding ASubscriberParty where
|
||||
smpEncode = \case
|
||||
ASP SRecipient -> "R"
|
||||
ASP SNotifier -> "N"
|
||||
ASP SRecipientService -> "R"
|
||||
ASP SNotifierService -> "N"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'R' -> pure $ ASP SRecipient
|
||||
'N' -> pure $ ASP SNotifier
|
||||
'R' -> pure $ ASP SRecipientService
|
||||
'N' -> pure $ ASP SNotifierService
|
||||
_ -> fail "bad ASubscriberParty"
|
||||
|
||||
instance StrEncoding ASubscriberParty where
|
||||
strEncode = smpEncode
|
||||
strP = smpP
|
||||
|
||||
subscriberParty :: SParty p -> Maybe (Dict (PartyI p, SubscriberParty p))
|
||||
subscriberParty = \case
|
||||
SRecipient -> Just Dict
|
||||
SNotifier -> Just Dict
|
||||
_ -> Nothing
|
||||
{-# INLINE subscriberParty #-}
|
||||
partyClientRole :: SParty p -> Maybe SMPServiceRole
|
||||
partyClientRole = \case
|
||||
SCreator -> Just SRMessaging
|
||||
SRecipient -> Just SRMessaging
|
||||
SRecipientService -> Just SRMessaging
|
||||
SSender -> Just SRMessaging
|
||||
SIdleClient -> Nothing
|
||||
SNotifier -> Just SRNotifier
|
||||
SNotifierService -> Just SRNotifier
|
||||
SSenderLink -> Just SRMessaging
|
||||
SProxiedClient -> Just SRMessaging
|
||||
SProxyService -> Just SRProxy
|
||||
{-# INLINE partyClientRole #-}
|
||||
|
||||
subscriberServiceRole :: SubscriberParty p => SParty p -> SMPServiceRole
|
||||
subscriberServiceRole = \case
|
||||
SRecipient -> SRMessaging
|
||||
SNotifier -> SRNotifier
|
||||
{-# INLINE subscriberServiceRole #-}
|
||||
partyServiceRole :: ServiceParty p => SParty p -> SMPServiceRole
|
||||
partyServiceRole = \case
|
||||
SRecipientService -> SRMessaging
|
||||
SNotifierService -> SRNotifier
|
||||
{-# INLINE partyServiceRole #-}
|
||||
|
||||
-- | Type for client command of any participant.
|
||||
data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p)
|
||||
@@ -398,7 +468,9 @@ deriving instance Show Cmd
|
||||
type Transmission c = (CorrId, EntityId, c)
|
||||
|
||||
-- | signed parsed transmission, with original raw bytes and parsing error.
|
||||
type SignedTransmission e c = (Maybe TAuthorizations, Signed, Transmission (Either e c))
|
||||
type SignedTransmission c = (Maybe TAuthorizations, Signed, Transmission c)
|
||||
|
||||
type SignedTransmissionOrError e c = Either (Transmission e) (SignedTransmission c)
|
||||
|
||||
type Signed = ByteString
|
||||
|
||||
@@ -439,9 +511,6 @@ decodeTAuthBytes s serviceSig
|
||||
| B.length s == C.cbAuthenticatorSize = Right $ Just (TAAuthenticator (C.CbAuthenticator s), serviceSig)
|
||||
| otherwise = (\sig -> Just (TASignature sig, serviceSig)) <$> C.decodeSignature s
|
||||
|
||||
-- | unparsed sent SMP transmission with signature, without session ID.
|
||||
type SignedRawTransmission = (Maybe TAuthorizations, CorrId, EntityId, ByteString)
|
||||
|
||||
-- | unparsed sent SMP transmission with signature.
|
||||
type SentRawTransmission = (Maybe TAuthorizations, ByteString)
|
||||
|
||||
@@ -466,10 +535,10 @@ data Command (p :: Party) where
|
||||
-- v6 of SMP servers only support signature algorithm for command authorization.
|
||||
-- v7 of SMP servers additionally support additional layer of authenticated encryption.
|
||||
-- RcvPublicAuthKey is defined as C.APublicKey - it can be either signature or DH public keys.
|
||||
NEW :: NewQueueReq -> Command Recipient
|
||||
NEW :: NewQueueReq -> Command Creator
|
||||
SUB :: Command Recipient
|
||||
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
|
||||
SUBS :: Command Recipient
|
||||
SUBS :: Command RecipientService
|
||||
KEY :: SndPublicAuthKey -> Command Recipient
|
||||
RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient
|
||||
LSET :: LinkId -> QueueLinkData -> Command Recipient
|
||||
@@ -486,14 +555,14 @@ data Command (p :: Party) where
|
||||
-- SEND v1 has to be supported for encoding/decoding
|
||||
-- SEND :: MsgBody -> Command Sender
|
||||
SEND :: MsgFlags -> MsgBody -> Command Sender
|
||||
PING :: Command Sender
|
||||
PING :: Command IdleClient
|
||||
-- Client accessing short links
|
||||
LKEY :: SndPublicAuthKey -> Command LinkClient
|
||||
LGET :: Command LinkClient
|
||||
-- SMP notification subscriber commands
|
||||
NSUB :: Command Notifier
|
||||
-- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command.
|
||||
NSUBS :: Command Notifier
|
||||
NSUBS :: Command NotifierService
|
||||
PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI
|
||||
-- Transmission to proxy:
|
||||
-- - entity ID: ID of the session with relay returned in PKEY (response to PRXY)
|
||||
@@ -827,9 +896,9 @@ noMsgFlags = MsgFlags {notification = False}
|
||||
-- * SMP command tags
|
||||
|
||||
data CommandTag (p :: Party) where
|
||||
NEW_ :: CommandTag Recipient
|
||||
NEW_ :: CommandTag Creator
|
||||
SUB_ :: CommandTag Recipient
|
||||
SUBS_ :: CommandTag Recipient
|
||||
SUBS_ :: CommandTag RecipientService
|
||||
KEY_ :: CommandTag Recipient
|
||||
RKEY_ :: CommandTag Recipient
|
||||
LSET_ :: CommandTag Recipient
|
||||
@@ -843,14 +912,14 @@ data CommandTag (p :: Party) where
|
||||
QUE_ :: CommandTag Recipient
|
||||
SKEY_ :: CommandTag Sender
|
||||
SEND_ :: CommandTag Sender
|
||||
PING_ :: CommandTag Sender
|
||||
PING_ :: CommandTag IdleClient
|
||||
LKEY_ :: CommandTag LinkClient
|
||||
LGET_ :: CommandTag LinkClient
|
||||
PRXY_ :: CommandTag ProxiedClient
|
||||
PFWD_ :: CommandTag ProxiedClient
|
||||
RFWD_ :: CommandTag ProxyService
|
||||
NSUB_ :: CommandTag Notifier
|
||||
NSUBS_ :: CommandTag Notifier
|
||||
NSUBS_ :: CommandTag NotifierService
|
||||
|
||||
data CmdTag = forall p. PartyI p => CT (SParty p) (CommandTag p)
|
||||
|
||||
@@ -916,9 +985,9 @@ instance PartyI p => Encoding (CommandTag p) where
|
||||
|
||||
instance ProtocolMsgTag CmdTag where
|
||||
decodeTag = \case
|
||||
"NEW" -> Just $ CT SRecipient NEW_
|
||||
"NEW" -> Just $ CT SCreator NEW_
|
||||
"SUB" -> Just $ CT SRecipient SUB_
|
||||
"SUBS" -> Just $ CT SRecipient SUBS_
|
||||
"SUBS" -> Just $ CT SRecipientService SUBS_
|
||||
"KEY" -> Just $ CT SRecipient KEY_
|
||||
"RKEY" -> Just $ CT SRecipient RKEY_
|
||||
"LSET" -> Just $ CT SRecipient LSET_
|
||||
@@ -932,14 +1001,14 @@ instance ProtocolMsgTag CmdTag where
|
||||
"QUE" -> Just $ CT SRecipient QUE_
|
||||
"SKEY" -> Just $ CT SSender SKEY_
|
||||
"SEND" -> Just $ CT SSender SEND_
|
||||
"PING" -> Just $ CT SSender PING_
|
||||
"PING" -> Just $ CT SIdleClient PING_
|
||||
"LKEY" -> Just $ CT SSenderLink LKEY_
|
||||
"LGET" -> Just $ CT SSenderLink LGET_
|
||||
"PRXY" -> Just $ CT SProxiedClient PRXY_
|
||||
"PFWD" -> Just $ CT SProxiedClient PFWD_
|
||||
"RFWD" -> Just $ CT SProxyService RFWD_
|
||||
"NSUB" -> Just $ CT SNotifier NSUB_
|
||||
"NSUBS" -> Just $ CT SNotifier NSUBS_
|
||||
"NSUBS" -> Just $ CT SNotifierService NSUBS_
|
||||
_ -> Nothing
|
||||
|
||||
instance Encoding CmdTag where
|
||||
@@ -1564,7 +1633,7 @@ instance Protocol SMPVersion ErrorType BrokerMsg where
|
||||
Cmd _ NSUB -> True
|
||||
_ -> False
|
||||
{-# INLINE useServiceAuth #-}
|
||||
protocolPing = Cmd SSender PING
|
||||
protocolPing = Cmd SIdleClient PING
|
||||
{-# INLINE protocolPing #-}
|
||||
protocolError = \case
|
||||
ERR e -> Just e
|
||||
@@ -1576,7 +1645,7 @@ class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -
|
||||
encodeProtocol :: Version v -> msg -> ByteString
|
||||
protocolP :: Version v -> Tag msg -> Parser msg
|
||||
fromProtocolError :: ProtocolErrorType -> err
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
|
||||
checkCredentials :: Maybe TAuthorizations -> EntityId -> msg -> Either err msg
|
||||
|
||||
instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
@@ -1620,7 +1689,7 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
|
||||
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (auth, _, EntityId entId, _) cmd = case cmd of
|
||||
checkCredentials auth (EntityId entId) cmd = case cmd of
|
||||
-- NEW must have signature but NOT queue ID
|
||||
NEW {}
|
||||
| isNothing auth -> Left $ CMD NO_AUTH
|
||||
@@ -1663,14 +1732,14 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
{-# INLINE encodeProtocol #-}
|
||||
|
||||
protocolP v = \case
|
||||
CT SRecipient tag ->
|
||||
Cmd SRecipient <$> case tag of
|
||||
NEW_
|
||||
| v >= shortLinksSMPVersion -> NEW <$> new smpP smpP
|
||||
| v >= sndAuthKeySMPVersion -> NEW <$> new smpP (qReq <$> smpP)
|
||||
| otherwise -> NEW <$> new auth (pure Nothing)
|
||||
CT SCreator NEW_ -> Cmd SCreator <$> newCmd
|
||||
where
|
||||
newCmd
|
||||
| v >= shortLinksSMPVersion = new smpP smpP
|
||||
| v >= sndAuthKeySMPVersion = new smpP (qReq <$> smpP)
|
||||
| otherwise = new auth (pure Nothing)
|
||||
where
|
||||
new p1 p2 = do
|
||||
new p1 p2 = NEW <$> do
|
||||
rcvAuthKey <- _smpP
|
||||
rcvDhKey <- smpP
|
||||
auth_ <- p1
|
||||
@@ -1681,8 +1750,9 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
pure NewQueueReq {rcvAuthKey, rcvDhKey, auth_, subMode, queueReqData} -- ntfCreds
|
||||
auth = optional (A.char 'A' *> smpP)
|
||||
qReq sndSecure = Just $ if sndSecure then QRMessaging Nothing else QRContact Nothing
|
||||
CT SRecipient tag ->
|
||||
Cmd SRecipient <$> case tag of
|
||||
SUB_ -> pure SUB
|
||||
SUBS_ -> pure SUBS
|
||||
KEY_ -> KEY <$> _smpP
|
||||
RKEY_ -> RKEY <$> _smpP
|
||||
LSET_ -> LSET <$> _smpP <*> smpP
|
||||
@@ -1694,11 +1764,12 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
OFF_ -> pure OFF
|
||||
DEL_ -> pure DEL
|
||||
QUE_ -> pure QUE
|
||||
CT SRecipientService SUBS_ -> pure $ Cmd SRecipientService SUBS
|
||||
CT SSender tag ->
|
||||
Cmd SSender <$> case tag of
|
||||
SKEY_ -> SKEY <$> _smpP
|
||||
SEND_ -> SEND <$> _smpP <*> (unTail <$> _smpP)
|
||||
PING_ -> pure PING
|
||||
CT SIdleClient PING_ -> pure $ Cmd SIdleClient PING
|
||||
CT SProxyService RFWD_ ->
|
||||
Cmd SProxyService . RFWD . EncFwdTransmission . unTail <$> _smpP
|
||||
CT SSenderLink tag ->
|
||||
@@ -1709,15 +1780,13 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
Cmd SProxiedClient <$> case tag of
|
||||
PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP)
|
||||
PRXY_ -> PRXY <$> _smpP <*> smpP
|
||||
CT SNotifier tag ->
|
||||
pure $ Cmd SNotifier $ case tag of
|
||||
NSUB_ -> NSUB
|
||||
NSUBS_ -> NSUBS
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
CT SNotifierService NSUBS_ -> pure $ Cmd SNotifierService NSUBS
|
||||
|
||||
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
checkCredentials tAuth entId (Cmd p c) = Cmd p <$> checkCredentials tAuth entId c
|
||||
{-# INLINE checkCredentials #-}
|
||||
|
||||
instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
|
||||
@@ -1804,7 +1873,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
|
||||
PEBlock -> BLOCK
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (_, _, EntityId entId, _) cmd = case cmd of
|
||||
checkCredentials _ (EntityId entId) cmd = case cmd of
|
||||
-- IDS response should not have queue ID
|
||||
IDS _ -> Right cmd
|
||||
-- ERR response does not always have queue ID
|
||||
@@ -2077,26 +2146,51 @@ tParse thParams@THandleParams {batch} s
|
||||
eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b)
|
||||
eitherList = either (\e -> [Left e])
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet :: forall v err cmd c p. (ProtocolEncoding v err cmd, Transport c) => THandle v c p -> IO (NonEmpty (SignedTransmission err cmd))
|
||||
tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th
|
||||
-- | Receive server transmissions
|
||||
tGetServer :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TServer -> IO (NonEmpty (SignedTransmissionOrError err cmd))
|
||||
tGetServer = tGet tDecodeServer
|
||||
{-# INLINE tGetServer #-}
|
||||
|
||||
tDecodeParseValidate :: forall v p err cmd. ProtocolEncoding v err cmd => THandleParams v p -> Either TransportError RawTransmission -> SignedTransmission err cmd
|
||||
tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case
|
||||
-- | Receive client transmissions
|
||||
tGetClient :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (NonEmpty (Transmission (Either err cmd)))
|
||||
tGetClient = tGet tDecodeClient
|
||||
{-# INLINE tGetClient #-}
|
||||
|
||||
tGet ::
|
||||
Transport c =>
|
||||
(THandleParams v p -> Either TransportError RawTransmission -> r) ->
|
||||
THandle v c p ->
|
||||
IO (NonEmpty r)
|
||||
tGet tDecode th@THandle {params} = L.map (tDecode params) <$> tGetParse th
|
||||
{-# INLINE tGet #-}
|
||||
|
||||
tDecodeServer :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v 'TServer -> Either TransportError RawTransmission -> SignedTransmissionOrError err cmd
|
||||
tDecodeServer THandleParams {sessionId, thVersion = v, implySessId} = \case
|
||||
Right RawTransmission {authenticator, serviceSig, authorized, sessId, corrId, entityId, command}
|
||||
| implySessId || sessId == sessionId ->
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator serviceSig
|
||||
in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission
|
||||
| otherwise -> (Nothing, "", (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd PESession))
|
||||
Left _ -> tError ""
|
||||
| implySessId || sessId == sessionId -> case decodeTAuthBytes authenticator serviceSig of
|
||||
Right tAuth -> bimap t ((tAuth,authorized,) . t) cmdOrErr
|
||||
where
|
||||
cmdOrErr = parseProtocol @v @err @cmd v command >>= checkCredentials tAuth entityId
|
||||
t :: a -> (CorrId, EntityId, a)
|
||||
t = (corrId,entityId,)
|
||||
Left _ -> tError corrId PEBlock
|
||||
| otherwise -> tError corrId PESession
|
||||
Left _ -> tError "" PEBlock
|
||||
where
|
||||
tError :: CorrId -> SignedTransmission err cmd
|
||||
tError corrId = (Nothing, "", (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd PEBlock))
|
||||
tError :: CorrId -> ProtocolErrorType -> SignedTransmissionOrError err cmd
|
||||
tError corrId err = Left (corrId, NoEntity, fromProtocolError @v @err @cmd err)
|
||||
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
|
||||
tParseValidate signed t@(sig, corrId, entityId, command) =
|
||||
let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t
|
||||
in (sig, signed, (corrId, entityId, cmd))
|
||||
tDecodeClient :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v 'TClient -> Either TransportError RawTransmission -> Transmission (Either err cmd)
|
||||
tDecodeClient THandleParams {sessionId, thVersion = v, implySessId} = \case
|
||||
Right RawTransmission {sessId, corrId, entityId, command}
|
||||
| implySessId || sessId == sessionId -> (corrId, entityId, cmdOrErr)
|
||||
| otherwise -> tError corrId PESession
|
||||
where
|
||||
cmdOrErr = parseProtocol @v @err @cmd v command >>= checkCredentials Nothing entityId
|
||||
Left _ -> tError "" PEBlock
|
||||
where
|
||||
tError :: CorrId -> ProtocolErrorType -> Transmission (Either err cmd)
|
||||
tError corrId err = (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd err)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''MsgFlags)
|
||||
|
||||
|
||||
@@ -52,12 +52,13 @@ import Control.Monad.Reader
|
||||
import Control.Monad.Trans.Except
|
||||
import Control.Monad.STM (retry)
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bifunctor (first, second)
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import qualified Data.ByteString.Builder as BLD
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Constraint (Dict (..))
|
||||
import Data.Dynamic (toDyn)
|
||||
import Data.Either (fromRight, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
@@ -1068,48 +1069,53 @@ cancelSub s = case subThread s of
|
||||
_ -> pure ()
|
||||
ProhibitSub -> pure ()
|
||||
|
||||
type VerifiedTransmissionOrError s = Either (Transmission BrokerMsg) (VerifiedTransmission s)
|
||||
|
||||
receive :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> s -> Client s -> M s ()
|
||||
receive h@THandle {params = THandleParams {thAuth, sessionId}} ms Client {rcvQ, sndQ, rcvActiveAt} = do
|
||||
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
|
||||
sa <- asks serverActive
|
||||
stats <- asks serverStats
|
||||
liftIO $ forever $ do
|
||||
ts <- tGet h
|
||||
ts <- tGetServer h
|
||||
unlessM (readTVarIO sa) $ throwIO $ userError "server stopped"
|
||||
atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime
|
||||
let service = peerClientService =<< thAuth
|
||||
(errs, cmds) <- partitionEithers <$> mapM (cmdAction stats service) (L.toList ts)
|
||||
updateBatchStats stats cmds
|
||||
write sndQ errs
|
||||
write rcvQ cmds
|
||||
let (es, ts') = partitionEithers $ L.toList ts
|
||||
errs = map (second ERR) es
|
||||
case ts' of
|
||||
(_, _, (_, _, Cmd p cmd)) : rest -> do
|
||||
let service = peerClientService =<< thAuth
|
||||
(errs', cmds) <- partitionEithers <$> case batchParty p of
|
||||
Just Dict | not (null rest) && all (sameParty p) ts'-> do
|
||||
updateBatchStats stats cmd -- even if nothing is verified
|
||||
let queueId (_, _, (_, qId, _)) = qId
|
||||
qs <- getQueueRecs ms p $ map queueId ts'
|
||||
zipWithM (\t -> verified stats t . verifyLoadedQueue service thAuth t) ts' qs
|
||||
_ -> mapM (\t -> verified stats t =<< verifyTransmission ms service thAuth t) ts'
|
||||
write rcvQ cmds
|
||||
write sndQ $ errs ++ errs'
|
||||
[] -> write sndQ errs
|
||||
where
|
||||
updateBatchStats :: ServerStats -> [(Maybe (StoreQueue s, QueueRec), Transmission Cmd)] -> IO ()
|
||||
sameParty :: SParty p -> SignedTransmission Cmd -> Bool
|
||||
sameParty p (_, _, (_, _, Cmd p' _)) = isJust $ testEquality p p'
|
||||
updateBatchStats :: ServerStats -> Command p -> IO ()
|
||||
updateBatchStats stats = \case
|
||||
(_, (_, _, (Cmd _ cmd))) : _ -> do
|
||||
let sel_ = case cmd of
|
||||
SUB -> Just qSubAllB
|
||||
DEL -> Just qDeletedAllB
|
||||
NSUB -> Just ntfSubB
|
||||
NDEL -> Just ntfDeletedB
|
||||
_ -> Nothing
|
||||
mapM_ (\sel -> incStat $ sel stats) sel_
|
||||
[] -> pure ()
|
||||
cmdAction :: ServerStats -> Maybe THPeerClientService -> SignedTransmission ErrorType Cmd -> IO (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd))
|
||||
cmdAction stats service (tAuth, authorized, (corrId, entId, cmdOrError)) =
|
||||
case cmdOrError of
|
||||
Left e -> pure $ Left (corrId, entId, ERR e)
|
||||
Right cmd -> verified =<< verifyTransmission ms service ((,C.cbNonce (bs corrId)) <$> thAuth) tAuth authorized entId cmd
|
||||
where
|
||||
verified = \case
|
||||
VRVerified q -> pure $ Right (q, (corrId, entId, cmd))
|
||||
VRFailed e -> do
|
||||
case cmd of
|
||||
Cmd _ SEND {} -> incStat $ msgSentAuth stats
|
||||
Cmd _ SUB -> incStat $ qSubAuth stats
|
||||
Cmd _ NSUB -> incStat $ ntfSubAuth stats
|
||||
Cmd _ GET -> incStat $ msgGetAuth stats
|
||||
_ -> pure ()
|
||||
pure $ Left (corrId, entId, ERR e)
|
||||
SUB -> incStat $ qSubAllB stats
|
||||
DEL -> incStat $ qDeletedAllB stats
|
||||
NDEL -> incStat $ ntfDeletedB stats
|
||||
NSUB -> incStat $ ntfSubB stats
|
||||
_ -> pure ()
|
||||
verified :: ServerStats -> SignedTransmission Cmd -> VerificationResult s -> IO (VerifiedTransmissionOrError s)
|
||||
verified stats (_, _, t@(corrId, entId, Cmd _ command)) = \case
|
||||
VRVerified q -> pure $ Right (q, t)
|
||||
VRFailed e -> Left (corrId, entId, ERR e) <$ when (e == AUTH) incAuthStat
|
||||
where
|
||||
incAuthStat = case command of
|
||||
SEND {} -> incStat $ msgSentAuth stats
|
||||
SUB -> incStat $ qSubAuth stats
|
||||
NSUB -> incStat $ ntfSubAuth stats
|
||||
GET -> incStat $ msgGetAuth stats
|
||||
_ -> pure ()
|
||||
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
|
||||
|
||||
send :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO ()
|
||||
@@ -1169,34 +1175,42 @@ data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFail
|
||||
-- - the queue or party key do not exist.
|
||||
-- In all cases, the time of the verification should depend only on the provided authorization type,
|
||||
-- a dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result.
|
||||
verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> QueueId -> Cmd -> IO (VerificationResult s)
|
||||
verifyTransmission ms service auth_ tAuth authorized queueId command@(Cmd party cmd)
|
||||
| verifyServiceSig = case party of
|
||||
SRecipient | hasRole SRMessaging -> case cmd of
|
||||
NEW NewQueueReq {rcvAuthKey = k} -> pure $ Nothing `verifiedWith` k
|
||||
SUB -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q)
|
||||
SUBS -> pure verifyServiceCmd
|
||||
_ -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q)
|
||||
SSender | hasRole SRMessaging -> case cmd of
|
||||
SKEY k -> verifySecure SSender k
|
||||
-- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
|
||||
SEND {} -> verifyQueue SSender $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified (Just q) else VRFailed AUTH
|
||||
PING -> pure $ VRVerified Nothing
|
||||
SSenderLink | hasRole SRMessaging -> case cmd of
|
||||
LKEY k -> verifySecure SSenderLink k
|
||||
LGET -> verifyQueue SSenderLink $ \q -> if isContactQueue (snd q) then VRVerified (Just q) else VRFailed AUTH
|
||||
SNotifier | hasRole SRNotifier -> case cmd of
|
||||
NSUB -> verifyQueue SNotifier $ \q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier $ snd q)
|
||||
NSUBS -> pure verifyServiceCmd
|
||||
SProxiedClient | hasRole SRMessaging -> pure $ VRVerified Nothing
|
||||
SProxyService | hasRole SRProxy -> pure $ VRVerified Nothing
|
||||
_ -> pure $ VRFailed $ CMD PROHIBITED
|
||||
| otherwise = pure $ VRFailed SERVICE
|
||||
verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> IO (VerificationResult s)
|
||||
verifyTransmission ms service thAuth t@(_, _, (_, queueId, Cmd p _)) = case queueParty p of
|
||||
Just Dict -> verifyLoadedQueue service thAuth t <$> getQueueRec ms p queueId
|
||||
Nothing -> pure $ verifyQueueTransmission service thAuth t Nothing
|
||||
|
||||
verifyLoadedQueue :: Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Either ErrorType (StoreQueue s, QueueRec) -> VerificationResult s
|
||||
verifyLoadedQueue service thAuth t@(tAuth, authorized, (corrId, _, _)) = \case
|
||||
Right q -> verifyQueueTransmission service thAuth t (Just q)
|
||||
Left AUTH -> dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
|
||||
Left e -> VRFailed e
|
||||
|
||||
verifyQueueTransmission :: forall s. Maybe THPeerClientService -> Maybe (THandleAuth 'TServer) -> SignedTransmission Cmd -> Maybe (StoreQueue s, QueueRec) -> VerificationResult s
|
||||
verifyQueueTransmission service thAuth (tAuth, authorized, (corrId, _, command@(Cmd p cmd))) q_
|
||||
| not checkRole = VRFailed $ CMD PROHIBITED
|
||||
| not verifyServiceSig = VRFailed SERVICE
|
||||
| otherwise = vc p cmd
|
||||
where
|
||||
hasRole role = case service of
|
||||
Just THClientService {serviceRole} -> serviceRole == role
|
||||
Nothing -> True
|
||||
verify = verifyCmdAuthorization auth_ tAuth authorized'
|
||||
vc :: SParty p -> Command p -> VerificationResult s -- this pattern match works with ghc8.10.7, flat case sees it as non-exhastive.
|
||||
vc SCreator (NEW NewQueueReq {rcvAuthKey = k}) = verifiedWith k
|
||||
vc SRecipient SUB = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
|
||||
vc SRecipient _ = verifyQueue $ \q -> verifiedWithKeys $ recipientKeys (snd q)
|
||||
vc SRecipientService SUBS = verifyServiceCmd
|
||||
vc SSender (SKEY k) = verifySecure k
|
||||
-- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command
|
||||
vc SSender SEND {} = verifyQueue $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified q_ else VRFailed AUTH
|
||||
vc SIdleClient PING = VRVerified Nothing
|
||||
vc SSenderLink (LKEY k) = verifySecure k
|
||||
vc SSenderLink LGET = verifyQueue $ \q -> if isContactQueue (snd q) then VRVerified q_ else VRFailed AUTH
|
||||
vc SNotifier NSUB = verifyQueue $ \q -> maybe dummyVerify (\n -> verifiedWith $ notifierKey n) (notifier $ snd q)
|
||||
vc SNotifierService NSUBS = verifyServiceCmd
|
||||
vc SProxiedClient _ = VRVerified Nothing
|
||||
vc SProxyService (RFWD _) = VRVerified Nothing
|
||||
checkRole = case (service, partyClientRole p) of
|
||||
(Just THClientService {serviceRole}, Just role) -> serviceRole == role
|
||||
_ -> True
|
||||
verify = verifyCmdAuthorization thAuth tAuth authorized' corrId
|
||||
verifyServiceCmd :: VerificationResult s
|
||||
verifyServiceCmd = case (service, tAuth) of
|
||||
(Just THClientService {serviceKey = k}, Just (TASignature (C.ASignature C.SEd25519 s), Nothing))
|
||||
@@ -1214,20 +1228,17 @@ verifyTransmission ms service auth_ tAuth authorized queueId command@(Cmd party
|
||||
(Just THClientService {serviceCertHash = XV.Fingerprint fp}, Just _) -> fp <> authorized
|
||||
_ -> authorized
|
||||
dummyVerify :: VerificationResult s
|
||||
dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed AUTH
|
||||
verifyQueue :: DirectParty p => SParty p -> ((StoreQueue s, QueueRec) -> VerificationResult s) -> IO (VerificationResult s)
|
||||
verifyQueue p v = either err v <$> getQueueRec ms p queueId
|
||||
where
|
||||
-- this prevents reporting any STORE errors as AUTH errors
|
||||
err = \case
|
||||
AUTH -> dummyVerify
|
||||
e -> VRFailed e
|
||||
verifySecure :: DirectParty p => SParty p -> SndPublicAuthKey -> IO (VerificationResult s)
|
||||
verifySecure p k = verifyQueue p $ \q -> if k `allowedKey` snd q then Just q `verifiedWith` k else dummyVerify
|
||||
verifiedWith :: Maybe (StoreQueue s, QueueRec) -> C.APublicAuthKey -> VerificationResult s
|
||||
verifiedWith q_ k = if verify k then VRVerified q_ else VRFailed AUTH
|
||||
verifiedWithKeys :: Maybe (StoreQueue s, QueueRec) -> NonEmpty C.APublicAuthKey -> VerificationResult s
|
||||
verifiedWithKeys q_ ks = if any verify ks then VRVerified q_ else VRFailed AUTH
|
||||
dummyVerify = dummyVerifyCmd thAuth tAuth authorized corrId `seq` VRFailed AUTH
|
||||
-- That a specific command requires queue signature verification is determined by `queueParty`,
|
||||
-- it should be coordinated with the case in this function (`verifyQueueTransmission`)
|
||||
verifyQueue :: ((StoreQueue s, QueueRec) -> VerificationResult s) -> VerificationResult s
|
||||
verifyQueue v = maybe (VRFailed INTERNAL) v q_
|
||||
verifySecure :: SndPublicAuthKey -> VerificationResult s
|
||||
verifySecure k = verifyQueue $ \q -> if k `allowedKey` snd q then verifiedWith k else dummyVerify
|
||||
verifiedWith :: C.APublicAuthKey -> VerificationResult s
|
||||
verifiedWith k = if verify k then VRVerified q_ else VRFailed AUTH
|
||||
verifiedWithKeys :: NonEmpty C.APublicAuthKey -> VerificationResult s
|
||||
verifiedWithKeys ks = if any verify ks then VRVerified q_ else VRFailed AUTH
|
||||
allowedKey k = \case
|
||||
QueueRec {queueMode = Just QMMessaging, senderKey} -> maybe True (k ==) senderKey
|
||||
_ -> False
|
||||
@@ -1243,8 +1254,9 @@ isSecuredMsgQueue QueueRec {queueMode, senderKey} = case queueMode of
|
||||
Just QMContact -> False
|
||||
_ -> isJust senderKey
|
||||
|
||||
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> C.APublicAuthKey -> Bool
|
||||
verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAuth
|
||||
-- Random correlation ID is used as a nonce in case crypto_box authenticator is used to authorize transmission
|
||||
verifyCmdAuthorization :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> C.APublicAuthKey -> Bool
|
||||
verifyCmdAuthorization thAuth tAuth authorized corrId key = maybe False (verify key) tAuth
|
||||
where
|
||||
verify :: C.APublicAuthKey -> TAuthorizations -> Bool
|
||||
verify (C.APublicAuthKey a k) = \case
|
||||
@@ -1252,18 +1264,20 @@ verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAu
|
||||
Just Refl -> C.verify' k s authorized
|
||||
_ -> C.verify' (dummySignKey a') s authorized `seq` False
|
||||
(TAAuthenticator s, _) -> case a of
|
||||
C.SX25519 -> verifyCmdAuth auth_ k s authorized
|
||||
_ -> verifyCmdAuth auth_ dummyKeyX25519 s authorized `seq` False
|
||||
C.SX25519 -> verifyCmdAuth thAuth k s authorized corrId
|
||||
_ -> verifyCmdAuth thAuth dummyKeyX25519 s authorized corrId `seq` False
|
||||
|
||||
verifyCmdAuth :: Maybe (THandleAuth 'TServer, C.CbNonce) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> Bool
|
||||
verifyCmdAuth auth_ k authenticator authorized = case auth_ of
|
||||
Just (THAuthServer {serverPrivKey = pk}, nonce) -> C.cbVerify k pk nonce authenticator authorized
|
||||
verifyCmdAuth :: Maybe (THandleAuth 'TServer) -> C.PublicKeyX25519 -> C.CbAuthenticator -> ByteString -> CorrId -> Bool
|
||||
verifyCmdAuth thAuth k authenticator authorized (CorrId corrId) = case thAuth of
|
||||
Just THAuthServer {serverPrivKey = pk} -> C.cbVerify k pk (C.cbNonce corrId) authenticator authorized
|
||||
Nothing -> False
|
||||
|
||||
dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TAuthorizations -> Bool
|
||||
dummyVerifyCmd auth_ authorized = \case
|
||||
(TASignature (C.ASignature a s), _) -> C.verify' (dummySignKey a) s authorized
|
||||
(TAAuthenticator s, _) -> verifyCmdAuth auth_ dummyKeyX25519 s authorized
|
||||
dummyVerifyCmd :: Maybe (THandleAuth 'TServer) -> Maybe TAuthorizations -> ByteString -> CorrId -> Maybe Bool
|
||||
dummyVerifyCmd thAuth tAuth authorized corrId = verify <$> tAuth
|
||||
where
|
||||
verify = \case
|
||||
(TASignature (C.ASignature a s), _) -> C.verify' (dummySignKey a) s authorized
|
||||
(TAAuthenticator s, _) -> verifyCmdAuth thAuth dummyKeyX25519 s authorized corrId
|
||||
|
||||
-- These dummy keys are used with `dummyVerify` function to mitigate timing attacks
|
||||
-- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes
|
||||
@@ -1272,13 +1286,6 @@ dummySignKey = \case
|
||||
C.SEd25519 -> dummyKeyEd25519
|
||||
C.SEd448 -> dummyKeyEd448
|
||||
|
||||
dummyAuthKey :: Maybe TAuthorizations -> C.APublicAuthKey
|
||||
dummyAuthKey = \case
|
||||
Just (TASignature (C.ASignature a _), _) -> case a of
|
||||
C.SEd25519 -> C.APublicAuthKey C.SEd25519 dummyKeyEd25519
|
||||
C.SEd448 -> C.APublicAuthKey C.SEd448 dummyKeyEd448
|
||||
_ -> C.APublicAuthKey C.SX25519 dummyKeyX25519
|
||||
|
||||
dummyKeyEd25519 :: C.PublicKey 'C.Ed25519
|
||||
dummyKeyEd25519 = "MCowBQYDK2VwAyEA139Oqs4QgpqbAmB0o7rZf6T19ryl7E65k4AYe0kE3Qs="
|
||||
|
||||
@@ -1392,34 +1399,32 @@ client
|
||||
mkIncProxyStats ps psOwn own sel = do
|
||||
incStat $ sel ps
|
||||
when own $ incStat $ sel psOwn
|
||||
processCommand :: Maybe THPeerClientService -> VersionSMP -> (Maybe (StoreQueue s, QueueRec), Transmission Cmd) -> M s (Maybe (Transmission BrokerMsg))
|
||||
processCommand :: Maybe THPeerClientService -> VersionSMP -> VerifiedTransmission s -> M s (Maybe (Transmission BrokerMsg))
|
||||
processCommand service clntVersion (q_, (corrId, entId, cmd)) = case cmd of
|
||||
Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command)
|
||||
Cmd SSender command -> Just <$> case command of
|
||||
SKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k
|
||||
SEND flags msgBody -> withQueue_ False $ sendMessage flags msgBody
|
||||
PING -> pure (corrId, NoEntity, PONG)
|
||||
Cmd SProxyService (RFWD encBlock) -> Just . (corrId, NoEntity,) <$> processForwardedCommand encBlock
|
||||
Cmd SIdleClient PING -> pure $ Just (corrId, NoEntity, PONG)
|
||||
Cmd SProxyService (RFWD encBlock) -> Just . (corrId,NoEntity,) <$> processForwardedCommand encBlock
|
||||
Cmd SSenderLink command -> Just <$> case command of
|
||||
LKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k $>> getQueueLink_ q qr
|
||||
LGET -> withQueue $ \q qr -> checkContact qr $ getQueueLink_ q qr
|
||||
Cmd SNotifier command -> Just . (corrId,entId,) <$> case command of
|
||||
NSUB -> case q_ of
|
||||
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
|
||||
_ -> pure $ ERR INTERNAL
|
||||
NSUBS -> case service of
|
||||
Just s -> subscribeServiceNotifications s
|
||||
Nothing -> pure $ ERR INTERNAL
|
||||
Cmd SNotifier NSUB -> Just . (corrId,entId,) <$> case q_ of
|
||||
Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds
|
||||
_ -> pure $ ERR INTERNAL
|
||||
Cmd SNotifierService NSUBS -> Just . (corrId,entId,) <$> case service of
|
||||
Just s -> subscribeServiceNotifications s
|
||||
Nothing -> pure $ ERR INTERNAL
|
||||
Cmd SCreator (NEW nqr@NewQueueReq {auth_}) ->
|
||||
Just <$> ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
|
||||
where
|
||||
allowNew = do
|
||||
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
|
||||
pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth
|
||||
Cmd SRecipient command ->
|
||||
Just <$> case command of
|
||||
NEW nqr@NewQueueReq {auth_} ->
|
||||
ifM allowNew (createQueue nqr) (pure (corrId, entId, ERR AUTH))
|
||||
where
|
||||
allowNew = do
|
||||
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
|
||||
pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth
|
||||
SUB -> withQueue subscribeQueue
|
||||
SUBS -> pure $ err (CMD PROHIBITED) -- "TODO [certs rcv]"
|
||||
GET -> withQueue getMessage
|
||||
ACK msgId -> withQueue $ acknowledgeMsg msgId
|
||||
KEY sKey -> withQueue $ \q _ -> either err (corrId,entId,) <$> secureQueue_ q sKey
|
||||
@@ -1438,6 +1443,7 @@ client
|
||||
OFF -> maybe (pure $ err INTERNAL) suspendQueue_ q_
|
||||
DEL -> maybe (pure $ err INTERNAL) delQueueAndMsgs q_
|
||||
QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr
|
||||
Cmd SRecipientService SUBS -> pure $ Just $ err (CMD PROHIBITED) -- "TODO [certs rcv]"
|
||||
where
|
||||
createQueue :: NewQueueReq -> M s (Transmission BrokerMsg)
|
||||
createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData}
|
||||
@@ -1492,7 +1498,7 @@ client
|
||||
| clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied
|
||||
| otherwise -> tryCreate (n - 1)
|
||||
Left e -> pure $ ERR e
|
||||
Right q -> do
|
||||
Right _q -> do
|
||||
stats <- asks serverStats
|
||||
incStat $ qCreated stats
|
||||
incStat $ qCount stats
|
||||
@@ -1500,7 +1506,7 @@ client
|
||||
-- when (isJust ntf) $ incStat $ ntfCreated stats
|
||||
case subMode of
|
||||
SMOnlyCreate -> pure ()
|
||||
SMSubscribe -> void $ subscribeQueue q qr
|
||||
SMSubscribe -> void $ subscribeNewQueue rcvId qr -- no need to check if message is available, it's a new queue
|
||||
pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData, serviceId = rcvServiceId} -- , serverNtfCreds = snd <$> ntf
|
||||
(corrId,entId,) <$> tryCreate (3 :: Int)
|
||||
|
||||
@@ -1563,9 +1569,9 @@ client
|
||||
|
||||
-- TODO [certs rcv] if serviceId is passed, associate with the service and respond with SOK
|
||||
subscribeQueue :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
|
||||
subscribeQueue q qr@QueueRec {rcvServiceId} =
|
||||
subscribeQueue q qr =
|
||||
liftIO (TM.lookupIO rId subscriptions) >>= \case
|
||||
Nothing -> newSub >>= deliver True
|
||||
Nothing -> subscribeNewQueue rId qr >>= deliver True
|
||||
Just s@Sub {subThread} -> do
|
||||
stats <- asks serverStats
|
||||
case subThread of
|
||||
@@ -1578,12 +1584,6 @@ client
|
||||
atomically (tryTakeTMVar $ delivered s) >> deliver False s
|
||||
where
|
||||
rId = recipientId q
|
||||
newSub :: M s Sub
|
||||
newSub = time "SUB newSub" . atomically $ do
|
||||
writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId)
|
||||
sub <- newSubscription NoSub
|
||||
TM.insert rId sub subscriptions
|
||||
pure sub
|
||||
deliver :: Bool -> Sub -> M s (Transmission BrokerMsg)
|
||||
deliver inc sub = do
|
||||
stats <- asks serverStats
|
||||
@@ -1592,6 +1592,13 @@ client
|
||||
liftIO $ when (inc && isJust msg_) $ incStat (qSub stats)
|
||||
liftIO $ deliverMessage "SUB" qr rId sub msg_
|
||||
|
||||
subscribeNewQueue :: RecipientId -> QueueRec -> M s Sub
|
||||
subscribeNewQueue rId QueueRec {rcvServiceId} = time "SUB newSub" . atomically $ do
|
||||
writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId)
|
||||
sub <- newSubscription NoSub
|
||||
TM.insert rId sub subscriptions
|
||||
pure sub
|
||||
|
||||
-- clients that use GET are not added to server subscribers
|
||||
getMessage :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg)
|
||||
getMessage q qr = time "GET" $ do
|
||||
@@ -1660,7 +1667,7 @@ client
|
||||
pure $ SOK $ Just serviceId
|
||||
| otherwise ->
|
||||
-- new or updated queue-service association
|
||||
liftIO (setQueueService (queueStore ms) q SNotifier (Just serviceId)) >>= \case
|
||||
liftIO (setQueueService (queueStore ms) q SNotifierService (Just serviceId)) >>= \case
|
||||
Left e -> pure $ ERR e
|
||||
Right () -> do
|
||||
hasSub <- atomically $ (<$ newServiceQueueSub) =<< hasServiceSub
|
||||
@@ -1677,7 +1684,7 @@ client
|
||||
modifyTVar' (totalServiceSubs ntfSubscribers) (+ 1) -- server count for all services
|
||||
Nothing -> case ntfServiceId of
|
||||
Just _ ->
|
||||
liftIO (setQueueService (queueStore ms) q SNotifier Nothing) >>= \case
|
||||
liftIO (setQueueService (queueStore ms) q SNotifierService Nothing) >>= \case
|
||||
Left e -> pure $ ERR e
|
||||
Right () -> do
|
||||
-- hasSubscription should never be True in this branch, because queue was associated with service.
|
||||
@@ -1900,7 +1907,7 @@ client
|
||||
let clntTHParams = smpTHParamsSetVersion fwdVersion thParams'
|
||||
-- only allowing single forwarded transactions
|
||||
t' <- case tParse clntTHParams b of
|
||||
t :| [] -> pure $ tDecodeParseValidate clntTHParams t
|
||||
t :| [] -> pure $ tDecodeServer clntTHParams t
|
||||
_ -> throwE BLOCK
|
||||
let clntThAuth = Just $ THAuthServer {serverPrivKey, peerClientService = Nothing, sessSecret' = Just clientSecret}
|
||||
-- process forwarded command
|
||||
@@ -1925,23 +1932,22 @@ client
|
||||
incStat $ pMsgFwdsRecv stats
|
||||
pure r3
|
||||
where
|
||||
rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmission ErrorType Cmd -> M s (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd))
|
||||
rejectOrVerify clntThAuth (tAuth, authorized, (corrId', entId', cmdOrError)) =
|
||||
case cmdOrError of
|
||||
Left e -> pure $ Left (corrId', entId', ERR e)
|
||||
Right cmd'
|
||||
| allowed -> liftIO $ verified <$> verifyTransmission ms Nothing ((,C.cbNonce (bs corrId')) <$> clntThAuth) tAuth authorized entId' cmd'
|
||||
| otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
|
||||
where
|
||||
allowed = case cmd' of
|
||||
Cmd SSender SEND {} -> True
|
||||
Cmd SSender (SKEY _) -> True
|
||||
Cmd SSenderLink (LKEY _) -> True
|
||||
Cmd SSenderLink LGET -> True
|
||||
_ -> False
|
||||
verified = \case
|
||||
VRVerified q -> Right (q, (corrId', entId', cmd'))
|
||||
VRFailed e -> Left (corrId', entId', ERR e)
|
||||
rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmissionOrError ErrorType Cmd -> M s (VerifiedTransmissionOrError s)
|
||||
rejectOrVerify clntThAuth = \case
|
||||
Left (corrId', entId', e) -> pure $ Left (corrId', entId', ERR e)
|
||||
Right t'@(_, _, t''@(corrId', entId', cmd'))
|
||||
| allowed -> liftIO $ verified <$> verifyTransmission ms Nothing clntThAuth t'
|
||||
| otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED)
|
||||
where
|
||||
allowed = case cmd' of
|
||||
Cmd SSender SEND {} -> True
|
||||
Cmd SSender (SKEY _) -> True
|
||||
Cmd SSenderLink (LKEY _) -> True
|
||||
Cmd SSenderLink LGET -> True
|
||||
_ -> False
|
||||
verified = \case
|
||||
VRVerified q -> Right (q, t'')
|
||||
VRFailed e -> Left (corrId', entId', ERR e)
|
||||
|
||||
deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> IO (Transmission BrokerMsg)
|
||||
deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $
|
||||
|
||||
@@ -39,6 +39,7 @@ module Simplex.Messaging.Server.Env.STM
|
||||
MsgStoreType,
|
||||
MsgStore (..),
|
||||
AStoreType (..),
|
||||
VerifiedTransmission,
|
||||
newEnv,
|
||||
mkJournalStoreConfig,
|
||||
msgStore,
|
||||
@@ -390,7 +391,7 @@ data Client s = Client
|
||||
ntfSubscriptions :: TMap NotifierId (),
|
||||
serviceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count
|
||||
ntfServiceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count
|
||||
rcvQ :: TBQueue (NonEmpty (Maybe (StoreQueue s, QueueRec), Transmission Cmd)),
|
||||
rcvQ :: TBQueue (NonEmpty (VerifiedTransmission s)),
|
||||
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
|
||||
msgQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
|
||||
procThreads :: TVar Int,
|
||||
@@ -403,6 +404,8 @@ data Client s = Client
|
||||
sndActiveAt :: TVar SystemTime
|
||||
}
|
||||
|
||||
type VerifiedTransmission s = (Maybe (StoreQueue s, QueueRec), Transmission Cmd)
|
||||
|
||||
data ServerSub = ServerSub (TVar SubscriptionThread) | ProhibitSub
|
||||
|
||||
data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId)
|
||||
|
||||
@@ -324,6 +324,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where
|
||||
{-# INLINE addQueue_ #-}
|
||||
getQueue_ = withQS getQueue_
|
||||
{-# INLINE getQueue_ #-}
|
||||
getQueues_ = withQS getQueues_
|
||||
{-# INLINE getQueues_ #-}
|
||||
addQueueLinkData = withQS addQueueLinkData
|
||||
{-# INLINE addQueueLinkData #-}
|
||||
getQueueLinkData = withQS getQueueLinkData
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
module Simplex.Messaging.Server.MsgStore.Types where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
@@ -107,14 +108,23 @@ addQueue :: MsgStoreClass s => s -> RecipientId -> QueueRec -> IO (Either ErrorT
|
||||
addQueue st = addQueue_ (queueStore st) (mkQueue st True)
|
||||
{-# INLINE addQueue #-}
|
||||
|
||||
getQueue :: (MsgStoreClass s, DirectParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s))
|
||||
getQueue :: (MsgStoreClass s, QueueParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s))
|
||||
getQueue st = getQueue_ (queueStore st) (mkQueue st)
|
||||
{-# INLINE getQueue #-}
|
||||
|
||||
getQueueRec :: (MsgStoreClass s, DirectParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s, QueueRec))
|
||||
getQueueRec st party qId =
|
||||
getQueue st party qId
|
||||
$>>= (\q -> maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q))
|
||||
getQueueRec :: (MsgStoreClass s, QueueParty p) => s -> SParty p -> QueueId -> IO (Either ErrorType (StoreQueue s, QueueRec))
|
||||
getQueueRec st party qId = getQueue st party qId $>>= readQueueRec
|
||||
|
||||
getQueues :: (MsgStoreClass s, BatchParty p) => s -> SParty p -> [QueueId] -> IO [Either ErrorType (StoreQueue s)]
|
||||
getQueues st = getQueues_ (queueStore st) (mkQueue st)
|
||||
{-# INLINE getQueues #-}
|
||||
|
||||
getQueueRecs :: (MsgStoreClass s, BatchParty p) => s -> SParty p -> [QueueId] -> IO [Either ErrorType (StoreQueue s, QueueRec)]
|
||||
getQueueRecs st party qIds = getQueues st party qIds >>= mapM (fmap join . mapM readQueueRec)
|
||||
|
||||
readQueueRec :: StoreQueueClass q => q -> IO (Either ErrorType (q, QueueRec))
|
||||
readQueueRec q = maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q)
|
||||
{-# INLINE readQueueRec #-}
|
||||
|
||||
getQueueSize :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO Int
|
||||
getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst)
|
||||
|
||||
@@ -37,18 +37,20 @@ import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Builder (Builder)
|
||||
import qualified Data.ByteString.Builder as BB
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Lazy as LB
|
||||
import Data.Bitraversable (bimapM)
|
||||
import Data.Either (fromRight)
|
||||
import Data.Either (fromRight, lefts, rights)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.List (foldl', intersperse, partition)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (catMaybes, fromMaybe)
|
||||
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text (Text)
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
@@ -62,7 +64,7 @@ import Database.PostgreSQL.Simple.ToField (Action (..), ToField (..))
|
||||
import Database.PostgreSQL.Simple.Errors (ConstraintViolation (..), constraintViolation)
|
||||
import Database.PostgreSQL.Simple.SqlQQ (sql)
|
||||
import GHC.IO (catchAny)
|
||||
import Simplex.Messaging.Agent.Client (withLockMap)
|
||||
import Simplex.Messaging.Agent.Client (withLockMap, withLocksMap)
|
||||
import Simplex.Messaging.Agent.Lock (Lock)
|
||||
import Simplex.Messaging.Agent.Store.AgentStore ()
|
||||
import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore)
|
||||
@@ -81,7 +83,7 @@ import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPServiceRole (..))
|
||||
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>))
|
||||
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>), ($>>=))
|
||||
import System.Exit (exitFailure)
|
||||
import System.IO (IOMode (..), hFlush, stdout)
|
||||
import UnliftIO.STM
|
||||
@@ -180,18 +182,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
-- hasId = anyM [TM.memberIO rId queues, TM.memberIO senderId senders, hasNotifier]
|
||||
-- hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.memberIO notifierId notifiers) notifier
|
||||
|
||||
getQueue_ :: DirectParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueue_ :: QueueParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueue_ st mkQ party qId = case party of
|
||||
SRecipient -> getRcvQueue qId
|
||||
SSender -> getSndQueue
|
||||
SProxyService -> getSndQueue
|
||||
SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
|
||||
SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue
|
||||
-- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server
|
||||
SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>))
|
||||
where
|
||||
PostgresQueueStore {queues, senders, links, notifiers} = st
|
||||
getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right)
|
||||
getSndQueue = TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
|
||||
loadRcvQueue = do
|
||||
(rId, qRec) <- loadQueue " WHERE recipient_id = ?"
|
||||
liftIO $ cacheQueue rId qRec $ \_ -> pure () -- recipient map already checked, not caching sender ref
|
||||
@@ -228,6 +228,47 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
TM.insert rId sq queues
|
||||
pure sq
|
||||
|
||||
getQueues_ :: forall p. BatchParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
|
||||
getQueues_ st mkQ party qIds = case party of
|
||||
SRecipient -> do
|
||||
qs <- readTVarIO queues
|
||||
let qs' = map (\qId -> get qs qId qId) qIds
|
||||
E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue
|
||||
SNotifier -> do
|
||||
ns <- readTVarIO notifiers
|
||||
qs <- readTVarIO queues
|
||||
let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds
|
||||
E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) ->
|
||||
forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query
|
||||
(nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs)
|
||||
where
|
||||
PostgresQueueStore {queues, notifiers} = st
|
||||
get :: M.Map QueueId a -> QueueId -> QueueId -> Either QueueId a
|
||||
get m qId = maybe (Left qId) Right . (`M.lookup` m)
|
||||
loadQueues :: [Either QueueId q] -> Query -> ((RecipientId, QueueRec) -> IO (Maybe (QueueId, q))) -> IO [Either ErrorType q]
|
||||
loadQueues qs' cond mkCacheQueue = do
|
||||
let qIds' = lefts qs'
|
||||
if null qIds'
|
||||
then pure $ map (first (const INTERNAL)) qs'
|
||||
else do
|
||||
qs_ <-
|
||||
runExceptT $ fmap M.fromList $
|
||||
withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds')))
|
||||
>>= liftIO . fmap catMaybes . mapM (mkCacheQueue . rowToQueueRec)
|
||||
pure $ map (result qs_) qs'
|
||||
where
|
||||
result :: Either ErrorType (M.Map QueueId q) -> Either QueueId q -> Either ErrorType q
|
||||
result _ (Right q) = Right q
|
||||
result qs_ (Left qId) = maybe (Left AUTH) Right . M.lookup qId =<< qs_
|
||||
cacheRcvQueue (rId, qRec) = do
|
||||
sq <- mkQ True rId qRec
|
||||
sq' <- withQueueLock sq "getQueue_" $ atomically $
|
||||
-- checking the cache again for concurrent reads, use previously loaded queue if exists.
|
||||
TM.lookup rId queues >>= \case
|
||||
Just sq' -> pure sq'
|
||||
Nothing -> sq <$ TM.insert rId sq queues
|
||||
pure $ Just (rId, sq')
|
||||
|
||||
getQueueLinkData :: PostgresQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
|
||||
getQueueLinkData st sq lnkId = runExceptT $ do
|
||||
qr <- ExceptT $ readQueueRecIO $ queueRec sq
|
||||
@@ -311,7 +352,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
db
|
||||
[sql|
|
||||
UPDATE msg_queues
|
||||
SET notifier_id = ?, notifier_key = ?, rcv_ntf_dh_secret = ?
|
||||
SET notifier_id = ?, notifier_key = ?, rcv_ntf_dh_secret = ?, ntf_service_id = NULL
|
||||
WHERE recipient_id = ? AND deleted_at IS NULL
|
||||
|]
|
||||
(nId, notifierKey, rcvNtfDhSecret, rId)
|
||||
@@ -333,7 +374,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
db
|
||||
[sql|
|
||||
UPDATE msg_queues
|
||||
SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL
|
||||
SET notifier_id = NULL, notifier_key = NULL, rcv_ntf_dh_secret = NULL, ntf_service_id = NULL
|
||||
WHERE recipient_id = ? AND deleted_at IS NULL
|
||||
|]
|
||||
(Only rId)
|
||||
@@ -402,15 +443,15 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
when new $ withLog "getCreateService" st (`logNewService` sr)
|
||||
pure serviceId
|
||||
|
||||
setQueueService :: (PartyI p, SubscriberParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
setQueueService :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
setQueueService st sq party serviceId = withQueueRec sq "setQueueService" $ \q -> case party of
|
||||
SRecipient
|
||||
SRecipientService
|
||||
| rcvServiceId q == serviceId -> pure ()
|
||||
| otherwise -> do
|
||||
assertUpdated $ withDB' "setQueueService" st $ \db ->
|
||||
DB.execute db "UPDATE msg_queues SET rcv_service_id = ? WHERE recipient_id = ? AND deleted_at IS NULL" (serviceId, rId)
|
||||
updateQueueRec q {rcvServiceId = serviceId}
|
||||
SNotifier -> case notifier q of
|
||||
SNotifierService -> case notifier q of
|
||||
Nothing -> throwE AUTH
|
||||
Just nc@NtfCreds {ntfServiceId = prevSrvId}
|
||||
| prevSrvId == serviceId -> pure ()
|
||||
|
||||
@@ -128,17 +128,29 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.member notifierId notifiers) notifier
|
||||
hasLink = maybe (pure False) (\(lnkId, _) -> TM.member lnkId links) queueData
|
||||
|
||||
getQueue_ :: DirectParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueue_ :: QueueParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueue_ st _ party qId =
|
||||
maybe (Left AUTH) Right <$> case party of
|
||||
SRecipient -> TM.lookupIO qId queues
|
||||
SSender -> getSndQueue
|
||||
SProxyService -> getSndQueue
|
||||
SSender -> TM.lookupIO qId senders $>>= (`TM.lookupIO` queues)
|
||||
SNotifier -> TM.lookupIO qId notifiers $>>= (`TM.lookupIO` queues)
|
||||
SSenderLink -> TM.lookupIO qId links $>>= (`TM.lookupIO` queues)
|
||||
where
|
||||
STMQueueStore {queues, senders, notifiers, links} = st
|
||||
getSndQueue = TM.lookupIO qId senders $>>= (`TM.lookupIO` queues)
|
||||
|
||||
getQueues_ :: BatchParty p => STMQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
|
||||
getQueues_ st _ party qIds = case party of
|
||||
SRecipient -> do
|
||||
qs <- readTVarIO queues
|
||||
pure $ map (get qs) qIds
|
||||
SNotifier -> do
|
||||
ns <- readTVarIO notifiers
|
||||
qs <- readTVarIO queues
|
||||
pure $ map (get qs <=< get ns) qIds
|
||||
where
|
||||
STMQueueStore {queues, notifiers} = st
|
||||
get :: M.Map QueueId a -> QueueId -> Either ErrorType a
|
||||
get m = maybe (Left AUTH) Right . (`M.lookup` m)
|
||||
|
||||
getQueueLinkData :: STMQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
|
||||
getQueueLinkData _ q lnkId = atomically $ readQueueRec (queueRec q) $>>= pure . getData
|
||||
@@ -292,7 +304,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
serviceNtfQueues <- newTVar S.empty
|
||||
pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues}
|
||||
|
||||
setQueueService :: (PartyI p, SubscriberParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
setQueueService :: (PartyI p, ServiceParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
setQueueService st sq party serviceId =
|
||||
atomically (readQueueRec qr $>>= setService)
|
||||
$>> withLog "setQueueService" st (\sl -> logQueueService sl rId party serviceId)
|
||||
@@ -301,13 +313,13 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
rId = recipientId sq
|
||||
setService :: QueueRec -> STM (Either ErrorType ())
|
||||
setService q@QueueRec {rcvServiceId = prevSrvId} = case party of
|
||||
SRecipient
|
||||
SRecipientService
|
||||
| prevSrvId == serviceId -> pure $ Right ()
|
||||
| otherwise -> do
|
||||
updateServiceQueues serviceRcvQueues rId prevSrvId
|
||||
let !q' = Just q {rcvServiceId = serviceId}
|
||||
writeTVar qr q' $> Right ()
|
||||
SNotifier -> case notifier q of
|
||||
SNotifierService -> case notifier q of
|
||||
Nothing -> pure $ Left AUTH
|
||||
Just nc@NtfCreds {notifierId = nId, ntfServiceId = prevNtfSrvId}
|
||||
| prevNtfSrvId == serviceId -> pure $ Right ()
|
||||
|
||||
@@ -31,7 +31,8 @@ class StoreQueueClass q => QueueStoreClass q s where
|
||||
loadedQueues :: s -> TMap RecipientId q
|
||||
compactQueues :: s -> IO Int64
|
||||
addQueue_ :: s -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q)
|
||||
getQueue_ :: DirectParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueue_ :: QueueParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueues_ :: BatchParty p => s -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
|
||||
getQueueLinkData :: s -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
|
||||
addQueueLinkData :: s -> q -> LinkId -> QueueLinkData -> IO (Either ErrorType ())
|
||||
deleteQueueLinkData :: s -> q -> IO (Either ErrorType ())
|
||||
@@ -45,7 +46,7 @@ class StoreQueueClass q => QueueStoreClass q s where
|
||||
updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec)
|
||||
deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q)))
|
||||
getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId)
|
||||
setQueueService :: (PartyI p, SubscriberParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]))
|
||||
getNtfServiceQueueCount :: s -> ServiceId -> IO (Either ErrorType Int64)
|
||||
|
||||
|
||||
@@ -286,7 +286,7 @@ logUpdateQueueTime s qId t = writeStoreLogRecord s $ UpdateTime qId t
|
||||
logNewService :: StoreLog 'WriteMode -> ServiceRec -> IO ()
|
||||
logNewService s = writeStoreLogRecord s . NewService
|
||||
|
||||
logQueueService :: (PartyI p, SubscriberParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO ()
|
||||
logQueueService :: (PartyI p, ServiceParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO ()
|
||||
logQueueService s rId party = writeStoreLogRecord s . QueueService rId (ASP party)
|
||||
|
||||
readWriteStoreLog :: (FilePath -> s -> IO ()) -> (StoreLog 'WriteMode -> s -> IO ()) -> FilePath -> s -> IO (StoreLog 'WriteMode)
|
||||
|
||||
@@ -122,7 +122,7 @@ storeLogTests =
|
||||
},
|
||||
SLTC
|
||||
{ name = "create queue, add notifier, register and associate notification service",
|
||||
saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifier) (Just serviceId)],
|
||||
saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifierService) (Just serviceId)],
|
||||
compacted = [NewService sr, CreateQueue rId qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}}],
|
||||
state = M.fromList [(rId, qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}})]
|
||||
},
|
||||
|
||||
@@ -206,7 +206,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
||||
[Right ()] <- tPut h [Right (sig, t')]
|
||||
pure ()
|
||||
tGet' h = do
|
||||
[(Nothing, _, (CorrId corrId, EntityId qId, Right cmd))] <- tGet h
|
||||
[(CorrId corrId, EntityId qId, Right cmd)] <- tGetClient h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
ntfTest :: Transport c => TProxy c 'TServer -> (THandleNTF c 'TClient -> IO ()) -> Expectation
|
||||
|
||||
@@ -72,18 +72,18 @@ ntfSyntaxTests (ATransport t) = do
|
||||
Expectation
|
||||
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
|
||||
|
||||
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse
|
||||
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
||||
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> Transmission (Either ErrorType NtfResponse)
|
||||
pattern RespNtf corrId queueId command <- (corrId, queueId, Right command)
|
||||
|
||||
deriving instance Eq NtfResponse
|
||||
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (Transmission (Either ErrorType NtfResponse))
|
||||
sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (sgn, tToSend)
|
||||
tGet1 h
|
||||
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> C.APrivateAuthKey -> (ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> C.APrivateAuthKey -> (ByteString, NtfEntityId, NtfCommand e) -> IO (Transmission (Either ErrorType NtfResponse))
|
||||
signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (authorize tForAuth, tToSend)
|
||||
|
||||
@@ -394,7 +394,7 @@ smpServerTest _ t = runSmpTest (ASType SQSMemory SMSJournal) $ \h -> tPut' h t >
|
||||
[Right ()] <- tPut h [Right (sig, t')]
|
||||
pure ()
|
||||
tGet' h = do
|
||||
[(Nothing, _, (CorrId corrId, EntityId qId, Right cmd))] <- tGet h
|
||||
[(CorrId corrId, EntityId qId, Right cmd)] <- tGetClient h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
smpTest :: (HasCallStack, Transport c) => TProxy c 'TServer -> AStoreType -> (HasCallStack => THandleSMP c 'TClient -> IO ()) -> Expectation
|
||||
|
||||
@@ -435,14 +435,14 @@ testNoProxy :: AStoreType -> IO ()
|
||||
testNoProxy msType = do
|
||||
withSmpServerConfigOn (transport @TLS) (cfgMS msType) testPort2 $ \_ -> do
|
||||
testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do
|
||||
(_, _, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing)
|
||||
(_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer Nothing)
|
||||
reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH)
|
||||
|
||||
testProxyAuth :: AStoreType -> IO ()
|
||||
testProxyAuth msType = do
|
||||
withSmpServerConfigOn (transport @TLS) proxyCfgAuth testPort $ \_ -> do
|
||||
testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do
|
||||
(_, _s, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong")
|
||||
(_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PRXY testSMPServer2 $ Just "wrong")
|
||||
reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH)
|
||||
where
|
||||
proxyCfgAuth = updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {newQueueBasicAuth = Just "correct"}
|
||||
|
||||
@@ -92,10 +92,10 @@ serverTests = do
|
||||
testInvQueueLinkData
|
||||
testContactQueueLinkData
|
||||
|
||||
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission ErrorType BrokerMsg
|
||||
pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
||||
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> Transmission (Either ErrorType BrokerMsg)
|
||||
pattern Resp corrId queueId command <- (corrId, queueId, Right command)
|
||||
|
||||
pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Recipient
|
||||
pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Creator
|
||||
pattern New rPub dhPub = NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just (QRMessaging Nothing)))
|
||||
|
||||
pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg
|
||||
@@ -104,19 +104,19 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh _sndSecure _linkId Nothing)
|
||||
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
|
||||
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
|
||||
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
|
||||
sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (sgn, tToSend)
|
||||
tGet1 h
|
||||
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
|
||||
signSendRecv h pk = signSendRecv_ h pk Nothing
|
||||
|
||||
serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
|
||||
serviceSignSendRecv h pk = signSendRecv_ h pk . Just
|
||||
|
||||
signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (Transmission (Either ErrorType BrokerMsg))
|
||||
signSendRecv_ h@THandle {params} (C.APrivateAuthKey a pk) serviceKey_ (corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (authorize tForAuth, tToSend)
|
||||
@@ -139,9 +139,9 @@ tPut1 h t = do
|
||||
[r] <- tPut h [Right t]
|
||||
pure r
|
||||
|
||||
tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (SignedTransmission err cmd)
|
||||
tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c 'TClient -> IO (Transmission (Either err cmd))
|
||||
tGet1 h = do
|
||||
[r] <- liftIO $ tGet h
|
||||
[r] <- liftIO $ tGetClient h
|
||||
pure r
|
||||
|
||||
(#==) :: (HasCallStack, Eq a, Show a) => (a, a) -> String -> Assertion
|
||||
@@ -519,7 +519,7 @@ testSwitchSub =
|
||||
Resp "" rId' DELD <- tGet1 rh2
|
||||
(rId', rId) #== "connection deleted event delivered to subscribed client"
|
||||
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else is delivered to the 1st TCP connection"
|
||||
|
||||
@@ -1017,7 +1017,7 @@ testMessageNotifications =
|
||||
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
|
||||
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh2
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
|
||||
@@ -1027,7 +1027,7 @@ testMessageNotifications =
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
|
||||
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
|
||||
Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3)
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
|
||||
(nPub'', nKey'') <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
@@ -1069,7 +1069,7 @@ testMessageServiceNotifications =
|
||||
Resp "" serviceId2 (ENDS 1) <- tGet1 nh1
|
||||
serviceId2 `shouldBe` serviceId
|
||||
deliverMessage rh rId rKey sh sId sKey nh2 "hello again" dec
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
|
||||
@@ -1079,7 +1079,7 @@ testMessageServiceNotifications =
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
|
||||
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
|
||||
Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3)
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
|
||||
-- new notification credentials
|
||||
@@ -1133,7 +1133,7 @@ testMsgExpireOnSend =
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else should be delivered"
|
||||
|
||||
@@ -1153,7 +1153,7 @@ testMsgExpireOnInterval =
|
||||
signSendRecv rh rKey ("2", rId, SUB) >>= \case
|
||||
Resp "2" _ OK -> pure ()
|
||||
r -> unexpected r
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing should be delivered"
|
||||
|
||||
@@ -1172,7 +1172,7 @@ testMsgNOTExpireOnInterval =
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGetClient @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else should be delivered"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user