parameterize protocol by error type (#644)

This commit is contained in:
Evgeny Poberezkin
2023-02-17 20:46:01 +00:00
committed by GitHub
parent 2ae3100bed
commit 2ddfb044fc
12 changed files with 216 additions and 176 deletions

View File

@@ -5,6 +5,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -162,7 +163,7 @@ import UnliftIO (mapConcurrently)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg))
type ClientVar err msg = TMVar (Either AgentErrorType (ProtocolClient err msg))
type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
@@ -319,15 +320,15 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
class ProtocolServerClient msg where
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
clientProtocolError :: ErrorType -> AgentErrorType
class ProtocolServerClient err msg | msg -> err where
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient err msg)
clientProtocolError :: err -> AgentErrorType
instance ProtocolServerClient BrokerMsg where
instance ProtocolServerClient ErrorType BrokerMsg where
getProtocolServerClient = getSMPServerClient
clientProtocolError = SMP
instance ProtocolServerClient NtfResponse where
instance ProtocolServerClient ErrorType NtfResponse where
getProtocolServerClient = getNtfServerClient
clientProtocolError = NTF
@@ -428,7 +429,7 @@ getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM
TM.insert tSess var clients
pure var
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar err msg -> m (ProtocolClient err msg)
waitForProtocolClient c (_, srv, _) clientVar = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
@@ -438,18 +439,18 @@ waitForProtocolClient c (_, srv, _) clientVar = do
Nothing -> Left $ BROKER (B.unpack $ strEncode srv) TIMEOUT
newProtocolClient ::
forall msg m.
forall err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
AgentClient ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
m (ProtocolClient msg) ->
TMap (TransportSession msg) (ClientVar err msg) ->
m (ProtocolClient err msg) ->
(AgentClient -> TransportSession msg -> m ()) ->
ClientVar msg ->
m (ProtocolClient msg)
ClientVar err msg ->
m (ProtocolClient err msg)
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
where
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
tryConnectClient :: (ProtocolClient err msg -> m a) -> m () -> m a
tryConnectClient successAction retryAction =
tryError connectClient >>= \r -> case r of
Right client -> do
@@ -474,7 +475,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconne
withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop
atomically . removeAsyncAction aId $ asyncClients c
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
hostEvent :: forall err msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient err msg -> ACommand 'Agent
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client
getClientConfig :: AgentMonad m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig
@@ -518,7 +519,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
where
k = (server, sndId)
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar err msg)) -> IO ()
closeProtocolServerClients c clientsSel =
atomically (swapTVar cs M.empty) >>= mapM_ (forkIO . closeClient)
where
@@ -541,32 +542,32 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur
where
newLock = createLock >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> m a) -> m a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchError` logServerError cl
where
stat cl = liftIO . incClientStat c userId cl statCmd
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
logServerError :: ProtocolClient err msg -> AgentErrorType -> m a
logServerError cl e = do
logServer "<--" c srv "" $ strEncode e
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> m a) -> m a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMPClient c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient c tSess (queueId q) cmdStr action
@@ -576,16 +577,16 @@ withSMPClient_ c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient_ c tSess (queueId q) cmdStr action
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a
withNtfClient c srv = withLogClient c (0, srv, Nothing)
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a
liftClient protocolError_ = liftError . protocolClientError protocolError_
protocolClientError :: (ErrorType -> AgentErrorType) -> HostName -> ProtocolClientError -> AgentErrorType
protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType
protocolClientError protocolError_ host = \case
PCEProtocolError e -> protocolError_ e
PCEResponseError e -> BROKER host $ RESPONSE e
PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e
PCEUnexpectedResponse _ -> BROKER host UNEXPECTED
PCEResponseTimeout -> BROKER host TIMEOUT
PCENetworkError -> BROKER host NETWORK
@@ -632,7 +633,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
Left e -> pure (Just $ testErr TSConnect e)
where
addr = B.unpack $ strEncode srv
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
testErr :: SMPTestStep -> SMPClientError -> SMPTestFailure
testErr step = SMPTestFailure step . protocolClientError SMP addr
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
@@ -682,7 +683,7 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
}
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> IO (Either SMPClientError ())
processSubResult c rq r = do
case r of
Left e ->
@@ -691,7 +692,7 @@ processSubResult c rq r = do
_ -> addSubscription c rq
pure r
temporaryClientError :: ProtocolClientError -> Bool
temporaryClientError :: ProtocolClientError err -> Bool
temporaryClientError = \case
PCENetworkError -> True
PCEResponseTimeout -> True
@@ -732,7 +733,7 @@ subscribeQueues c qs = do
type BatchResponses e = (NonEmpty (RcvQueue, Either e ()))
-- statBatchSize is not used to batch the commands, only for traffic statistics
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
sendTSessionBatches statCmd statBatchSize action c qs =
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
where
@@ -759,7 +760,7 @@ sendTSessionBatches statCmd statBatchSize action c qs =
let n = (length qs - 1) `div` statBatchSize + 1
in incClientStatN c userId smp n (statCmd <> "S") "OK"
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)
sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
where
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
@@ -1047,7 +1048,7 @@ incStat AgentClient {agentStats} n k = do
Just v -> modifyTVar' v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
incClientStat :: AgentClient -> UserId -> ProtocolClient err msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
@@ -1057,7 +1058,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do
where
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN :: AgentClient -> UserId -> ProtocolClient err msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c userId pc n cmd res = do
atomically $ incStat c n statsKey
where

View File

@@ -1121,7 +1121,7 @@ instance ToJSON ConnectionErrorType where
-- | SMP server errors.
data BrokerErrorType
= -- | invalid server response (failed to parse)
RESPONSE {smpErr :: ErrorType}
RESPONSE {smpErr :: String}
| -- | unexpected response
UNEXPECTED
| -- | network error
@@ -1164,27 +1164,27 @@ instance StrEncoding AgentErrorType where
<|> "CONN " *> (CONN <$> parseRead1)
<|> "SMP " *> (SMP <$> strP)
<|> "NTF " *> (NTF <$> strP)
<|> "BROKER " *> (BROKER <$> srvP <* " RESPONSE " <*> (RESPONSE <$> strP))
<|> "BROKER " *> (BROKER <$> srvP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP))
<|> "BROKER " *> (BROKER <$> srvP <* A.space <*> parseRead1)
<|> "BROKER " *> (BROKER <$> textP <* " RESPONSE " <*> (RESPONSE <$> textP))
<|> "BROKER " *> (BROKER <$> textP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP))
<|> "BROKER " *> (BROKER <$> textP <* A.space <*> parseRead1)
<|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString)
<|> "AGENT " *> (AGENT <$> parseRead1)
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
where
srvP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ')
textP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ')
strEncode = \case
CMD e -> "CMD " <> bshow e
CONN e -> "CONN " <> bshow e
SMP e -> "SMP " <> strEncode e
NTF e -> "NTF " <> strEncode e
BROKER srv (RESPONSE e) -> "BROKER " <> addr srv <> " RESPONSE " <> strEncode e
BROKER srv (TRANSPORT e) -> "BROKER " <> addr srv <> " TRANSPORT " <> serializeTransportError e
BROKER srv e -> "BROKER " <> addr srv <> " " <> bshow e
BROKER srv (RESPONSE e) -> "BROKER " <> text srv <> " RESPONSE " <> text e
BROKER srv (TRANSPORT e) -> "BROKER " <> text srv <> " TRANSPORT " <> serializeTransportError e
BROKER srv e -> "BROKER " <> text srv <> " " <> bshow e
AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e
AGENT e -> "AGENT " <> bshow e
INTERNAL e -> "INTERNAL " <> bshow e
where
addr = encodeUtf8 . T.pack
text = encodeUtf8 . T.pack
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU

View File

@@ -55,6 +55,7 @@ module Simplex.Messaging.Client
-- * Supporting types and client configuration
ProtocolClientError (..),
SMPClientError,
ProtocolClientConfig (..),
NetworkConfig (..),
TransportSessionMode (..),
@@ -107,28 +108,28 @@ import System.Timeout (timeout)
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
--
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
data ProtocolClient msg = ProtocolClient
data ProtocolClient err msg = ProtocolClient
{ action :: Maybe (Async ()),
sessionId :: SessionId,
sessionTs :: UTCTime,
thVersion :: Version,
client_ :: PClient msg
client_ :: PClient err msg
}
data PClient msg = PClient
data PClient err msg = PClient
{ connected :: TVar Bool,
transportSession :: TransportSession msg,
transportHost :: TransportHost,
tcpTimeout :: Int,
pingErrorCount :: TVar Int,
clientCorrId :: TVar Natural,
sentCommands :: TMap CorrId (Request msg),
sentCommands :: TMap CorrId (Request err msg),
sndQ :: TBQueue (NonEmpty SentRawTransmission),
rcvQ :: TBQueue (NonEmpty (SignedTransmission msg)),
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
msgQ :: Maybe (TBQueue (ServerTransmission msg))
}
type SMPClient = ProtocolClient SMP.BrokerMsg
type SMPClient = ProtocolClient ErrorType SMP.BrokerMsg
-- | Type for client command data
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
@@ -231,14 +232,14 @@ defaultClientConfig =
smpServerVRange = supportedSMPServerVRange
}
data Request msg = Request
data Request err msg = Request
{ queueId :: QueueId,
responseVar :: TMVar (Response msg)
responseVar :: TMVar (Response err msg)
}
type Response msg = Either ProtocolClientError msg
type Response err msg = Either (ProtocolClientError err) msg
chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either ProtocolClientError TransportHost
chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either (ProtocolClientError err) TransportHost
chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts =
firstOrError $ case hostMode of
HMOnionViaSocks -> maybe publicHost (const onionHost) socksProxy
@@ -252,15 +253,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
onionHost = find isOnionHost hosts
publicHost = find (not . isOnionHost) hosts
clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient msg -> String
clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String
clientServer = B.unpack . strEncode . snd3 . transportSession . client_
where
snd3 (_, s, _) = s
transportHost' :: ProtocolClient msg -> TransportHost
transportHost' :: ProtocolClient err msg -> TransportHost
transportHost' = transportHost . client_
transportSession' :: ProtocolClient msg -> TransportSession msg
transportSession' :: ProtocolClient err msg -> TransportSession msg
transportSession' = transportSession . client_
type UserId = Int64
@@ -273,7 +274,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
--
-- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall msg. Protocol msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient :: forall err msg. Protocol err msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
case chooseTransportHost networkConfig (host srv) of
Right useHost ->
@@ -282,7 +283,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
Left e -> pure $ Left e
where
NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig
mkProtocolClient :: TransportHost -> STM (PClient msg)
mkProtocolClient :: TransportHost -> STM (PClient err msg)
mkProtocolClient transportHost = do
connected <- newTVar False
pingErrorCount <- newTVar 0
@@ -304,7 +305,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
msgQ
}
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient msg -> IO (Either ProtocolClientError (ProtocolClient msg))
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
runClient (port', ATransport t) useHost c = do
cVar <- newEmptyTMVarIO
let tcConfig = transportClientConfig networkConfig
@@ -325,9 +326,9 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
"80" -> ("80", transport @WS)
p -> (p, transport @TLS)
client :: forall c. Transport c => TProxy c -> PClient msg -> TMVar (Either ProtocolClientError (ProtocolClient msg)) -> c -> IO ()
client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO ()
client _ c cVar h =
runExceptT (protocolClientHandshake @msg h (keyHash srv) smpServerVRange) >>= \case
runExceptT (protocolClientHandshake @err @msg h (keyHash srv) smpServerVRange) >>= \case
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
Right th@THandle {sessionId, thVersion} -> do
sessionTs <- getCurrentTime
@@ -338,16 +339,16 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0])
`finally` disconnected c'
send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
send :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
receive :: Transport c => ProtocolClient msg -> THandle c -> IO ()
receive :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
ping :: ProtocolClient msg -> IO ()
ping :: ProtocolClient err msg -> IO ()
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
threadDelay smpPingInterval
runExceptT (sendProtocolCommand c Nothing "" protocolPing) >>= \case
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case
Left PCEResponseTimeout -> do
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
when (maxCnt == 0 || cnt < maxCnt) $ ping c
@@ -355,10 +356,10 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
where
maxCnt = smpPingCount networkConfig
process :: ProtocolClient msg -> IO ()
process :: ProtocolClient err msg -> IO ()
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
processMsg :: ProtocolClient msg -> SignedTransmission msg -> IO ()
processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO ()
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, qId, respOrErr)) =
if B.null $ bs corrId
then sendMsg respOrErr
@@ -376,7 +377,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
_ -> Right r
else Left . PCEUnexpectedResponse $ bshow respOrErr
where
sendMsg :: Either ErrorType msg -> IO ()
sendMsg :: Either err msg -> IO ()
sendMsg = \case
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ
Left e -> putStrLn $ "SMP client error: " <> show e
@@ -385,17 +386,17 @@ proxyUsername :: TransportSession msg -> ByteString
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
-- | Disconnects client from the server and terminates client threads.
closeProtocolClient :: ProtocolClient msg -> IO ()
closeProtocolClient :: ProtocolClient err msg -> IO ()
closeProtocolClient = mapM_ uninterruptibleCancel . action
-- | SMP client error type.
data ProtocolClientError
data ProtocolClientError err
= -- | Correctly parsed SMP server ERR response.
-- This error is forwarded to the agent client as `ERR SMP err`.
PCEProtocolError ErrorType
PCEProtocolError err
| -- | Invalid server response that failed to parse.
-- Forwarded to the agent client as `ERR BROKER RESPONSE`.
PCEResponseError ErrorType
PCEResponseError err
| -- | Different response from what is expected to a certain SMP command,
-- e.g. server should respond `IDS` or `ERR` to `NEW` command,
-- other responses would result in this error.
@@ -418,6 +419,8 @@ data ProtocolClientError
PCEIOError IOException
deriving (Eq, Show, Exception)
type SMPClientError = ProtocolClientError ErrorType
-- | Create a new SMP queue.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#create-queue-command
@@ -427,7 +430,7 @@ createSMPQueue ::
RcvPublicVerifyKey ->
RcvPublicDhKey ->
Maybe BasicAuth ->
ExceptT ProtocolClientError IO QueueIdsKeys
ExceptT SMPClientError IO QueueIdsKeys
createSMPQueue c rpKey rKey dhKey auth =
sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey auth) >>= \case
IDS qik -> pure qik
@@ -436,7 +439,7 @@ createSMPQueue c rpKey rKey dhKey auth =
-- | Subscribe to the SMP queue.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
subscribeSMPQueue c rpKey rId =
sendSMPCommand c (Just rpKey) rId SUB >>= \case
OK -> return ()
@@ -444,7 +447,7 @@ subscribeSMPQueue c rpKey rId =
r -> throwE . PCEUnexpectedResponse $ bshow r
-- | Subscribe to multiple SMP queues batching commands if supported.
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
where
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs
@@ -457,14 +460,14 @@ subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg
serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {transportSession}} entityId message =
(transportSession, thVersion, sessionId, entityId, message)
-- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue
--
-- https://github.covm/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#receive-a-message-from-the-queue
getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO (Maybe RcvMessage)
getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO (Maybe RcvMessage)
getSMPMessage c rpKey rId =
sendSMPCommand c (Just rpKey) rId GET >>= \case
OK -> pure Nothing
@@ -474,26 +477,26 @@ getSMPMessage c rpKey rId =
-- | Subscribe to the SMP queue notifications.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT ProtocolClientError IO ()
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO ()
subscribeSMPQueueNotifications = okSMPCommand NSUB
-- | Secure the SMP queue by adding a sender public key.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT ProtocolClientError IO ()
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO ()
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
-- | Enable notifications for the queue for push notifications server.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT ProtocolClientError IO (NotifierId, RcvNtfPublicDhKey)
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT SMPClientError IO (NotifierId, RcvNtfPublicDhKey)
enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey =
sendSMPCommand c (Just rpKey) rId (NKEY notifierKey rcvNtfPublicDhKey) >>= \case
NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey)
r -> throwE . PCEUnexpectedResponse $ bshow r
-- | Enable notifications for the multiple queues for push notifications server.
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either ProtocolClientError (NotifierId, RcvNtfPublicDhKey)))
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either SMPClientError (NotifierId, RcvNtfPublicDhKey)))
enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
where
cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs
@@ -505,17 +508,17 @@ enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
-- | Disable notifications for the queue for push notifications server.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
disableSMPQueueNotifications = okSMPCommand NDEL
-- | Disable notifications for multiple queues for push notifications server.
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
disableSMPQueuesNtfs = okSMPCommands NDEL
-- | Send SMP message.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT ProtocolClientError IO ()
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT SMPClientError IO ()
sendSMPMessage c spKey sId flags msg =
sendSMPCommand c spKey sId (SEND flags msg) >>= \case
OK -> pure ()
@@ -524,7 +527,7 @@ sendSMPMessage c spKey sId flags msg =
-- | Acknowledge message delivery (server deletes the message).
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT ProtocolClientError IO ()
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT SMPClientError IO ()
ackSMPMessage c rpKey rId msgId =
sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case
OK -> return ()
@@ -535,26 +538,26 @@ ackSMPMessage c rpKey rId msgId =
-- The existing messages from the queue will still be delivered.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
suspendSMPQueue = okSMPCommand OFF
-- | Irreversibly delete SMP queue and all messages in it.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
deleteSMPQueue = okSMPCommand DEL
-- | Delete multiple SMP queues batching commands if supported.
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
deleteSMPQueues = okSMPCommands DEL
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
okSMPCommand cmd c pKey qId =
sendSMPCommand c (Just pKey) qId cmd >>= \case
OK -> return ()
r -> throwE . PCEUnexpectedResponse $ bshow r
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either ProtocolClientError ()))
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either SMPClientError ()))
okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
where
aCmd = Cmd sParty cmd
@@ -565,11 +568,11 @@ okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
Left e -> Left e
-- | Send SMP command
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
-- | Send multiple commands with batching and collect responses
sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg))
sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either (ProtocolClientError err) msg))
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do
ts <- mapM (runExceptT . mkTransmission c) cs
mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts
@@ -578,22 +581,22 @@ sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do
Left e -> pure $ Left e
-- | Send Protocol command
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}} pKey qId cmd = do
(t, r) <- mkTransmission c (pKey, qId, cmd)
ExceptT $ sendRecv t r
where
-- two separate "atomically" needed to avoid blocking
sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg)
sendRecv :: SentRawTransmission -> TMVar (Response err msg) -> IO (Response err msg)
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout c (atomically $ takeTMVar r)
withTimeout :: ProtocolClient msg -> IO (Either ProtocolClientError msg) -> IO (Either ProtocolClientError msg)
withTimeout :: ProtocolClient err msg -> IO (Either (ProtocolClientError err) msg) -> IO (Either (ProtocolClientError err) msg)
withTimeout ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} a =
timeout tcpTimeout a >>= \case
Just r -> atomically (writeTVar pingErrorCount 0) >> pure r
_ -> pure $ Left PCEResponseTimeout
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> ExceptT (ProtocolClientError err) IO (SentRawTransmission, TMVar (Response err msg))
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do
corrId <- liftIO $ atomically getNextCorrId
let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
@@ -606,7 +609,7 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo
pure . CorrId $ bshow i
signTransmission :: ByteString -> SentRawTransmission
signTransmission t = ((`C.sign` t) <$> pKey, t)
mkRequest :: CorrId -> STM (TMVar (Response msg))
mkRequest :: CorrId -> STM (TMVar (Response err msg))
mkRequest corrId = do
r <- newEmptyTMVar
TM.insert corrId (Request qId r) sentCommands

View File

@@ -38,14 +38,14 @@ import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
type SMPClientVar = TMVar (Either ProtocolClientError SMPClient)
type SMPClientVar = TMVar (Either SMPClientError SMPClient)
data SMPClientAgentEvent
= CAConnected SMPServer
| CADisconnected SMPServer (Set SMPSub)
| CAReconnected SMPServer
| CAResubscribed SMPServer SMPSub
| CASubError SMPServer SMPSub ProtocolClientError
| CASubError SMPServer SMPSub SMPClientError
data SMPSubParty = SPRecipient | SPNotifier
deriving (Eq, Ord, Show)
@@ -111,7 +111,7 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} = do
asyncClients <- newTVar []
pure SMPClientAgent {agentCfg, msgQ, agentQ, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT ProtocolClientError IO SMPClient
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
atomically getClientVar >>= either newSMPClient waitForSMPClient
where
@@ -124,7 +124,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
TM.insert srv smpVar smpClients
pure smpVar
waitForSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
waitForSMPClient smpVar = do
let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar)
@@ -133,10 +133,10 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
Just (Left e) -> Left e
Nothing -> Left PCEResponseTimeout
newSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
where
tryConnectClient :: (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO () -> ExceptT ProtocolClientError IO a
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
tryConnectClient successAction retryAction =
tryE connectClient >>= \r -> case r of
Right smp -> do
@@ -150,16 +150,16 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
putTMVar smpVar (Left e)
TM.delete srv smpClients
throwE e
tryConnectAsync :: ExceptT ProtocolClientError IO ()
tryConnectAsync :: ExceptT SMPClientError IO ()
tryConnectAsync = do
a <- async connectAsync
atomically $ modifyTVar' (asyncClients ca) (a :)
connectAsync :: ExceptT ProtocolClientError IO ()
connectAsync :: ExceptT SMPClientError IO ()
connectAsync =
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
void $ tryConnectClient (const reconnectClient) loop
connectClient :: ExceptT ProtocolClientError IO SMPClient
connectClient :: ExceptT SMPClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
clientDisconnected :: SMPClient -> IO ()
@@ -188,17 +188,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
notify . CADisconnected srv $ M.keysSet ss
reconnectServer
reconnectServer :: ExceptT ProtocolClientError IO ()
reconnectServer :: ExceptT SMPClientError IO ()
reconnectServer = do
a <- async tryReconnectClient
atomically $ modifyTVar' (reconnections ca) (a :)
tryReconnectClient :: ExceptT ProtocolClientError IO ()
tryReconnectClient :: ExceptT SMPClientError IO ()
tryReconnectClient = do
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
reconnectClient `catchE` const loop
reconnectClient :: ExceptT ProtocolClientError IO ()
reconnectClient :: ExceptT SMPClientError IO ()
reconnectClient = do
withSMP ca srv $ \smp -> do
notify $ CAReconnected srv
@@ -207,13 +207,13 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
unlessM (atomically $ hasSub (srvSubs ca) srv s) $
subscribe_ smp sub `catchE` handleError s
where
subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO ()
subscribe_ smp sub@(s, _) = do
smpSubscribe smp sub
atomically $ addSubscription ca srv sub
notify $ CAResubscribed srv s
handleError :: SMPSub -> ProtocolClientError -> ExceptT ProtocolClientError IO ()
handleError :: SMPSub -> SMPClientError -> ExceptT SMPClientError IO ()
handleError s = \case
e@PCEResponseTimeout -> throwE e
e@PCENetworkError -> throwE e
@@ -221,7 +221,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
notify $ CASubError srv s e
atomically $ removePendingSubscription ca srv s
notify :: SMPClientAgentEvent -> ExceptT ProtocolClientError IO ()
notify :: SMPClientAgentEvent -> ExceptT SMPClientError IO ()
notify evt = atomically $ writeTBQueue (agentQ ca) evt
closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m ()
@@ -241,15 +241,15 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO a
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO a
withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError
where
logSMPError :: ProtocolClientError -> ExceptT ProtocolClientError IO a
logSMPError :: SMPClientError -> ExceptT SMPClientError IO a
logSMPError e = do
liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e
throwE e
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO ()
subscribeQueue ca srv sub = do
atomically $ addPendingSubscription ca srv sub
withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleError
@@ -267,7 +267,7 @@ showServer :: SMPServer -> ByteString
showServer ProtocolServer {host, port} =
strEncode host <> B.pack (if null port then "" else ':' : port)
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO ()
smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId
where
subscribe_ = case party of

View File

@@ -10,54 +10,57 @@ import Data.Word (Word16)
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Protocol (ErrorType)
import Simplex.Messaging.Util (bshow)
type NtfClient = ProtocolClient NtfResponse
type NtfClient = ProtocolClient ErrorType NtfResponse
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
type NtfClientError = ProtocolClientError ErrorType
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519)
ntfRegisterToken c pKey newTkn =
sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case
NRTknId tknId dhKey -> pure (tknId, dhKey)
r -> throwE . PCEUnexpectedResponse $ bshow r
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO ()
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT NtfClientError IO ()
ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO NtfTknStatus
ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT NtfClientError IO NtfTknStatus
ntfCheckToken c pKey tknId =
sendNtfCommand c (Just pKey) tknId TCHK >>= \case
NRTkn stat -> pure stat
r -> throwE . PCEUnexpectedResponse $ bshow r
ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT ProtocolClientError IO ()
ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT NtfClientError IO ()
ntfReplaceToken c pKey tknId token = okNtfCommand (TRPL token) c pKey tknId
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT NtfClientError IO ()
ntfDeleteToken = okNtfCommand TDEL
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT NtfClientError IO ()
ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO NtfSubscriptionId
ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT NtfClientError IO NtfSubscriptionId
ntfCreateSubscription c pKey newSub =
sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case
NRSubId subId -> pure subId
r -> throwE . PCEUnexpectedResponse $ bshow r
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT NtfClientError IO NtfSubStatus
ntfCheckSubscription c pKey subId =
sendNtfCommand c (Just pKey) subId SCHK >>= \case
NRSub stat -> pure stat
r -> throwE . PCEUnexpectedResponse $ bshow r
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT NtfClientError IO ()
ntfDeleteSubscription = okNtfCommand SDEL
-- | Send notification server command
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT NtfClientError IO NtfResponse
sendNtfCommand c pKey entId cmd = sendProtocolCommand c pKey entId (NtfCmd sNtfEntity cmd)
okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT ProtocolClientError IO ()
okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT NtfClientError IO ()
okNtfCommand cmd c pKey entId =
sendNtfCommand c (Just pKey) entId cmd >>= \case
NROk -> return ()

View File

@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -146,7 +147,7 @@ instance Encoding ANewNtfEntity where
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP)
_ -> fail "bad ANewNtfEntity"
instance Protocol NtfResponse where
instance Protocol ErrorType NtfResponse where
type ProtoCommand NtfResponse = NtfCmd
type ProtoType NtfResponse = 'PNTF
protocolClientHandshake = ntfClientHandshake
@@ -183,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
deriving instance Show NtfCmd
instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
type Tag (NtfCommand e) = NtfCommandTag e
encodeProtocol _v = \case
TNEW newTkn -> e (TNEW_, ' ', newTkn)
@@ -202,6 +203,9 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag)
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials (sig, _, entityId, _) cmd = case cmd of
-- TNEW and SNEW must have signature but NOT token/subscription IDs
TNEW {} -> sigNoEntity
@@ -219,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
| not (B.null entityId) = Left $ CMD HAS_AUTH
| otherwise = Right cmd
instance ProtocolEncoding NtfCmd where
instance ProtocolEncoding ErrorType NtfCmd where
type Tag NtfCmd = NtfCmdTag
encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c
@@ -239,6 +243,9 @@ instance ProtocolEncoding NtfCmd where
SDEL_ -> pure SDEL
PING_ -> pure PING
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
data NtfResponseTag
@@ -283,7 +290,7 @@ data NtfResponse
| NRPong
deriving (Show)
instance ProtocolEncoding NtfResponse where
instance ProtocolEncoding ErrorType NtfResponse where
type Tag NtfResponse = NtfResponseTag
encodeProtocol _v = \case
NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey)
@@ -306,6 +313,13 @@ instance ProtocolEncoding NtfResponse where
NRSub_ -> NRSub <$> _smpP
NRPong_ -> pure NRPong
fromProtocolError = \case
PECmdSyntax -> CMD SYNTAX
PECmdUnknown -> CMD UNKNOWN
PESession -> SESSION
PEBlock -> BLOCK
{-# INLINE fromProtocolError #-}
checkCredentials (_, _, entId, _) cmd = case cmd of
-- IDTKN response must not have queue ID
NRTknId {} -> noEntity

View File

@@ -27,7 +27,7 @@ import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
import Data.Time.Clock.System (getSystemTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import Network.Socket (ServiceName)
import Simplex.Messaging.Client (ProtocolClientError (..))
import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError)
import Simplex.Messaging.Client.Agent
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
@@ -227,7 +227,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
where
showServer' = decodeLatin1 . strEncode . host
handleSubError :: SMPQueueNtf -> ProtocolClientError -> M ()
handleSubError :: SMPQueueNtf -> SMPClientError -> M ()
handleSubError smpQueue = \case
PCEProtocolError AUTH -> updateSubStatus smpQueue NSAuth
PCEProtocolError e -> updateErr "SMP error " e
@@ -343,7 +343,7 @@ send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = for
data VerificationResult = VRVerified NtfRequest | VRFailed
verifyNtfTransmission :: SignedTransmission NtfCmd -> NtfCmd -> M VerificationResult
verifyNtfTransmission :: SignedTransmission ErrorType NtfCmd -> NtfCmd -> M VerificationResult
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
st <- asks store
case cmd of

View File

@@ -4,6 +4,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -19,11 +20,10 @@
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# HLINT ignore "Use newtype instead of data" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
{-# HLINT ignore "Use newtype instead of data" #-}
-- |
-- Module : Simplex.Messaging.ProtocolEncoding
-- Copyright : (c) simplex.chat
@@ -52,6 +52,7 @@ module Simplex.Messaging.Protocol
SParty (..),
PartyI (..),
QueueIdsKeys (..),
ProtocolErrorType (..),
ErrorType (..),
CommandError (..),
Transmission,
@@ -224,7 +225,7 @@ deriving instance Show Cmd
type Transmission c = (CorrId, EntityId, c)
-- | signed parsed transmission, with original raw bytes and parsing error.
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
type SignedTransmission e c = (Maybe C.ASignature, Signed, Transmission (Either e c))
type Signed = ByteString
@@ -874,6 +875,8 @@ type MsgId = ByteString
-- | SMP message body.
type MsgBody = ByteString
data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock
-- | Type for protocol errors.
data ErrorType
= -- | incorrect block format, encoding or signature size
@@ -944,16 +947,16 @@ transmissionP = do
command <- A.takeByteString
pure RawTransmission {signature, signed, sessId, corrId, entityId, command}
class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
type ProtoCommand msg = cmd | cmd -> msg
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
protocolPing :: ProtoCommand msg
protocolError :: msg -> Maybe ErrorType
protocolError :: msg -> Maybe err
type ProtoServer msg = ProtocolServer (ProtoType msg)
instance Protocol BrokerMsg where
instance Protocol ErrorType BrokerMsg where
type ProtoCommand BrokerMsg = Cmd
type ProtoType BrokerMsg = 'PSMP
protocolClientHandshake = smpClientHandshake
@@ -962,13 +965,14 @@ instance Protocol BrokerMsg where
ERR e -> Just e
_ -> Nothing
class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
type Tag msg
encodeProtocol :: Version -> msg -> ByteString
protocolP :: Version -> Tag msg -> Parser msg
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
fromProtocolError :: ProtocolErrorType -> err
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
instance PartyI p => ProtocolEncoding (Command p) where
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
type Tag (Command p) = CommandTag p
encodeProtocol v = \case
NEW rKey dhKey auth_ -> case auth_ of
@@ -999,6 +1003,9 @@ instance PartyI p => ProtocolEncoding (Command p) where
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials (sig, _, queueId, _) cmd = case cmd of
-- NEW must have signature but NOT queue ID
NEW {}
@@ -1018,7 +1025,7 @@ instance PartyI p => ProtocolEncoding (Command p) where
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
| otherwise -> Right cmd
instance ProtocolEncoding Cmd where
instance ProtocolEncoding ErrorType Cmd where
type Tag Cmd = CmdTag
encodeProtocol v (Cmd _ c) = encodeProtocol v c
@@ -1048,9 +1055,12 @@ instance ProtocolEncoding Cmd where
PING_ -> pure PING
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
instance ProtocolEncoding BrokerMsg where
instance ProtocolEncoding ErrorType BrokerMsg where
type Tag BrokerMsg = BrokerMsgTag
encodeProtocol v = \case
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
@@ -1085,6 +1095,13 @@ instance ProtocolEncoding BrokerMsg where
ERR_ -> ERR <$> _smpP
PONG_ -> pure PONG
fromProtocolError = \case
PECmdSyntax -> CMD SYNTAX
PECmdUnknown -> CMD UNKNOWN
PESession -> SESSION
PEBlock -> BLOCK
{-# INLINE fromProtocolError #-}
checkCredentials (_, _, queueId, _) cmd = case cmd of
-- IDS response should not have queue ID
IDS _ -> Right cmd
@@ -1103,12 +1120,12 @@ _smpP :: Encoding a => Parser a
_smpP = A.space *> smpP
-- | Parse SMP protocol commands and broker messages
parseProtocol :: ProtocolEncoding msg => Version -> ByteString -> Either ErrorType msg
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
parseProtocol v s =
let (tag, params) = B.break (== ' ') s
in case decodeTag tag of
Just cmd -> parse (protocolP v cmd) (CMD SYNTAX) params
Nothing -> Left $ CMD UNKNOWN
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
checkParty c = case testEquality (sParty @p) (sParty @p') of
@@ -1203,7 +1220,7 @@ tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t
tEncodeBatch :: Int -> ByteString -> ByteString
tEncodeBatch n s = lenEncode n `B.cons` s
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
@@ -1223,22 +1240,22 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
eitherList = either (\e -> [Left e])
-- | Receive client and server transmissions (determined by `cmd` type).
tGet :: forall cmd c. (ProtocolEncoding cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission cmd))
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
tGet th@THandle {sessionId, thVersion = v} = L.map (tDecodeParseValidate sessionId v) <$> tGetParse th
tDecodeParseValidate :: forall cmd. ProtocolEncoding cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission cmd
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate sessionId v = \case
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
| sessId == sessionId ->
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
| otherwise -> (Nothing, "", (CorrId corrId, "", Left SESSION))
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
Left _ -> tError ""
where
tError :: ByteString -> SignedTransmission cmd
tError corrId = (Nothing, "", (CorrId corrId, "", Left BLOCK))
tError :: ByteString -> SignedTransmission err cmd
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission cmd
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
tParseValidate signed t@(sig, corrId, entityId, command) =
let cmd = parseProtocol v command >>= checkCredentials t
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
in (sig, signed, (CorrId corrId, entityId, cmd))

View File

@@ -266,7 +266,7 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do
write sndQ $ fst as
write rcvQ $ snd as
where
cmdAction :: SignedTransmission Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
cmdAction :: SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
cmdAction (sig, signed, (corrId, queueId, cmdOrError)) =
case cmdOrError of
Left e -> pure $ Left (corrId, queueId, ERR e)

View File

@@ -6,7 +6,7 @@ module CoreTests.ProtocolErrorTests where
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..))
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..))
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (parseAll)
import Test.Hspec
@@ -17,12 +17,14 @@ protocolErrorTests :: Spec
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
describe "errors parsing / serializing" $ do
it "should parse SMP protocol errors" . property $ \(err :: AgentErrorType) ->
errServerHasSpaces err
errHasSpaces err
|| parseAll strP (strEncode err) == Right err
it "should parse SMP agent errors" . property $ \(err :: AgentErrorType) ->
errServerHasSpaces err
errHasSpaces err
|| parseAll strP (strEncode err) == Right err
where
errServerHasSpaces = \case
BROKER srv _ -> ' ' `B.elem` encodeUtf8 (T.pack srv)
errHasSpaces = \case
BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e
BROKER srv _ -> hasSpaces srv
_ -> False
hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s)

View File

@@ -64,16 +64,16 @@ ntfSyntaxTests (ATransport t) = do
Expectation
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission NtfResponse
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
Right () <- tPut1 h (sgn, t)
tGet1 h
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
Right () <- tPut1 h (Just $ C.sign pk t, t)

View File

@@ -63,7 +63,7 @@ serverTests t@(ATransport t') = do
testMsgExpireOnInterval t'
testMsgNOTExpireOnInterval t'
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission BrokerMsg
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission ErrorType BrokerMsg
pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command))
pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg
@@ -72,13 +72,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh)
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
Right () <- tPut1 h (sgn, t)
tGet1 h
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
Right () <- tPut1 h (Just $ C.sign pk t, t)
@@ -89,7 +89,7 @@ tPut1 h t = do
[r] <- tPut h [t]
pure r
tGet1 :: (ProtocolEncoding cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission cmd)
tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd)
tGet1 h = do
[r] <- liftIO $ tGet h
pure r
@@ -428,7 +428,7 @@ testSwitchSub (ATransport t) =
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3)
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
1000 `timeout` tGet @BrokerMsg rh1 >>= \case
1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case
Nothing -> return ()
Just _ -> error "nothing else is delivered to the 1st TCP connection"
@@ -869,14 +869,14 @@ testMessageNotifications (ATransport t) =
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
Resp "" _ (NMSG _ _) <- tGet1 nh2
1000 `timeout` tGet @BrokerMsg nh1 >>= \case
1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there")
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
1000 `timeout` tGet @BrokerMsg nh2 >>= \case
1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
@@ -895,7 +895,7 @@ testMsgExpireOnSend t =
testSMPClient @c $ \rh -> do
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
1000 `timeout` tGet @BrokerMsg rh >>= \case
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing else should be delivered"
@@ -911,7 +911,7 @@ testMsgExpireOnInterval t =
threadDelay 2500000
testSMPClient @c $ \rh -> do
Resp "2" _ OK <- signSendRecv rh rKey ("2", rId, SUB)
1000 `timeout` tGet @BrokerMsg rh >>= \case
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing should be delivered"
@@ -929,7 +929,7 @@ testMsgNOTExpireOnInterval t =
testSMPClient @c $ \rh -> do
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
1000 `timeout` tGet @BrokerMsg rh >>= \case
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing else should be delivered"