diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 7b5fba6bf..86758b945 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -22,11 +22,11 @@ import Simplex.Messaging.Agent.ServerClient (ServerClient (..), newServerClient) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Server (randomBytes) -import Simplex.Messaging.Server.Transmission (Cmd (..), CorrId (..), SParty (..)) +import Simplex.Messaging.Server.Transmission (Cmd (..), CorrId (..), PublicKey, RecipientId, SParty (..), SenderId) import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport import UnliftIO.Async -import UnliftIO.Exception (Exception, SomeException) +import UnliftIO.Exception (SomeException) import qualified UnliftIO.Exception as E import UnliftIO.IO import UnliftIO.STM @@ -61,8 +61,8 @@ send h AgentClient {sndQ} = forever $ atomically (readTBQueue sndQ) >>= tPut h client :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () client c@AgentClient {rcvQ, sndQ} = forever $ do - t@(corrId, cAlias, cmd) <- atomically $ readTBQueue rcvQ - runExceptT (processCommand c t cmd) >>= \case + t@(corrId, cAlias, _) <- atomically $ readTBQueue rcvQ + runExceptT (processCommand c t) >>= \case Left e -> atomically $ writeTBQueue sndQ (corrId, cAlias, ERR e) Right _ -> return () @@ -75,14 +75,20 @@ withStore action = Left _ -> throwError INTERNAL Right c -> return c -processCommand :: forall m. (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) => AgentClient -> ATransmission 'Client -> ACommand 'Client -> m () -processCommand AgentClient {respQ, servers, commands} t = \case - NEW smpServer _ -> do - srv <- getSMPServer smpServer - smpT <- mkSmpNEW - atomically $ writeTBQueue (smpSndQ srv) smpT - return () - _ -> throwError PROHIBITED +processCommand :: + forall m. + (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) => + AgentClient -> + ATransmission 'Client -> + m () +processCommand AgentClient {respQ, servers, commands} t@(_, connAlias, cmd) = + case cmd of + NEW smpServer _ -> do + srv <- getSMPServer smpServer + smpT <- mkSmpNEW smpServer + atomically $ writeTBQueue (smpSndQ srv) smpT + return () + _ -> throwError PROHIBITED where replyError :: ErrorType -> SomeException -> m a replyError err e = do @@ -90,60 +96,89 @@ processCommand AgentClient {respQ, servers, commands} t = \case throwError err getSMPServer :: SMPServer -> m ServerClient - getSMPServer s@SMPServer {host, port} = do + getSMPServer SMPServer {host, port} = do defPort <- asks $ smpTcpPort . config let p = fromMaybe defPort port atomically (M.lookup (host, p) <$> readTVar servers) - >>= maybe (newSMPServer s host p) return + >>= maybe (newSMPServer host p) return - newSMPServer :: SMPServer -> HostName -> ServiceName -> m ServerClient - newSMPServer s host port = do + newSMPServer :: HostName -> ServiceName -> m ServerClient + newSMPServer host port = do cfg <- asks $ smpConfig . config - store <- asks db - _serverId <- withStore (addServer store s) `E.catch` replyError INTERNAL + -- store <- asks db + -- _serverId <- withStore (addServer store s) `E.catch` replyError INTERNAL srv <- newServerClient cfg respQ host port `E.catch` replyError (BROKER smpErrTCPConnection) atomically . modifyTVar servers $ M.insert (host, port) srv return srv - mkSmpNEW :: m SMP.Transmission - mkSmpNEW = do + mkSmpNEW :: SMPServer -> m SMP.Transmission + mkSmpNEW smpServer = do g <- asks idsDrg smpCorrId <- atomically $ CorrId <$> randomBytes 4 g recipientKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair - let recipientPrivateKey = recipientKey + let rcvPrivateKey = recipientKey toSMP = ("", (smpCorrId, "", Cmd SRecipient $ SMP.NEW recipientKey)) req = Request { fromClient = t, toSMP, - state = NEWRequestState {recipientKey, recipientPrivateKey} + state = NEWRequestState {connAlias, smpServer, rcvPrivateKey} } atomically . modifyTVar commands $ M.insert smpCorrId req -- TODO check ID collision return toSMP processSmp :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () -processSmp AgentClient {respQ, sndQ, commands} = forever $ do +processSmp c@AgentClient {respQ, sndQ, commands} = forever $ do (_, (smpCorrId, qId, cmdOrErr)) <- atomically $ readTBQueue respQ liftIO $ putStrLn "received from server" -- TODO remove liftIO $ print (smpCorrId, qId, cmdOrErr) req <- atomically $ M.lookup smpCorrId <$> readTVar commands case req of -- TODO empty correlation ID is ok - it can be a message Nothing -> atomically $ writeTBQueue sndQ ("", "", ERR $ BROKER smpErrCorrelationId) - Just r -> processResponse r cmdOrErr - where - processResponse :: Request -> Either SMP.ErrorType SMP.Cmd -> m () - processResponse Request {fromClient = (corrId, cAlias, cmd), toSMP = (_, (_, _, smpCmd)), state} cmdOrErr = do - case cmdOrErr of - Left e -> respond $ ERR (SMP e) - Right resp -> case resp of - Cmd SBroker (SMP.IDS recipientId senderId) -> case smpCmd of - Cmd SRecipient (SMP.NEW _) -> case (cmd, state) of - (NEW _ _, NEWRequestState {recipientKey, recipientPrivateKey}) -> do - -- TODO all good - process response - respond $ ERR UNKNOWN - _ -> respond $ ERR INTERNAL - _ -> respond $ ERR (BROKER smpUnexpectedResponse) - _ -> respond $ ERR UNSUPPORTED - where - respond :: ACommand 'Agent -> m () - respond c = atomically $ writeTBQueue sndQ (corrId, cAlias, c) + Just r@Request {fromClient = (corrId, cAlias, _)} -> + runExceptT (processResponse c r cmdOrErr) >>= \case + Left e -> atomically $ writeTBQueue sndQ (corrId, cAlias, ERR e) + Right _ -> return () + +processResponse :: + forall m. + (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) => + AgentClient -> + Request -> + Either SMP.ErrorType SMP.Cmd -> + m () +processResponse + AgentClient {sndQ} + Request {fromClient = (corrId, cAlias, cmd), toSMP = (_, (_, _, smpCmd)), state} + cmdOrErr = do + case cmdOrErr of + Left e -> throwError $ SMP e + Right resp -> case resp of + Cmd SBroker (SMP.IDS recipientId senderId) -> case smpCmd of + Cmd SRecipient (SMP.NEW _) -> case (cmd, state) of + (NEW _ _, NEWRequestState {connAlias, smpServer, rcvPrivateKey}) -> do + -- TODO all good - process response + g <- asks idsDrg + st <- asks db + encryptKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair + let decryptKey = encryptKey + withStore $ + createRcvConn st connAlias $ + ReceiveQueue + { server = smpServer, + rcvId = recipientId, + rcvPrivateKey, + sndId = Just senderId, + sndKey = Nothing, + decryptKey, + verifyKey = Nothing, + status = New, + ackMode = AckMode On + } + respond . INV $ SMPQueueInfo smpServer senderId encryptKey + _ -> throwError INTERNAL + _ -> throwError $ BROKER smpUnexpectedResponse + _ -> throwError UNSUPPORTED + where + respond :: ACommand 'Agent -> m () + respond c = atomically $ writeTBQueue sndQ (corrId, cAlias, c) diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index e8b1d4717..646b5461a 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -50,8 +50,9 @@ data Request = Request } data RequestState = NEWRequestState - { recipientKey :: PublicKey, - recipientPrivateKey :: PrivateKey + { connAlias :: ConnAlias, + smpServer :: SMPServer, + rcvPrivateKey :: PrivateKey } newAgentClient :: Natural -> STM AgentClient diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 3ea272ac8..6399be206 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -117,12 +117,12 @@ upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = do DB.execute conn [s| - INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?) - ON CONFLICT (host, port) DO UPDATE SET - host=excluded.host, - port=excluded.port, - key_hash=excluded.key_hash; - |] + INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?) + ON CONFLICT (host, port) DO UPDATE SET + host=excluded.host, + port=excluded.port, + key_hash=excluded.key_hash; + |] srv DB.queryNamed conn @@ -323,7 +323,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto -- TODO refactor ito a single query with join, and parse as `Only connAlias :. rcvQueue :. sndQueue` getConn :: SQLiteStore -> ConnAlias -> m SomeConn - getConn st connAlias = do + getConn st connAlias = getConnection st connAlias >>= \case (Just rcvQId, Just sndQId) -> do rcvQ <- getRcvQueue st rcvQId diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 222a89aa0..03511086a 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -29,7 +30,7 @@ import Simplex.Messaging.Server.Transmission Encoded, MsgBody, PublicKey, - QueueId, + SenderId, errBadParameters, errMessageBody, ) @@ -115,7 +116,7 @@ newtype AckMode = AckMode Mode deriving (Eq, Show) newtype SubMode = SubMode Mode deriving (Show) -data SMPQueueInfo = SMPQueueInfo SMPServer QueueId EncryptionKey +data SMPQueueInfo = SMPQueueInfo SMPServer SenderId EncryptionKey deriving (Show) type EncryptionKey = PublicKey @@ -214,7 +215,12 @@ parseCommand command = case B.words command of errInv = Left $ SYNTAX errBadInvitation serializeCommand :: ACommand p -> ByteString -serializeCommand = B.pack . show +serializeCommand = \case + INV (SMPQueueInfo srv qId ek) -> "INV smp::" <> server srv <> "::" <> encode qId <> "::" <> encode ek + c -> B.pack $ show c + where + server :: SMPServer -> ByteString + server SMPServer {host, port, keyHash} = B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash tPutRaw :: MonadIO m => Handle -> ARawTransmission -> m () tPutRaw h (corrId, connAlias, command) = do