parameterize protocol by error type (#644)

This commit is contained in:
Evgeny Poberezkin
2023-02-17 20:46:01 +00:00
committed by GitHub
parent 2ae3100bed
commit 2ddfb044fc
12 changed files with 216 additions and 176 deletions
+35 -34
View File
@@ -5,6 +5,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -162,7 +163,7 @@ import UnliftIO (mapConcurrently)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg))
type ClientVar err msg = TMVar (Either AgentErrorType (ProtocolClient err msg))
type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
@@ -319,15 +320,15 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
class ProtocolServerClient msg where
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
clientProtocolError :: ErrorType -> AgentErrorType
class ProtocolServerClient err msg | msg -> err where
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient err msg)
clientProtocolError :: err -> AgentErrorType
instance ProtocolServerClient BrokerMsg where
instance ProtocolServerClient ErrorType BrokerMsg where
getProtocolServerClient = getSMPServerClient
clientProtocolError = SMP
instance ProtocolServerClient NtfResponse where
instance ProtocolServerClient ErrorType NtfResponse where
getProtocolServerClient = getNtfServerClient
clientProtocolError = NTF
@@ -428,7 +429,7 @@ getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM
TM.insert tSess var clients
pure var
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar err msg -> m (ProtocolClient err msg)
waitForProtocolClient c (_, srv, _) clientVar = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
@@ -438,18 +439,18 @@ waitForProtocolClient c (_, srv, _) clientVar = do
Nothing -> Left $ BROKER (B.unpack $ strEncode srv) TIMEOUT
newProtocolClient ::
forall msg m.
forall err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
AgentClient ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
m (ProtocolClient msg) ->
TMap (TransportSession msg) (ClientVar err msg) ->
m (ProtocolClient err msg) ->
(AgentClient -> TransportSession msg -> m ()) ->
ClientVar msg ->
m (ProtocolClient msg)
ClientVar err msg ->
m (ProtocolClient err msg)
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
where
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
tryConnectClient :: (ProtocolClient err msg -> m a) -> m () -> m a
tryConnectClient successAction retryAction =
tryError connectClient >>= \r -> case r of
Right client -> do
@@ -474,7 +475,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconne
withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop
atomically . removeAsyncAction aId $ asyncClients c
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
hostEvent :: forall err msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient err msg -> ACommand 'Agent
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client
getClientConfig :: AgentMonad m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig
@@ -518,7 +519,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
where
k = (server, sndId)
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar err msg)) -> IO ()
closeProtocolServerClients c clientsSel =
atomically (swapTVar cs M.empty) >>= mapM_ (forkIO . closeClient)
where
@@ -541,32 +542,32 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur
where
newLock = createLock >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> m a) -> m a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchError` logServerError cl
where
stat cl = liftIO . incClientStat c userId cl statCmd
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
logServerError :: ProtocolClient err msg -> AgentErrorType -> m a
logServerError cl e = do
logServer "<--" c srv "" $ strEncode e
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> m a) -> m a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMPClient c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient c tSess (queueId q) cmdStr action
@@ -576,16 +577,16 @@ withSMPClient_ c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient_ c tSess (queueId q) cmdStr action
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a
withNtfClient c srv = withLogClient c (0, srv, Nothing)
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a
liftClient protocolError_ = liftError . protocolClientError protocolError_
protocolClientError :: (ErrorType -> AgentErrorType) -> HostName -> ProtocolClientError -> AgentErrorType
protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType
protocolClientError protocolError_ host = \case
PCEProtocolError e -> protocolError_ e
PCEResponseError e -> BROKER host $ RESPONSE e
PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e
PCEUnexpectedResponse _ -> BROKER host UNEXPECTED
PCEResponseTimeout -> BROKER host TIMEOUT
PCENetworkError -> BROKER host NETWORK
@@ -632,7 +633,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
Left e -> pure (Just $ testErr TSConnect e)
where
addr = B.unpack $ strEncode srv
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
testErr :: SMPTestStep -> SMPClientError -> SMPTestFailure
testErr step = SMPTestFailure step . protocolClientError SMP addr
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
@@ -682,7 +683,7 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
}
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> IO (Either SMPClientError ())
processSubResult c rq r = do
case r of
Left e ->
@@ -691,7 +692,7 @@ processSubResult c rq r = do
_ -> addSubscription c rq
pure r
temporaryClientError :: ProtocolClientError -> Bool
temporaryClientError :: ProtocolClientError err -> Bool
temporaryClientError = \case
PCENetworkError -> True
PCEResponseTimeout -> True
@@ -732,7 +733,7 @@ subscribeQueues c qs = do
type BatchResponses e = (NonEmpty (RcvQueue, Either e ()))
-- statBatchSize is not used to batch the commands, only for traffic statistics
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
sendTSessionBatches statCmd statBatchSize action c qs =
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
where
@@ -759,7 +760,7 @@ sendTSessionBatches statCmd statBatchSize action c qs =
let n = (length qs - 1) `div` statBatchSize + 1
in incClientStatN c userId smp n (statCmd <> "S") "OK"
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)
sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
where
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
@@ -1047,7 +1048,7 @@ incStat AgentClient {agentStats} n k = do
Just v -> modifyTVar' v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
incClientStat :: AgentClient -> UserId -> ProtocolClient err msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
@@ -1057,7 +1058,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do
where
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN :: AgentClient -> UserId -> ProtocolClient err msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c userId pc n cmd res = do
atomically $ incStat c n statsKey
where
+9 -9
View File
@@ -1121,7 +1121,7 @@ instance ToJSON ConnectionErrorType where
-- | SMP server errors.
data BrokerErrorType
= -- | invalid server response (failed to parse)
RESPONSE {smpErr :: ErrorType}
RESPONSE {smpErr :: String}
| -- | unexpected response
UNEXPECTED
| -- | network error
@@ -1164,27 +1164,27 @@ instance StrEncoding AgentErrorType where
<|> "CONN " *> (CONN <$> parseRead1)
<|> "SMP " *> (SMP <$> strP)
<|> "NTF " *> (NTF <$> strP)
<|> "BROKER " *> (BROKER <$> srvP <* " RESPONSE " <*> (RESPONSE <$> strP))
<|> "BROKER " *> (BROKER <$> srvP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP))
<|> "BROKER " *> (BROKER <$> srvP <* A.space <*> parseRead1)
<|> "BROKER " *> (BROKER <$> textP <* " RESPONSE " <*> (RESPONSE <$> textP))
<|> "BROKER " *> (BROKER <$> textP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP))
<|> "BROKER " *> (BROKER <$> textP <* A.space <*> parseRead1)
<|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString)
<|> "AGENT " *> (AGENT <$> parseRead1)
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
where
srvP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ')
textP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ')
strEncode = \case
CMD e -> "CMD " <> bshow e
CONN e -> "CONN " <> bshow e
SMP e -> "SMP " <> strEncode e
NTF e -> "NTF " <> strEncode e
BROKER srv (RESPONSE e) -> "BROKER " <> addr srv <> " RESPONSE " <> strEncode e
BROKER srv (TRANSPORT e) -> "BROKER " <> addr srv <> " TRANSPORT " <> serializeTransportError e
BROKER srv e -> "BROKER " <> addr srv <> " " <> bshow e
BROKER srv (RESPONSE e) -> "BROKER " <> text srv <> " RESPONSE " <> text e
BROKER srv (TRANSPORT e) -> "BROKER " <> text srv <> " TRANSPORT " <> serializeTransportError e
BROKER srv e -> "BROKER " <> text srv <> " " <> bshow e
AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e
AGENT e -> "AGENT " <> bshow e
INTERNAL e -> "INTERNAL " <> bshow e
where
addr = encodeUtf8 . T.pack
text = encodeUtf8 . T.pack
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU