mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-29 14:30:22 +00:00
ntf server implementation, updated ntf protocol, ntf client based on refactored protocol client, bare-bones SMP agent to manage ntf connections (to connect to ntf server) (#338)
* process ntf server commands * when subscription is re-created and it was ENDed, resubscribe to SMP * SMPClientAgent draft * SMPClientAgent: remove double tracking of subscriptions * subscriber frame * PING error now throws error to restart SMPClient for more reliable re-connection (#342) * increase TCP timeout to 5 sec * add pragmas and vacuum db (#343) * vacuum in each connection to enable auto-vacuum (#344) * update protocol, token verification * refactor SMPClient to ProtocoClient, to use with notification server protocol * notification server client, managing notification clients in the agent * stub for push payload Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4e1184d9eb
commit
d31958855f
30
protocol/diagrams/notifications/register-token.mmd
Normal file
30
protocol/diagrams/notifications/register-token.mmd
Normal file
@@ -0,0 +1,30 @@
|
||||
sequenceDiagram
|
||||
participant M as mobile app
|
||||
participant C as chat core
|
||||
participant A as agent
|
||||
participant P as push server
|
||||
participant APN as APN
|
||||
|
||||
note over M, APN: get device token
|
||||
M ->> APN: registerForRemoteNotifications()
|
||||
APN ->> M: device token
|
||||
|
||||
note over M, P: register device token with push server
|
||||
M ->> C: /_ntf register <token>
|
||||
C ->> A: registerNtfToken(<token>)
|
||||
A ->> P: TNEW
|
||||
P ->> A: ID (tokenId)
|
||||
A ->> C: registered
|
||||
C ->> M: registered
|
||||
|
||||
note over M, APN: verify device token
|
||||
P ->> APN: E2E encrypted code<br>in background<br>notification
|
||||
APN ->> M: deliver background notification with e2ee verification token
|
||||
M ->> C: /_ntf verify <e2ee code>
|
||||
C ->> A: verifyNtfToken(<e2ee code>)
|
||||
A ->> P: TVFY code
|
||||
P ->> A: OK / ERR
|
||||
A ->> C: verified
|
||||
C ->> M: verified
|
||||
|
||||
note over M, APN: now token ID can be used
|
||||
@@ -43,10 +43,12 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
Simplex.Messaging.Crypto
|
||||
Simplex.Messaging.Crypto.Ratchet
|
||||
Simplex.Messaging.Encoding
|
||||
Simplex.Messaging.Encoding.String
|
||||
Simplex.Messaging.Notifications.Client
|
||||
Simplex.Messaging.Notifications.Protocol
|
||||
Simplex.Messaging.Notifications.Server
|
||||
Simplex.Messaging.Notifications.Server.Env
|
||||
|
||||
@@ -76,12 +76,13 @@ import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
import Simplex.Messaging.Client (SMPServerTransmission)
|
||||
import Simplex.Messaging.Client (ServerTransmission)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken)
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (MsgBody)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, MsgBody)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM)
|
||||
@@ -149,6 +150,10 @@ deleteConnection c = withAgentEnv c . deleteConnection' c
|
||||
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
setSMPServers c = withAgentEnv c . setSMPServers' c
|
||||
|
||||
-- | Register device notifications token
|
||||
registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken c = withAgentEnv c . registerNtfToken' c
|
||||
|
||||
withAgentEnv :: AgentClient -> ReaderT Env m a -> m a
|
||||
withAgentEnv c = (`runReaderT` agentEnv c)
|
||||
|
||||
@@ -490,10 +495,13 @@ deleteConnection' c connId =
|
||||
withStore (`deleteConn` connId)
|
||||
|
||||
-- | Change servers to be used for creating new queues, in Reader monad
|
||||
setSMPServers' :: forall m. AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
setSMPServers' c servers = do
|
||||
atomically $ writeTVar (smpServers c) servers
|
||||
|
||||
registerNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken' c token = pure ()
|
||||
|
||||
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
|
||||
getSMPServer c = do
|
||||
smpServers <- readTVarIO $ smpServers c
|
||||
@@ -511,7 +519,7 @@ subscriber c@AgentClient {msgQ} = forever $ do
|
||||
Left e -> liftIO $ print e
|
||||
Right _ -> return ()
|
||||
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SMPServerTransmission -> m ()
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
|
||||
processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do
|
||||
withStore (\st -> getRcvConn st srv rId) >>= \case
|
||||
SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
{-# LANGUAGE TypeSynonymInstances #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Client
|
||||
( AgentClient (..),
|
||||
@@ -57,9 +56,12 @@ import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
|
||||
import Simplex.Messaging.Notifications.Client (NtfClient)
|
||||
import Simplex.Messaging.Notifications.Protocol (NtfResponse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
@@ -67,18 +69,22 @@ import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, when
|
||||
import Simplex.Messaging.Version
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async, forConcurrently_)
|
||||
import UnliftIO.Exception (Exception, IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg))
|
||||
|
||||
type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
|
||||
|
||||
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
|
||||
|
||||
data AgentClient = AgentClient
|
||||
{ rcvQ :: TBQueue (ATransmission 'Client),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
smpServers :: TVar (NonEmpty SMPServer),
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
ntfClients :: TMap ProtocolServer NtfClientVar,
|
||||
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
subscrConns :: TMap ConnId SMPServer,
|
||||
@@ -101,6 +107,7 @@ newAgentClient agentEnv = do
|
||||
msgQ <- newTBQueue qSize
|
||||
smpServers <- newTVar $ initialSMPServers (config agentEnv)
|
||||
smpClients <- TM.empty
|
||||
ntfClients <- TM.empty
|
||||
subscrSrvrs <- TM.empty
|
||||
pendingSubscrSrvrs <- TM.empty
|
||||
subscrConns <- TM.empty
|
||||
@@ -111,80 +118,30 @@ newAgentClient agentEnv = do
|
||||
asyncClients <- newTVar []
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1)
|
||||
lock <- newTMVar ()
|
||||
return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
|
||||
return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
|
||||
|
||||
-- | Agent monad with MonadReader Env and MonadError AgentErrorType
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
newtype InternalException e = InternalException {unInternalException :: e}
|
||||
deriving (Eq, Show)
|
||||
class ProtocolServerClient msg where
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtocolServer -> m (ProtocolClient msg)
|
||||
|
||||
instance Exception e => Exception (InternalException e)
|
||||
instance ProtocolServerClient BrokerMsg where getProtocolServerClient = getSMPServerClient
|
||||
|
||||
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||
withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
|
||||
withRunInIO exceptToIO =
|
||||
withExceptT unInternalException . ExceptT . E.try $
|
||||
withRunInIO $ \run ->
|
||||
exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
|
||||
instance ProtocolServerClient NtfResponse where getProtocolServerClient = getNtfServerClient
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
||||
atomically (getClientVar srv smpClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv smpClients connectClient reconnectClient)
|
||||
(waitForProtocolClient smpCfg)
|
||||
where
|
||||
getClientVar :: STM (Either SMPClientVar SMPClientVar)
|
||||
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
|
||||
|
||||
newClientVar :: STM SMPClientVar
|
||||
newClientVar = do
|
||||
smpVar <- newEmptyTMVar
|
||||
TM.insert srv smpVar smpClients
|
||||
pure smpVar
|
||||
|
||||
waitForSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient
|
||||
waitForSMPClient smpVar = do
|
||||
SMPClientConfig {tcpTimeout} <- asks $ smpCfg . config
|
||||
smpClient_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar smpVar)
|
||||
liftEither $ case smpClient_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left $ BROKER TIMEOUT
|
||||
|
||||
newSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar smpVar r
|
||||
successAction smp
|
||||
Left e -> do
|
||||
if e == BROKER NETWORK || e == BROKER TIMEOUT
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
throwError e
|
||||
tryConnectAsync :: m ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients c) (a :)
|
||||
connectAsync :: m ()
|
||||
connectAsync = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- asks $ smpCfg . config
|
||||
u <- askUnliftIO
|
||||
liftEitherError smpClientError (getSMPClient srv cfg msgQ $ clientDisconnected u)
|
||||
`E.catch` internalError
|
||||
where
|
||||
internalError :: IOException -> m SMPClient
|
||||
internalError = throwError . INTERNAL . show
|
||||
liftEitherError protocolClientError (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
|
||||
|
||||
clientDisconnected :: UnliftIO m -> IO ()
|
||||
clientDisconnected u = do
|
||||
@@ -194,13 +151,14 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
cVar_ <- TM.lookupDelete srv $ subscrSrvrs c
|
||||
forM cVar_ $ \cVar -> do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
|
||||
addPendingSubs cVar cs = do
|
||||
let ps = pendingSubscrSrvrs c
|
||||
TM.lookup srv ps >>= \case
|
||||
@@ -225,30 +183,100 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withAgentLock c . withSMP c srv $ \smp -> do
|
||||
withAgentLock c . withClient c srv $ \smp -> do
|
||||
cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c)
|
||||
forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) ->
|
||||
whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $
|
||||
subscribe_ smp sub `catchError` handleError connId
|
||||
where
|
||||
subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT SMPClientError IO ()
|
||||
subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT ProtocolClientError IO ()
|
||||
subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do
|
||||
subscribeSMPQueue smp rcvPrivateKey rcvId
|
||||
addSubscription c rq connId
|
||||
liftIO $ notifySub UP connId
|
||||
|
||||
handleError :: ConnId -> SMPClientError -> ExceptT SMPClientError IO ()
|
||||
handleError :: ConnId -> ProtocolClientError -> ExceptT ProtocolClientError IO ()
|
||||
handleError connId = \case
|
||||
e@SMPResponseTimeout -> throwError e
|
||||
e@SMPNetworkError -> throwError e
|
||||
e@PCEResponseTimeout -> throwError e
|
||||
e@PCENetworkError -> throwError e
|
||||
e -> do
|
||||
liftIO $ notifySub (ERR $ smpClientError e) connId
|
||||
liftIO $ notifySub (ERR $ protocolClientError e) connId
|
||||
atomically $ removePendingSubscription c srv connId
|
||||
|
||||
notifySub :: ACommand 'Agent -> ConnId -> IO ()
|
||||
notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
|
||||
closeAgentClient :: MonadUnliftIO m => AgentClient -> m ()
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> ProtocolServer -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {ntfClients} srv =
|
||||
atomically (getClientVar srv ntfClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv ntfClients connectClient $ pure ())
|
||||
(waitForProtocolClient ntfCfg)
|
||||
where
|
||||
connectClient :: m NtfClient
|
||||
connectClient = do
|
||||
cfg <- asks $ ntfCfg . config
|
||||
liftEitherError protocolClientError (getProtocolClient srv cfg Nothing clientDisconnected)
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
atomically $ TM.delete srv ntfClients
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getClientVar :: forall a. ProtocolServer -> TMap ProtocolServer (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients
|
||||
where
|
||||
newClientVar :: STM (TMVar a)
|
||||
newClientVar = do
|
||||
var <- newEmptyTMVar
|
||||
TM.insert srv var clients
|
||||
pure var
|
||||
|
||||
waitForProtocolClient :: AgentMonad m => (AgentConfig -> ProtocolClientConfig) -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient clientConfig clientVar = do
|
||||
ProtocolClientConfig {tcpTimeout} <- asks $ clientConfig . config
|
||||
client_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar clientVar)
|
||||
liftEither $ case client_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left $ BROKER TIMEOUT
|
||||
|
||||
newProtocolClient ::
|
||||
forall msg m.
|
||||
AgentMonad m =>
|
||||
AgentClient ->
|
||||
ProtocolServer ->
|
||||
TMap ProtocolServer (ClientVar msg) ->
|
||||
m (ProtocolClient msg) ->
|
||||
m () ->
|
||||
ClientVar msg ->
|
||||
m (ProtocolClient msg)
|
||||
newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar clientVar r
|
||||
successAction smp
|
||||
Left e -> do
|
||||
if e == BROKER NETWORK || e == BROKER TIMEOUT
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar clientVar (Left e)
|
||||
TM.delete srv clients
|
||||
throwError e
|
||||
tryConnectAsync :: m ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients c) (a :)
|
||||
connectAsync :: m ()
|
||||
connectAsync = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
closeAgentClient :: MonadIO m => AgentClient -> m ()
|
||||
closeAgentClient c = liftIO $ do
|
||||
closeSMPServerClients c
|
||||
cancelActions $ reconnections c
|
||||
@@ -260,7 +288,7 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli
|
||||
where
|
||||
closeClient smpVar =
|
||||
atomically (readTMVar smpVar) >>= \case
|
||||
Right smp -> closeSMPClient smp `E.catch` \(_ :: E.SomeException) -> pure ()
|
||||
Right smp -> closeProtocolClient smp `E.catch` \(_ :: E.SomeException) -> pure ()
|
||||
_ -> pure ()
|
||||
|
||||
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
|
||||
@@ -272,40 +300,40 @@ withAgentLock AgentClient {lock} =
|
||||
(void . atomically $ takeTMVar lock)
|
||||
(atomically $ putTMVar lock ())
|
||||
|
||||
withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a
|
||||
withSMP_ c srv action =
|
||||
(getSMPServerClient c srv >>= action) `catchError` logServerError
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ c srv action = (getProtocolServerClient c srv >>= action) `catchError` logServerError
|
||||
where
|
||||
logServerError :: AgentErrorType -> m a
|
||||
logServerError e = do
|
||||
logServer "<--" c srv "" $ bshow e
|
||||
throwError e
|
||||
|
||||
withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a
|
||||
withLogSMP_ c srv qId cmdStr action = do
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ c srv qId cmdStr action = do
|
||||
logServer "-->" c srv qId cmdStr
|
||||
res <- withSMP_ c srv action
|
||||
res <- withClient_ c srv action
|
||||
logServer "<--" c srv qId "OK"
|
||||
return res
|
||||
|
||||
withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMP c srv action = withSMP_ c srv $ liftSMP . action
|
||||
withClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c srv action = withClient_ c srv $ liftClient . action
|
||||
|
||||
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action
|
||||
withLogClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient . action
|
||||
|
||||
liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a
|
||||
liftSMP = liftError smpClientError
|
||||
liftClient :: AgentMonad m => ExceptT ProtocolClientError IO a -> m a
|
||||
liftClient = liftError protocolClientError
|
||||
|
||||
smpClientError :: SMPClientError -> AgentErrorType
|
||||
smpClientError = \case
|
||||
SMPServerError e -> SMP e
|
||||
SMPResponseError e -> BROKER $ RESPONSE e
|
||||
SMPUnexpectedResponse -> BROKER UNEXPECTED
|
||||
SMPResponseTimeout -> BROKER TIMEOUT
|
||||
SMPNetworkError -> BROKER NETWORK
|
||||
SMPTransportError e -> BROKER $ TRANSPORT e
|
||||
e -> INTERNAL $ show e
|
||||
protocolClientError :: ProtocolClientError -> AgentErrorType
|
||||
protocolClientError = \case
|
||||
PCEProtocolError e -> SMP e
|
||||
PCEResponseError e -> BROKER $ RESPONSE e
|
||||
PCEUnexpectedResponse -> BROKER UNEXPECTED
|
||||
PCEResponseTimeout -> BROKER TIMEOUT
|
||||
PCENetworkError -> BROKER NETWORK
|
||||
PCETransportError e -> BROKER $ TRANSPORT e
|
||||
e@PCESignatureError {} -> INTERNAL $ show e
|
||||
e@PCEIOError {} -> INTERNAL $ show e
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c srv =
|
||||
@@ -324,7 +352,7 @@ newRcvQueue_ a c srv = do
|
||||
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
logServer "-->" c srv "" "NEW"
|
||||
QIK {rcvId, sndId, rcvPublicDhKey} <-
|
||||
withSMP c srv $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey
|
||||
withClient c srv $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
let rq =
|
||||
RcvQueue
|
||||
@@ -342,15 +370,15 @@ newRcvQueue_ a c srv = do
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
|
||||
atomically $ addPendingSubscription c rq connId
|
||||
withLogSMP c server rcvId "SUB" $ \smp -> do
|
||||
withLogClient c server rcvId "SUB" $ \smp -> do
|
||||
liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case
|
||||
Left e -> do
|
||||
atomically . when (e /= SMPNetworkError && e /= SMPResponseTimeout) $
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription c server connId
|
||||
throwError e
|
||||
Right _ -> addSubscription c rq connId
|
||||
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c rq@RcvQueue {server} connId = atomically $ do
|
||||
TM.insert connId server $ subscrConns c
|
||||
addSubs_ (subscrSrvrs c) rq connId
|
||||
@@ -365,7 +393,7 @@ addSubs_ ss rq@RcvQueue {server} connId =
|
||||
Just m -> TM.insert connId rq m
|
||||
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
|
||||
|
||||
removeSubscription :: MonadUnliftIO m => AgentClient -> ConnId -> m ()
|
||||
removeSubscription :: MonadIO m => AgentClient -> ConnId -> m ()
|
||||
removeSubscription c@AgentClient {subscrConns} connId = atomically $ do
|
||||
server_ <- TM.lookupDelete connId subscrConns
|
||||
mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_
|
||||
@@ -377,12 +405,12 @@ removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> S
|
||||
removeSubs_ ss server connId =
|
||||
TM.lookup server ss >>= mapM_ (TM.delete connId)
|
||||
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
logServer :: MonadIO m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} srv qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer SMPServer {host, port} =
|
||||
showServer ProtocolServer {host, port} =
|
||||
B.pack $ host <> if null port then "" else ':' : port
|
||||
|
||||
logSecret :: ByteString -> ByteString
|
||||
@@ -390,17 +418,17 @@ logSecret bs = encode $ B.take 3 bs
|
||||
|
||||
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withLogSMP_ c server sndId "SEND <CONF>" $ \smp -> do
|
||||
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
|
||||
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
|
||||
liftSMP $ sendSMPMessage smp Nothing sndId msg
|
||||
liftClient $ sendSMPMessage smp Nothing sndId msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo =
|
||||
withLogSMP_ c smpServer senderId "SEND <INV>" $ \smp -> do
|
||||
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
|
||||
msg <- mkInvitation
|
||||
liftSMP $ sendSMPMessage smp Nothing senderId msg
|
||||
liftClient $ sendSMPMessage smp Nothing senderId msg
|
||||
where
|
||||
mkInvitation :: m ByteString
|
||||
-- this is only encrypted with per-queue E2E, not with double ratchet
|
||||
@@ -411,30 +439,30 @@ sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) co
|
||||
|
||||
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
|
||||
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
|
||||
withLogSMP c server rcvId "KEY <key>" $ \smp ->
|
||||
withLogClient c server rcvId "KEY <key>" $ \smp ->
|
||||
secureSMPQueue smp rcvPrivateKey rcvId senderKey
|
||||
|
||||
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
sendAck c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "ACK" $ \smp ->
|
||||
withLogClient c server rcvId "ACK" $ \smp ->
|
||||
ackSMPMessage smp rcvPrivateKey rcvId
|
||||
|
||||
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "OFF" $ \smp ->
|
||||
withLogClient c server rcvId "OFF" $ \smp ->
|
||||
suspendSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "DEL" $ \smp ->
|
||||
withLogClient c server rcvId "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg =
|
||||
withLogSMP_ c server sndId "SEND <MSG>" $ \smp -> do
|
||||
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
|
||||
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
|
||||
liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
|
||||
liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
|
||||
|
||||
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do
|
||||
|
||||
@@ -23,6 +23,7 @@ import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent (SMPClientAgentConfig, defaultSMPClientAgentConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO.STM
|
||||
@@ -36,7 +37,8 @@ data AgentConfig = AgentConfig
|
||||
dbFile :: FilePath,
|
||||
dbPoolSize :: Int,
|
||||
yesToMigrations :: Bool,
|
||||
smpCfg :: SMPClientConfig,
|
||||
smpCfg :: ProtocolClientConfig,
|
||||
ntfCfg :: ProtocolClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
helloTimeout :: NominalDiffTime,
|
||||
caCertificateFile :: FilePath,
|
||||
@@ -55,7 +57,8 @@ defaultAgentConfig =
|
||||
dbFile = "smp-agent.db",
|
||||
dbPoolSize = 4,
|
||||
yesToMigrations = False,
|
||||
smpCfg = smpDefaultConfig,
|
||||
smpCfg = defaultClientConfig,
|
||||
ntfCfg = defaultClientConfig,
|
||||
reconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
@@ -49,7 +50,8 @@ module Simplex.Messaging.Agent.Protocol
|
||||
AgentMessageType (..),
|
||||
APrivHeader (..),
|
||||
AMessage (..),
|
||||
SMPServer (..),
|
||||
SMPServer,
|
||||
pattern SMPServer,
|
||||
SrvLoc (..),
|
||||
SMPQueueUri (..),
|
||||
SMPQueueInfo (..),
|
||||
@@ -131,9 +133,10 @@ import Simplex.Messaging.Protocol
|
||||
( ErrorType,
|
||||
MsgBody,
|
||||
MsgId,
|
||||
SMPServer (..),
|
||||
SMPServer,
|
||||
SndPublicVerifyKey,
|
||||
SrvLoc (..),
|
||||
pattern SMPServer,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTransportError, transportErrorP)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
@@ -59,7 +60,7 @@ import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), Skipp
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (blobFieldParser)
|
||||
import Simplex.Messaging.Protocol (MsgBody)
|
||||
import Simplex.Messaging.Protocol (MsgBody, ProtocolServer (..))
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (bshow, liftIOEither)
|
||||
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
||||
@@ -127,9 +128,26 @@ connectSQLiteStore dbFilePath poolSize = do
|
||||
connectDB :: FilePath -> IO DB.Connection
|
||||
connectDB path = do
|
||||
dbConn <- DB.open path
|
||||
DB.execute_ dbConn "PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL;"
|
||||
DB.execute_ dbConn "PRAGMA foreign_keys = ON;"
|
||||
-- DB.execute_ dbConn "PRAGMA trusted_schema = OFF;"
|
||||
DB.execute_ dbConn "PRAGMA secure_delete = ON;"
|
||||
DB.execute_ dbConn "PRAGMA auto_vacuum = FULL;"
|
||||
DB.execute_ dbConn "VACUUM;"
|
||||
-- _printPragmas dbConn path
|
||||
pure dbConn
|
||||
|
||||
_printPragmas :: DB.Connection -> FilePath -> IO ()
|
||||
_printPragmas db path = do
|
||||
foreign_keys <- DB.query_ db "PRAGMA foreign_keys;" :: IO [[Int]]
|
||||
print $ path <> " foreign_keys: " <> show foreign_keys
|
||||
-- when run via sqlite-simple query for trusted_schema seems to return empty list
|
||||
trusted_schema <- DB.query_ db "PRAGMA trusted_schema;" :: IO [[Int]]
|
||||
print $ path <> " trusted_schema: " <> show trusted_schema
|
||||
secure_delete <- DB.query_ db "PRAGMA secure_delete;" :: IO [[Int]]
|
||||
print $ path <> " secure_delete: " <> show secure_delete
|
||||
auto_vacuum <- DB.query_ db "PRAGMA auto_vacuum;" :: IO [[Int]]
|
||||
print $ path <> " auto_vacuum: " <> show auto_vacuum
|
||||
|
||||
checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a)
|
||||
checkConstraint err action = action `E.catch` (pure . Left . handleSQLError err)
|
||||
|
||||
@@ -190,7 +208,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
getConn_ db connId
|
||||
|
||||
getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn
|
||||
getRcvConn st SMPServer {host, port} rcvId =
|
||||
getRcvConn st ProtocolServer {host, port} rcvId =
|
||||
liftIOEither . withTransaction st $ \db ->
|
||||
DB.queryNamed
|
||||
db
|
||||
@@ -235,7 +253,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueStatus st RcvQueue {rcvId, server = SMPServer {host, port}} status =
|
||||
setRcvQueueStatus st RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
|
||||
-- ? throw error if queue does not exist?
|
||||
liftIO . withTransaction st $ \db ->
|
||||
DB.executeNamed
|
||||
@@ -248,7 +266,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
[":status" := status, ":host" := host, ":port" := port, ":rcv_id" := rcvId]
|
||||
|
||||
setRcvQueueConfirmedE2E :: SQLiteStore -> RcvQueue -> C.DhSecretX25519 -> m ()
|
||||
setRcvQueueConfirmedE2E st RcvQueue {rcvId, server = SMPServer {host, port}} e2eDhSecret =
|
||||
setRcvQueueConfirmedE2E st RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret =
|
||||
liftIO . withTransaction st $ \db ->
|
||||
DB.executeNamed
|
||||
db
|
||||
@@ -266,7 +284,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
]
|
||||
|
||||
setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m ()
|
||||
setSndQueueStatus st SndQueue {sndId, server = SMPServer {host, port}} status =
|
||||
setSndQueueStatus st SndQueue {sndId, server = ProtocolServer {host, port}} status =
|
||||
-- ? throw error if queue does not exist?
|
||||
liftIO . withTransaction st $ \db ->
|
||||
DB.executeNamed
|
||||
@@ -640,7 +658,7 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
|
||||
-- * Server upsert helper
|
||||
|
||||
upsertServer_ :: DB.Connection -> SMPServer -> IO ()
|
||||
upsertServer_ dbConn SMPServer {host, port, keyHash} = do
|
||||
upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -23,9 +24,10 @@
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Client
|
||||
( -- * Connect (disconnect) client to (from) SMP server
|
||||
ProtocolClient,
|
||||
SMPClient,
|
||||
getSMPClient,
|
||||
closeSMPClient,
|
||||
getProtocolClient,
|
||||
closeProtocolClient,
|
||||
|
||||
-- * SMP protocol command functions
|
||||
createSMPQueue,
|
||||
@@ -37,13 +39,13 @@ module Simplex.Messaging.Client
|
||||
ackSMPMessage,
|
||||
suspendSMPQueue,
|
||||
deleteSMPQueue,
|
||||
sendSMPCommand,
|
||||
sendProtocolCommand,
|
||||
|
||||
-- * Supporting types and client configuration
|
||||
SMPClientError (..),
|
||||
SMPClientConfig (..),
|
||||
smpDefaultConfig,
|
||||
SMPServerTransmission,
|
||||
ProtocolClientError (..),
|
||||
ProtocolClientConfig (..),
|
||||
defaultClientConfig,
|
||||
ServerTransmission,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -60,7 +62,7 @@ import Data.Maybe (fromMaybe)
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError)
|
||||
@@ -72,31 +74,30 @@ import System.Timeout (timeout)
|
||||
|
||||
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
|
||||
--
|
||||
-- The only exported selector is blockSize that is negotiated
|
||||
-- with the server during the TCP transport handshake.
|
||||
--
|
||||
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
|
||||
data SMPClient = SMPClient
|
||||
data ProtocolClient msg = ProtocolClient
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
sessionId :: ByteString,
|
||||
smpServer :: SMPServer,
|
||||
protocolServer :: ProtocolServer,
|
||||
tcpTimeout :: Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TMap CorrId Request,
|
||||
sentCommands :: TMap CorrId (Request msg),
|
||||
sndQ :: TBQueue SentRawTransmission,
|
||||
rcvQ :: TBQueue (SignedTransmission BrokerMsg),
|
||||
msgQ :: TBQueue SMPServerTransmission
|
||||
rcvQ :: TBQueue (SignedTransmission msg),
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission msg))
|
||||
}
|
||||
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type SMPServerTransmission = (SMPServer, RecipientId, BrokerMsg)
|
||||
type SMPClient = ProtocolClient SMP.BrokerMsg
|
||||
|
||||
-- | SMP client configuration.
|
||||
data SMPClientConfig = SMPClientConfig
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type ServerTransmission msg = (ProtocolServer, QueueId, msg)
|
||||
|
||||
-- | protocol client configuration.
|
||||
data ProtocolClientConfig = ProtocolClientConfig
|
||||
{ -- | size of TBQueue to use for server commands and responses
|
||||
qSize :: Natural,
|
||||
-- | default SMP server port if port is not specified in SMPServer
|
||||
-- | default server port if port is not specified in ProtocolServer
|
||||
defaultTransport :: (ServiceName, ATransport),
|
||||
-- | timeout of TCP commands (microseconds)
|
||||
tcpTimeout :: Int,
|
||||
@@ -106,34 +107,35 @@ data SMPClientConfig = SMPClientConfig
|
||||
smpPing :: Int
|
||||
}
|
||||
|
||||
-- | Default SMP client configuration.
|
||||
smpDefaultConfig :: SMPClientConfig
|
||||
smpDefaultConfig =
|
||||
SMPClientConfig
|
||||
-- | Default protocol client configuration.
|
||||
defaultClientConfig :: ProtocolClientConfig
|
||||
defaultClientConfig =
|
||||
ProtocolClientConfig
|
||||
{ qSize = 64,
|
||||
defaultTransport = ("5223", transport @TLS),
|
||||
tcpTimeout = 4_000_000,
|
||||
tcpTimeout = 5_000_000,
|
||||
tcpKeepAlive = Just defaultKeepAliveOpts,
|
||||
smpPing = 600_000_000 -- 10min
|
||||
smpPing = 300_000_000 -- 5 min
|
||||
}
|
||||
|
||||
data Request = Request
|
||||
data Request msg = Request
|
||||
{ queueId :: QueueId,
|
||||
responseVar :: TMVar Response
|
||||
responseVar :: TMVar (Response msg)
|
||||
}
|
||||
|
||||
type Response = Either SMPClientError BrokerMsg
|
||||
type Response msg = Either ProtocolClientError msg
|
||||
|
||||
-- | Connects to 'SMPServer' using passed client configuration
|
||||
-- | Connects to 'ProtocolServer' using passed client configuration
|
||||
-- and queue for messages and notifications.
|
||||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient)
|
||||
getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected =
|
||||
atomically mkSMPClient >>= runClient useTransport
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtocolServer -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected =
|
||||
(atomically mkSMPClient >>= runClient useTransport)
|
||||
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
|
||||
where
|
||||
mkSMPClient :: STM SMPClient
|
||||
mkSMPClient :: STM (ProtocolClient msg)
|
||||
mkSMPClient = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- newTVar 0
|
||||
@@ -141,11 +143,11 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp
|
||||
sndQ <- newTBQueue qSize
|
||||
rcvQ <- newTBQueue qSize
|
||||
return
|
||||
SMPClient
|
||||
ProtocolClient
|
||||
{ action = undefined,
|
||||
sessionId = undefined,
|
||||
connected,
|
||||
smpServer,
|
||||
protocolServer,
|
||||
tcpTimeout,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
@@ -154,50 +156,51 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp
|
||||
msgQ
|
||||
}
|
||||
|
||||
runClient :: (ServiceName, ATransport) -> SMPClient -> IO (Either SMPClientError SMPClient)
|
||||
runClient :: (ServiceName, ATransport) -> ProtocolClient msg -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
runClient (port', ATransport t) c = do
|
||||
thVar <- newEmptyTMVarIO
|
||||
action <-
|
||||
async $
|
||||
runTransportClient (host smpServer) port' (keyHash smpServer) tcpKeepAlive (client t c thVar)
|
||||
`finally` atomically (putTMVar thVar $ Left SMPNetworkError)
|
||||
runTransportClient (host protocolServer) port' (keyHash protocolServer) tcpKeepAlive (client t c thVar)
|
||||
`finally` atomically (putTMVar thVar $ Left PCENetworkError)
|
||||
th_ <- tcpTimeout `timeout` atomically (takeTMVar thVar)
|
||||
pure $ case th_ of
|
||||
Just (Right THandle {sessionId}) -> Right c {action, sessionId}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left SMPNetworkError
|
||||
Nothing -> Left PCENetworkError
|
||||
|
||||
useTransport :: (ServiceName, ATransport)
|
||||
useTransport = case port smpServer of
|
||||
useTransport = case port protocolServer of
|
||||
"" -> defaultTransport cfg
|
||||
"80" -> ("80", transport @WS)
|
||||
p -> (p, transport @TLS)
|
||||
|
||||
client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError (THandle c)) -> c -> IO ()
|
||||
client :: forall c. Transport c => TProxy c -> ProtocolClient msg -> TMVar (Either ProtocolClientError (THandle c)) -> c -> IO ()
|
||||
client _ c thVar h =
|
||||
runExceptT (smpClientHandshake h $ keyHash smpServer) >>= \case
|
||||
Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e
|
||||
runExceptT (smpClientHandshake h $ keyHash protocolServer) >>= \case
|
||||
Left e -> atomically . putTMVar thVar . Left $ PCETransportError e
|
||||
Right th@THandle {sessionId} -> do
|
||||
atomically $ do
|
||||
writeTVar (connected c) True
|
||||
putTMVar thVar $ Right th
|
||||
let c' = c {sessionId} :: SMPClient
|
||||
let c' = c {sessionId} :: ProtocolClient msg
|
||||
-- TODO remove ping if 0 is passed (or Nothing?)
|
||||
raceAny_ [send c' th, process c', receive c' th, ping c']
|
||||
`finally` disconnected
|
||||
|
||||
send :: Transport c => SMPClient -> THandle c -> IO ()
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
||||
receive :: Transport c => SMPClient -> THandle c -> IO ()
|
||||
receive SMPClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
receive :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
receive ProtocolClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: SMPClient -> IO ()
|
||||
ping :: ProtocolClient msg -> IO ()
|
||||
ping c = forever $ do
|
||||
threadDelay smpPing
|
||||
runExceptT $ sendSMPCommand c Nothing "" PING
|
||||
void . either throwIO pure =<< runExceptT (sendProtocolCommand c Nothing "" protocolPing)
|
||||
|
||||
process :: SMPClient -> IO ()
|
||||
process SMPClient {rcvQ, sentCommands} = forever $ do
|
||||
process :: ProtocolClient msg -> IO ()
|
||||
process ProtocolClient {rcvQ, sentCommands} = forever $ do
|
||||
(_, _, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
|
||||
if B.null $ bs corrId
|
||||
then sendMsg qId respOrErr
|
||||
@@ -209,45 +212,48 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp
|
||||
putTMVar responseVar $
|
||||
if queueId == qId
|
||||
then case respOrErr of
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (ERR e) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPUnexpectedResponse
|
||||
Left e -> Left $ PCEResponseError e
|
||||
Right r -> case protocolError r of
|
||||
Just e -> Left $ PCEProtocolError e
|
||||
_ -> Right r
|
||||
else Left PCEUnexpectedResponse
|
||||
|
||||
sendMsg :: QueueId -> Either ErrorType BrokerMsg -> IO ()
|
||||
sendMsg :: QueueId -> Either ErrorType msg -> IO ()
|
||||
sendMsg qId = \case
|
||||
Right cmd -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd)
|
||||
Right cmd -> atomically $ mapM_ (`writeTBQueue` (protocolServer, qId, cmd)) msgQ
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
_ -> return ()
|
||||
|
||||
-- | Disconnects SMP client from the server and terminates client threads.
|
||||
closeSMPClient :: SMPClient -> IO ()
|
||||
closeSMPClient = uninterruptibleCancel . action
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeProtocolClient :: ProtocolClient msg -> IO ()
|
||||
closeProtocolClient = uninterruptibleCancel . action
|
||||
|
||||
-- | SMP client error type.
|
||||
data SMPClientError
|
||||
data ProtocolClientError
|
||||
= -- | Correctly parsed SMP server ERR response.
|
||||
-- This error is forwarded to the agent client as `ERR SMP err`.
|
||||
SMPServerError ErrorType
|
||||
PCEProtocolError ErrorType
|
||||
| -- | Invalid server response that failed to parse.
|
||||
-- Forwarded to the agent client as `ERR BROKER RESPONSE`.
|
||||
SMPResponseError ErrorType
|
||||
PCEResponseError ErrorType
|
||||
| -- | Different response from what is expected to a certain SMP command,
|
||||
-- e.g. server should respond `IDS` or `ERR` to `NEW` command,
|
||||
-- other responses would result in this error.
|
||||
-- Forwarded to the agent client as `ERR BROKER UNEXPECTED`.
|
||||
SMPUnexpectedResponse
|
||||
PCEUnexpectedResponse
|
||||
| -- | Used for TCP connection and command response timeouts.
|
||||
-- Forwarded to the agent client as `ERR BROKER TIMEOUT`.
|
||||
SMPResponseTimeout
|
||||
PCEResponseTimeout
|
||||
| -- | Failure to establish TCP connection.
|
||||
-- Forwarded to the agent client as `ERR BROKER NETWORK`.
|
||||
SMPNetworkError
|
||||
PCENetworkError
|
||||
| -- | TCP transport handshake or some other transport error.
|
||||
-- Forwarded to the agent client as `ERR BROKER TRANSPORT e`.
|
||||
SMPTransportError TransportError
|
||||
PCETransportError TransportError
|
||||
| -- | Error when cryptographically "signing" the command.
|
||||
SMPSignatureError C.CryptoError
|
||||
PCESignatureError C.CryptoError
|
||||
| -- | IO Error
|
||||
PCEIOError IOException
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
-- | Create a new SMP queue.
|
||||
@@ -258,92 +264,95 @@ createSMPQueue ::
|
||||
RcvPrivateSignKey ->
|
||||
RcvPublicVerifyKey ->
|
||||
RcvPublicDhKey ->
|
||||
ExceptT SMPClientError IO QueueIdsKeys
|
||||
ExceptT ProtocolClientError IO QueueIdsKeys
|
||||
createSMPQueue c rpKey rKey dhKey =
|
||||
sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey) >>= \case
|
||||
IDS qik -> pure qik
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
-- | Subscribe to the SMP queue.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueue c@SMPClient {smpServer, msgQ} rpKey rId =
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
subscribeSMPQueue c@ProtocolClient {protocolServer, msgQ} rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId SUB >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} ->
|
||||
lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd)
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
lift . atomically $ mapM_ (`writeTBQueue` (protocolServer, rId, cmd)) msgQ
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
-- | Subscribe to the SMP queue notifications.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT ProtocolClientError IO ()
|
||||
subscribeSMPQueueNotifications = okSMPCommand NSUB
|
||||
|
||||
-- | Secure the SMP queue by adding a sender public key.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO ()
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT ProtocolClientError IO ()
|
||||
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
|
||||
|
||||
-- | Enable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT SMPClientError IO NotifierId
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT ProtocolClientError IO NotifierId
|
||||
enableSMPQueueNotifications c rpKey rId notifierKey =
|
||||
sendSMPCommand c (Just rpKey) rId (NKEY notifierKey) >>= \case
|
||||
NID nId -> pure nId
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
-- | Send SMP message.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT SMPClientError IO ()
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT ProtocolClientError IO ()
|
||||
sendSMPMessage c spKey sId msg =
|
||||
sendSMPCommand c spKey sId (SEND msg) >>= \case
|
||||
OK -> pure ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
-- | Acknowledge message delivery (server deletes the message).
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
ackSMPMessage c@SMPClient {smpServer, msgQ} rpKey rId =
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
ackSMPMessage c@ProtocolClient {protocolServer, msgQ} rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId ACK >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} ->
|
||||
lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd)
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
lift . atomically $ mapM_ (`writeTBQueue` (protocolServer, rId, cmd)) msgQ
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
-- | Irreversibly suspend SMP queue.
|
||||
-- The existing messages from the queue will still be delivered.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
suspendSMPQueue = okSMPCommand OFF
|
||||
|
||||
-- | Irreversibly delete SMP queue and all messages in it.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand DEL
|
||||
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
sendSMPCommand c (Just pKey) qId cmd >>= \case
|
||||
OK -> return ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
-- | Send SMP command
|
||||
-- TODO sign all requests (SEND of SMP confirmation would be signed with the same key that is passed to the recipient)
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
|
||||
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeout} pKey qId cmd = do
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
|
||||
sendSMPCommand c pKey qId = sendProtocolCommand c pKey qId . Cmd sParty
|
||||
|
||||
-- | Send Protocol command
|
||||
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtocolCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtocolCommand msg -> ExceptT ProtocolClientError IO msg
|
||||
sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeout} pKey qId cmd = do
|
||||
corrId <- lift_ getNextCorrId
|
||||
t <- signTransmission $ encodeTransmission sessionId (corrId, qId, cmd)
|
||||
ExceptT $ sendRecv corrId t
|
||||
where
|
||||
lift_ :: STM a -> ExceptT SMPClientError IO a
|
||||
lift_ :: STM a -> ExceptT ProtocolClientError IO a
|
||||
lift_ action = ExceptT $ Right <$> atomically action
|
||||
|
||||
getNextCorrId :: STM CorrId
|
||||
@@ -351,20 +360,20 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeou
|
||||
i <- stateTVar clientCorrId $ \i -> (i, i + 1)
|
||||
pure . CorrId $ bshow i
|
||||
|
||||
signTransmission :: ByteString -> ExceptT SMPClientError IO SentRawTransmission
|
||||
signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission
|
||||
signTransmission t = case pKey of
|
||||
Nothing -> return (Nothing, t)
|
||||
Just pk -> do
|
||||
sig <- liftError SMPSignatureError $ C.sign pk t
|
||||
sig <- liftError PCESignatureError $ C.sign pk t
|
||||
return (Just sig, t)
|
||||
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: CorrId -> SentRawTransmission -> IO Response
|
||||
sendRecv :: CorrId -> SentRawTransmission -> IO (Response msg)
|
||||
sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar
|
||||
where
|
||||
withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
send :: CorrId -> SentRawTransmission -> STM (TMVar Response)
|
||||
send :: CorrId -> SentRawTransmission -> STM (TMVar (Response msg))
|
||||
send corrId t = do
|
||||
r <- newEmptyTMVar
|
||||
TM.insert corrId (Request qId r) sentCommands
|
||||
|
||||
301
src/Simplex/Messaging/Client/Agent.hs
Normal file
301
src/Simplex/Messaging/Client/Agent.hs
Normal file
@@ -0,0 +1,301 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Simplex.Messaging.Client.Agent where
|
||||
|
||||
import Control.Concurrent (forkIO)
|
||||
import Control.Concurrent.Async (Async, uninterruptibleCancel)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (tryE, whenM)
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async, forConcurrently_)
|
||||
import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type SMPClientVar = TMVar (Either ProtocolClientError SMPClient)
|
||||
|
||||
data SMPClientAgentEvent
|
||||
= CAConnected SMPServer
|
||||
| CADisconnected SMPServer (Set SMPSub)
|
||||
| CAReconnected SMPServer
|
||||
| CAResubscribed SMPServer SMPSub
|
||||
| CASubError SMPServer SMPSub ProtocolClientError
|
||||
|
||||
data SMPSubParty = SPRecipient | SPNotifier
|
||||
deriving (Eq, Ord)
|
||||
|
||||
type SMPSub = (SMPSubParty, QueueId)
|
||||
|
||||
-- type SMPServerSub = (SMPServer, SMPSub)
|
||||
|
||||
data SMPClientAgentConfig = SMPClientAgentConfig
|
||||
{ smpCfg :: ProtocolClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
msgQSize :: Natural,
|
||||
agentQSize :: Natural
|
||||
}
|
||||
|
||||
defaultSMPClientAgentConfig :: SMPClientAgentConfig
|
||||
defaultSMPClientAgentConfig =
|
||||
SMPClientAgentConfig
|
||||
{ smpCfg = defaultClientConfig,
|
||||
reconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
increaseAfter = 10 * second,
|
||||
maxInterval = 10 * second
|
||||
},
|
||||
msgQSize = 64,
|
||||
agentQSize = 64
|
||||
}
|
||||
where
|
||||
second = 1000000
|
||||
|
||||
data SMPClientAgent = SMPClientAgent
|
||||
{ agentCfg :: SMPClientAgentConfig,
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
agentQ :: TBQueue SMPClientAgentEvent,
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateSignKey),
|
||||
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateSignKey),
|
||||
reconnections :: TVar [Async ()],
|
||||
asyncClients :: TVar [Async ()]
|
||||
}
|
||||
|
||||
newtype InternalException e = InternalException {unInternalException :: e}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Exception e => Exception (InternalException e)
|
||||
|
||||
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||
withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
|
||||
withRunInIO exceptToIO =
|
||||
withExceptT unInternalException . ExceptT . E.try $
|
||||
withRunInIO $ \run ->
|
||||
exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
|
||||
|
||||
newSMPClientAgent :: SMPClientAgentConfig -> STM SMPClientAgent
|
||||
newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} = do
|
||||
msgQ <- newTBQueue msgQSize
|
||||
agentQ <- newTBQueue agentQSize
|
||||
smpClients <- TM.empty
|
||||
srvSubs <- TM.empty
|
||||
pendingSrvSubs <- TM.empty
|
||||
reconnections <- newTVar []
|
||||
asyncClients <- newTVar []
|
||||
pure SMPClientAgent {agentCfg, msgQ, agentQ, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
|
||||
|
||||
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT ProtocolClientError IO SMPClient
|
||||
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
||||
where
|
||||
getClientVar :: STM (Either SMPClientVar SMPClientVar)
|
||||
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
|
||||
|
||||
newClientVar :: STM SMPClientVar
|
||||
newClientVar = do
|
||||
smpVar <- newEmptyTMVar
|
||||
TM.insert srv smpVar smpClients
|
||||
pure smpVar
|
||||
|
||||
waitForSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
|
||||
waitForSMPClient smpVar = do
|
||||
let ProtocolClientConfig {tcpTimeout} = smpCfg agentCfg
|
||||
smpClient_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar smpVar)
|
||||
liftEither $ case smpClient_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left PCEResponseTimeout
|
||||
|
||||
newSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO () -> ExceptT ProtocolClientError IO a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryE connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar smpVar r
|
||||
successAction smp
|
||||
Left e -> do
|
||||
if e == PCENetworkError || e == PCEResponseTimeout
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
throwE e
|
||||
tryConnectAsync :: ExceptT ProtocolClientError IO ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients ca) (a :)
|
||||
connectAsync :: ExceptT ProtocolClientError IO ()
|
||||
connectAsync =
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
|
||||
void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: ExceptT ProtocolClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateSignKey))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs sVar = do
|
||||
ss <- readTVar sVar
|
||||
addPendingSubs sVar ss
|
||||
pure ss
|
||||
|
||||
addPendingSubs sVar ss = do
|
||||
let ps = pendingSrvSubs ca
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union ss v
|
||||
_ -> TM.insert srv sVar ps
|
||||
|
||||
serverDown :: Map SMPSub C.APrivateSignKey -> IO ()
|
||||
serverDown ss = unless (M.null ss) . void . runExceptT $ do
|
||||
notify . CADisconnected srv $ M.keysSet ss
|
||||
reconnectServer
|
||||
|
||||
reconnectServer :: ExceptT ProtocolClientError IO ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections ca) (a :)
|
||||
|
||||
tryReconnectClient :: ExceptT ProtocolClientError IO ()
|
||||
tryReconnectClient = do
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
|
||||
reconnectClient `catchE` const loop
|
||||
|
||||
reconnectClient :: ExceptT ProtocolClientError IO ()
|
||||
reconnectClient = do
|
||||
withSMP ca srv $ \smp -> do
|
||||
notify $ CAReconnected srv
|
||||
cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca)
|
||||
forConcurrently_ (maybe [] M.assocs cs) $ \sub@(s, _) ->
|
||||
whenM (atomically $ hasSub (srvSubs ca) srv s) $
|
||||
subscribe_ smp sub `catchE` handleError s
|
||||
where
|
||||
subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
subscribe_ smp sub@(s, _) = do
|
||||
smpSubscribe smp sub
|
||||
atomically $ addSubscription ca srv sub
|
||||
notify $ CAResubscribed srv s
|
||||
|
||||
handleError :: SMPSub -> ProtocolClientError -> ExceptT ProtocolClientError IO ()
|
||||
handleError s = \case
|
||||
e@PCEResponseTimeout -> throwE e
|
||||
e@PCENetworkError -> throwE e
|
||||
e -> do
|
||||
notify $ CASubError srv s e
|
||||
atomically $ removePendingSubscription ca srv s
|
||||
|
||||
notify :: SMPClientAgentEvent -> ExceptT ProtocolClientError IO ()
|
||||
notify evt = atomically $ writeTBQueue (agentQ ca) evt
|
||||
|
||||
closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m ()
|
||||
closeSMPClientAgent c = liftIO $ do
|
||||
closeSMPServerClients c
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
|
||||
closeSMPServerClients :: SMPClientAgent -> IO ()
|
||||
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
|
||||
where
|
||||
closeClient smpVar =
|
||||
atomically (readTMVar smpVar) >>= \case
|
||||
Right smp -> closeProtocolClient smp `E.catch` \(_ :: E.SomeException) -> pure ()
|
||||
_ -> pure ()
|
||||
|
||||
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
|
||||
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel
|
||||
|
||||
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO a
|
||||
withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError
|
||||
where
|
||||
logSMPError :: ProtocolClientError -> ExceptT ProtocolClientError IO a
|
||||
logSMPError e = do
|
||||
liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e
|
||||
throwE e
|
||||
|
||||
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
subscribeQueue ca srv sub = do
|
||||
atomically $ addPendingSubscription ca srv sub
|
||||
withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleError
|
||||
where
|
||||
subscribe_ smp = do
|
||||
smpSubscribe smp sub
|
||||
atomically $ addSubscription ca srv sub
|
||||
|
||||
handleError e = do
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv $ fst sub
|
||||
throwE e
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer ProtocolServer {host, port} =
|
||||
B.pack $ host <> if null port then "" else ':' : port
|
||||
|
||||
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId
|
||||
where
|
||||
subscribe_ = case party of
|
||||
SPRecipient -> subscribeSMPQueue
|
||||
SPNotifier -> subscribeSMPQueueNotifications
|
||||
|
||||
addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM ()
|
||||
addSubscription ca srv sub = do
|
||||
addSub_ (srvSubs ca) srv sub
|
||||
removePendingSubscription ca srv $ fst sub
|
||||
|
||||
addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM ()
|
||||
addPendingSubscription = addSub_ . pendingSrvSubs
|
||||
|
||||
addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM ()
|
||||
addSub_ subs srv (s, key) =
|
||||
TM.lookup srv subs >>= \case
|
||||
Just m -> TM.insert s key m
|
||||
_ -> TM.singleton s key >>= \v -> TM.insert srv v subs
|
||||
|
||||
removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
||||
removeSubscription = removeSub_ . srvSubs
|
||||
|
||||
removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
||||
removePendingSubscription = removeSub_ . pendingSrvSubs
|
||||
|
||||
removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM ()
|
||||
removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s)
|
||||
|
||||
getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateSignKey)
|
||||
getSubKey subs srv s = fmap join . mapM (TM.lookup s) =<< TM.lookup srv subs
|
||||
|
||||
hasSub :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM Bool
|
||||
hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs
|
||||
55
src/Simplex/Messaging/Notifications/Client.hs
Normal file
55
src/Simplex/Messaging/Notifications/Client.hs
Normal file
@@ -0,0 +1,55 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Client where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
|
||||
type NtfClient = ProtocolClient NtfResponse
|
||||
|
||||
registerNtfToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
registerNtfToken c pKey newTkn =
|
||||
sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case
|
||||
NRId tknId dhKey -> pure (tknId, dhKey)
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
verifyNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegistrationCode -> ExceptT ProtocolClientError IO ()
|
||||
verifyNtfToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
|
||||
|
||||
deleteNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
|
||||
deleteNtfToken = okNtfCommand TDEL
|
||||
|
||||
enableNtfCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
|
||||
enableNtfCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
|
||||
|
||||
createNtfSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519)
|
||||
createNtfSubsciption c pKey newSub =
|
||||
sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case
|
||||
NRId tknId dhKey -> pure (tknId, dhKey)
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
checkNtfSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
|
||||
checkNtfSubscription c pKey subId =
|
||||
sendNtfCommand c (Just pKey) subId SCHK >>= \case
|
||||
NRStat stat -> pure stat
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
deleteNfgSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
|
||||
deleteNfgSubscription = okNtfCommand SDEL
|
||||
|
||||
-- | Send notification server command
|
||||
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse
|
||||
sendNtfCommand c pKey entId = sendProtocolCommand c pKey entId . NtfCmd sNtfEntity
|
||||
|
||||
okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT ProtocolClientError IO ()
|
||||
okNtfCommand cmd c pKey entId =
|
||||
sendNtfCommand c (Just pKey) entId cmd >>= \case
|
||||
NROk -> return ()
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
@@ -1,6 +1,11 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Protocol where
|
||||
@@ -8,154 +13,333 @@ module Simplex.Messaging.Notifications.Protocol where
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
data NtfCommandTag
|
||||
= NCCreate_
|
||||
| NCCheck_
|
||||
| NCToken_
|
||||
| NCDelete_
|
||||
data NtfEntity = Token | Subscription
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfCommandTag where
|
||||
data SNtfEntity :: NtfEntity -> Type where
|
||||
SToken :: SNtfEntity 'Token
|
||||
SSubscription :: SNtfEntity 'Subscription
|
||||
|
||||
instance TestEquality SNtfEntity where
|
||||
testEquality SToken SToken = Just Refl
|
||||
testEquality SSubscription SSubscription = Just Refl
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
deriving instance Show (SNtfEntity e)
|
||||
|
||||
class NtfEntityI (e :: NtfEntity) where sNtfEntity :: SNtfEntity e
|
||||
|
||||
instance NtfEntityI 'Token where sNtfEntity = SToken
|
||||
|
||||
instance NtfEntityI 'Subscription where sNtfEntity = SSubscription
|
||||
|
||||
data NtfCommandTag (e :: NtfEntity) where
|
||||
TNEW_ :: NtfCommandTag 'Token
|
||||
TVFY_ :: NtfCommandTag 'Token
|
||||
TDEL_ :: NtfCommandTag 'Token
|
||||
TCRN_ :: NtfCommandTag 'Token
|
||||
SNEW_ :: NtfCommandTag 'Subscription
|
||||
SCHK_ :: NtfCommandTag 'Subscription
|
||||
SDEL_ :: NtfCommandTag 'Subscription
|
||||
PING_ :: NtfCommandTag 'Subscription
|
||||
|
||||
deriving instance Show (NtfCommandTag e)
|
||||
|
||||
data NtfCmdTag = forall e. NtfEntityI e => NCT (SNtfEntity e) (NtfCommandTag e)
|
||||
|
||||
instance NtfEntityI e => Encoding (NtfCommandTag e) where
|
||||
smpEncode = \case
|
||||
NCCreate_ -> "CREATE"
|
||||
NCCheck_ -> "CHECK"
|
||||
NCToken_ -> "TOKEN"
|
||||
NCDelete_ -> "DELETE"
|
||||
TNEW_ -> "TNEW"
|
||||
TVFY_ -> "TVFY"
|
||||
TDEL_ -> "TDEL"
|
||||
TCRN_ -> "TCRN"
|
||||
SNEW_ -> "SNEW"
|
||||
SCHK_ -> "SCHK"
|
||||
SDEL_ -> "SDEL"
|
||||
PING_ -> "PING"
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag NtfCommandTag where
|
||||
instance Encoding NtfCmdTag where
|
||||
smpEncode (NCT _ t) = smpEncode t
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag NtfCmdTag where
|
||||
decodeTag = \case
|
||||
"CREATE" -> Just NCCreate_
|
||||
"CHECK" -> Just NCCheck_
|
||||
"TOKEN" -> Just NCToken_
|
||||
"DELETE" -> Just NCDelete_
|
||||
"TNEW" -> Just $ NCT SToken TNEW_
|
||||
"TVFY" -> Just $ NCT SToken TVFY_
|
||||
"TDEL" -> Just $ NCT SToken TDEL_
|
||||
"TCRN" -> Just $ NCT SToken TCRN_
|
||||
"SNEW" -> Just $ NCT SSubscription SNEW_
|
||||
"SCHK" -> Just $ NCT SSubscription SCHK_
|
||||
"SDEL" -> Just $ NCT SSubscription SDEL_
|
||||
"PING" -> Just $ NCT SSubscription PING_
|
||||
_ -> Nothing
|
||||
|
||||
data NtfCommand
|
||||
= NCCreate DeviceToken SMPQueueNtfUri C.APublicVerifyKey C.PublicKeyX25519
|
||||
| NCCheck
|
||||
| NCToken DeviceToken
|
||||
| NCDelete
|
||||
instance NtfEntityI e => ProtocolMsgTag (NtfCommandTag e) where
|
||||
decodeTag s = decodeTag s >>= (\(NCT _ t) -> checkEntity' t)
|
||||
|
||||
instance Protocol NtfCommand where
|
||||
type Tag NtfCommand = NtfCommandTag
|
||||
type NtfRegistrationCode = ByteString
|
||||
|
||||
data NewNtfEntity (e :: NtfEntity) where
|
||||
NewNtfTkn :: DeviceToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> NewNtfEntity 'Token
|
||||
NewNtfSub :: NtfTokenId -> SMPQueueNtf -> NewNtfEntity 'Subscription
|
||||
|
||||
data ANewNtfEntity = forall e. NtfEntityI e => ANE (SNtfEntity e) (NewNtfEntity e)
|
||||
|
||||
instance NtfEntityI e => Encoding (NewNtfEntity e) where
|
||||
smpEncode = \case
|
||||
NewNtfTkn tkn verifyKey dhPubKey -> smpEncode ('T', tkn, verifyKey, dhPubKey)
|
||||
NewNtfSub tknId smpQueue -> smpEncode ('S', tknId, smpQueue)
|
||||
smpP = (\(ANE _ c) -> checkEntity c) <$?> smpP
|
||||
|
||||
instance Encoding ANewNtfEntity where
|
||||
smpEncode (ANE _ e) = smpEncode e
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'T' -> ANE SToken <$> (NewNtfTkn <$> smpP <*> smpP <*> smpP)
|
||||
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP)
|
||||
_ -> fail "bad ANewNtfEntity"
|
||||
|
||||
instance Protocol NtfResponse where
|
||||
type ProtocolCommand NtfResponse = NtfCmd
|
||||
protocolPing = NtfCmd SSubscription PING
|
||||
protocolError = \case
|
||||
NRErr e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
data NtfCommand (e :: NtfEntity) where
|
||||
-- | register new device token for notifications
|
||||
TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token
|
||||
-- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token
|
||||
TVFY :: NtfRegistrationCode -> NtfCommand 'Token
|
||||
-- | delete token - all subscriptions will be removed and no more notifications will be sent
|
||||
TDEL :: NtfCommand 'Token
|
||||
-- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable
|
||||
TCRN :: Word16 -> NtfCommand 'Token
|
||||
-- | create SMP subscription
|
||||
SNEW :: NewNtfEntity 'Subscription -> NtfCommand 'Subscription
|
||||
-- | check SMP subscription status (response is STAT)
|
||||
SCHK :: NtfCommand 'Subscription
|
||||
-- | delete SMP subscription
|
||||
SDEL :: NtfCommand 'Subscription
|
||||
-- | keep-alive command
|
||||
PING :: NtfCommand 'Subscription
|
||||
|
||||
data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
|
||||
|
||||
instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
type Tag (NtfCommand e) = NtfCommandTag e
|
||||
encodeProtocol = \case
|
||||
NCCreate token smpQueue verifyKey dhKey -> e (NCCreate_, ' ', token, smpQueue, verifyKey, dhKey)
|
||||
NCCheck -> e NCCheck_
|
||||
NCToken token -> e (NCToken_, ' ', token)
|
||||
NCDelete -> e NCDelete_
|
||||
TNEW newTkn -> e (TNEW_, ' ', newTkn)
|
||||
TVFY code -> e (TVFY_, ' ', code)
|
||||
TDEL -> e TDEL_
|
||||
TCRN int -> e (TCRN_, ' ', int)
|
||||
SNEW newSub -> e (SNEW_, ' ', newSub)
|
||||
SCHK -> e SCHK_
|
||||
SDEL -> e SDEL_
|
||||
PING -> e PING_
|
||||
where
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
protocolP = \case
|
||||
NCCreate_ -> NCCreate <$> _smpP <*> smpP <*> smpP <*> smpP
|
||||
NCCheck_ -> pure NCCheck
|
||||
NCToken_ -> NCToken <$> _smpP
|
||||
NCDelete_ -> pure NCDelete
|
||||
protocolP tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP (NCT (sNtfEntity @e) tag)
|
||||
|
||||
checkCredentials (sig, _, subId, _) cmd = case cmd of
|
||||
-- CREATE must have signature but NOT subscription ID
|
||||
NCCreate {}
|
||||
| isNothing sig -> Left $ CMD NO_AUTH
|
||||
| not (B.null subId) -> Left $ CMD HAS_AUTH
|
||||
| otherwise -> Right cmd
|
||||
-- other client commands must have both signature and subscription ID
|
||||
checkCredentials (sig, _, entityId, _) cmd = case cmd of
|
||||
-- TNEW and SNEW must have signature but NOT token/subscription IDs
|
||||
TNEW {} -> sigNoEntity
|
||||
SNEW {} -> sigNoEntity
|
||||
PING
|
||||
| isNothing sig && B.null entityId -> Right cmd
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other client commands must have both signature and entity ID
|
||||
_
|
||||
| isNothing sig || B.null subId -> Left $ CMD NO_AUTH
|
||||
| isNothing sig || B.null entityId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
where
|
||||
sigNoEntity
|
||||
| isNothing sig = Left $ CMD NO_AUTH
|
||||
| not (B.null entityId) = Left $ CMD HAS_AUTH
|
||||
| otherwise = Right cmd
|
||||
|
||||
instance ProtocolEncoding NtfCmd where
|
||||
type Tag NtfCmd = NtfCmdTag
|
||||
encodeProtocol (NtfCmd _ c) = encodeProtocol c
|
||||
|
||||
protocolP = \case
|
||||
NCT SToken tag ->
|
||||
NtfCmd SToken <$> case tag of
|
||||
TNEW_ -> TNEW <$> _smpP
|
||||
TVFY_ -> TVFY <$> _smpP
|
||||
TDEL_ -> pure TDEL
|
||||
TCRN_ -> TCRN <$> _smpP
|
||||
NCT SSubscription tag ->
|
||||
NtfCmd SSubscription <$> case tag of
|
||||
SNEW_ -> SNEW <$> _smpP
|
||||
SCHK_ -> pure SCHK
|
||||
SDEL_ -> pure SDEL
|
||||
PING_ -> pure PING
|
||||
|
||||
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
|
||||
|
||||
data NtfResponseTag
|
||||
= NRSubId_
|
||||
= NRId_
|
||||
| NROk_
|
||||
| NRErr_
|
||||
| NRStat_
|
||||
| NRPong_
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfResponseTag where
|
||||
smpEncode = \case
|
||||
NRSubId_ -> "ID"
|
||||
NRId_ -> "ID"
|
||||
NROk_ -> "OK"
|
||||
NRErr_ -> "ERR"
|
||||
NRStat_ -> "STAT"
|
||||
NRPong_ -> "PONG"
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag NtfResponseTag where
|
||||
decodeTag = \case
|
||||
"ID" -> Just NRSubId_
|
||||
"ID" -> Just NRId_
|
||||
"OK" -> Just NROk_
|
||||
"ERR" -> Just NRErr_
|
||||
"STAT" -> Just NRStat_
|
||||
"PONG" -> Just NRPong_
|
||||
_ -> Nothing
|
||||
|
||||
data NtfResponse
|
||||
= NRSubId C.PublicKeyX25519
|
||||
= NRId NtfEntityId C.PublicKeyX25519
|
||||
| NROk
|
||||
| NRErr ErrorType
|
||||
| NRStat NtfStatus
|
||||
| NRStat NtfSubStatus
|
||||
| NRPong
|
||||
|
||||
instance Protocol NtfResponse where
|
||||
instance ProtocolEncoding NtfResponse where
|
||||
type Tag NtfResponse = NtfResponseTag
|
||||
encodeProtocol = \case
|
||||
NRSubId dhKey -> e (NRSubId_, ' ', dhKey)
|
||||
NRId entId dhKey -> e (NRId_, ' ', entId, dhKey)
|
||||
NROk -> e NROk_
|
||||
NRErr err -> e (NRErr_, ' ', err)
|
||||
NRStat stat -> e (NRStat_, ' ', stat)
|
||||
NRPong -> e NRPong_
|
||||
where
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
protocolP = \case
|
||||
NRSubId_ -> NRSubId <$> _smpP
|
||||
NRId_ -> NRId <$> _smpP <*> smpP
|
||||
NROk_ -> pure NROk
|
||||
NRErr_ -> NRErr <$> _smpP
|
||||
NRStat_ -> NRStat <$> _smpP
|
||||
NRPong_ -> pure NRPong
|
||||
|
||||
checkCredentials (_, _, subId, _) cmd = case cmd of
|
||||
-- ERR response does not always have subscription ID
|
||||
checkCredentials (_, _, entId, _) cmd = case cmd of
|
||||
-- ID response must not have queue ID
|
||||
NRId {} -> noEntity
|
||||
-- ERR response does not always have entity ID
|
||||
NRErr _ -> Right cmd
|
||||
-- other server responses must have subscription ID
|
||||
-- PONG response must not have queue ID
|
||||
NRPong -> noEntity
|
||||
-- other server responses must have entity ID
|
||||
_
|
||||
| B.null subId -> Left $ CMD NO_ENTITY
|
||||
| B.null entId -> Left $ CMD NO_ENTITY
|
||||
| otherwise -> Right cmd
|
||||
where
|
||||
noEntity
|
||||
| B.null entId = Right cmd
|
||||
| otherwise = Left $ CMD HAS_AUTH
|
||||
|
||||
data SMPQueueNtfUri = SMPQueueNtfUri
|
||||
{ smpServer :: SMPServer,
|
||||
data SMPQueueNtf = SMPQueueNtf
|
||||
{ smpServer :: ProtocolServer,
|
||||
notifierId :: NotifierId,
|
||||
notifierKey :: NtfPrivateSignKey
|
||||
}
|
||||
|
||||
instance Encoding SMPQueueNtfUri where
|
||||
smpEncode SMPQueueNtfUri {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey)
|
||||
instance Encoding SMPQueueNtf where
|
||||
smpEncode SMPQueueNtf {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey)
|
||||
smpP = do
|
||||
(smpServer, notifierId, notifierKey) <- smpP
|
||||
pure $ SMPQueueNtfUri smpServer notifierId notifierKey
|
||||
pure $ SMPQueueNtf smpServer notifierId notifierKey
|
||||
|
||||
newtype DeviceToken = DeviceToken ByteString
|
||||
data PushPlatform = PPApple
|
||||
|
||||
instance Encoding PushPlatform where
|
||||
smpEncode = \case
|
||||
PPApple -> "A"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'A' -> pure PPApple
|
||||
_ -> fail "bad PushPlatform"
|
||||
|
||||
data DeviceToken = DeviceToken PushPlatform ByteString
|
||||
|
||||
instance Encoding DeviceToken where
|
||||
smpEncode (DeviceToken t) = smpEncode t
|
||||
smpP = DeviceToken <$> smpP
|
||||
smpEncode (DeviceToken p t) = smpEncode (p, t)
|
||||
smpP = DeviceToken <$> smpP <*> smpP
|
||||
|
||||
type NtfSubsciptionId = ByteString
|
||||
type NtfEntityId = ByteString
|
||||
|
||||
data NtfStatus = NSPending | NSActive | NSEnd | NSSMPAuth
|
||||
type NtfSubscriptionId = NtfEntityId
|
||||
|
||||
instance Encoding NtfStatus where
|
||||
type NtfTokenId = NtfEntityId
|
||||
|
||||
data NtfSubStatus
|
||||
= -- | state after SNEW
|
||||
NSNew
|
||||
| -- | pending connection/subscription to SMP server
|
||||
NSPending
|
||||
| -- | connected and subscribed to SMP server
|
||||
NSActive
|
||||
| -- | NEND received (we currently do not support it)
|
||||
NSEnd
|
||||
| -- | SMP AUTH error
|
||||
NSSMPAuth
|
||||
deriving (Eq)
|
||||
|
||||
instance Encoding NtfSubStatus where
|
||||
smpEncode = \case
|
||||
NSPending -> "PENDING"
|
||||
NSNew -> "NEW"
|
||||
NSPending -> "PENDING" -- e.g. after SMP server disconnect/timeout while ntf server is retrying to connect
|
||||
NSActive -> "ACTIVE"
|
||||
NSEnd -> "END"
|
||||
NSSMPAuth -> "SMP_AUTH"
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"NEW" -> pure NSNew
|
||||
"PENDING" -> pure NSPending
|
||||
"ACTIVE" -> pure NSActive
|
||||
"END" -> pure NSEnd
|
||||
"SMP_AUTH" -> pure NSSMPAuth
|
||||
_ -> fail "bad NtfError"
|
||||
|
||||
data NtfTknStatus
|
||||
= -- | state after registration (TNEW)
|
||||
NTNew
|
||||
| -- | if initial notification or verification failed (push provider error)
|
||||
NTInvalid
|
||||
| -- | if initial notification succeeded
|
||||
NTConfirmed
|
||||
| -- | after successful verification (TVFY)
|
||||
NTActive
|
||||
| -- | after it is no longer valid (push provider error)
|
||||
NTExpired
|
||||
deriving (Eq)
|
||||
|
||||
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
|
||||
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
|
||||
Just Refl -> Right c
|
||||
Nothing -> Left "bad command party"
|
||||
|
||||
checkEntity' :: forall t p p'. (NtfEntityI p, NtfEntityI p') => t p' -> Maybe (t p)
|
||||
checkEntity' c = case testEquality (sNtfEntity @p) (sNtfEntity @p') of
|
||||
Just Refl -> Just c
|
||||
_ -> Nothing
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server where
|
||||
|
||||
@@ -12,13 +15,16 @@ import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Functor (($>))
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Server.Subscriptions
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), Transmission, encodeTransmission, tGet, tPut)
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), SignedTransmission, Transmission, encodeTransmission, tGet, tPut)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server
|
||||
import Simplex.Messaging.Transport (ATransport (..), THandle (..), TProxy, Transport)
|
||||
import Simplex.Messaging.Transport.Server (runTransportServer)
|
||||
@@ -37,7 +43,10 @@ runNtfServerBlocking started cfg@NtfServerConfig {transports} = do
|
||||
runReaderT ntfServer env
|
||||
where
|
||||
ntfServer :: (MonadUnliftIO m', MonadReader NtfEnv m') => m' ()
|
||||
ntfServer = raceAny_ (map runServer transports)
|
||||
ntfServer = do
|
||||
s <- asks subscriber
|
||||
ps <- asks pushServer
|
||||
raceAny_ (ntfSubscriber s : ntfPush ps : map runServer transports)
|
||||
|
||||
runServer :: (MonadUnliftIO m', MonadReader NtfEnv m') => (ServiceName, ATransport) -> m' ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
@@ -51,11 +60,61 @@ runNtfServerBlocking started cfg@NtfServerConfig {transports} = do
|
||||
Right th -> runNtfClientTransport th
|
||||
Left _ -> pure ()
|
||||
|
||||
ntfSubscriber :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfSubscriber -> m ()
|
||||
ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = do
|
||||
raceAny_ [subscribe, receiveSMP, receiveAgent]
|
||||
where
|
||||
subscribe :: m ()
|
||||
subscribe = forever $ do
|
||||
atomically (readTBQueue subQ) >>= \case
|
||||
NtfSub NtfSubData {smpQueue} -> do
|
||||
let SMPQueueNtf {smpServer, notifierId, notifierKey} = smpQueue
|
||||
liftIO (runExceptT $ subscribeQueue ca smpServer ((SPNotifier, notifierId), notifierKey)) >>= \case
|
||||
Right _ -> pure () -- update subscription status
|
||||
Left e -> pure ()
|
||||
|
||||
receiveSMP :: m ()
|
||||
receiveSMP = forever $ do
|
||||
(srv, ntfId, msg) <- atomically $ readTBQueue msgQ
|
||||
case msg of
|
||||
SMP.NMSG -> do
|
||||
-- check when the last NMSG was received from this queue
|
||||
-- update timestamp
|
||||
-- check what was the last hidden notification was sent (and whether to this queue)
|
||||
-- decide whether it should be sent as hidden or visible
|
||||
-- construct and possibly encrypt notification
|
||||
-- send it
|
||||
pure ()
|
||||
_ -> pure ()
|
||||
pure ()
|
||||
|
||||
receiveAgent =
|
||||
forever $
|
||||
atomically (readTBQueue agentQ) >>= \case
|
||||
CAConnected _ -> pure ()
|
||||
CADisconnected srv subs -> do
|
||||
-- update subscription statuses
|
||||
pure ()
|
||||
CAReconnected _ -> pure ()
|
||||
CAResubscribed srv sub -> do
|
||||
-- update subscription status
|
||||
pure ()
|
||||
CASubError srv sub err -> do
|
||||
-- update subscription status
|
||||
pure ()
|
||||
|
||||
ntfPush :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m ()
|
||||
ntfPush NtfPushServer {pushQ} = forever $ do
|
||||
atomically (readTBQueue pushQ) >>= \case
|
||||
(NtfTknData {}, Notification {}) -> pure ()
|
||||
|
||||
runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m ()
|
||||
runNtfClientTransport th@THandle {sessionId} = do
|
||||
q <- asks $ tbqSize . config
|
||||
c <- atomically $ newNtfServerClient q sessionId
|
||||
raceAny_ [send th c, client c, receive th c]
|
||||
qSize <- asks $ clientQSize . config
|
||||
c <- atomically $ newNtfServerClient qSize sessionId
|
||||
s <- asks subscriber
|
||||
ps <- asks pushServer
|
||||
raceAny_ [send th c, client c s ps, receive th c]
|
||||
`finally` clientDisconnected c
|
||||
|
||||
clientDisconnected :: MonadUnliftIO m => NtfServerClient -> m ()
|
||||
@@ -63,14 +122,13 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
|
||||
receive th NtfServerClient {rcvQ, sndQ} = forever $ do
|
||||
(sig, signed, (corrId, queueId, cmdOrError)) <- tGet th
|
||||
t@(sig, signed, (corrId, subId, cmdOrError)) <- tGet th
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, queueId, NRErr e)
|
||||
Right cmd -> do
|
||||
verified <- verifyTransmission sig signed queueId cmd
|
||||
if verified
|
||||
then write rcvQ (corrId, queueId, cmd)
|
||||
else write sndQ (corrId, queueId, NRErr AUTH)
|
||||
Left e -> write sndQ (corrId, subId, NRErr e)
|
||||
Right cmd ->
|
||||
verifyNtfTransmission t cmd >>= \case
|
||||
VRVerified req -> write rcvQ req
|
||||
VRFailed -> write sndQ (corrId, subId, NRErr AUTH)
|
||||
where
|
||||
write q t = atomically $ writeTBQueue q t
|
||||
|
||||
@@ -79,33 +137,121 @@ send h NtfServerClient {sndQ, sessionId} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
liftIO $ tPut h (Nothing, encodeTransmission sessionId t)
|
||||
|
||||
verifyTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => Maybe C.ASignature -> ByteString -> NtfSubsciptionId -> NtfCommand -> m Bool
|
||||
verifyTransmission sig_ signed subId cmd = do
|
||||
case cmd of
|
||||
NCCreate _ _ k _ -> pure $ verifyCmdSignature sig_ signed k
|
||||
_ -> do
|
||||
st <- asks store
|
||||
verifySubCmd <$> atomically (getNtfSubscription st subId)
|
||||
where
|
||||
verifySubCmd = \case
|
||||
Right sub -> verifyCmdSignature sig_ signed $ subVerifyKey sub
|
||||
Left _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` False
|
||||
data VerificationResult = VRVerified NtfRequest | VRFailed
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> m ()
|
||||
client NtfServerClient {rcvQ, sndQ} =
|
||||
verifyNtfTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => SignedTransmission NtfCmd -> NtfCmd -> m VerificationResult
|
||||
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
|
||||
case cmd of
|
||||
NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) ->
|
||||
-- TODO check that token is not already in store
|
||||
pure $
|
||||
if verifyCmdSignature sig_ signed k
|
||||
then VRVerified (NtfReqNew corrId (ANE SToken n))
|
||||
else VRFailed
|
||||
NtfCmd SToken c -> do
|
||||
st <- asks store
|
||||
atomically (getNtfToken st entId) >>= \case
|
||||
Just r@(NtfTkn NtfTknData {tknVerifyKey}) ->
|
||||
pure $
|
||||
if verifyCmdSignature sig_ signed tknVerifyKey
|
||||
then VRVerified (NtfReqCmd SToken r (corrId, entId, c))
|
||||
else VRFailed
|
||||
_ -> pure VRFailed -- TODO dummy verification
|
||||
_ -> pure VRFailed
|
||||
|
||||
-- do
|
||||
-- st <- asks store
|
||||
-- case cmd of
|
||||
-- NCSubCreate tokenId smpQueue -> verifyCreateCmd verifyKey newSub <$> atomically (getNtfSubViaSMPQueue st smpQueue)
|
||||
-- _ -> verifySubCmd <$> atomically (getNtfSub st subId)
|
||||
-- where
|
||||
-- verifyCreateCmd k newSub sub_
|
||||
-- | verifyCmdSignature sig_ signed k = case sub_ of
|
||||
-- Just sub -> if k == subVerifyKey sub then VRCommand sub else VRFail
|
||||
-- _ -> VRCreate newSub
|
||||
-- | otherwise = VRFail
|
||||
-- verifySubCmd = \case
|
||||
-- Just sub -> if verifyCmdSignature sig_ signed $ subVerifyKey sub then VRCommand sub else VRFail
|
||||
-- _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFail
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m ()
|
||||
client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} =
|
||||
forever $
|
||||
atomically (readTBQueue rcvQ)
|
||||
>>= processCommand
|
||||
>>= atomically . writeTBQueue sndQ
|
||||
where
|
||||
processCommand :: Transmission NtfCommand -> m (Transmission NtfResponse)
|
||||
processCommand (corrId, subId, cmd) = case cmd of
|
||||
NCCreate _token _smpQueue _verifyKey _dhKey -> do
|
||||
pure (corrId, subId, NROk)
|
||||
NCCheck -> do
|
||||
pure (corrId, subId, NROk)
|
||||
NCToken _token -> do
|
||||
pure (corrId, subId, NROk)
|
||||
NCDelete -> do
|
||||
pure (corrId, subId, NROk)
|
||||
processCommand :: NtfRequest -> m (Transmission NtfResponse)
|
||||
processCommand = \case
|
||||
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do
|
||||
st <- asks store
|
||||
(srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair'
|
||||
let dhSecret = C.dh' dhPubKey srvDrivDhKey
|
||||
tknId <- getId
|
||||
atomically $ do
|
||||
tkn <- mkNtfTknData newTkn dhSecret
|
||||
addNtfToken st tknId tkn
|
||||
writeTBQueue pushQ (tkn, Notification)
|
||||
-- pure (corrId, sId, NRSubId pubDhKey)
|
||||
pure (corrId, "", NRId tknId srvDhPubKey)
|
||||
NtfReqCmd SToken tkn (corrId, tknId, cmd) ->
|
||||
(corrId,tknId,) <$> case cmd of
|
||||
TNEW newTkn -> pure NROk -- TODO when duplicate token sent
|
||||
TVFY code -> pure NROk
|
||||
TDEL -> pure NROk
|
||||
TCRN int -> pure NROk
|
||||
NtfReqNew corrId (ANE SSubscription newSub) -> pure (corrId, "", NROk)
|
||||
NtfReqCmd SSubscription sub (corrId, subId, cmd) ->
|
||||
(corrId,subId,) <$> case cmd of
|
||||
SNEW newSub -> pure NROk
|
||||
SCHK -> pure NROk
|
||||
SDEL -> pure NROk
|
||||
PING -> pure NRPong
|
||||
getId :: m NtfEntityId
|
||||
getId = do
|
||||
n <- asks $ subIdBytes . config
|
||||
gVar <- asks idsDrg
|
||||
atomically (randomBytes n gVar)
|
||||
|
||||
-- NReqCreate corrId tokenId smpQueue -> pure (corrId, "", NROk)
|
||||
-- do
|
||||
-- st <- asks store
|
||||
-- (pubDhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
-- let dhSecret = C.dh' dhPubKey privDhKey
|
||||
-- sub <- atomically $ mkNtfSubsciption smpQueue token verifyKey dhSecret
|
||||
-- addSubRetry 3 st sub >>= \case
|
||||
-- Nothing -> pure (corrId, "", NRErr INTERNAL)
|
||||
-- Just sId -> do
|
||||
-- atomically $ writeTBQueue subQ sub
|
||||
-- pure (corrId, sId, NRSubId pubDhKey)
|
||||
-- where
|
||||
-- addSubRetry :: Int -> NtfSubscriptionsStore -> NtfSubsciption -> m (Maybe NtfSubsciptionId)
|
||||
-- addSubRetry 0 _ _ = pure Nothing
|
||||
-- addSubRetry n st sub = do
|
||||
-- sId <- getId
|
||||
-- -- create QueueRec record with these ids and keys
|
||||
-- atomically (addNtfSub st sId sub) >>= \case
|
||||
-- Nothing -> addSubRetry (n - 1) st sub
|
||||
-- _ -> pure $ Just sId
|
||||
-- getId :: m NtfSubsciptionId
|
||||
-- getId = do
|
||||
-- n <- asks $ subIdBytes . config
|
||||
-- gVar <- asks idsDrg
|
||||
-- atomically (randomBytes n gVar)
|
||||
-- NReqCommand sub@NtfSubsciption {tokenId, subStatus} (corrId, subId, cmd) ->
|
||||
-- (corrId,subId,) <$> case cmd of
|
||||
-- NCSubCreate tokenId smpQueue -> pure NROk
|
||||
-- do
|
||||
-- st <- asks store
|
||||
-- (pubDhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
-- let dhSecret = C.dh' (dhPubKey newSub) privDhKey
|
||||
-- atomically (updateNtfSub st sub newSub dhSecret) >>= \case
|
||||
-- Nothing -> pure $ NRErr INTERNAL
|
||||
-- _ -> atomically $ do
|
||||
-- whenM ((== NSEnd) <$> readTVar status) $ writeTBQueue subQ sub
|
||||
-- pure $ NRSubId pubDhKey
|
||||
-- NCSubCheck -> NRStat <$> readTVarIO subStatus
|
||||
-- NCSubDelete -> do
|
||||
-- st <- asks store
|
||||
-- atomically (deleteNtfSub st subId) $> NROk
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Env where
|
||||
@@ -6,26 +9,27 @@ module Simplex.Messaging.Notifications.Server.Env where
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Subscriptions
|
||||
import Simplex.Messaging.Protocol (Transmission)
|
||||
import Simplex.Messaging.Protocol (CorrId, Transmission)
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
import UnliftIO.STM
|
||||
|
||||
data NtfServerConfig = NtfServerConfig
|
||||
{ transports :: [(ServiceName, ATransport)],
|
||||
subscriptionIdBytes :: Int,
|
||||
tbqSize :: Natural,
|
||||
smpCfg :: SMPClientConfig,
|
||||
subIdBytes :: Int,
|
||||
clientQSize :: Natural,
|
||||
subQSize :: Natural,
|
||||
pushQSize :: Natural,
|
||||
smpAgentCfg :: SMPClientAgentConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
-- CA certificate private key is not needed for initialization
|
||||
caCertificateFile :: FilePath,
|
||||
@@ -33,26 +37,55 @@ data NtfServerConfig = NtfServerConfig
|
||||
certificateFile :: FilePath
|
||||
}
|
||||
|
||||
data Notification = Notification
|
||||
|
||||
data NtfEnv = NtfEnv
|
||||
{ config :: NtfServerConfig,
|
||||
serverIdentity :: C.KeyHash,
|
||||
store :: NtfSubscriptions,
|
||||
subscriber :: NtfSubscriber,
|
||||
pushServer :: NtfPushServer,
|
||||
store :: NtfStore,
|
||||
idsDrg :: TVar ChaChaDRG,
|
||||
serverIdentity :: C.KeyHash,
|
||||
tlsServerParams :: T.ServerParams,
|
||||
serverIdentity :: C.KeyHash
|
||||
}
|
||||
|
||||
newNtfServerEnv :: (MonadUnliftIO m, MonadRandom m) => NtfServerConfig -> m NtfEnv
|
||||
newNtfServerEnv config@NtfServerConfig {caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
idsDrg <- newTVarIO =<< drgNew
|
||||
store <- newTVarIO M.empty
|
||||
store <- atomically newNtfStore
|
||||
subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg
|
||||
pushServer <- atomically $ newNtfPushServer pushQSize
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
let serverIdentity = C.KeyHash fp
|
||||
pure NtfEnv {config, store, idsDrg, tlsServerParams, serverIdentity}
|
||||
pure NtfEnv {config, subscriber, pushServer, store, idsDrg, tlsServerParams, serverIdentity = C.KeyHash fp}
|
||||
|
||||
data NtfSubscriber = NtfSubscriber
|
||||
{ subQ :: TBQueue (NtfEntityRec 'Subscription),
|
||||
smpAgent :: SMPClientAgent
|
||||
}
|
||||
|
||||
newNtfSubscriber :: Natural -> SMPClientAgentConfig -> STM NtfSubscriber
|
||||
newNtfSubscriber qSize smpAgentCfg = do
|
||||
smpAgent <- newSMPClientAgent smpAgentCfg
|
||||
subQ <- newTBQueue qSize
|
||||
pure NtfSubscriber {smpAgent, subQ}
|
||||
|
||||
newtype NtfPushServer = NtfPushServer
|
||||
{ pushQ :: TBQueue (NtfTknData, Notification)
|
||||
}
|
||||
|
||||
newNtfPushServer :: Natural -> STM NtfPushServer
|
||||
newNtfPushServer qSize = do
|
||||
pushQ <- newTBQueue qSize
|
||||
pure NtfPushServer {pushQ}
|
||||
|
||||
data NtfRequest
|
||||
= NtfReqNew CorrId ANewNtfEntity
|
||||
| forall e. NtfEntityI e => NtfReqCmd (SNtfEntity e) (NtfEntityRec e) (Transmission (NtfCommand e))
|
||||
|
||||
data NtfServerClient = NtfServerClient
|
||||
{ rcvQ :: TBQueue (Transmission NtfCommand),
|
||||
{ rcvQ :: TBQueue NtfRequest,
|
||||
sndQ :: TBQueue (Transmission NtfResponse),
|
||||
sessionId :: ByteString,
|
||||
connected :: TVar Bool
|
||||
|
||||
11
src/Simplex/Messaging/Notifications/Server/Push.hs
Normal file
11
src/Simplex/Messaging/Notifications/Server/Push.hs
Normal file
@@ -0,0 +1,11 @@
|
||||
module Simplex.Messaging.Notifications.Server.Push where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Simplex.Messaging.Protocol (NotifierId, SMPServer)
|
||||
|
||||
data NtfPushPayload = NPVerification ByteString | NPNotification SMPServer NotifierId | NPPing
|
||||
|
||||
class PushProvider p where
|
||||
newPushProvider :: STM p
|
||||
requestBody :: p -> NtfPushPayload -> ByteString -- ?
|
||||
@@ -1,25 +1,109 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Subscriptions where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad
|
||||
import Crypto.PubKey.Curve25519 (dhSecret)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), NotifierId, NtfPrivateSignKey, SMPServer)
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), NotifierId, NtfPrivateSignKey, ProtocolServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util ((<$$>))
|
||||
|
||||
type NtfSubscriptionsData = Map NtfSubsciptionId NtfSubsciptionRec
|
||||
|
||||
type NtfSubscriptions = TVar NtfSubscriptionsData
|
||||
|
||||
data NtfSubsciptionRec = NtfSubsciptionRec
|
||||
{ smpServer :: SMPServer,
|
||||
notifierId :: NotifierId,
|
||||
notifierKey :: NtfPrivateSignKey,
|
||||
token :: DeviceToken,
|
||||
status :: TVar NtfStatus,
|
||||
subVerifyKey :: C.APublicVerifyKey,
|
||||
subDHSecret :: C.DhSecretX25519
|
||||
data NtfStore = NtfStore
|
||||
{ tokens :: TMap NtfTokenId NtfTknData,
|
||||
tokenIds :: TMap DeviceToken NtfTokenId
|
||||
}
|
||||
|
||||
getNtfSubscription :: NtfSubscriptions -> NtfSubsciptionId -> STM (Either ErrorType NtfSubsciptionRec)
|
||||
getNtfSubscription st subId = maybe (Left AUTH) Right . M.lookup subId <$> readTVar st
|
||||
newNtfStore :: STM NtfStore
|
||||
newNtfStore = do
|
||||
tokens <- TM.empty
|
||||
tokenIds <- TM.empty
|
||||
pure NtfStore {tokens, tokenIds}
|
||||
|
||||
data NtfTknData = NtfTknData
|
||||
{ token :: DeviceToken,
|
||||
tknStatus :: TVar NtfTknStatus,
|
||||
tknVerifyKey :: C.APublicVerifyKey,
|
||||
tknDhSecret :: C.DhSecretX25519
|
||||
}
|
||||
|
||||
mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> STM NtfTknData
|
||||
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret = do
|
||||
tknStatus <- newTVar NTNew
|
||||
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret}
|
||||
|
||||
data NtfSubscriptionsStore = NtfSubscriptionsStore
|
||||
|
||||
-- { subscriptions :: TMap NtfSubsciptionId NtfSubsciption,
|
||||
-- activeSubscriptions :: TMap (SMPServer, NotifierId) NtfSubsciptionId
|
||||
-- }
|
||||
-- do
|
||||
-- subscriptions <- newTVar M.empty
|
||||
-- activeSubscriptions <- newTVar M.empty
|
||||
-- pure NtfSubscriptionsStore {subscriptions, activeSubscriptions}
|
||||
|
||||
data NtfSubData = NtfSubData
|
||||
{ smpQueue :: SMPQueueNtf,
|
||||
tokenId :: NtfTokenId,
|
||||
subStatus :: TVar NtfSubStatus
|
||||
}
|
||||
|
||||
data NtfEntityRec (e :: NtfEntity) where
|
||||
NtfTkn :: NtfTknData -> NtfEntityRec 'Token
|
||||
NtfSub :: NtfSubData -> NtfEntityRec 'Subscription
|
||||
|
||||
data ANtfEntityRec = forall e. NtfEntityI e => NER (SNtfEntity e) (NtfEntityRec e)
|
||||
|
||||
getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token))
|
||||
getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st)
|
||||
|
||||
addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM ()
|
||||
addNtfToken st tknId tkn = pure ()
|
||||
|
||||
-- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e))
|
||||
-- getNtfRec st ent entId = case ent of
|
||||
-- SToken -> NtfTkn <$$> TM.lookup entId (tokens st)
|
||||
-- SSubscription -> pure Nothing
|
||||
|
||||
-- getNtfVerifyKey :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e, C.APublicVerifyKey))
|
||||
-- getNtfVerifyKey st ent entId =
|
||||
-- getNtfRec st ent entId >>= \case
|
||||
-- Just r@(NtfTkn NtfTknData {tknVerifyKey}) -> pure $ Just (r, tknVerifyKey)
|
||||
-- Just r@(NtfSub NtfSubData {tokenId}) ->
|
||||
-- getNtfRec st SToken tokenId >>= \case
|
||||
-- Just (NtfTkn NtfTknData {tknVerifyKey}) -> pure $ Just (r, tknVerifyKey)
|
||||
-- _ -> pure Nothing
|
||||
-- _ -> pure Nothing
|
||||
|
||||
-- mkNtfSubsciption :: SMPQueueNtf -> NtfTokenId -> STM NtfSubsciption
|
||||
-- mkNtfSubsciption smpQueue tokenId = do
|
||||
-- subStatus <- newTVar NSNew
|
||||
-- pure NtfSubsciption {smpQueue, tokenId, subStatus}
|
||||
|
||||
-- getNtfSub :: NtfSubscriptionsStore -> NtfSubsciptionId -> STM (Maybe NtfSubsciption)
|
||||
-- getNtfSub st subId = pure Nothing -- maybe (pure $ Left AUTH) (fmap Right . readTVar) . M.lookup subId . subscriptions =<< readTVar st
|
||||
|
||||
-- getNtfSubViaSMPQueue :: NtfSubscriptionsStore -> SMPQueueNtf -> STM (Maybe NtfSubsciption)
|
||||
-- getNtfSubViaSMPQueue st smpQueue = pure Nothing
|
||||
|
||||
-- -- replace keeping status
|
||||
-- updateNtfSub :: NtfSubscriptionsStore -> NtfSubsciption -> SMPQueueNtf -> NtfTokenId -> C.DhSecretX25519 -> STM (Maybe ())
|
||||
-- updateNtfSub st sub smpQueue tokenId dhSecret = pure Nothing
|
||||
|
||||
-- addNtfSub :: NtfSubscriptionsStore -> NtfSubsciptionId -> NtfSubsciption -> STM (Maybe ())
|
||||
-- addNtfSub st subId sub = pure Nothing
|
||||
|
||||
-- deleteNtfSub :: NtfSubscriptionsStore -> NtfSubsciptionId -> STM ()
|
||||
-- deleteNtfSub st subId = pure ()
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeFamilyDependencies #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
-- |
|
||||
-- Module : Simplex.Messaging.Protocol
|
||||
-- Module : Simplex.Messaging.ProtocolEncoding
|
||||
-- Copyright : (c) simplex.chat
|
||||
-- License : AGPL-3
|
||||
--
|
||||
@@ -37,7 +39,7 @@ module Simplex.Messaging.Protocol
|
||||
e2eEncMessageLength,
|
||||
|
||||
-- * SMP protocol types
|
||||
Protocol (..),
|
||||
ProtocolEncoding (..),
|
||||
Command (..),
|
||||
Party (..),
|
||||
Cmd (..),
|
||||
@@ -55,7 +57,10 @@ module Simplex.Messaging.Protocol
|
||||
PubHeader (..),
|
||||
ClientMessage (..),
|
||||
PrivHeader (..),
|
||||
SMPServer (..),
|
||||
Protocol (..),
|
||||
ProtocolServer (..),
|
||||
SMPServer,
|
||||
pattern SMPServer,
|
||||
SrvLoc (..),
|
||||
CorrId (..),
|
||||
QueueId,
|
||||
@@ -163,7 +168,7 @@ data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p)
|
||||
deriving instance Show Cmd
|
||||
|
||||
-- | Parsed SMP transmission without signature, size and session ID.
|
||||
type Transmission c = (CorrId, QueueId, c)
|
||||
type Transmission c = (CorrId, EntityId, c)
|
||||
|
||||
-- | signed parsed transmission, with original raw bytes and parsing error.
|
||||
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
|
||||
@@ -196,7 +201,9 @@ type SenderId = QueueId
|
||||
type NotifierId = QueueId
|
||||
|
||||
-- | SMP queue ID on the server.
|
||||
type QueueId = ByteString
|
||||
type QueueId = EntityId
|
||||
|
||||
type EntityId = ByteString
|
||||
|
||||
-- | Parameterized type for SMP protocol commands from all clients.
|
||||
data Command (p :: Party) where
|
||||
@@ -266,7 +273,7 @@ class ProtocolMsgTag t where
|
||||
|
||||
messageTagP :: ProtocolMsgTag t => Parser t
|
||||
messageTagP =
|
||||
maybe (fail "bad command") pure . decodeTag
|
||||
maybe (fail "bad message") pure . decodeTag
|
||||
=<< (A.takeTill (== ' ') <* optional A.space)
|
||||
|
||||
instance PartyI p => Encoding (CommandTag p) where
|
||||
@@ -374,34 +381,39 @@ instance Encoding ClientMessage where
|
||||
smpEncode (ClientMessage h msg) = smpEncode h <> msg
|
||||
smpP = ClientMessage <$> smpP <*> A.takeByteString
|
||||
|
||||
type SMPServer = ProtocolServer
|
||||
|
||||
pattern SMPServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer
|
||||
pattern SMPServer host port keyHash = ProtocolServer host port keyHash
|
||||
|
||||
-- | SMP server location and transport key digest (hash).
|
||||
data SMPServer = SMPServer
|
||||
data ProtocolServer = ProtocolServer
|
||||
{ host :: HostName,
|
||||
port :: ServiceName,
|
||||
keyHash :: C.KeyHash
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString SMPServer where
|
||||
instance IsString ProtocolServer where
|
||||
fromString = parseString strDecode
|
||||
|
||||
instance Encoding SMPServer where
|
||||
smpEncode SMPServer {host, port, keyHash} =
|
||||
instance Encoding ProtocolServer where
|
||||
smpEncode ProtocolServer {host, port, keyHash} =
|
||||
smpEncode (host, port, keyHash)
|
||||
smpP = do
|
||||
(host, port, keyHash) <- smpP
|
||||
pure SMPServer {host, port, keyHash}
|
||||
pure ProtocolServer {host, port, keyHash}
|
||||
|
||||
instance StrEncoding SMPServer where
|
||||
strEncode SMPServer {host, port, keyHash} =
|
||||
instance StrEncoding ProtocolServer where
|
||||
strEncode ProtocolServer {host, port, keyHash} =
|
||||
"smp://" <> strEncode keyHash <> "@" <> strEncode (SrvLoc host port)
|
||||
strP = do
|
||||
_ <- "smp://"
|
||||
keyHash <- strP <* A.char '@'
|
||||
SrvLoc host port <- strP
|
||||
pure SMPServer {host, port, keyHash}
|
||||
pure ProtocolServer {host, port, keyHash}
|
||||
|
||||
instance ToJSON SMPServer where
|
||||
instance ToJSON ProtocolServer where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
@@ -540,13 +552,25 @@ transmissionP = do
|
||||
command <- A.takeByteString
|
||||
pure RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
|
||||
class Protocol msg where
|
||||
class (ProtocolEncoding msg, ProtocolEncoding (ProtocolCommand msg)) => Protocol msg where
|
||||
type ProtocolCommand msg = cmd | cmd -> msg
|
||||
protocolPing :: ProtocolCommand msg
|
||||
protocolError :: msg -> Maybe ErrorType
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
type ProtocolCommand BrokerMsg = Cmd
|
||||
protocolPing = Cmd SSender PING
|
||||
protocolError = \case
|
||||
ERR e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where
|
||||
type Tag msg
|
||||
encodeProtocol :: msg -> ByteString
|
||||
protocolP :: Tag msg -> Parser msg
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
|
||||
|
||||
instance PartyI p => Protocol (Command p) where
|
||||
instance PartyI p => ProtocolEncoding (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
encodeProtocol = \case
|
||||
NEW rKey dhKey -> e (NEW_, ' ', rKey, dhKey)
|
||||
@@ -584,7 +608,7 @@ instance PartyI p => Protocol (Command p) where
|
||||
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance Protocol Cmd where
|
||||
instance ProtocolEncoding Cmd where
|
||||
type Tag Cmd = CmdTag
|
||||
encodeProtocol (Cmd _ c) = encodeProtocol c
|
||||
|
||||
@@ -606,7 +630,7 @@ instance Protocol Cmd where
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
instance ProtocolEncoding BrokerMsg where
|
||||
type Tag BrokerMsg = BrokerMsgTag
|
||||
encodeProtocol = \case
|
||||
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
|
||||
@@ -632,7 +656,7 @@ instance Protocol BrokerMsg where
|
||||
PONG_ -> pure PONG
|
||||
|
||||
checkCredentials (_, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response must not have queue ID
|
||||
-- IDS response should not have queue ID
|
||||
IDS _ -> Right cmd
|
||||
-- ERR response does not always have queue ID
|
||||
ERR _ -> Right cmd
|
||||
@@ -649,7 +673,7 @@ _smpP :: Encoding a => Parser a
|
||||
_smpP = A.space *> smpP
|
||||
|
||||
-- | Parse SMP protocol commands and broker messages
|
||||
parseProtocol :: (Protocol msg, ProtocolMsgTag (Tag msg)) => ByteString -> Either ErrorType msg
|
||||
parseProtocol :: ProtocolEncoding msg => ByteString -> Either ErrorType msg
|
||||
parseProtocol s =
|
||||
let (tag, params) = B.break (== ' ') s
|
||||
in case decodeTag tag of
|
||||
@@ -712,7 +736,7 @@ instance Encoding CommandError where
|
||||
tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t
|
||||
|
||||
encodeTransmission :: Protocol c => ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission :: ProtocolEncoding c => ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol command
|
||||
|
||||
@@ -721,11 +745,7 @@ tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmissi
|
||||
tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet ::
|
||||
forall cmd c m.
|
||||
(Protocol cmd, ProtocolMsgTag (Tag cmd), Transport c, MonadIO m) =>
|
||||
THandle c ->
|
||||
m (SignedTransmission cmd)
|
||||
tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (SignedTransmission cmd)
|
||||
tGet th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
where
|
||||
decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd)
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
-- and optional append only log of SMP queue records.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, verifyCmdSignature, dummyVerifyCmd) where
|
||||
module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, verifyCmdSignature, dummyVerifyCmd, randomBytes) where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
|
||||
@@ -11,6 +11,7 @@ module Simplex.Messaging.TMap
|
||||
adjust,
|
||||
update,
|
||||
alter,
|
||||
alterF,
|
||||
union,
|
||||
)
|
||||
where
|
||||
@@ -65,6 +66,12 @@ alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM ()
|
||||
alter f k m = modifyTVar' m $ M.alter f k
|
||||
{-# INLINE alter #-}
|
||||
|
||||
alterF :: Ord k => (Maybe a -> STM (Maybe a)) -> k -> TMap k a -> STM ()
|
||||
alterF f k m = do
|
||||
mv <- M.alterF f k =<< readTVar m
|
||||
writeTVar m $! mv
|
||||
{-# INLINE alterF #-}
|
||||
|
||||
union :: Ord k => Map k a -> TMap k a -> STM ()
|
||||
union m' m = modifyTVar' m $ M.union m'
|
||||
{-# INLINE union #-}
|
||||
|
||||
@@ -52,7 +52,7 @@ foreign import capi "netinet/tcp.h value TCP_KEEPINTVL" _TCP_KEEPINTVL :: CInt
|
||||
|
||||
foreign import capi "netinet/tcp.h value TCP_KEEPCNT" _TCP_KEEPCNT :: CInt
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
setSocketKeepAlive :: Socket -> KeepAliveOpts -> IO ()
|
||||
setSocketKeepAlive sock KeepAliveOpts {keepCnt, keepIdle, keepIntvl} = do
|
||||
|
||||
@@ -11,7 +11,7 @@ import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (smpClientVRange)
|
||||
import Simplex.Messaging.Protocol (ProtocolServer (..), smpClientVRange)
|
||||
import Simplex.Messaging.Version
|
||||
import Test.Hspec
|
||||
|
||||
@@ -20,7 +20,7 @@ uri = "smp.simplex.im"
|
||||
|
||||
srv :: SMPServer
|
||||
srv =
|
||||
SMPServer
|
||||
ProtocolServer
|
||||
{ host = "smp.simplex.im",
|
||||
port = "5223",
|
||||
keyHash = C.KeyHash "\215m\248\251"
|
||||
|
||||
@@ -24,7 +24,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig)
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
@@ -162,7 +162,7 @@ cfg =
|
||||
tbqSize = 1,
|
||||
dbFile = testDB,
|
||||
smpCfg =
|
||||
smpDefaultConfig
|
||||
defaultClientConfig
|
||||
{ qSize = 1,
|
||||
defaultTransport = (testPort, transport @TLS),
|
||||
tcpTimeout = 500_000
|
||||
|
||||
Reference in New Issue
Block a user