From 2ddfb044fcb7bf880b6082cd9e29befe57aba247 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 17 Feb 2023 20:46:01 +0000 Subject: [PATCH] parameterize protocol by error type (#644) --- src/Simplex/Messaging/Agent/Client.hs | 69 +++++------ src/Simplex/Messaging/Agent/Protocol.hs | 18 +-- src/Simplex/Messaging/Client.hs | 111 +++++++++--------- src/Simplex/Messaging/Client/Agent.hs | 38 +++--- src/Simplex/Messaging/Notifications/Client.hs | 27 +++-- .../Messaging/Notifications/Protocol.hs | 22 +++- src/Simplex/Messaging/Notifications/Server.hs | 6 +- src/Simplex/Messaging/Protocol.hs | 61 ++++++---- src/Simplex/Messaging/Server.hs | 2 +- tests/CoreTests/ProtocolErrorTests.hs | 12 +- tests/NtfServerTests.hs | 6 +- tests/ServerTests.hs | 20 ++-- 12 files changed, 216 insertions(+), 176 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9b68c4172..d79daa694 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -5,6 +5,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -162,7 +163,7 @@ import UnliftIO (mapConcurrently) import qualified UnliftIO.Exception as E import UnliftIO.STM -type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg)) +type ClientVar err msg = TMVar (Either AgentErrorType (ProtocolClient err msg)) type SMPClientVar = TMVar (Either AgentErrorType SMPClient) @@ -319,15 +320,15 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do agentClientStore :: AgentClient -> SQLiteStore agentClientStore AgentClient {agentEnv = Env {store}} = store -class ProtocolServerClient msg where - getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg) - clientProtocolError :: ErrorType -> AgentErrorType +class ProtocolServerClient err msg | msg -> err where + getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient err msg) + clientProtocolError :: err -> AgentErrorType -instance ProtocolServerClient BrokerMsg where +instance ProtocolServerClient ErrorType BrokerMsg where getProtocolServerClient = getSMPServerClient clientProtocolError = SMP -instance ProtocolServerClient NtfResponse where +instance ProtocolServerClient ErrorType NtfResponse where getProtocolServerClient = getNtfServerClient clientProtocolError = NTF @@ -428,7 +429,7 @@ getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM TM.insert tSess var clients pure var -waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg) +waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar err msg -> m (ProtocolClient err msg) waitForProtocolClient c (_, srv, _) clientVar = do NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar) @@ -438,18 +439,18 @@ waitForProtocolClient c (_, srv, _) clientVar = do Nothing -> Left $ BROKER (B.unpack $ strEncode srv) TIMEOUT newProtocolClient :: - forall msg m. + forall err msg m. (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> - TMap (TransportSession msg) (ClientVar msg) -> - m (ProtocolClient msg) -> + TMap (TransportSession msg) (ClientVar err msg) -> + m (ProtocolClient err msg) -> (AgentClient -> TransportSession msg -> m ()) -> - ClientVar msg -> - m (ProtocolClient msg) + ClientVar err msg -> + m (ProtocolClient err msg) newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync where - tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a + tryConnectClient :: (ProtocolClient err msg -> m a) -> m () -> m a tryConnectClient successAction retryAction = tryError connectClient >>= \r -> case r of Right client -> do @@ -474,7 +475,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconne withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop atomically . removeAsyncAction aId $ asyncClients c -hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent +hostEvent :: forall err msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient err msg -> ACommand 'Agent hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client getClientConfig :: AgentMonad m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig @@ -518,7 +519,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} = where k = (server, sndId) -closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar err msg)) -> IO () closeProtocolServerClients c clientsSel = atomically (swapTVar cs M.empty) >>= mapM_ (forkIO . closeClient) where @@ -541,32 +542,32 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur where newLock = createLock >>= \l -> TM.insert key l locks $> l -withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a +withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> m a) -> m a withClient_ c tSess@(userId, srv, _) statCmd action = do cl <- getProtocolServerClient c tSess (action cl <* stat cl "OK") `catchError` logServerError cl where stat cl = liftIO . incClientStat c userId cl statCmd - logServerError :: ProtocolClient msg -> AgentErrorType -> m a + logServerError :: ProtocolClient err msg -> AgentErrorType -> m a logServerError cl e = do logServer "<--" c srv "" $ strEncode e stat cl $ strEncode e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a +withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> m a) -> m a withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do logServer "-->" c srv entId cmdStr res <- withClient_ c tSess cmdStr action logServer "<--" c srv entId "OK" return res -withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a -withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client +withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client -withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a -withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client +withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client -withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a +withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withSMPClient c q cmdStr action = do tSess <- mkSMPTransportSession c q withLogClient c tSess (queueId q) cmdStr action @@ -576,16 +577,16 @@ withSMPClient_ c q cmdStr action = do tSess <- mkSMPTransportSession c q withLogClient_ c tSess (queueId q) cmdStr action -withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a +withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a withNtfClient c srv = withLogClient c (0, srv, Nothing) -liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a +liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a liftClient protocolError_ = liftError . protocolClientError protocolError_ -protocolClientError :: (ErrorType -> AgentErrorType) -> HostName -> ProtocolClientError -> AgentErrorType +protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType protocolClientError protocolError_ host = \case PCEProtocolError e -> protocolError_ e - PCEResponseError e -> BROKER host $ RESPONSE e + PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e PCEUnexpectedResponse _ -> BROKER host UNEXPECTED PCEResponseTimeout -> BROKER host TIMEOUT PCENetworkError -> BROKER host NETWORK @@ -632,7 +633,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do Left e -> pure (Just $ testErr TSConnect e) where addr = B.unpack $ strEncode srv - testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure + testErr :: SMPTestStep -> SMPClientError -> SMPTestFailure testErr step = SMPTestFailure step . protocolClientError SMP addr mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg) @@ -682,7 +683,7 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do } pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey) -processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) +processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> IO (Either SMPClientError ()) processSubResult c rq r = do case r of Left e -> @@ -691,7 +692,7 @@ processSubResult c rq r = do _ -> addSubscription c rq pure r -temporaryClientError :: ProtocolClientError -> Bool +temporaryClientError :: ProtocolClientError err -> Bool temporaryClientError = \case PCENetworkError -> True PCEResponseTimeout -> True @@ -732,7 +733,7 @@ subscribeQueues c qs = do type BatchResponses e = (NonEmpty (RcvQueue, Either e ())) -- statBatchSize is not used to batch the commands, only for traffic statistics -sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] +sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] sendTSessionBatches statCmd statBatchSize action c qs = concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues) where @@ -759,7 +760,7 @@ sendTSessionBatches statCmd statBatchSize action c qs = let n = (length qs - 1) `div` statBatchSize + 1 in incClientStatN c userId smp n (statCmd <> "S") "OK" -sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError) +sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError) sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) @@ -1047,7 +1048,7 @@ incStat AgentClient {agentStats} n k = do Just v -> modifyTVar' v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats -incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO () +incClientStat :: AgentClient -> UserId -> ProtocolClient err msg -> ByteString -> ByteString -> IO () incClientStat c userId pc = incClientStatN c userId pc 1 incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO () @@ -1057,7 +1058,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do where statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res} -incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO () +incClientStatN :: AgentClient -> UserId -> ProtocolClient err msg -> Int -> ByteString -> ByteString -> IO () incClientStatN c userId pc n cmd res = do atomically $ incStat c n statsKey where diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 39bae0a24..c08f86fd7 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -1121,7 +1121,7 @@ instance ToJSON ConnectionErrorType where -- | SMP server errors. data BrokerErrorType = -- | invalid server response (failed to parse) - RESPONSE {smpErr :: ErrorType} + RESPONSE {smpErr :: String} | -- | unexpected response UNEXPECTED | -- | network error @@ -1164,27 +1164,27 @@ instance StrEncoding AgentErrorType where <|> "CONN " *> (CONN <$> parseRead1) <|> "SMP " *> (SMP <$> strP) <|> "NTF " *> (NTF <$> strP) - <|> "BROKER " *> (BROKER <$> srvP <* " RESPONSE " <*> (RESPONSE <$> strP)) - <|> "BROKER " *> (BROKER <$> srvP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP)) - <|> "BROKER " *> (BROKER <$> srvP <* A.space <*> parseRead1) + <|> "BROKER " *> (BROKER <$> textP <* " RESPONSE " <*> (RESPONSE <$> textP)) + <|> "BROKER " *> (BROKER <$> textP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP)) + <|> "BROKER " *> (BROKER <$> textP <* A.space <*> parseRead1) <|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString) <|> "AGENT " *> (AGENT <$> parseRead1) <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) where - srvP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ') + textP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ') strEncode = \case CMD e -> "CMD " <> bshow e CONN e -> "CONN " <> bshow e SMP e -> "SMP " <> strEncode e NTF e -> "NTF " <> strEncode e - BROKER srv (RESPONSE e) -> "BROKER " <> addr srv <> " RESPONSE " <> strEncode e - BROKER srv (TRANSPORT e) -> "BROKER " <> addr srv <> " TRANSPORT " <> serializeTransportError e - BROKER srv e -> "BROKER " <> addr srv <> " " <> bshow e + BROKER srv (RESPONSE e) -> "BROKER " <> text srv <> " RESPONSE " <> text e + BROKER srv (TRANSPORT e) -> "BROKER " <> text srv <> " TRANSPORT " <> serializeTransportError e + BROKER srv e -> "BROKER " <> text srv <> " " <> bshow e AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e AGENT e -> "AGENT " <> bshow e INTERNAL e -> "INTERNAL " <> bshow e where - addr = encodeUtf8 . T.pack + text = encodeUtf8 . T.pack instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 1af672873..b0f13a35f 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -55,6 +55,7 @@ module Simplex.Messaging.Client -- * Supporting types and client configuration ProtocolClientError (..), + SMPClientError, ProtocolClientConfig (..), NetworkConfig (..), TransportSessionMode (..), @@ -107,28 +108,28 @@ import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- -- Use 'getSMPClient' to connect to an SMP server and create a client handle. -data ProtocolClient msg = ProtocolClient +data ProtocolClient err msg = ProtocolClient { action :: Maybe (Async ()), sessionId :: SessionId, sessionTs :: UTCTime, thVersion :: Version, - client_ :: PClient msg + client_ :: PClient err msg } -data PClient msg = PClient +data PClient err msg = PClient { connected :: TVar Bool, transportSession :: TransportSession msg, transportHost :: TransportHost, tcpTimeout :: Int, pingErrorCount :: TVar Int, clientCorrId :: TVar Natural, - sentCommands :: TMap CorrId (Request msg), + sentCommands :: TMap CorrId (Request err msg), sndQ :: TBQueue (NonEmpty SentRawTransmission), - rcvQ :: TBQueue (NonEmpty (SignedTransmission msg)), + rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), msgQ :: Maybe (TBQueue (ServerTransmission msg)) } -type SMPClient = ProtocolClient SMP.BrokerMsg +type SMPClient = ProtocolClient ErrorType SMP.BrokerMsg -- | Type for client command data type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg) @@ -231,14 +232,14 @@ defaultClientConfig = smpServerVRange = supportedSMPServerVRange } -data Request msg = Request +data Request err msg = Request { queueId :: QueueId, - responseVar :: TMVar (Response msg) + responseVar :: TMVar (Response err msg) } -type Response msg = Either ProtocolClientError msg +type Response err msg = Either (ProtocolClientError err) msg -chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either ProtocolClientError TransportHost +chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either (ProtocolClientError err) TransportHost chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts = firstOrError $ case hostMode of HMOnionViaSocks -> maybe publicHost (const onionHost) socksProxy @@ -252,15 +253,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts onionHost = find isOnionHost hosts publicHost = find (not . isOnionHost) hosts -clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient msg -> String +clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String clientServer = B.unpack . strEncode . snd3 . transportSession . client_ where snd3 (_, s, _) = s -transportHost' :: ProtocolClient msg -> TransportHost +transportHost' :: ProtocolClient err msg -> TransportHost transportHost' = transportHost . client_ -transportSession' :: ProtocolClient msg -> TransportSession msg +transportSession' :: ProtocolClient err msg -> TransportSession msg transportSession' = transportSession . client_ type UserId = Int64 @@ -273,7 +274,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId) -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall msg. Protocol msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg)) +getProtocolClient :: forall err msg. Protocol err msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> @@ -282,7 +283,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig - mkProtocolClient :: TransportHost -> STM (PClient msg) + mkProtocolClient :: TransportHost -> STM (PClient err msg) mkProtocolClient transportHost = do connected <- newTVar False pingErrorCount <- newTVar 0 @@ -304,7 +305,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, msgQ } - runClient :: (ServiceName, ATransport) -> TransportHost -> PClient msg -> IO (Either ProtocolClientError (ProtocolClient msg)) + runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) runClient (port', ATransport t) useHost c = do cVar <- newEmptyTMVarIO let tcConfig = transportClientConfig networkConfig @@ -325,9 +326,9 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, "80" -> ("80", transport @WS) p -> (p, transport @TLS) - client :: forall c. Transport c => TProxy c -> PClient msg -> TMVar (Either ProtocolClientError (ProtocolClient msg)) -> c -> IO () + client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO () client _ c cVar h = - runExceptT (protocolClientHandshake @msg h (keyHash srv) smpServerVRange) >>= \case + runExceptT (protocolClientHandshake @err @msg h (keyHash srv) smpServerVRange) >>= \case Left e -> atomically . putTMVar cVar . Left $ PCETransportError e Right th@THandle {sessionId, thVersion} -> do sessionTs <- getCurrentTime @@ -338,16 +339,16 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0]) `finally` disconnected c' - send :: Transport c => ProtocolClient msg -> THandle c -> IO () + send :: Transport c => ProtocolClient err msg -> THandle c -> IO () send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPut h - receive :: Transport c => ProtocolClient msg -> THandle c -> IO () + receive :: Transport c => ProtocolClient err msg -> THandle c -> IO () receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ - ping :: ProtocolClient msg -> IO () + ping :: ProtocolClient err msg -> IO () ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do threadDelay smpPingInterval - runExceptT (sendProtocolCommand c Nothing "" protocolPing) >>= \case + runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case Left PCEResponseTimeout -> do cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1) when (maxCnt == 0 || cnt < maxCnt) $ ping c @@ -355,10 +356,10 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, where maxCnt = smpPingCount networkConfig - process :: ProtocolClient msg -> IO () + process :: ProtocolClient err msg -> IO () process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c) - processMsg :: ProtocolClient msg -> SignedTransmission msg -> IO () + processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO () processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, qId, respOrErr)) = if B.null $ bs corrId then sendMsg respOrErr @@ -376,7 +377,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, _ -> Right r else Left . PCEUnexpectedResponse $ bshow respOrErr where - sendMsg :: Either ErrorType msg -> IO () + sendMsg :: Either err msg -> IO () sendMsg = \case Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ Left e -> putStrLn $ "SMP client error: " <> show e @@ -385,17 +386,17 @@ proxyUsername :: TransportSession msg -> ByteString proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_ -- | Disconnects client from the server and terminates client threads. -closeProtocolClient :: ProtocolClient msg -> IO () +closeProtocolClient :: ProtocolClient err msg -> IO () closeProtocolClient = mapM_ uninterruptibleCancel . action -- | SMP client error type. -data ProtocolClientError +data ProtocolClientError err = -- | Correctly parsed SMP server ERR response. -- This error is forwarded to the agent client as `ERR SMP err`. - PCEProtocolError ErrorType + PCEProtocolError err | -- | Invalid server response that failed to parse. -- Forwarded to the agent client as `ERR BROKER RESPONSE`. - PCEResponseError ErrorType + PCEResponseError err | -- | Different response from what is expected to a certain SMP command, -- e.g. server should respond `IDS` or `ERR` to `NEW` command, -- other responses would result in this error. @@ -418,6 +419,8 @@ data ProtocolClientError PCEIOError IOException deriving (Eq, Show, Exception) +type SMPClientError = ProtocolClientError ErrorType + -- | Create a new SMP queue. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#create-queue-command @@ -427,7 +430,7 @@ createSMPQueue :: RcvPublicVerifyKey -> RcvPublicDhKey -> Maybe BasicAuth -> - ExceptT ProtocolClientError IO QueueIdsKeys + ExceptT SMPClientError IO QueueIdsKeys createSMPQueue c rpKey rKey dhKey auth = sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey auth) >>= \case IDS qik -> pure qik @@ -436,7 +439,7 @@ createSMPQueue c rpKey rKey dhKey auth = -- | Subscribe to the SMP queue. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue -subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO () +subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO () subscribeSMPQueue c rpKey rId = sendSMPCommand c (Just rpKey) rId SUB >>= \case OK -> return () @@ -444,7 +447,7 @@ subscribeSMPQueue c rpKey rId = r -> throwE . PCEUnexpectedResponse $ bshow r -- | Subscribe to multiple SMP queues batching commands if supported. -subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ())) +subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ())) subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs where cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs @@ -457,14 +460,14 @@ subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c) -serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg +serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {transportSession}} entityId message = (transportSession, thVersion, sessionId, entityId, message) -- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue -- -- https://github.covm/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#receive-a-message-from-the-queue -getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO (Maybe RcvMessage) +getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO (Maybe RcvMessage) getSMPMessage c rpKey rId = sendSMPCommand c (Just rpKey) rId GET >>= \case OK -> pure Nothing @@ -474,26 +477,26 @@ getSMPMessage c rpKey rId = -- | Subscribe to the SMP queue notifications. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications -subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT ProtocolClientError IO () +subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO () subscribeSMPQueueNotifications = okSMPCommand NSUB -- | Secure the SMP queue by adding a sender public key. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command -secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT ProtocolClientError IO () +secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO () secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId -- | Enable notifications for the queue for push notifications server. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command -enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT ProtocolClientError IO (NotifierId, RcvNtfPublicDhKey) +enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT SMPClientError IO (NotifierId, RcvNtfPublicDhKey) enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey = sendSMPCommand c (Just rpKey) rId (NKEY notifierKey rcvNtfPublicDhKey) >>= \case NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey) r -> throwE . PCEUnexpectedResponse $ bshow r -- | Enable notifications for the multiple queues for push notifications server. -enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either ProtocolClientError (NotifierId, RcvNtfPublicDhKey))) +enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either SMPClientError (NotifierId, RcvNtfPublicDhKey))) enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs where cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs @@ -505,17 +508,17 @@ enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs -- | Disable notifications for the queue for push notifications server. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command -disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO () +disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO () disableSMPQueueNotifications = okSMPCommand NDEL -- | Disable notifications for multiple queues for push notifications server. -disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ())) +disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ())) disableSMPQueuesNtfs = okSMPCommands NDEL -- | Send SMP message. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message -sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT ProtocolClientError IO () +sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT SMPClientError IO () sendSMPMessage c spKey sId flags msg = sendSMPCommand c spKey sId (SEND flags msg) >>= \case OK -> pure () @@ -524,7 +527,7 @@ sendSMPMessage c spKey sId flags msg = -- | Acknowledge message delivery (server deletes the message). -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery -ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT ProtocolClientError IO () +ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT SMPClientError IO () ackSMPMessage c rpKey rId msgId = sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case OK -> return () @@ -535,26 +538,26 @@ ackSMPMessage c rpKey rId msgId = -- The existing messages from the queue will still be delivered. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue -suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO () +suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () suspendSMPQueue = okSMPCommand OFF -- | Irreversibly delete SMP queue and all messages in it. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue -deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO () +deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO () deleteSMPQueue = okSMPCommand DEL -- | Delete multiple SMP queues batching commands if supported. -deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ())) +deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ())) deleteSMPQueues = okSMPCommands DEL -okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO () +okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO () okSMPCommand cmd c pKey qId = sendSMPCommand c (Just pKey) qId cmd >>= \case OK -> return () r -> throwE . PCEUnexpectedResponse $ bshow r -okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either ProtocolClientError ())) +okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either SMPClientError ())) okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs where aCmd = Cmd sParty cmd @@ -565,11 +568,11 @@ okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs Left e -> Left e -- | Send SMP command -sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg +sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) -- | Send multiple commands with batching and collect responses -sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg)) +sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either (ProtocolClientError err) msg)) sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do ts <- mapM (runExceptT . mkTransmission c) cs mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts @@ -578,22 +581,22 @@ sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do Left e -> pure $ Left e -- | Send Protocol command -sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg +sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}} pKey qId cmd = do (t, r) <- mkTransmission c (pKey, qId, cmd) ExceptT $ sendRecv t r where -- two separate "atomically" needed to avoid blocking - sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg) + sendRecv :: SentRawTransmission -> TMVar (Response err msg) -> IO (Response err msg) sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout c (atomically $ takeTMVar r) -withTimeout :: ProtocolClient msg -> IO (Either ProtocolClientError msg) -> IO (Either ProtocolClientError msg) +withTimeout :: ProtocolClient err msg -> IO (Either (ProtocolClientError err) msg) -> IO (Either (ProtocolClientError err) msg) withTimeout ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} a = timeout tcpTimeout a >>= \case Just r -> atomically (writeTVar pingErrorCount 0) >> pure r _ -> pure $ Left PCEResponseTimeout -mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg)) +mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> ExceptT (ProtocolClientError err) IO (SentRawTransmission, TMVar (Response err msg)) mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do corrId <- liftIO $ atomically getNextCorrId let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd) @@ -606,7 +609,7 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo pure . CorrId $ bshow i signTransmission :: ByteString -> SentRawTransmission signTransmission t = ((`C.sign` t) <$> pKey, t) - mkRequest :: CorrId -> STM (TMVar (Response msg)) + mkRequest :: CorrId -> STM (TMVar (Response err msg)) mkRequest corrId = do r <- newEmptyTMVar TM.insert corrId (Request qId r) sentCommands diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index feefacecd..2d9562aea 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -38,14 +38,14 @@ import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM -type SMPClientVar = TMVar (Either ProtocolClientError SMPClient) +type SMPClientVar = TMVar (Either SMPClientError SMPClient) data SMPClientAgentEvent = CAConnected SMPServer | CADisconnected SMPServer (Set SMPSub) | CAReconnected SMPServer | CAResubscribed SMPServer SMPSub - | CASubError SMPServer SMPSub ProtocolClientError + | CASubError SMPServer SMPSub SMPClientError data SMPSubParty = SPRecipient | SPNotifier deriving (Eq, Ord, Show) @@ -111,7 +111,7 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} = do asyncClients <- newTVar [] pure SMPClientAgent {agentCfg, msgQ, agentQ, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients} -getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT ProtocolClientError IO SMPClient +getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = atomically getClientVar >>= either newSMPClient waitForSMPClient where @@ -124,7 +124,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = TM.insert srv smpVar smpClients pure smpVar - waitForSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient + waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient waitForSMPClient smpVar = do let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar) @@ -133,10 +133,10 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = Just (Left e) -> Left e Nothing -> Left PCEResponseTimeout - newSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient + newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient newSMPClient smpVar = tryConnectClient pure tryConnectAsync where - tryConnectClient :: (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO () -> ExceptT ProtocolClientError IO a + tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a tryConnectClient successAction retryAction = tryE connectClient >>= \r -> case r of Right smp -> do @@ -150,16 +150,16 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = putTMVar smpVar (Left e) TM.delete srv smpClients throwE e - tryConnectAsync :: ExceptT ProtocolClientError IO () + tryConnectAsync :: ExceptT SMPClientError IO () tryConnectAsync = do a <- async connectAsync atomically $ modifyTVar' (asyncClients ca) (a :) - connectAsync :: ExceptT ProtocolClientError IO () + connectAsync :: ExceptT SMPClientError IO () connectAsync = withRetryInterval (reconnectInterval agentCfg) $ \loop -> void $ tryConnectClient (const reconnectClient) loop - connectClient :: ExceptT ProtocolClientError IO SMPClient + connectClient :: ExceptT SMPClientError IO SMPClient connectClient = ExceptT $ getProtocolClient (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected clientDisconnected :: SMPClient -> IO () @@ -188,17 +188,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = notify . CADisconnected srv $ M.keysSet ss reconnectServer - reconnectServer :: ExceptT ProtocolClientError IO () + reconnectServer :: ExceptT SMPClientError IO () reconnectServer = do a <- async tryReconnectClient atomically $ modifyTVar' (reconnections ca) (a :) - tryReconnectClient :: ExceptT ProtocolClientError IO () + tryReconnectClient :: ExceptT SMPClientError IO () tryReconnectClient = do withRetryInterval (reconnectInterval agentCfg) $ \loop -> reconnectClient `catchE` const loop - reconnectClient :: ExceptT ProtocolClientError IO () + reconnectClient :: ExceptT SMPClientError IO () reconnectClient = do withSMP ca srv $ \smp -> do notify $ CAReconnected srv @@ -207,13 +207,13 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = unlessM (atomically $ hasSub (srvSubs ca) srv s) $ subscribe_ smp sub `catchE` handleError s where - subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO () + subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO () subscribe_ smp sub@(s, _) = do smpSubscribe smp sub atomically $ addSubscription ca srv sub notify $ CAResubscribed srv s - handleError :: SMPSub -> ProtocolClientError -> ExceptT ProtocolClientError IO () + handleError :: SMPSub -> SMPClientError -> ExceptT SMPClientError IO () handleError s = \case e@PCEResponseTimeout -> throwE e e@PCENetworkError -> throwE e @@ -221,7 +221,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = notify $ CASubError srv s e atomically $ removePendingSubscription ca srv s - notify :: SMPClientAgentEvent -> ExceptT ProtocolClientError IO () + notify :: SMPClientAgentEvent -> ExceptT SMPClientError IO () notify evt = atomically $ writeTBQueue (agentQ ca) evt closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m () @@ -241,15 +241,15 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli cancelActions :: Foldable f => TVar (f (Async ())) -> IO () cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel -withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO a +withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO a withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError where - logSMPError :: ProtocolClientError -> ExceptT ProtocolClientError IO a + logSMPError :: SMPClientError -> ExceptT SMPClientError IO a logSMPError e = do liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e throwE e -subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO () +subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO () subscribeQueue ca srv sub = do atomically $ addPendingSubscription ca srv sub withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleError @@ -267,7 +267,7 @@ showServer :: SMPServer -> ByteString showServer ProtocolServer {host, port} = strEncode host <> B.pack (if null port then "" else ':' : port) -smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO () +smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO () smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId where subscribe_ = case party of diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs index 8f690e955..dfd84c909 100644 --- a/src/Simplex/Messaging/Notifications/Client.hs +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -10,54 +10,57 @@ import Data.Word (Word16) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Protocol (ErrorType) import Simplex.Messaging.Util (bshow) -type NtfClient = ProtocolClient NtfResponse +type NtfClient = ProtocolClient ErrorType NtfResponse -ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519) +type NtfClientError = ProtocolClientError ErrorType + +ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519) ntfRegisterToken c pKey newTkn = sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case NRTknId tknId dhKey -> pure (tknId, dhKey) r -> throwE . PCEUnexpectedResponse $ bshow r -ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO () +ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT NtfClientError IO () ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId -ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO NtfTknStatus +ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT NtfClientError IO NtfTknStatus ntfCheckToken c pKey tknId = sendNtfCommand c (Just pKey) tknId TCHK >>= \case NRTkn stat -> pure stat r -> throwE . PCEUnexpectedResponse $ bshow r -ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT ProtocolClientError IO () +ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT NtfClientError IO () ntfReplaceToken c pKey tknId token = okNtfCommand (TRPL token) c pKey tknId -ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO () +ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT NtfClientError IO () ntfDeleteToken = okNtfCommand TDEL -ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO () +ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT NtfClientError IO () ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId -ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO NtfSubscriptionId +ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT NtfClientError IO NtfSubscriptionId ntfCreateSubscription c pKey newSub = sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case NRSubId subId -> pure subId r -> throwE . PCEUnexpectedResponse $ bshow r -ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus +ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT NtfClientError IO NtfSubStatus ntfCheckSubscription c pKey subId = sendNtfCommand c (Just pKey) subId SCHK >>= \case NRSub stat -> pure stat r -> throwE . PCEUnexpectedResponse $ bshow r -ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO () +ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT NtfClientError IO () ntfDeleteSubscription = okNtfCommand SDEL -- | Send notification server command -sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse +sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT NtfClientError IO NtfResponse sendNtfCommand c pKey entId cmd = sendProtocolCommand c pKey entId (NtfCmd sNtfEntity cmd) -okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT ProtocolClientError IO () +okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT NtfClientError IO () okNtfCommand cmd c pKey entId = sendNtfCommand c (Just pKey) entId cmd >>= \case NROk -> return () diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 4e40e2fe1..90e32bc38 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -146,7 +147,7 @@ instance Encoding ANewNtfEntity where 'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP) _ -> fail "bad ANewNtfEntity" -instance Protocol NtfResponse where +instance Protocol ErrorType NtfResponse where type ProtoCommand NtfResponse = NtfCmd type ProtoType NtfResponse = 'PNTF protocolClientHandshake = ntfClientHandshake @@ -183,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e) deriving instance Show NtfCmd -instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where +instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where type Tag (NtfCommand e) = NtfCommandTag e encodeProtocol _v = \case TNEW newTkn -> e (TNEW_, ' ', newTkn) @@ -202,6 +203,9 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag) + fromProtocolError = fromProtocolError @ErrorType @NtfResponse + {-# INLINE fromProtocolError #-} + checkCredentials (sig, _, entityId, _) cmd = case cmd of -- TNEW and SNEW must have signature but NOT token/subscription IDs TNEW {} -> sigNoEntity @@ -219,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where | not (B.null entityId) = Left $ CMD HAS_AUTH | otherwise = Right cmd -instance ProtocolEncoding NtfCmd where +instance ProtocolEncoding ErrorType NtfCmd where type Tag NtfCmd = NtfCmdTag encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c @@ -239,6 +243,9 @@ instance ProtocolEncoding NtfCmd where SDEL_ -> pure SDEL PING_ -> pure PING + fromProtocolError = fromProtocolError @ErrorType @NtfResponse + {-# INLINE fromProtocolError #-} + checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c data NtfResponseTag @@ -283,7 +290,7 @@ data NtfResponse | NRPong deriving (Show) -instance ProtocolEncoding NtfResponse where +instance ProtocolEncoding ErrorType NtfResponse where type Tag NtfResponse = NtfResponseTag encodeProtocol _v = \case NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey) @@ -306,6 +313,13 @@ instance ProtocolEncoding NtfResponse where NRSub_ -> NRSub <$> _smpP NRPong_ -> pure NRPong + fromProtocolError = \case + PECmdSyntax -> CMD SYNTAX + PECmdUnknown -> CMD UNKNOWN + PESession -> SESSION + PEBlock -> BLOCK + {-# INLINE fromProtocolError #-} + checkCredentials (_, _, entId, _) cmd = case cmd of -- IDTKN response must not have queue ID NRTknId {} -> noEntity diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index bed8ea433..bbc7e3d43 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -27,7 +27,7 @@ import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime) import Data.Time.Clock.System (getSystemTime) import Data.Time.Format.ISO8601 (iso8601Show) import Network.Socket (ServiceName) -import Simplex.Messaging.Client (ProtocolClientError (..)) +import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String @@ -227,7 +227,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge where showServer' = decodeLatin1 . strEncode . host - handleSubError :: SMPQueueNtf -> ProtocolClientError -> M () + handleSubError :: SMPQueueNtf -> SMPClientError -> M () handleSubError smpQueue = \case PCEProtocolError AUTH -> updateSubStatus smpQueue NSAuth PCEProtocolError e -> updateErr "SMP error " e @@ -343,7 +343,7 @@ send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = for data VerificationResult = VRVerified NtfRequest | VRFailed -verifyNtfTransmission :: SignedTransmission NtfCmd -> NtfCmd -> M VerificationResult +verifyNtfTransmission :: SignedTransmission ErrorType NtfCmd -> NtfCmd -> M VerificationResult verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do st <- asks store case cmd of diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 9056297db..238fc971a 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -4,6 +4,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -19,11 +20,10 @@ {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# HLINT ignore "Use newtype instead of data" #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} -{-# HLINT ignore "Use newtype instead of data" #-} - -- | -- Module : Simplex.Messaging.ProtocolEncoding -- Copyright : (c) simplex.chat @@ -52,6 +52,7 @@ module Simplex.Messaging.Protocol SParty (..), PartyI (..), QueueIdsKeys (..), + ProtocolErrorType (..), ErrorType (..), CommandError (..), Transmission, @@ -224,7 +225,7 @@ deriving instance Show Cmd type Transmission c = (CorrId, EntityId, c) -- | signed parsed transmission, with original raw bytes and parsing error. -type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c)) +type SignedTransmission e c = (Maybe C.ASignature, Signed, Transmission (Either e c)) type Signed = ByteString @@ -874,6 +875,8 @@ type MsgId = ByteString -- | SMP message body. type MsgBody = ByteString +data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock + -- | Type for protocol errors. data ErrorType = -- | incorrect block format, encoding or signature size @@ -944,16 +947,16 @@ transmissionP = do command <- A.takeByteString pure RawTransmission {signature, signed, sessId, corrId, entityId, command} -class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where +class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where type ProtoCommand msg = cmd | cmd -> msg type ProtoType msg = (sch :: ProtocolType) | sch -> msg protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) protocolPing :: ProtoCommand msg - protocolError :: msg -> Maybe ErrorType + protocolError :: msg -> Maybe err type ProtoServer msg = ProtocolServer (ProtoType msg) -instance Protocol BrokerMsg where +instance Protocol ErrorType BrokerMsg where type ProtoCommand BrokerMsg = Cmd type ProtoType BrokerMsg = 'PSMP protocolClientHandshake = smpClientHandshake @@ -962,13 +965,14 @@ instance Protocol BrokerMsg where ERR e -> Just e _ -> Nothing -class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where +class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where type Tag msg encodeProtocol :: Version -> msg -> ByteString protocolP :: Version -> Tag msg -> Parser msg - checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg + fromProtocolError :: ProtocolErrorType -> err + checkCredentials :: SignedRawTransmission -> msg -> Either err msg -instance PartyI p => ProtocolEncoding (Command p) where +instance PartyI p => ProtocolEncoding ErrorType (Command p) where type Tag (Command p) = CommandTag p encodeProtocol v = \case NEW rKey dhKey auth_ -> case auth_ of @@ -999,6 +1003,9 @@ instance PartyI p => ProtocolEncoding (Command p) where protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag) + fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + {-# INLINE fromProtocolError #-} + checkCredentials (sig, _, queueId, _) cmd = case cmd of -- NEW must have signature but NOT queue ID NEW {} @@ -1018,7 +1025,7 @@ instance PartyI p => ProtocolEncoding (Command p) where | isNothing sig || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding Cmd where +instance ProtocolEncoding ErrorType Cmd where type Tag Cmd = CmdTag encodeProtocol v (Cmd _ c) = encodeProtocol v c @@ -1048,9 +1055,12 @@ instance ProtocolEncoding Cmd where PING_ -> pure PING CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB + fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + {-# INLINE fromProtocolError #-} + checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c -instance ProtocolEncoding BrokerMsg where +instance ProtocolEncoding ErrorType BrokerMsg where type Tag BrokerMsg = BrokerMsgTag encodeProtocol v = \case IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh) @@ -1085,6 +1095,13 @@ instance ProtocolEncoding BrokerMsg where ERR_ -> ERR <$> _smpP PONG_ -> pure PONG + fromProtocolError = \case + PECmdSyntax -> CMD SYNTAX + PECmdUnknown -> CMD UNKNOWN + PESession -> SESSION + PEBlock -> BLOCK + {-# INLINE fromProtocolError #-} + checkCredentials (_, _, queueId, _) cmd = case cmd of -- IDS response should not have queue ID IDS _ -> Right cmd @@ -1103,12 +1120,12 @@ _smpP :: Encoding a => Parser a _smpP = A.space *> smpP -- | Parse SMP protocol commands and broker messages -parseProtocol :: ProtocolEncoding msg => Version -> ByteString -> Either ErrorType msg +parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg parseProtocol v s = let (tag, params) = B.break (== ' ') s in case decodeTag tag of - Just cmd -> parse (protocolP v cmd) (CMD SYNTAX) params - Nothing -> Left $ CMD UNKNOWN + Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params + Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sParty @p) (sParty @p') of @@ -1203,7 +1220,7 @@ tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t tEncodeBatch :: Int -> ByteString -> ByteString tEncodeBatch n s = lenEncode n `B.cons` s -encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString +encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString encodeTransmission v sessionId (CorrId corrId, queueId, command) = smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command @@ -1223,22 +1240,22 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b eitherList = either (\e -> [Left e]) -- | Receive client and server transmissions (determined by `cmd` type). -tGet :: forall cmd c. (ProtocolEncoding cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission cmd)) +tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd)) tGet th@THandle {sessionId, thVersion = v} = L.map (tDecodeParseValidate sessionId v) <$> tGetParse th -tDecodeParseValidate :: forall cmd. ProtocolEncoding cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission cmd +tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission err cmd tDecodeParseValidate sessionId v = \case Right RawTransmission {signature, signed, sessId, corrId, entityId, command} | sessId == sessionId -> let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature in either (const $ tError corrId) (tParseValidate signed) decodedTransmission - | otherwise -> (Nothing, "", (CorrId corrId, "", Left SESSION)) + | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession)) Left _ -> tError "" where - tError :: ByteString -> SignedTransmission cmd - tError corrId = (Nothing, "", (CorrId corrId, "", Left BLOCK)) + tError :: ByteString -> SignedTransmission err cmd + tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock)) - tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission cmd + tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd tParseValidate signed t@(sig, corrId, entityId, command) = - let cmd = parseProtocol v command >>= checkCredentials t + let cmd = parseProtocol @err @cmd v command >>= checkCredentials t in (sig, signed, (CorrId corrId, entityId, cmd)) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index e66da37be..bd8c5330c 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -266,7 +266,7 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do write sndQ $ fst as write rcvQ $ snd as where - cmdAction :: SignedTransmission Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd)) + cmdAction :: SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd)) cmdAction (sig, signed, (corrId, queueId, cmdOrError)) = case cmdOrError of Left e -> pure $ Left (corrId, queueId, ERR e) diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index 30964be2e..39a00eb88 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -6,7 +6,7 @@ module CoreTests.ProtocolErrorTests where import qualified Data.ByteString.Char8 as B import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) -import Simplex.Messaging.Agent.Protocol (AgentErrorType (..)) +import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..)) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) import Test.Hspec @@ -17,12 +17,14 @@ protocolErrorTests :: Spec protocolErrorTests = modifyMaxSuccess (const 1000) $ do describe "errors parsing / serializing" $ do it "should parse SMP protocol errors" . property $ \(err :: AgentErrorType) -> - errServerHasSpaces err + errHasSpaces err || parseAll strP (strEncode err) == Right err it "should parse SMP agent errors" . property $ \(err :: AgentErrorType) -> - errServerHasSpaces err + errHasSpaces err || parseAll strP (strEncode err) == Right err where - errServerHasSpaces = \case - BROKER srv _ -> ' ' `B.elem` encodeUtf8 (T.pack srv) + errHasSpaces = \case + BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e + BROKER srv _ -> hasSpaces srv _ -> False + hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s) diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index 7601652d1..fab7fde51 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -64,16 +64,16 @@ ntfSyntaxTests (ATransport t) = do Expectation command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response -pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission NtfResponse +pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command)) -sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse) +sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, t) tGet1 h -signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse) +signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) Right () <- tPut1 h (Just $ C.sign pk t, t) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 884e88a29..09cb2854a 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -63,7 +63,7 @@ serverTests t@(ATransport t') = do testMsgExpireOnInterval t' testMsgNOTExpireOnInterval t' -pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission BrokerMsg +pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission ErrorType BrokerMsg pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command)) pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg @@ -72,13 +72,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh) pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg) +sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, t) tGet1 h -signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg) +signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) Right () <- tPut1 h (Just $ C.sign pk t, t) @@ -89,7 +89,7 @@ tPut1 h t = do [r] <- tPut h [t] pure r -tGet1 :: (ProtocolEncoding cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission cmd) +tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd) tGet1 h = do [r] <- liftIO $ tGet h pure r @@ -428,7 +428,7 @@ testSwitchSub (ATransport t) = Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3) (ok3, OK) #== "accepts ACK from the 2nd TCP connection" - 1000 `timeout` tGet @BrokerMsg rh1 >>= \case + 1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" @@ -869,14 +869,14 @@ testMessageNotifications (ATransport t) = Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2) (dec mId2 msg2, Right "hello again") #== "delivered from queue again" Resp "" _ (NMSG _ _) <- tGet1 nh2 - 1000 `timeout` tGet @BrokerMsg nh1 >>= \case + 1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" - 1000 `timeout` tGet @BrokerMsg nh2 >>= \case + 1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" @@ -895,7 +895,7 @@ testMsgExpireOnSend t = testSMPClient @c $ \rh -> do Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @BrokerMsg rh >>= \case + 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" @@ -911,7 +911,7 @@ testMsgExpireOnInterval t = threadDelay 2500000 testSMPClient @c $ \rh -> do Resp "2" _ OK <- signSendRecv rh rKey ("2", rId, SUB) - 1000 `timeout` tGet @BrokerMsg rh >>= \case + 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing should be delivered" @@ -929,7 +929,7 @@ testMsgNOTExpireOnInterval t = testSMPClient @c $ \rh -> do Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @BrokerMsg rh >>= \case + 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered"