mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-25 22:52:15 +00:00
Merge branch 'master' into xftp
This commit is contained in:
@@ -71,6 +71,7 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash
|
||||
Simplex.Messaging.Agent.TAsyncs
|
||||
Simplex.Messaging.Agent.TRcvQueues
|
||||
Simplex.Messaging.Client
|
||||
|
||||
@@ -757,7 +757,7 @@ sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do
|
||||
enqueueCommand :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> m ()
|
||||
enqueueCommand c corrId connId server aCommand = do
|
||||
resumeSrvCmds c server
|
||||
commandId <- withStore' c $ \db -> createCommand db corrId connId server aCommand
|
||||
commandId <- withStore c $ \db -> createCommand db corrId connId server aCommand
|
||||
queuePendingCommands c server [commandId]
|
||||
|
||||
resumeSrvCmds :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -162,7 +163,7 @@ import UnliftIO (mapConcurrently)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg))
|
||||
type ClientVar err msg = TMVar (Either AgentErrorType (ProtocolClient err msg))
|
||||
|
||||
type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
|
||||
|
||||
@@ -319,15 +320,15 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
agentClientStore :: AgentClient -> SQLiteStore
|
||||
agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
|
||||
class ProtocolServerClient msg where
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
|
||||
clientProtocolError :: ErrorType -> AgentErrorType
|
||||
class ProtocolServerClient err msg | msg -> err where
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient err msg)
|
||||
clientProtocolError :: err -> AgentErrorType
|
||||
|
||||
instance ProtocolServerClient BrokerMsg where
|
||||
instance ProtocolServerClient ErrorType BrokerMsg where
|
||||
getProtocolServerClient = getSMPServerClient
|
||||
clientProtocolError = SMP
|
||||
|
||||
instance ProtocolServerClient NtfResponse where
|
||||
instance ProtocolServerClient ErrorType NtfResponse where
|
||||
getProtocolServerClient = getNtfServerClient
|
||||
clientProtocolError = NTF
|
||||
|
||||
@@ -428,7 +429,7 @@ getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM
|
||||
TM.insert tSess var clients
|
||||
pure var
|
||||
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar err msg -> m (ProtocolClient err msg)
|
||||
waitForProtocolClient c (_, srv, _) clientVar = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
|
||||
@@ -438,18 +439,18 @@ waitForProtocolClient c (_, srv, _) clientVar = do
|
||||
Nothing -> Left $ BROKER (B.unpack $ strEncode srv) TIMEOUT
|
||||
|
||||
newProtocolClient ::
|
||||
forall msg m.
|
||||
forall err msg m.
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
|
||||
AgentClient ->
|
||||
TransportSession msg ->
|
||||
TMap (TransportSession msg) (ClientVar msg) ->
|
||||
m (ProtocolClient msg) ->
|
||||
TMap (TransportSession msg) (ClientVar err msg) ->
|
||||
m (ProtocolClient err msg) ->
|
||||
(AgentClient -> TransportSession msg -> m ()) ->
|
||||
ClientVar msg ->
|
||||
m (ProtocolClient msg)
|
||||
ClientVar err msg ->
|
||||
m (ProtocolClient err msg)
|
||||
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
|
||||
tryConnectClient :: (ProtocolClient err msg -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right client -> do
|
||||
@@ -474,7 +475,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconne
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop
|
||||
atomically . removeAsyncAction aId $ asyncClients c
|
||||
|
||||
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
|
||||
hostEvent :: forall err msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient err msg -> ACommand 'Agent
|
||||
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client
|
||||
|
||||
getClientConfig :: AgentMonad m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig
|
||||
@@ -518,7 +519,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
|
||||
where
|
||||
k = (server, sndId)
|
||||
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar err msg)) -> IO ()
|
||||
closeProtocolServerClients c clientsSel =
|
||||
atomically (swapTVar cs M.empty) >>= mapM_ (forkIO . closeClient)
|
||||
where
|
||||
@@ -541,32 +542,32 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur
|
||||
where
|
||||
newLock = createLock >>= \l -> TM.insert key l locks $> l
|
||||
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> m a) -> m a
|
||||
withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
cl <- getProtocolServerClient c tSess
|
||||
(action cl <* stat cl "OK") `catchError` logServerError cl
|
||||
where
|
||||
stat cl = liftIO . incClientStat c userId cl statCmd
|
||||
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
|
||||
logServerError :: ProtocolClient err msg -> AgentErrorType -> m a
|
||||
logServerError cl e = do
|
||||
logServer "<--" c srv "" $ strEncode e
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> m a) -> m a
|
||||
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
|
||||
logServer "-->" c srv entId cmdStr
|
||||
res <- withClient_ c tSess cmdStr action
|
||||
logServer "<--" c srv entId "OK"
|
||||
return res
|
||||
|
||||
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
|
||||
|
||||
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg, ProtocolTypeI (ProtoType msg), Encoding err, Show err) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient err msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
|
||||
|
||||
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMPClient c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
withLogClient c tSess (queueId q) cmdStr action
|
||||
@@ -576,16 +577,16 @@ withSMPClient_ c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
withLogClient_ c tSess (queueId q) cmdStr action
|
||||
|
||||
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a
|
||||
withNtfClient c srv = withLogClient c (0, srv, Nothing)
|
||||
|
||||
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
|
||||
liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a
|
||||
liftClient protocolError_ = liftError . protocolClientError protocolError_
|
||||
|
||||
protocolClientError :: (ErrorType -> AgentErrorType) -> HostName -> ProtocolClientError -> AgentErrorType
|
||||
protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType
|
||||
protocolClientError protocolError_ host = \case
|
||||
PCEProtocolError e -> protocolError_ e
|
||||
PCEResponseError e -> BROKER host $ RESPONSE e
|
||||
PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e
|
||||
PCEUnexpectedResponse _ -> BROKER host UNEXPECTED
|
||||
PCEResponseTimeout -> BROKER host TIMEOUT
|
||||
PCENetworkError -> BROKER host NETWORK
|
||||
@@ -632,7 +633,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
Left e -> pure (Just $ testErr TSConnect e)
|
||||
where
|
||||
addr = B.unpack $ strEncode srv
|
||||
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
|
||||
testErr :: SMPTestStep -> SMPClientError -> SMPTestFailure
|
||||
testErr step = SMPTestFailure step . protocolClientError SMP addr
|
||||
|
||||
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
|
||||
@@ -682,7 +683,7 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
|
||||
}
|
||||
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
|
||||
|
||||
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> IO (Either SMPClientError ())
|
||||
processSubResult c rq r = do
|
||||
case r of
|
||||
Left e ->
|
||||
@@ -725,7 +726,7 @@ subscribeQueues c qs = do
|
||||
type BatchResponses e = (NonEmpty (RcvQueue, Either e ()))
|
||||
|
||||
-- statBatchSize is not used to batch the commands, only for traffic statistics
|
||||
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
sendTSessionBatches statCmd statBatchSize action c qs =
|
||||
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
|
||||
where
|
||||
@@ -752,7 +753,7 @@ sendTSessionBatches statCmd statBatchSize action c qs =
|
||||
let n = (length qs - 1) `div` statBatchSize + 1
|
||||
in incClientStatN c userId smp n (statCmd <> "S") "OK"
|
||||
|
||||
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)
|
||||
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError)
|
||||
sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
|
||||
where
|
||||
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
|
||||
@@ -1040,7 +1041,7 @@ incStat AgentClient {agentStats} n k = do
|
||||
Just v -> modifyTVar' v (+ n)
|
||||
_ -> newTVar n >>= \v -> TM.insert k v agentStats
|
||||
|
||||
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat :: AgentClient -> UserId -> ProtocolClient err msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c userId pc = incClientStatN c userId pc 1
|
||||
|
||||
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
@@ -1050,7 +1051,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do
|
||||
where
|
||||
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
|
||||
|
||||
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN :: AgentClient -> UserId -> ProtocolClient err msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN c userId pc n cmd res = do
|
||||
atomically $ incStat c n statsKey
|
||||
where
|
||||
|
||||
@@ -82,11 +82,11 @@ processNtfSub c (connId, cmd) = do
|
||||
case clientNtfCreds of
|
||||
Just ClientNtfCreds {notifierId} -> do
|
||||
let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey
|
||||
withStore' c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate
|
||||
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate
|
||||
addNtfNTFWorker ntfServer
|
||||
Nothing -> do
|
||||
let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew
|
||||
withStore' c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey
|
||||
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey
|
||||
addNtfSMPWorker smpServer
|
||||
(Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do
|
||||
case (clientNtfCreds, ntfQueueId) of
|
||||
|
||||
@@ -1121,7 +1121,7 @@ instance ToJSON ConnectionErrorType where
|
||||
-- | SMP server errors.
|
||||
data BrokerErrorType
|
||||
= -- | invalid server response (failed to parse)
|
||||
RESPONSE {smpErr :: ErrorType}
|
||||
RESPONSE {smpErr :: String}
|
||||
| -- | unexpected response
|
||||
UNEXPECTED
|
||||
| -- | network error
|
||||
@@ -1164,27 +1164,27 @@ instance StrEncoding AgentErrorType where
|
||||
<|> "CONN " *> (CONN <$> parseRead1)
|
||||
<|> "SMP " *> (SMP <$> strP)
|
||||
<|> "NTF " *> (NTF <$> strP)
|
||||
<|> "BROKER " *> (BROKER <$> srvP <* " RESPONSE " <*> (RESPONSE <$> strP))
|
||||
<|> "BROKER " *> (BROKER <$> srvP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP))
|
||||
<|> "BROKER " *> (BROKER <$> srvP <* A.space <*> parseRead1)
|
||||
<|> "BROKER " *> (BROKER <$> textP <* " RESPONSE " <*> (RESPONSE <$> textP))
|
||||
<|> "BROKER " *> (BROKER <$> textP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP))
|
||||
<|> "BROKER " *> (BROKER <$> textP <* A.space <*> parseRead1)
|
||||
<|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString)
|
||||
<|> "AGENT " *> (AGENT <$> parseRead1)
|
||||
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
|
||||
where
|
||||
srvP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ')
|
||||
textP = T.unpack . safeDecodeUtf8 <$> A.takeTill (== ' ')
|
||||
strEncode = \case
|
||||
CMD e -> "CMD " <> bshow e
|
||||
CONN e -> "CONN " <> bshow e
|
||||
SMP e -> "SMP " <> strEncode e
|
||||
NTF e -> "NTF " <> strEncode e
|
||||
BROKER srv (RESPONSE e) -> "BROKER " <> addr srv <> " RESPONSE " <> strEncode e
|
||||
BROKER srv (TRANSPORT e) -> "BROKER " <> addr srv <> " TRANSPORT " <> serializeTransportError e
|
||||
BROKER srv e -> "BROKER " <> addr srv <> " " <> bshow e
|
||||
BROKER srv (RESPONSE e) -> "BROKER " <> text srv <> " RESPONSE " <> text e
|
||||
BROKER srv (TRANSPORT e) -> "BROKER " <> text srv <> " TRANSPORT " <> serializeTransportError e
|
||||
BROKER srv e -> "BROKER " <> text srv <> " " <> bshow e
|
||||
AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e
|
||||
AGENT e -> "AGENT " <> bshow e
|
||||
INTERNAL e -> "INTERNAL " <> bshow e
|
||||
where
|
||||
addr = encodeUtf8 . T.pack
|
||||
text = encodeUtf8 . T.pack
|
||||
|
||||
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
|
||||
@@ -510,6 +510,8 @@ data StoreError
|
||||
SEUserNotFound
|
||||
| -- | Connection not found (or both queues absent).
|
||||
SEConnNotFound
|
||||
| -- | Server not found.
|
||||
SEServerNotFound
|
||||
| -- | Connection already used.
|
||||
SEConnDuplicate
|
||||
| -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa.
|
||||
|
||||
@@ -400,16 +400,16 @@ updateNewConnSnd db connId sq =
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertRcvQueue_ db connId q
|
||||
void $ insertRcvQueue_ db connId q serverKeyHash_
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
||||
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertSndQueue_ db connId q
|
||||
void $ insertSndQueue_ db connId q serverKeyHash_
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
|
||||
@@ -447,8 +447,8 @@ addConnRcvQueue db connId rq =
|
||||
|
||||
addConnRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
||||
addConnRcvQueue_ db connId rq@RcvQueue {server} = do
|
||||
upsertServer_ db server
|
||||
insertRcvQueue_ db connId rq
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
insertRcvQueue_ db connId rq serverKeyHash_
|
||||
|
||||
addConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
|
||||
addConnSndQueue db connId sq =
|
||||
@@ -459,8 +459,8 @@ addConnSndQueue db connId sq =
|
||||
|
||||
addConnSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
||||
addConnSndQueue_ db connId sq@SndQueue {server} = do
|
||||
upsertServer_ db server
|
||||
insertSndQueue_ db connId sq
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
insertSndQueue_ db connId sq serverKeyHash_
|
||||
|
||||
setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
|
||||
setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
|
||||
@@ -858,18 +858,21 @@ updateRatchet db connId rc skipped = do
|
||||
forM_ (M.assocs mks) $ \(msgN, mk) ->
|
||||
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
|
||||
|
||||
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO AsyncCmdId
|
||||
createCommand db corrId connId srv cmd = do
|
||||
DB.execute
|
||||
db
|
||||
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command) VALUES (?,?,?,?,?,?)"
|
||||
(host_, port_, corrId, connId, agentCommandTag cmd, cmd)
|
||||
insertedRowId db
|
||||
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO (Either StoreError AsyncCmdId)
|
||||
createCommand db corrId connId srv_ cmd = runExceptT $ do
|
||||
(host_, port_, serverKeyHash_) <- serverFields
|
||||
liftIO $ do
|
||||
DB.execute
|
||||
db
|
||||
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash) VALUES (?,?,?,?,?,?,?)"
|
||||
(host_, port_, corrId, connId, agentCommandTag cmd, cmd, serverKeyHash_)
|
||||
insertedRowId db
|
||||
where
|
||||
(host_, port_) =
|
||||
case srv of
|
||||
Just (SMPServer host port _) -> (Just host, Just port)
|
||||
_ -> (Nothing, Nothing)
|
||||
serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash)
|
||||
serverFields = case srv_ of
|
||||
Just srv@(SMPServer host port _) ->
|
||||
(Just host,Just port,) <$> ExceptT (getServerKeyHash_ db srv)
|
||||
Nothing -> pure (Nothing, Nothing, Nothing)
|
||||
|
||||
insertedRowId :: DB.Connection -> IO Int64
|
||||
insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()"
|
||||
@@ -880,7 +883,7 @@ getPendingCommands db connId = do
|
||||
<$> DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT c.host, c.port, s.key_hash, c.command_id
|
||||
SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.command_id
|
||||
FROM commands c
|
||||
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
|
||||
WHERE conn_id = ?
|
||||
@@ -1003,7 +1006,7 @@ getNtfSubscription db connId =
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT s.host, s.port, s.key_hash, ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
|
||||
SELECT s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash,
|
||||
nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts
|
||||
FROM ntf_subscriptions nsb
|
||||
JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port
|
||||
@@ -1021,21 +1024,23 @@ getNtfSubscription db connId =
|
||||
_ -> Nothing
|
||||
in (NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action)
|
||||
|
||||
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO ()
|
||||
createNtfSubscription db ntfSubscription action = do
|
||||
let NtfSubscription {connId, smpServer = (SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
|
||||
createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ())
|
||||
createNtfSubscription db ntfSubscription action = runExceptT $ do
|
||||
let NtfSubscription {connId, smpServer = smpServer@(SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription
|
||||
smpServerKeyHash_ <- ExceptT $ getServerKeyHash_ db smpServer
|
||||
actionTs <- liftIO getCurrentTime
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO ntf_subscriptions
|
||||
(conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
|
||||
ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
|]
|
||||
( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
|
||||
:. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs)
|
||||
)
|
||||
liftIO $
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO ntf_subscriptions
|
||||
(conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id,
|
||||
ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts, smp_server_key_hash)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
|]
|
||||
( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId)
|
||||
:. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, smpServerKeyHash_)
|
||||
)
|
||||
where
|
||||
(ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action
|
||||
|
||||
@@ -1136,7 +1141,7 @@ getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) = do
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT ns.conn_id, s.host, s.port, s.key_hash,
|
||||
SELECT ns.conn_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash),
|
||||
ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action
|
||||
FROM ntf_subscriptions ns
|
||||
JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port
|
||||
@@ -1321,20 +1326,25 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
|
||||
|
||||
{- ORMOLU_ENABLE -}
|
||||
|
||||
-- * Server upsert helper
|
||||
-- * Server helper
|
||||
|
||||
upsertServer_ :: DB.Connection -> SMPServer -> IO ()
|
||||
upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO servers (host, port, key_hash) VALUES (:host,:port,:key_hash)
|
||||
ON CONFLICT (host, port) DO UPDATE SET
|
||||
host=excluded.host,
|
||||
port=excluded.port,
|
||||
key_hash=excluded.key_hash;
|
||||
|]
|
||||
[":host" := host, ":port" := port, ":key_hash" := keyHash]
|
||||
-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
|
||||
createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
|
||||
createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
|
||||
getServerKeyHash_ db newSrv >>= \case
|
||||
Right keyHash_ -> pure keyHash_
|
||||
Left _ -> insertNewServer_ $> Nothing
|
||||
where
|
||||
insertNewServer_ =
|
||||
DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
|
||||
|
||||
-- | Returns the stored server key hash if it is different from the passed one, or the error if the server does not exist.
|
||||
getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
|
||||
getServerKeyHash_ db ProtocolServer {host, port, keyHash} = do
|
||||
firstRow useKeyHash SEServerNotFound $
|
||||
DB.query db "SELECT key_hash FROM servers WHERE host = ? AND port = ?" (host, port)
|
||||
where
|
||||
useKeyHash (Only keyHash') = if keyHash /= keyHash' then Just keyHash else Nothing
|
||||
|
||||
upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
|
||||
upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
||||
@@ -1351,30 +1361,30 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
||||
insertRcvQueue_ db connId' RcvQueue {..} = do
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> Maybe C.KeyHash -> IO Int64
|
||||
insertRcvQueue_ db connId' RcvQueue {..} serverKeyHash_ = do
|
||||
qId <- newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO rcv_queues
|
||||
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion))
|
||||
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
|
||||
pure qId
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
||||
insertSndQueue_ db connId' SndQueue {..} = do
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> Maybe C.KeyHash -> IO Int64
|
||||
insertSndQueue_ db connId' SndQueue {..} serverKeyHash_ = do
|
||||
qId <- newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO snd_queues
|
||||
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion))
|
||||
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
|
||||
pure qId
|
||||
|
||||
newQueueId_ :: [Only Int64] -> Int64
|
||||
@@ -1443,7 +1453,7 @@ getRcvQueuesByConnId_ db connId =
|
||||
rcvQueueQuery :: Query
|
||||
rcvQueueQuery =
|
||||
[sql|
|
||||
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
|
||||
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, q.delete_errors,
|
||||
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
|
||||
@@ -1477,7 +1487,7 @@ getSndQueuesByConnId_ dbConn connId =
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
FROM snd_queues q
|
||||
JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN connections c ON q.conn_id = c.conn_id
|
||||
|
||||
@@ -40,6 +40,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queu
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
|
||||
@@ -59,7 +60,8 @@ schemaMigrations =
|
||||
("m20220915_connection_queues", m20220915_connection_queues),
|
||||
("m20230110_users", m20230110_users),
|
||||
("m20230117_fkey_indexes", m20230117_fkey_indexes),
|
||||
("m20230120_delete_errors", m20230120_delete_errors)
|
||||
("m20230120_delete_errors", m20230120_delete_errors),
|
||||
("m20230217_server_key_hash", m20230217_server_key_hash)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
-- server_key_hash is not null for records whose entities refer to a server
|
||||
-- that was previously saved with the same host and port but different key hash
|
||||
m20230217_server_key_hash :: Query
|
||||
m20230217_server_key_hash =
|
||||
[sql|
|
||||
ALTER TABLE rcv_queues ADD COLUMN server_key_hash BLOB;
|
||||
|
||||
ALTER TABLE snd_queues ADD COLUMN server_key_hash BLOB;
|
||||
|
||||
ALTER TABLE ntf_subscriptions ADD COLUMN smp_server_key_hash BLOB;
|
||||
|
||||
ALTER TABLE commands ADD COLUMN server_key_hash BLOB;
|
||||
|]
|
||||
@@ -48,6 +48,7 @@ CREATE TABLE rcv_queues(
|
||||
rcv_primary INTEGER CHECK(rcv_primary NOT NULL),
|
||||
replace_rcv_queue_id INTEGER NULL,
|
||||
delete_errors INTEGER DEFAULT 0 CHECK(delete_errors NOT NULL),
|
||||
server_key_hash BLOB,
|
||||
PRIMARY KEY(host, port, rcv_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
@@ -68,6 +69,7 @@ CREATE TABLE snd_queues(
|
||||
snd_queue_id INTEGER CHECK(snd_queue_id NOT NULL),
|
||||
snd_primary INTEGER CHECK(snd_primary NOT NULL),
|
||||
replace_snd_queue_id INTEGER NULL,
|
||||
server_key_hash BLOB,
|
||||
PRIMARY KEY(host, port, snd_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
@@ -199,6 +201,7 @@ CREATE TABLE ntf_subscriptions(
|
||||
updated_by_supervisor INTEGER NOT NULL DEFAULT 0, -- to be checked on updates by workers to not overwrite supervisor command(state still should be updated)
|
||||
created_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
smp_server_key_hash BLOB,
|
||||
PRIMARY KEY(conn_id),
|
||||
FOREIGN KEY(smp_host, smp_port) REFERENCES servers(host, port)
|
||||
ON DELETE SET NULL ON UPDATE CASCADE,
|
||||
@@ -214,6 +217,7 @@ CREATE TABLE commands(
|
||||
command_tag BLOB NOT NULL,
|
||||
command BLOB NOT NULL,
|
||||
agent_version INTEGER NOT NULL DEFAULT 1,
|
||||
server_key_hash BLOB,
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
@@ -55,6 +55,7 @@ module Simplex.Messaging.Client
|
||||
|
||||
-- * Supporting types and client configuration
|
||||
ProtocolClientError (..),
|
||||
SMPClientError,
|
||||
ProtocolClientConfig (..),
|
||||
NetworkConfig (..),
|
||||
TransportSessionMode (..),
|
||||
@@ -108,28 +109,28 @@ import System.Timeout (timeout)
|
||||
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
|
||||
--
|
||||
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
|
||||
data ProtocolClient msg = ProtocolClient
|
||||
data ProtocolClient err msg = ProtocolClient
|
||||
{ action :: Maybe (Async ()),
|
||||
sessionId :: SessionId,
|
||||
sessionTs :: UTCTime,
|
||||
thVersion :: Version,
|
||||
client_ :: PClient msg
|
||||
client_ :: PClient err msg
|
||||
}
|
||||
|
||||
data PClient msg = PClient
|
||||
data PClient err msg = PClient
|
||||
{ connected :: TVar Bool,
|
||||
transportSession :: TransportSession msg,
|
||||
transportHost :: TransportHost,
|
||||
tcpTimeout :: Int,
|
||||
pingErrorCount :: TVar Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TMap CorrId (Request msg),
|
||||
sentCommands :: TMap CorrId (Request err msg),
|
||||
sndQ :: TBQueue (NonEmpty SentRawTransmission),
|
||||
rcvQ :: TBQueue (NonEmpty (SignedTransmission msg)),
|
||||
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission msg))
|
||||
}
|
||||
|
||||
type SMPClient = ProtocolClient SMP.BrokerMsg
|
||||
type SMPClient = ProtocolClient ErrorType SMP.BrokerMsg
|
||||
|
||||
-- | Type for client command data
|
||||
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
|
||||
@@ -232,14 +233,14 @@ defaultClientConfig =
|
||||
smpServerVRange = supportedSMPServerVRange
|
||||
}
|
||||
|
||||
data Request msg = Request
|
||||
data Request err msg = Request
|
||||
{ queueId :: QueueId,
|
||||
responseVar :: TMVar (Response msg)
|
||||
responseVar :: TMVar (Response err msg)
|
||||
}
|
||||
|
||||
type Response msg = Either ProtocolClientError msg
|
||||
type Response err msg = Either (ProtocolClientError err) msg
|
||||
|
||||
chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either ProtocolClientError TransportHost
|
||||
chooseTransportHost :: NetworkConfig -> NonEmpty TransportHost -> Either (ProtocolClientError err) TransportHost
|
||||
chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts =
|
||||
firstOrError $ case hostMode of
|
||||
HMOnionViaSocks -> maybe publicHost (const onionHost) socksProxy
|
||||
@@ -253,15 +254,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
|
||||
onionHost = find isOnionHost hosts
|
||||
publicHost = find (not . isOnionHost) hosts
|
||||
|
||||
clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient msg -> String
|
||||
clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String
|
||||
clientServer = B.unpack . strEncode . snd3 . transportSession . client_
|
||||
where
|
||||
snd3 (_, s, _) = s
|
||||
|
||||
transportHost' :: ProtocolClient msg -> TransportHost
|
||||
transportHost' :: ProtocolClient err msg -> TransportHost
|
||||
transportHost' = transportHost . client_
|
||||
|
||||
transportSession' :: ProtocolClient msg -> TransportSession msg
|
||||
transportSession' :: ProtocolClient err msg -> TransportSession msg
|
||||
transportSession' = transportSession . client_
|
||||
|
||||
type UserId = Int64
|
||||
@@ -274,7 +275,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
|
||||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getProtocolClient :: forall msg. Protocol msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient :: forall err msg. Protocol err msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
|
||||
getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
|
||||
case chooseTransportHost networkConfig (host srv) of
|
||||
Right useHost ->
|
||||
@@ -283,7 +284,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
Left e -> pure $ Left e
|
||||
where
|
||||
NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig
|
||||
mkProtocolClient :: TransportHost -> STM (PClient msg)
|
||||
mkProtocolClient :: TransportHost -> STM (PClient err msg)
|
||||
mkProtocolClient transportHost = do
|
||||
connected <- newTVar False
|
||||
pingErrorCount <- newTVar 0
|
||||
@@ -305,7 +306,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
msgQ
|
||||
}
|
||||
|
||||
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient msg -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
|
||||
runClient (port', ATransport t) useHost c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
let tcConfig = transportClientConfig networkConfig
|
||||
@@ -326,9 +327,9 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
"80" -> ("80", transport @WS)
|
||||
p -> (p, transport @TLS)
|
||||
|
||||
client :: forall c. Transport c => TProxy c -> PClient msg -> TMVar (Either ProtocolClientError (ProtocolClient msg)) -> c -> IO ()
|
||||
client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO ()
|
||||
client _ c cVar h =
|
||||
runExceptT (protocolClientHandshake @msg h (keyHash srv) smpServerVRange) >>= \case
|
||||
runExceptT (protocolClientHandshake @err @msg h (keyHash srv) smpServerVRange) >>= \case
|
||||
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
|
||||
Right th@THandle {sessionId, thVersion} -> do
|
||||
sessionTs <- getCurrentTime
|
||||
@@ -339,16 +340,16 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0])
|
||||
`finally` disconnected c'
|
||||
|
||||
send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
send :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
|
||||
send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
||||
receive :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
receive :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
|
||||
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: ProtocolClient msg -> IO ()
|
||||
ping :: ProtocolClient err msg -> IO ()
|
||||
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
|
||||
threadDelay smpPingInterval
|
||||
runExceptT (sendProtocolCommand c Nothing "" protocolPing) >>= \case
|
||||
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case
|
||||
Left PCEResponseTimeout -> do
|
||||
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
|
||||
when (maxCnt == 0 || cnt < maxCnt) $ ping c
|
||||
@@ -356,10 +357,10 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
where
|
||||
maxCnt = smpPingCount networkConfig
|
||||
|
||||
process :: ProtocolClient msg -> IO ()
|
||||
process :: ProtocolClient err msg -> IO ()
|
||||
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
|
||||
|
||||
processMsg :: ProtocolClient msg -> SignedTransmission msg -> IO ()
|
||||
processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO ()
|
||||
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, qId, respOrErr)) =
|
||||
if B.null $ bs corrId
|
||||
then sendMsg respOrErr
|
||||
@@ -377,7 +378,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
_ -> Right r
|
||||
else Left . PCEUnexpectedResponse $ bshow respOrErr
|
||||
where
|
||||
sendMsg :: Either ErrorType msg -> IO ()
|
||||
sendMsg :: Either err msg -> IO ()
|
||||
sendMsg = \case
|
||||
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ
|
||||
Left e -> putStrLn $ "SMP client error: " <> show e
|
||||
@@ -386,17 +387,17 @@ proxyUsername :: TransportSession msg -> ByteString
|
||||
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeProtocolClient :: ProtocolClient msg -> IO ()
|
||||
closeProtocolClient :: ProtocolClient err msg -> IO ()
|
||||
closeProtocolClient = mapM_ uninterruptibleCancel . action
|
||||
|
||||
-- | SMP client error type.
|
||||
data ProtocolClientError
|
||||
data ProtocolClientError err
|
||||
= -- | Correctly parsed SMP server ERR response.
|
||||
-- This error is forwarded to the agent client as `ERR SMP err`.
|
||||
PCEProtocolError ErrorType
|
||||
PCEProtocolError err
|
||||
| -- | Invalid server response that failed to parse.
|
||||
-- Forwarded to the agent client as `ERR BROKER RESPONSE`.
|
||||
PCEResponseError ErrorType
|
||||
PCEResponseError err
|
||||
| -- | Different response from what is expected to a certain SMP command,
|
||||
-- e.g. server should respond `IDS` or `ERR` to `NEW` command,
|
||||
-- other responses would result in this error.
|
||||
@@ -419,7 +420,9 @@ data ProtocolClientError
|
||||
PCEIOError IOException
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
temporaryClientError :: ProtocolClientError -> Bool
|
||||
type SMPClientError = ProtocolClientError ErrorType
|
||||
|
||||
temporaryClientError :: ProtocolClientError err -> Bool
|
||||
temporaryClientError = \case
|
||||
PCENetworkError -> True
|
||||
PCEResponseTimeout -> True
|
||||
@@ -435,7 +438,7 @@ createSMPQueue ::
|
||||
RcvPublicVerifyKey ->
|
||||
RcvPublicDhKey ->
|
||||
Maybe BasicAuth ->
|
||||
ExceptT ProtocolClientError IO QueueIdsKeys
|
||||
ExceptT SMPClientError IO QueueIdsKeys
|
||||
createSMPQueue c rpKey rKey dhKey auth =
|
||||
sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey auth) >>= \case
|
||||
IDS qik -> pure qik
|
||||
@@ -444,7 +447,7 @@ createSMPQueue c rpKey rKey dhKey auth =
|
||||
-- | Subscribe to the SMP queue.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueue c rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId SUB >>= \case
|
||||
OK -> return ()
|
||||
@@ -452,7 +455,7 @@ subscribeSMPQueue c rpKey rId =
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Subscribe to multiple SMP queues batching commands if supported.
|
||||
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
|
||||
where
|
||||
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs
|
||||
@@ -465,14 +468,14 @@ subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
|
||||
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
|
||||
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
|
||||
|
||||
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
|
||||
serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg
|
||||
serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {transportSession}} entityId message =
|
||||
(transportSession, thVersion, sessionId, entityId, message)
|
||||
|
||||
-- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue
|
||||
--
|
||||
-- https://github.covm/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#receive-a-message-from-the-queue
|
||||
getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO (Maybe RcvMessage)
|
||||
getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO (Maybe RcvMessage)
|
||||
getSMPMessage c rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId GET >>= \case
|
||||
OK -> pure Nothing
|
||||
@@ -482,26 +485,26 @@ getSMPMessage c rpKey rId =
|
||||
-- | Subscribe to the SMP queue notifications.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT ProtocolClientError IO ()
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueueNotifications = okSMPCommand NSUB
|
||||
|
||||
-- | Secure the SMP queue by adding a sender public key.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT ProtocolClientError IO ()
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO ()
|
||||
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
|
||||
|
||||
-- | Enable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT ProtocolClientError IO (NotifierId, RcvNtfPublicDhKey)
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT SMPClientError IO (NotifierId, RcvNtfPublicDhKey)
|
||||
enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey =
|
||||
sendSMPCommand c (Just rpKey) rId (NKEY notifierKey rcvNtfPublicDhKey) >>= \case
|
||||
NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey)
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Enable notifications for the multiple queues for push notifications server.
|
||||
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either ProtocolClientError (NotifierId, RcvNtfPublicDhKey)))
|
||||
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either SMPClientError (NotifierId, RcvNtfPublicDhKey)))
|
||||
enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
|
||||
where
|
||||
cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs
|
||||
@@ -513,17 +516,17 @@ enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
|
||||
-- | Disable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
|
||||
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
disableSMPQueueNotifications = okSMPCommand NDEL
|
||||
|
||||
-- | Disable notifications for multiple queues for push notifications server.
|
||||
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
disableSMPQueuesNtfs = okSMPCommands NDEL
|
||||
|
||||
-- | Send SMP message.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT ProtocolClientError IO ()
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT SMPClientError IO ()
|
||||
sendSMPMessage c spKey sId flags msg =
|
||||
sendSMPCommand c spKey sId (SEND flags msg) >>= \case
|
||||
OK -> pure ()
|
||||
@@ -532,7 +535,7 @@ sendSMPMessage c spKey sId flags msg =
|
||||
-- | Acknowledge message delivery (server deletes the message).
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT ProtocolClientError IO ()
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT SMPClientError IO ()
|
||||
ackSMPMessage c rpKey rId msgId =
|
||||
sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case
|
||||
OK -> return ()
|
||||
@@ -543,26 +546,26 @@ ackSMPMessage c rpKey rId msgId =
|
||||
-- The existing messages from the queue will still be delivered.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
suspendSMPQueue = okSMPCommand OFF
|
||||
|
||||
-- | Irreversibly delete SMP queue and all messages in it.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand DEL
|
||||
|
||||
-- | Delete multiple SMP queues batching commands if supported.
|
||||
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
deleteSMPQueues = okSMPCommands DEL
|
||||
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
sendSMPCommand c (Just pKey) qId cmd >>= \case
|
||||
OK -> return ()
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
|
||||
where
|
||||
aCmd = Cmd sParty cmd
|
||||
@@ -573,11 +576,11 @@ okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
|
||||
Left e -> Left e
|
||||
|
||||
-- | Send SMP command
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
|
||||
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
|
||||
|
||||
-- | Send multiple commands with batching and collect responses
|
||||
sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg))
|
||||
sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either (ProtocolClientError err) msg))
|
||||
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do
|
||||
ts <- mapM (runExceptT . mkTransmission c) cs
|
||||
mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts
|
||||
@@ -586,22 +589,22 @@ sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do
|
||||
Left e -> pure $ Left e
|
||||
|
||||
-- | Send Protocol command
|
||||
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
|
||||
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
|
||||
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}} pKey qId cmd = do
|
||||
(t, r) <- mkTransmission c (pKey, qId, cmd)
|
||||
ExceptT $ sendRecv t r
|
||||
where
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg)
|
||||
sendRecv :: SentRawTransmission -> TMVar (Response err msg) -> IO (Response err msg)
|
||||
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout c (atomically $ takeTMVar r)
|
||||
|
||||
withTimeout :: ProtocolClient msg -> IO (Either ProtocolClientError msg) -> IO (Either ProtocolClientError msg)
|
||||
withTimeout :: ProtocolClient err msg -> IO (Either (ProtocolClientError err) msg) -> IO (Either (ProtocolClientError err) msg)
|
||||
withTimeout ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} a =
|
||||
timeout tcpTimeout a >>= \case
|
||||
Just r -> atomically (writeTVar pingErrorCount 0) >> pure r
|
||||
_ -> pure $ Left PCEResponseTimeout
|
||||
|
||||
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
|
||||
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> ExceptT (ProtocolClientError err) IO (SentRawTransmission, TMVar (Response err msg))
|
||||
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do
|
||||
corrId <- liftIO $ atomically getNextCorrId
|
||||
let t = signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
|
||||
@@ -614,7 +617,7 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo
|
||||
pure . CorrId $ bshow i
|
||||
signTransmission :: ByteString -> SentRawTransmission
|
||||
signTransmission t = ((`C.sign` t) <$> pKey, t)
|
||||
mkRequest :: CorrId -> STM (TMVar (Response msg))
|
||||
mkRequest :: CorrId -> STM (TMVar (Response err msg))
|
||||
mkRequest corrId = do
|
||||
r <- newEmptyTMVar
|
||||
TM.insert corrId (Request qId r) sentCommands
|
||||
|
||||
@@ -38,14 +38,14 @@ import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type SMPClientVar = TMVar (Either ProtocolClientError SMPClient)
|
||||
type SMPClientVar = TMVar (Either SMPClientError SMPClient)
|
||||
|
||||
data SMPClientAgentEvent
|
||||
= CAConnected SMPServer
|
||||
| CADisconnected SMPServer (Set SMPSub)
|
||||
| CAReconnected SMPServer
|
||||
| CAResubscribed SMPServer SMPSub
|
||||
| CASubError SMPServer SMPSub ProtocolClientError
|
||||
| CASubError SMPServer SMPSub SMPClientError
|
||||
|
||||
data SMPSubParty = SPRecipient | SPNotifier
|
||||
deriving (Eq, Ord, Show)
|
||||
@@ -111,7 +111,7 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} = do
|
||||
asyncClients <- newTVar []
|
||||
pure SMPClientAgent {agentCfg, msgQ, agentQ, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
|
||||
|
||||
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT ProtocolClientError IO SMPClient
|
||||
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
|
||||
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
||||
where
|
||||
@@ -124,7 +124,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
TM.insert srv smpVar smpClients
|
||||
pure smpVar
|
||||
|
||||
waitForSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
|
||||
waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
||||
waitForSMPClient smpVar = do
|
||||
let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
|
||||
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar)
|
||||
@@ -133,10 +133,10 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left PCEResponseTimeout
|
||||
|
||||
newSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
|
||||
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO () -> ExceptT ProtocolClientError IO a
|
||||
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryE connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
@@ -150,16 +150,16 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
throwE e
|
||||
tryConnectAsync :: ExceptT ProtocolClientError IO ()
|
||||
tryConnectAsync :: ExceptT SMPClientError IO ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients ca) (a :)
|
||||
connectAsync :: ExceptT ProtocolClientError IO ()
|
||||
connectAsync :: ExceptT SMPClientError IO ()
|
||||
connectAsync =
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
|
||||
void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: ExceptT ProtocolClientError IO SMPClient
|
||||
connectClient :: ExceptT SMPClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
|
||||
clientDisconnected :: SMPClient -> IO ()
|
||||
@@ -188,17 +188,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
notify . CADisconnected srv $ M.keysSet ss
|
||||
reconnectServer
|
||||
|
||||
reconnectServer :: ExceptT ProtocolClientError IO ()
|
||||
reconnectServer :: ExceptT SMPClientError IO ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections ca) (a :)
|
||||
|
||||
tryReconnectClient :: ExceptT ProtocolClientError IO ()
|
||||
tryReconnectClient :: ExceptT SMPClientError IO ()
|
||||
tryReconnectClient = do
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
|
||||
reconnectClient `catchE` const loop
|
||||
|
||||
reconnectClient :: ExceptT ProtocolClientError IO ()
|
||||
reconnectClient :: ExceptT SMPClientError IO ()
|
||||
reconnectClient = do
|
||||
withSMP ca srv $ \smp -> do
|
||||
notify $ CAReconnected srv
|
||||
@@ -207,13 +207,13 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
unlessM (atomically $ hasSub (srvSubs ca) srv s) $
|
||||
subscribe_ smp sub `catchE` handleError s
|
||||
where
|
||||
subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO ()
|
||||
subscribe_ smp sub@(s, _) = do
|
||||
smpSubscribe smp sub
|
||||
atomically $ addSubscription ca srv sub
|
||||
notify $ CAResubscribed srv s
|
||||
|
||||
handleError :: SMPSub -> ProtocolClientError -> ExceptT ProtocolClientError IO ()
|
||||
handleError :: SMPSub -> SMPClientError -> ExceptT SMPClientError IO ()
|
||||
handleError s = \case
|
||||
e@PCEResponseTimeout -> throwE e
|
||||
e@PCENetworkError -> throwE e
|
||||
@@ -221,7 +221,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
notify $ CASubError srv s e
|
||||
atomically $ removePendingSubscription ca srv s
|
||||
|
||||
notify :: SMPClientAgentEvent -> ExceptT ProtocolClientError IO ()
|
||||
notify :: SMPClientAgentEvent -> ExceptT SMPClientError IO ()
|
||||
notify evt = atomically $ writeTBQueue (agentQ ca) evt
|
||||
|
||||
closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m ()
|
||||
@@ -241,15 +241,15 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli
|
||||
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
|
||||
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel
|
||||
|
||||
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO a
|
||||
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO a
|
||||
withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError
|
||||
where
|
||||
logSMPError :: ProtocolClientError -> ExceptT ProtocolClientError IO a
|
||||
logSMPError :: SMPClientError -> ExceptT SMPClientError IO a
|
||||
logSMPError e = do
|
||||
liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e
|
||||
throwE e
|
||||
|
||||
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO ()
|
||||
subscribeQueue ca srv sub = do
|
||||
atomically $ addPendingSubscription ca srv sub
|
||||
withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleError
|
||||
@@ -267,7 +267,7 @@ showServer :: SMPServer -> ByteString
|
||||
showServer ProtocolServer {host, port} =
|
||||
strEncode host <> B.pack (if null port then "" else ':' : port)
|
||||
|
||||
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT SMPClientError IO ()
|
||||
smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId
|
||||
where
|
||||
subscribe_ = case party of
|
||||
|
||||
@@ -10,54 +10,57 @@ import Data.Word (Word16)
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Protocol (ErrorType)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
|
||||
type NtfClient = ProtocolClient NtfResponse
|
||||
type NtfClient = ProtocolClient ErrorType NtfResponse
|
||||
|
||||
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
type NtfClientError = ProtocolClientError ErrorType
|
||||
|
||||
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
ntfRegisterToken c pKey newTkn =
|
||||
sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case
|
||||
NRTknId tknId dhKey -> pure (tknId, dhKey)
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO ()
|
||||
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT NtfClientError IO ()
|
||||
ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
|
||||
|
||||
ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO NtfTknStatus
|
||||
ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT NtfClientError IO NtfTknStatus
|
||||
ntfCheckToken c pKey tknId =
|
||||
sendNtfCommand c (Just pKey) tknId TCHK >>= \case
|
||||
NRTkn stat -> pure stat
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT ProtocolClientError IO ()
|
||||
ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT NtfClientError IO ()
|
||||
ntfReplaceToken c pKey tknId token = okNtfCommand (TRPL token) c pKey tknId
|
||||
|
||||
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT NtfClientError IO ()
|
||||
ntfDeleteToken = okNtfCommand TDEL
|
||||
|
||||
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
|
||||
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT NtfClientError IO ()
|
||||
ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
|
||||
|
||||
ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO NtfSubscriptionId
|
||||
ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT NtfClientError IO NtfSubscriptionId
|
||||
ntfCreateSubscription c pKey newSub =
|
||||
sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case
|
||||
NRSubId subId -> pure subId
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
|
||||
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT NtfClientError IO NtfSubStatus
|
||||
ntfCheckSubscription c pKey subId =
|
||||
sendNtfCommand c (Just pKey) subId SCHK >>= \case
|
||||
NRSub stat -> pure stat
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT NtfClientError IO ()
|
||||
ntfDeleteSubscription = okNtfCommand SDEL
|
||||
|
||||
-- | Send notification server command
|
||||
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse
|
||||
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT NtfClientError IO NtfResponse
|
||||
sendNtfCommand c pKey entId cmd = sendProtocolCommand c pKey entId (NtfCmd sNtfEntity cmd)
|
||||
|
||||
okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT ProtocolClientError IO ()
|
||||
okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT NtfClientError IO ()
|
||||
okNtfCommand cmd c pKey entId =
|
||||
sendNtfCommand c (Just pKey) entId cmd >>= \case
|
||||
NROk -> return ()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
@@ -146,7 +147,7 @@ instance Encoding ANewNtfEntity where
|
||||
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP)
|
||||
_ -> fail "bad ANewNtfEntity"
|
||||
|
||||
instance Protocol NtfResponse where
|
||||
instance Protocol ErrorType NtfResponse where
|
||||
type ProtoCommand NtfResponse = NtfCmd
|
||||
type ProtoType NtfResponse = 'PNTF
|
||||
protocolClientHandshake = ntfClientHandshake
|
||||
@@ -183,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
|
||||
|
||||
deriving instance Show NtfCmd
|
||||
|
||||
instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
|
||||
type Tag (NtfCommand e) = NtfCommandTag e
|
||||
encodeProtocol _v = \case
|
||||
TNEW newTkn -> e (TNEW_, ' ', newTkn)
|
||||
@@ -202,6 +203,9 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
|
||||
protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag)
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (sig, _, entityId, _) cmd = case cmd of
|
||||
-- TNEW and SNEW must have signature but NOT token/subscription IDs
|
||||
TNEW {} -> sigNoEntity
|
||||
@@ -219,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
| not (B.null entityId) = Left $ CMD HAS_AUTH
|
||||
| otherwise = Right cmd
|
||||
|
||||
instance ProtocolEncoding NtfCmd where
|
||||
instance ProtocolEncoding ErrorType NtfCmd where
|
||||
type Tag NtfCmd = NtfCmdTag
|
||||
encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c
|
||||
|
||||
@@ -239,6 +243,9 @@ instance ProtocolEncoding NtfCmd where
|
||||
SDEL_ -> pure SDEL
|
||||
PING_ -> pure PING
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
|
||||
|
||||
data NtfResponseTag
|
||||
@@ -283,7 +290,7 @@ data NtfResponse
|
||||
| NRPong
|
||||
deriving (Show)
|
||||
|
||||
instance ProtocolEncoding NtfResponse where
|
||||
instance ProtocolEncoding ErrorType NtfResponse where
|
||||
type Tag NtfResponse = NtfResponseTag
|
||||
encodeProtocol _v = \case
|
||||
NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey)
|
||||
@@ -306,6 +313,13 @@ instance ProtocolEncoding NtfResponse where
|
||||
NRSub_ -> NRSub <$> _smpP
|
||||
NRPong_ -> pure NRPong
|
||||
|
||||
fromProtocolError = \case
|
||||
PECmdSyntax -> CMD SYNTAX
|
||||
PECmdUnknown -> CMD UNKNOWN
|
||||
PESession -> SESSION
|
||||
PEBlock -> BLOCK
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (_, _, entId, _) cmd = case cmd of
|
||||
-- IDTKN response must not have queue ID
|
||||
NRTknId {} -> noEntity
|
||||
|
||||
@@ -27,7 +27,7 @@ import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
|
||||
import Data.Time.Clock.System (getSystemTime)
|
||||
import Data.Time.Format.ISO8601 (iso8601Show)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Client (ProtocolClientError (..))
|
||||
import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError)
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
@@ -227,7 +227,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
where
|
||||
showServer' = decodeLatin1 . strEncode . host
|
||||
|
||||
handleSubError :: SMPQueueNtf -> ProtocolClientError -> M ()
|
||||
handleSubError :: SMPQueueNtf -> SMPClientError -> M ()
|
||||
handleSubError smpQueue = \case
|
||||
PCEProtocolError AUTH -> updateSubStatus smpQueue NSAuth
|
||||
PCEProtocolError e -> updateErr "SMP error " e
|
||||
@@ -343,7 +343,7 @@ send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = for
|
||||
|
||||
data VerificationResult = VRVerified NtfRequest | VRFailed
|
||||
|
||||
verifyNtfTransmission :: SignedTransmission NtfCmd -> NtfCmd -> M VerificationResult
|
||||
verifyNtfTransmission :: SignedTransmission ErrorType NtfCmd -> NtfCmd -> M VerificationResult
|
||||
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
|
||||
st <- asks store
|
||||
case cmd of
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -19,11 +20,10 @@
|
||||
{-# LANGUAGE TypeFamilyDependencies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
|
||||
-- |
|
||||
-- Module : Simplex.Messaging.ProtocolEncoding
|
||||
-- Copyright : (c) simplex.chat
|
||||
@@ -52,6 +52,7 @@ module Simplex.Messaging.Protocol
|
||||
SParty (..),
|
||||
PartyI (..),
|
||||
QueueIdsKeys (..),
|
||||
ProtocolErrorType (..),
|
||||
ErrorType (..),
|
||||
CommandError (..),
|
||||
Transmission,
|
||||
@@ -226,7 +227,7 @@ deriving instance Show Cmd
|
||||
type Transmission c = (CorrId, EntityId, c)
|
||||
|
||||
-- | signed parsed transmission, with original raw bytes and parsing error.
|
||||
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
|
||||
type SignedTransmission e c = (Maybe C.ASignature, Signed, Transmission (Either e c))
|
||||
|
||||
type Signed = ByteString
|
||||
|
||||
@@ -894,6 +895,8 @@ type MsgId = ByteString
|
||||
-- | SMP message body.
|
||||
type MsgBody = ByteString
|
||||
|
||||
data ProtocolErrorType = PECmdSyntax | PECmdUnknown | PESession | PEBlock
|
||||
|
||||
-- | Type for protocol errors.
|
||||
data ErrorType
|
||||
= -- | incorrect block format, encoding or signature size
|
||||
@@ -964,16 +967,16 @@ transmissionP = do
|
||||
command <- A.takeByteString
|
||||
pure RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
|
||||
class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where
|
||||
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
|
||||
type ProtoCommand msg = cmd | cmd -> msg
|
||||
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
|
||||
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
protocolPing :: ProtoCommand msg
|
||||
protocolError :: msg -> Maybe ErrorType
|
||||
protocolError :: msg -> Maybe err
|
||||
|
||||
type ProtoServer msg = ProtocolServer (ProtoType msg)
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
instance Protocol ErrorType BrokerMsg where
|
||||
type ProtoCommand BrokerMsg = Cmd
|
||||
type ProtoType BrokerMsg = 'PSMP
|
||||
protocolClientHandshake = smpClientHandshake
|
||||
@@ -982,13 +985,14 @@ instance Protocol BrokerMsg where
|
||||
ERR e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
|
||||
type Tag msg
|
||||
encodeProtocol :: Version -> msg -> ByteString
|
||||
protocolP :: Version -> Tag msg -> Parser msg
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
|
||||
fromProtocolError :: ProtocolErrorType -> err
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
|
||||
|
||||
instance PartyI p => ProtocolEncoding (Command p) where
|
||||
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
encodeProtocol v = \case
|
||||
NEW rKey dhKey auth_ -> case auth_ of
|
||||
@@ -1019,6 +1023,9 @@ instance PartyI p => ProtocolEncoding (Command p) where
|
||||
|
||||
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (sig, _, queueId, _) cmd = case cmd of
|
||||
-- NEW must have signature but NOT queue ID
|
||||
NEW {}
|
||||
@@ -1038,7 +1045,7 @@ instance PartyI p => ProtocolEncoding (Command p) where
|
||||
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance ProtocolEncoding Cmd where
|
||||
instance ProtocolEncoding ErrorType Cmd where
|
||||
type Tag Cmd = CmdTag
|
||||
encodeProtocol v (Cmd _ c) = encodeProtocol v c
|
||||
|
||||
@@ -1068,9 +1075,12 @@ instance ProtocolEncoding Cmd where
|
||||
PING_ -> pure PING
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
|
||||
instance ProtocolEncoding BrokerMsg where
|
||||
instance ProtocolEncoding ErrorType BrokerMsg where
|
||||
type Tag BrokerMsg = BrokerMsgTag
|
||||
encodeProtocol v = \case
|
||||
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
|
||||
@@ -1105,6 +1115,13 @@ instance ProtocolEncoding BrokerMsg where
|
||||
ERR_ -> ERR <$> _smpP
|
||||
PONG_ -> pure PONG
|
||||
|
||||
fromProtocolError = \case
|
||||
PECmdSyntax -> CMD SYNTAX
|
||||
PECmdUnknown -> CMD UNKNOWN
|
||||
PESession -> SESSION
|
||||
PEBlock -> BLOCK
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (_, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response should not have queue ID
|
||||
IDS _ -> Right cmd
|
||||
@@ -1123,12 +1140,12 @@ _smpP :: Encoding a => Parser a
|
||||
_smpP = A.space *> smpP
|
||||
|
||||
-- | Parse SMP protocol commands and broker messages
|
||||
parseProtocol :: ProtocolEncoding msg => Version -> ByteString -> Either ErrorType msg
|
||||
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
|
||||
parseProtocol v s =
|
||||
let (tag, params) = B.break (== ' ') s
|
||||
in case decodeTag tag of
|
||||
Just cmd -> parse (protocolP v cmd) (CMD SYNTAX) params
|
||||
Nothing -> Left $ CMD UNKNOWN
|
||||
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
|
||||
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
|
||||
|
||||
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
|
||||
checkParty c = case testEquality (sParty @p) (sParty @p') of
|
||||
@@ -1223,7 +1240,7 @@ tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t
|
||||
tEncodeBatch :: Int -> ByteString -> ByteString
|
||||
tEncodeBatch n s = lenEncode n `B.cons` s
|
||||
|
||||
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission :: ProtocolEncoding e c => Version -> ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
|
||||
|
||||
@@ -1243,22 +1260,22 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
|
||||
eitherList = either (\e -> [Left e])
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet :: forall cmd c. (ProtocolEncoding cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission cmd))
|
||||
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
|
||||
tGet th@THandle {sessionId, thVersion = v} = L.map (tDecodeParseValidate sessionId v) <$> tGetParse th
|
||||
|
||||
tDecodeParseValidate :: forall cmd. ProtocolEncoding cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission cmd
|
||||
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId -> Version -> Either TransportError RawTransmission -> SignedTransmission err cmd
|
||||
tDecodeParseValidate sessionId v = \case
|
||||
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
| sessId == sessionId ->
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
|
||||
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left SESSION))
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
|
||||
Left _ -> tError ""
|
||||
where
|
||||
tError :: ByteString -> SignedTransmission cmd
|
||||
tError corrId = (Nothing, "", (CorrId corrId, "", Left BLOCK))
|
||||
tError :: ByteString -> SignedTransmission err cmd
|
||||
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
|
||||
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission cmd
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
|
||||
tParseValidate signed t@(sig, corrId, entityId, command) =
|
||||
let cmd = parseProtocol v command >>= checkCredentials t
|
||||
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
|
||||
in (sig, signed, (CorrId corrId, entityId, cmd))
|
||||
|
||||
@@ -266,7 +266,7 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do
|
||||
write sndQ $ fst as
|
||||
write rcvQ $ snd as
|
||||
where
|
||||
cmdAction :: SignedTransmission Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
|
||||
cmdAction :: SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
|
||||
cmdAction (sig, signed, (corrId, queueId, cmdOrError)) =
|
||||
case cmdOrError of
|
||||
Left e -> pure $ Left (corrId, queueId, ERR e)
|
||||
|
||||
@@ -25,6 +25,7 @@ module AgentTests.FunctionalAPITests
|
||||
)
|
||||
where
|
||||
|
||||
import AgentTests.ConnectionRequestTests (connReqData, queueAddr, testE2ERatchetParams)
|
||||
import Control.Concurrent (killThread, threadDelay)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (ExceptT, runExceptT, throwError)
|
||||
@@ -45,8 +46,9 @@ import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store (UserId)
|
||||
import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultClientConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..))
|
||||
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), supportedSMPClientVRange)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
@@ -96,8 +98,10 @@ runRight action =
|
||||
|
||||
functionalAPITests :: ATransport -> Spec
|
||||
functionalAPITests t = do
|
||||
describe "Establishing duplex connection" $
|
||||
describe "Establishing duplex connection" $ do
|
||||
testMatrix2 t runAgentClientTest
|
||||
it "should connect when server with multiple identities is stored" $
|
||||
withSmpServer t testServerMultipleIdentities
|
||||
describe "Establishing duplex connection v2, different Ratchet versions" $
|
||||
testRatchetMatrix2 t runAgentClientTest
|
||||
describe "Establish duplex connection via contact address" $
|
||||
@@ -1010,6 +1014,40 @@ testTwoUsers = do
|
||||
hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO ()
|
||||
hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n
|
||||
|
||||
testServerMultipleIdentities :: HasCallStack => IO ()
|
||||
testServerMultipleIdentities = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get bob ##> ("", aliceId, CON)
|
||||
exchangeGreetings alice bobId bob aliceId
|
||||
-- this saves queue with second server identity
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
bob' <- liftIO $ getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
exchangeGreetingsMsgId 6 alice bobId bob' aliceId
|
||||
where
|
||||
secondIdentityCReq :: ConnectionRequestUri 'CMInvitation
|
||||
secondIdentityCReq =
|
||||
CRInvitationUri
|
||||
connReqData
|
||||
{ crSmpQueues =
|
||||
[ SMPQueueUri
|
||||
supportedSMPClientVRange
|
||||
queueAddr
|
||||
{ smpServer = SMPServer "localhost" "5001" (C.KeyHash "\215m\248\251")
|
||||
}
|
||||
]
|
||||
}
|
||||
testE2ERatchetParams
|
||||
|
||||
exchangeGreetings :: HasCallStack => AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings = exchangeGreetingsMsgId 4
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ module CoreTests.ProtocolErrorTests where
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..))
|
||||
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..))
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Test.Hspec
|
||||
@@ -17,12 +17,14 @@ protocolErrorTests :: Spec
|
||||
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
|
||||
describe "errors parsing / serializing" $ do
|
||||
it "should parse SMP protocol errors" . property $ \(err :: AgentErrorType) ->
|
||||
errServerHasSpaces err
|
||||
errHasSpaces err
|
||||
|| parseAll strP (strEncode err) == Right err
|
||||
it "should parse SMP agent errors" . property $ \(err :: AgentErrorType) ->
|
||||
errServerHasSpaces err
|
||||
errHasSpaces err
|
||||
|| parseAll strP (strEncode err) == Right err
|
||||
where
|
||||
errServerHasSpaces = \case
|
||||
BROKER srv _ -> ' ' `B.elem` encodeUtf8 (T.pack srv)
|
||||
errHasSpaces = \case
|
||||
BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e
|
||||
BROKER srv _ -> hasSpaces srv
|
||||
_ -> False
|
||||
hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s)
|
||||
|
||||
@@ -64,16 +64,16 @@ ntfSyntaxTests (ATransport t) = do
|
||||
Expectation
|
||||
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
|
||||
|
||||
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission NtfResponse
|
||||
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse
|
||||
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
||||
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (sgn, t)
|
||||
tGet1 h
|
||||
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (Just $ C.sign pk t, t)
|
||||
|
||||
@@ -63,7 +63,7 @@ serverTests t@(ATransport t') = do
|
||||
testMsgExpireOnInterval t'
|
||||
testMsgNOTExpireOnInterval t'
|
||||
|
||||
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission BrokerMsg
|
||||
pattern Resp :: CorrId -> QueueId -> BrokerMsg -> SignedTransmission ErrorType BrokerMsg
|
||||
pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
||||
|
||||
pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg
|
||||
@@ -72,13 +72,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh)
|
||||
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
|
||||
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
|
||||
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (sgn, t)
|
||||
tGet1 h
|
||||
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (Just $ C.sign pk t, t)
|
||||
@@ -89,7 +89,7 @@ tPut1 h t = do
|
||||
[r] <- tPut h [t]
|
||||
pure r
|
||||
|
||||
tGet1 :: (ProtocolEncoding cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission cmd)
|
||||
tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd)
|
||||
tGet1 h = do
|
||||
[r] <- liftIO $ tGet h
|
||||
pure r
|
||||
@@ -428,7 +428,7 @@ testSwitchSub (ATransport t) =
|
||||
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3)
|
||||
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
|
||||
|
||||
1000 `timeout` tGet @BrokerMsg rh1 >>= \case
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else is delivered to the 1st TCP connection"
|
||||
|
||||
@@ -869,14 +869,14 @@ testMessageNotifications (ATransport t) =
|
||||
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
|
||||
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh2
|
||||
1000 `timeout` tGet @BrokerMsg nh1 >>= \case
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
|
||||
Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there")
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
|
||||
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
|
||||
1000 `timeout` tGet @BrokerMsg nh2 >>= \case
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
|
||||
|
||||
@@ -895,7 +895,7 @@ testMsgExpireOnSend t =
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else should be delivered"
|
||||
|
||||
@@ -911,7 +911,7 @@ testMsgExpireOnInterval t =
|
||||
threadDelay 2500000
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "2" _ OK <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
1000 `timeout` tGet @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing should be delivered"
|
||||
|
||||
@@ -929,7 +929,7 @@ testMsgNOTExpireOnInterval t =
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else should be delivered"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user