From f50da16d0a25c422daaa635aa81ea4803a264d61 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Sat, 23 Jan 2021 17:06:01 +0400 Subject: [PATCH] reorganize Protocol and Agent Store (#25) * chore: move members from Server/Transmission.hs to Protocol.hs * chore: revert qualified SMP import for server and client * chore: fix corrId call * chore: move common types to Common.hs * chore: decompose SQLite.hs * chore: rename Agent/Transmission.hs ErrorType -> AgentErrorType * chore: move Protocol ErrorType -> Common SMPErrorType * chore: rename Common -> Types * chore: revert SMPErrorType -> ErrorType --- src/Simplex/Messaging/Agent.hs | 12 +- src/Simplex/Messaging/Agent/Client.hs | 19 +- src/Simplex/Messaging/Agent/Env/SQLite.hs | 1 + src/Simplex/Messaging/Agent/Store.hs | 2 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 361 +---------------- .../Messaging/Agent/Store/SQLite/Types.hs | 29 ++ .../Messaging/Agent/Store/SQLite/Util.hs | 369 ++++++++++++++++++ src/Simplex/Messaging/Agent/Transmission.hs | 39 +- src/Simplex/Messaging/Client.hs | 3 +- .../{Server/Transmission.hs => Protocol.hs} | 52 +-- src/Simplex/Messaging/Server.hs | 3 +- src/Simplex/Messaging/Server/Env/STM.hs | 3 +- src/Simplex/Messaging/Server/MsgStore.hs | 2 +- src/Simplex/Messaging/Server/MsgStore/STM.hs | 2 +- src/Simplex/Messaging/Server/QueueStore.hs | 3 +- .../Messaging/Server/QueueStore/STM.hs | 3 +- src/Simplex/Messaging/Types.hs | 60 +++ tests/AgentTests/SQLite.hs | 1 + tests/SMPAgentClient.hs | 2 +- tests/SMPClient.hs | 2 +- tests/ServerTests.hs | 3 +- 21 files changed, 512 insertions(+), 459 deletions(-) create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Types.hs create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Util.hs rename src/Simplex/Messaging/{Server/Transmission.hs => Protocol.hs} (89%) create mode 100644 src/Simplex/Messaging/Types.hs diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index dc4451009..c536d89b3 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -21,13 +21,13 @@ import Data.Text.Encoding import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Store -import Simplex.Messaging.Agent.Store.SQLite +import Simplex.Messaging.Agent.Store.SQLite.Types import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (SMPServerTransmission) +import Simplex.Messaging.Types (CorrId (..), MsgBody, PrivateKey, SenderKey) +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server (randomBytes) -import Simplex.Messaging.Server.Transmission (CorrId (..), PrivateKey) -import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport import UnliftIO.Async import UnliftIO.Exception (SomeException) @@ -129,7 +129,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) = ReplyOff -> return () respond CON - sendMessage :: SMP.MsgBody -> m () + sendMessage :: MsgBody -> m () sendMessage msgBody = withStore (`getConn` connAlias) >>= \case SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq @@ -222,7 +222,7 @@ processSMPTransmission c@AgentClient {sndQ} (srv, rId, cmd) = do notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) -connectToSendQueue :: AgentMonad m => AgentClient -> SendQueue -> SMP.SenderKey -> m () +connectToSendQueue :: AgentMonad m => AgentClient -> SendQueue -> SenderKey -> m () connectToSendQueue c sq senderKey = do sendConfirmation c sq senderKey withStore $ \st -> updateSndQueueStatus st sq Confirmed @@ -233,7 +233,7 @@ decryptMessage :: MonadUnliftIO m => PrivateKey -> ByteString -> m ByteString decryptMessage _decryptKey = return newSendQueue :: - (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SendQueue, SMP.SenderKey) + (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SendQueue, SenderKey) newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do g <- asks idsDrg senderKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 8f223a5bc..ad73f3d6b 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -41,8 +41,7 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client import Simplex.Messaging.Server (randomBytes) -import Simplex.Messaging.Server.Transmission (PrivateKey, PublicKey, SenderKey) -import qualified Simplex.Messaging.Server.Transmission as SMP +import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, QueueId, SenderKey) import UnliftIO.Concurrent import UnliftIO.Exception (SomeException) import qualified UnliftIO.Exception as E @@ -66,7 +65,7 @@ newAgentClient cc qSize = do writeTVar cc clientId return AgentClient {rcvQ, sndQ, msgQ, smpClients, clientId} -type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m) +type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient getSMPServerClient AgentClient {smpClients, msgQ} srv = @@ -80,7 +79,7 @@ getSMPServerClient AgentClient {smpClients, msgQ} srv = atomically . modifyTVar smpClients $ M.insert srv c return c - throwErr :: ErrorType -> SomeException -> m a + throwErr :: AgentErrorType -> SomeException -> m a throwErr err e = do liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove throwError err @@ -94,18 +93,18 @@ withSMP c srv action = liftIO (first smpClientError <$> runExceptT (action smp)) >>= liftEither - smpClientError :: SMPClientError -> ErrorType + smpClientError :: SMPClientError -> AgentErrorType smpClientError = \case SMPServerError e -> SMP e -- TODO handle other errors _ -> INTERNAL - logServerError :: ErrorType -> m a + logServerError :: AgentErrorType -> m a logServerError e = do logServer "<--" c srv "" $ (B.pack . show) e throwError e -withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> SMP.QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a +withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withLogSMP c srv qId cmdStr action = do logServer "-->" c srv qId cmdStr res <- withSMP c srv action @@ -136,7 +135,7 @@ newReceiveQueue c srv = do } return (rcvQueue, SMPQueueInfo srv sId encryptKey) -logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> SMP.QueueId -> ByteString -> m () +logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} SMPServer {host, port} qId cmdStr = logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, server, ":", logSecret qId, cmdStr] where @@ -152,7 +151,7 @@ sendConfirmation c SendQueue {server, sndId} senderKey = do withLogSMP c server sndId "SEND " $ \smp -> sendSMPMessage smp "" sndId msg where - mkConfirmation :: m SMP.MsgBody + mkConfirmation :: m MsgBody mkConfirmation = do let msg = serializeSMPMessage $ SMPConfirmation senderKey -- TODO encryption @@ -172,7 +171,7 @@ sendHello c SendQueue {server, sndId, sndPrivateKey, encryptKey} = do send 0 _ _ = throwE SMPResponseTimeout -- TODO different error send retry msg smp = sendSMPMessage smp sndPrivateKey sndId msg `catchE` \case - SMPServerError SMP.AUTH -> do + SMPServerError AUTH -> do threadDelay 100000 send (retry - 1) msg smp e -> throwE e diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 0847b38a8..fb75716a2 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -10,6 +10,7 @@ import Crypto.Random import Network.Socket import Numeric.Natural import Simplex.Messaging.Agent.Store.SQLite +import Simplex.Messaging.Agent.Store.SQLite.Types import Simplex.Messaging.Client import UnliftIO.STM diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 9e12a7e5f..acfa45341 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -15,7 +15,7 @@ import Data.Time.Clock (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Server.Transmission (PrivateKey, PublicKey, RecipientId, SenderId) +import Simplex.Messaging.Types data ReceiveQueue = ReceiveQueue { server :: SMPServer, diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 7f9ebd24b..aa90a99e9 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -8,9 +8,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -19,47 +17,17 @@ module Simplex.Messaging.Agent.Store.SQLite where import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift -import Data.Int (Int64) import Data.Maybe -import qualified Data.Text as T -import Data.Time -import Database.SQLite.Simple hiding (Connection) import qualified Database.SQLite.Simple as DB -import Database.SQLite.Simple.FromField -import Database.SQLite.Simple.Internal (Field (..)) -import Database.SQLite.Simple.Ok -import Database.SQLite.Simple.QQ (sql) -import Database.SQLite.Simple.ToField -import Network.Socket import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Schema +import Simplex.Messaging.Agent.Store.SQLite.Types +import Simplex.Messaging.Agent.Store.SQLite.Util import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Server.Transmission (RecipientId, SenderId) -import Simplex.Messaging.Util -import Text.Read -import qualified UnliftIO.Exception as E +import Simplex.Messaging.Types import UnliftIO.STM -addRcvQueueQuery :: Query -addRcvQueueQuery = - [sql| - INSERT INTO receive_queues - ( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode) - VALUES - (:server_id,:rcv_id,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status,:ack_mode); - |] - -data SQLiteStore = SQLiteStore - { dbFilename :: String, - conn :: DB.Connection, - serversLock :: TMVar (), - rcvQueuesLock :: TMVar (), - sndQueuesLock :: TMVar (), - connectionsLock :: TMVar (), - messagesLock :: TMVar () - } - newSQLiteStore :: MonadUnliftIO m => String -> m SQLiteStore newSQLiteStore dbFilename = do conn <- liftIO $ DB.open dbFilename @@ -80,329 +48,6 @@ newSQLiteStore dbFilename = do messagesLock } -type QueueRowId = Int64 - -type ConnectionRowId = Int64 - -fromFieldToReadable :: forall a. (Read a, E.Typeable a) => Field -> Ok a -fromFieldToReadable = \case - f@(Field (SQLText t) _) -> - let str = T.unpack t - in case readMaybe str of - Just x -> Ok x - _ -> returnError ConversionFailed f ("invalid string: " <> str) - f -> returnError ConversionFailed f "expecting SQLText column type" - -withLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> (DB.Connection -> m a) -> m a -withLock st tableLock f = do - let lock = tableLock st - E.bracket_ - (atomically $ takeTMVar lock) - (atomically $ putTMVar lock ()) - (f $ conn st) - -insertWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m Int64 -insertWithLock st tableLock queryStr q = do - withLock st tableLock $ \c -> liftIO $ do - DB.execute c queryStr q - DB.lastInsertRowId c - -executeWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m () -executeWithLock st tableLock queryStr q = do - withLock st tableLock $ \c -> liftIO $ do - DB.execute c queryStr q - -instance ToRow SMPServer where - toRow SMPServer {host, port, keyHash} = toRow (host, port, keyHash) - -instance FromRow SMPServer where - fromRow = SMPServer <$> field <*> field <*> field - -upsertServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServer -> m SMPServerId -upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = do - r <- liftIO $ do - DB.execute - conn - [sql| - INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?) - ON CONFLICT (host, port) DO UPDATE SET - host=excluded.host, - port=excluded.port, - key_hash=excluded.key_hash; - |] - srv - DB.queryNamed - conn - "SELECT server_id FROM servers WHERE host = :host AND port = :port" - [":host" := host, ":port" := port] - case r of - [Only serverId] -> return serverId - _ -> throwError SEInternal - -getServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServerId -> m SMPServer -getServer SQLiteStore {conn} serverId = do - r <- - liftIO $ - DB.queryNamed - conn - "SELECT host, port, key_hash FROM servers WHERE server_id = :server_id" - [":server_id" := serverId] - case r of - [smpServer] -> return smpServer - _ -> throwError SENotFound - -instance ToField AckMode where toField (AckMode mode) = toField $ show mode - -instance FromField AckMode where fromField = AckMode <$$> fromFieldToReadable - -instance ToField QueueStatus where toField = toField . show - -instance FromField QueueStatus where fromField = fromFieldToReadable - -instance ToRow ReceiveQueue where - toRow ReceiveQueue {rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode} = - toRow (rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode) - -instance FromRow ReceiveQueue where - fromRow = ReceiveQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field - --- TODO refactor into a single query with join -getRcvQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m ReceiveQueue -getRcvQueue st@SQLiteStore {conn} queueRowId = do - r <- - liftIO $ - DB.queryNamed - conn - [sql| - SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode - FROM receive_queues - WHERE receive_queue_id = :rowId; - |] - [":rowId" := queueRowId] - case r of - [Only serverId :. rcvQueue] -> - (\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId - _ -> throwError SENotFound - -getRcvQueueByRecipientId :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> m ReceiveQueue -getRcvQueueByRecipientId st@SQLiteStore {conn} rcvId host port = do - r <- - liftIO $ - DB.queryNamed - conn - [sql| - SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode - FROM receive_queues - WHERE rcv_id = :rcvId AND server_id IN ( - SELECT server_id - FROM servers - WHERE host = :host AND port = :port - ); - |] - [":rcvId" := rcvId, ":host" := host, ":port" := port] - case r of - [Only serverId :. rcvQueue] -> - (\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId - _ -> throwError SENotFound - --- TODO refactor into a single query with join -getSndQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m SendQueue -getSndQueue st@SQLiteStore {conn} queueRowId = do - r <- - liftIO $ - DB.queryNamed - conn - [sql| - SELECT server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode - FROM send_queues - WHERE send_queue_id = :rowId; - |] - [":rowId" := queueRowId] - case r of - [Only serverId :. sndQueue] -> - (\srv -> (sndQueue {server = srv} :: SendQueue)) <$> getServer st serverId - _ -> throwError SENotFound - -insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId -insertRcvQueue store serverId rcvQueue = - insertWithLock - store - rcvQueuesLock - [sql| - INSERT INTO receive_queues - ( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode) - VALUES (?,?,?,?,?,?,?,?,?); - |] - (Only serverId :. rcvQueue) - -insertRcvConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () -insertRcvConnection store connAlias rcvQueueId = - void $ - insertWithLock - store - connectionsLock - "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,?,NULL);" - (Only connAlias :. Only rcvQueueId) - -updateRcvConnectionWithSndQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () -updateRcvConnectionWithSndQueue store connAlias sndQueueId = - executeWithLock - store - connectionsLock - [sql| - UPDATE connections - SET send_queue_id = ? - WHERE conn_alias = ?; - |] - (Only sndQueueId :. Only connAlias) - -instance ToRow SendQueue where - toRow SendQueue {sndId, sndPrivateKey, encryptKey, signKey, status, ackMode} = - toRow (sndId, sndPrivateKey, encryptKey, signKey, status, ackMode) - -instance FromRow SendQueue where - fromRow = SendQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field - -insertSndQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> SendQueue -> m QueueRowId -insertSndQueue store serverId sndQueue = - insertWithLock - store - sndQueuesLock - [sql| - INSERT INTO send_queues - ( server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode) - VALUES (?,?,?,?,?,?,?); - |] - (Only serverId :. sndQueue) - -insertSndConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () -insertSndConnection store connAlias sndQueueId = - void $ - insertWithLock - store - connectionsLock - "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,NULL,?);" - (Only connAlias :. Only sndQueueId) - -updateSndConnectionWithRcvQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () -updateSndConnectionWithRcvQueue store connAlias rcvQueueId = - executeWithLock - store - connectionsLock - [sql| - UPDATE connections - SET receive_queue_id = ? - WHERE conn_alias = ?; - |] - (Only rcvQueueId :. Only connAlias) - -getConnection :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> ConnAlias -> m (Maybe QueueRowId, Maybe QueueRowId) -getConnection SQLiteStore {conn} connAlias = do - r <- - liftIO $ - DB.queryNamed - conn - "SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias" - [":conn_alias" := connAlias] - case r of - [queueIds] -> return queueIds - _ -> throwError SEInternal - -instance FromRow ConnAlias where - fromRow = field - -getConnAliasByRcvQueue :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> RecipientId -> m ConnAlias -getConnAliasByRcvQueue SQLiteStore {conn} rcvId = do - r <- - liftIO $ - DB.queryNamed - conn - [sql| - SELECT c.conn_alias - FROM connections c - JOIN receive_queues rq - ON c.receive_queue_id = rq.receive_queue_id - WHERE rq.rcv_id = :rcvId; - |] - [":rcvId" := rcvId] - case r of - [connAlias] -> return connAlias - _ -> throwError SEInternal - -deleteRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m () -deleteRcvQueue store rcvQueueId = do - executeWithLock - store - rcvQueuesLock - "DELETE FROM receive_queues WHERE receive_queue_id = ?" - (Only rcvQueueId) - -deleteSndQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m () -deleteSndQueue store sndQueueId = do - executeWithLock - store - sndQueuesLock - "DELETE FROM send_queues WHERE send_queue_id = ?" - (Only sndQueueId) - -deleteConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> m () -deleteConnection store connAlias = do - executeWithLock - store - connectionsLock - "DELETE FROM connections WHERE conn_alias = ?" - (Only connAlias) - -updateReceiveQueueStatus :: MonadUnliftIO m => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> QueueStatus -> m () -updateReceiveQueueStatus store rcvQueueId host port status = - executeWithLock - store - rcvQueuesLock - [sql| - UPDATE receive_queues - SET status = ? - WHERE rcv_id = ? - AND server_id IN ( - SELECT server_id - FROM servers - WHERE host = ? AND port = ? - ); - |] - (Only status :. Only rcvQueueId :. Only host :. Only port) - -updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m () -updateSendQueueStatus store sndQueueId host port status = - executeWithLock - store - sndQueuesLock - [sql| - UPDATE send_queues - SET status = ? - WHERE snd_id = ? - AND server_id IN ( - SELECT server_id - FROM servers - WHERE host = ? AND port = ? - ); - |] - (Only status :. Only sndQueueId :. Only host :. Only port) - -instance ToField QueueDirection where toField = toField . show - --- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus? -insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m () -insertMsg store connAlias qDirection agentMsgId msg = do - tstamp <- liftIO getCurrentTime - void $ - insertWithLock - store - messagesLock - [sql| - INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status) - VALUES (?,?,?,?,?,"MDTransmitted"); - |] - (Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg) - instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where addServer store smpServer = upsertServer store smpServer diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Types.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Types.hs new file mode 100644 index 000000000..2fa9cc0ac --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Types.hs @@ -0,0 +1,29 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Simplex.Messaging.Agent.Store.SQLite.Types where + +import Data.Int (Int64) +import qualified Database.SQLite.Simple as DB +import UnliftIO.STM + +data SQLiteStore = SQLiteStore + { dbFilename :: String, + conn :: DB.Connection, + serversLock :: TMVar (), + rcvQueuesLock :: TMVar (), + sndQueuesLock :: TMVar (), + connectionsLock :: TMVar (), + messagesLock :: TMVar () + } + +type QueueRowId = Int64 + +type ConnectionRowId = Int64 diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs new file mode 100644 index 000000000..eca0b3672 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs @@ -0,0 +1,369 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Simplex.Messaging.Agent.Store.SQLite.Util where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Unlift +import Data.Int (Int64) +import qualified Data.Text as T +import Data.Time +import Database.SQLite.Simple hiding (Connection) +import qualified Database.SQLite.Simple as DB +import Database.SQLite.Simple.FromField +import Database.SQLite.Simple.Internal (Field (..)) +import Database.SQLite.Simple.Ok +import Database.SQLite.Simple.QQ (sql) +import Database.SQLite.Simple.ToField +import Network.Socket +import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.SQLite.Types +import Simplex.Messaging.Agent.Store.Types +import Simplex.Messaging.Agent.Transmission +import Simplex.Messaging.Types +import Simplex.Messaging.Util +import Text.Read +import qualified UnliftIO.Exception as E +import UnliftIO.STM + +addRcvQueueQuery :: Query +addRcvQueueQuery = + [sql| + INSERT INTO receive_queues + ( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode) + VALUES + (:server_id,:rcv_id,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status,:ack_mode); + |] + +fromFieldToReadable :: forall a. (Read a, E.Typeable a) => Field -> Ok a +fromFieldToReadable = \case + f@(Field (SQLText t) _) -> + let str = T.unpack t + in case readMaybe str of + Just x -> Ok x + _ -> returnError ConversionFailed f ("invalid string: " <> str) + f -> returnError ConversionFailed f "expecting SQLText column type" + +withLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> (DB.Connection -> m a) -> m a +withLock st tableLock f = do + let lock = tableLock st + E.bracket_ + (atomically $ takeTMVar lock) + (atomically $ putTMVar lock ()) + (f $ conn st) + +insertWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m Int64 +insertWithLock st tableLock queryStr q = do + withLock st tableLock $ \c -> liftIO $ do + DB.execute c queryStr q + DB.lastInsertRowId c + +executeWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m () +executeWithLock st tableLock queryStr q = do + withLock st tableLock $ \c -> liftIO $ do + DB.execute c queryStr q + +instance ToRow SMPServer where + toRow SMPServer {host, port, keyHash} = toRow (host, port, keyHash) + +instance FromRow SMPServer where + fromRow = SMPServer <$> field <*> field <*> field + +upsertServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServer -> m SMPServerId +upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = do + r <- liftIO $ do + DB.execute + conn + [sql| + INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?) + ON CONFLICT (host, port) DO UPDATE SET + host=excluded.host, + port=excluded.port, + key_hash=excluded.key_hash; + |] + srv + DB.queryNamed + conn + "SELECT server_id FROM servers WHERE host = :host AND port = :port" + [":host" := host, ":port" := port] + case r of + [Only serverId] -> return serverId + _ -> throwError SEInternal + +getServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServerId -> m SMPServer +getServer SQLiteStore {conn} serverId = do + r <- + liftIO $ + DB.queryNamed + conn + "SELECT host, port, key_hash FROM servers WHERE server_id = :server_id" + [":server_id" := serverId] + case r of + [smpServer] -> return smpServer + _ -> throwError SENotFound + +instance ToField AckMode where toField (AckMode mode) = toField $ show mode + +instance FromField AckMode where fromField = AckMode <$$> fromFieldToReadable + +instance ToField QueueStatus where toField = toField . show + +instance FromField QueueStatus where fromField = fromFieldToReadable + +instance ToRow ReceiveQueue where + toRow ReceiveQueue {rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode} = + toRow (rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode) + +instance FromRow ReceiveQueue where + fromRow = ReceiveQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field + +-- TODO refactor into a single query with join +getRcvQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m ReceiveQueue +getRcvQueue st@SQLiteStore {conn} queueRowId = do + r <- + liftIO $ + DB.queryNamed + conn + [sql| + SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode + FROM receive_queues + WHERE receive_queue_id = :rowId; + |] + [":rowId" := queueRowId] + case r of + [Only serverId :. rcvQueue] -> + (\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId + _ -> throwError SENotFound + +getRcvQueueByRecipientId :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> m ReceiveQueue +getRcvQueueByRecipientId st@SQLiteStore {conn} rcvId host port = do + r <- + liftIO $ + DB.queryNamed + conn + [sql| + SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode + FROM receive_queues + WHERE rcv_id = :rcvId AND server_id IN ( + SELECT server_id + FROM servers + WHERE host = :host AND port = :port + ); + |] + [":rcvId" := rcvId, ":host" := host, ":port" := port] + case r of + [Only serverId :. rcvQueue] -> + (\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId + _ -> throwError SENotFound + +-- TODO refactor into a single query with join +getSndQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m SendQueue +getSndQueue st@SQLiteStore {conn} queueRowId = do + r <- + liftIO $ + DB.queryNamed + conn + [sql| + SELECT server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode + FROM send_queues + WHERE send_queue_id = :rowId; + |] + [":rowId" := queueRowId] + case r of + [Only serverId :. sndQueue] -> + (\srv -> (sndQueue {server = srv} :: SendQueue)) <$> getServer st serverId + _ -> throwError SENotFound + +insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId +insertRcvQueue store serverId rcvQueue = + insertWithLock + store + rcvQueuesLock + [sql| + INSERT INTO receive_queues + ( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode) + VALUES (?,?,?,?,?,?,?,?,?); + |] + (Only serverId :. rcvQueue) + +insertRcvConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () +insertRcvConnection store connAlias rcvQueueId = + void $ + insertWithLock + store + connectionsLock + "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,?,NULL);" + (Only connAlias :. Only rcvQueueId) + +updateRcvConnectionWithSndQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () +updateRcvConnectionWithSndQueue store connAlias sndQueueId = + executeWithLock + store + connectionsLock + [sql| + UPDATE connections + SET send_queue_id = ? + WHERE conn_alias = ?; + |] + (Only sndQueueId :. Only connAlias) + +instance ToRow SendQueue where + toRow SendQueue {sndId, sndPrivateKey, encryptKey, signKey, status, ackMode} = + toRow (sndId, sndPrivateKey, encryptKey, signKey, status, ackMode) + +instance FromRow SendQueue where + fromRow = SendQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field + +insertSndQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> SendQueue -> m QueueRowId +insertSndQueue store serverId sndQueue = + insertWithLock + store + sndQueuesLock + [sql| + INSERT INTO send_queues + ( server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode) + VALUES (?,?,?,?,?,?,?); + |] + (Only serverId :. sndQueue) + +insertSndConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () +insertSndConnection store connAlias sndQueueId = + void $ + insertWithLock + store + connectionsLock + "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,NULL,?);" + (Only connAlias :. Only sndQueueId) + +updateSndConnectionWithRcvQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m () +updateSndConnectionWithRcvQueue store connAlias rcvQueueId = + executeWithLock + store + connectionsLock + [sql| + UPDATE connections + SET receive_queue_id = ? + WHERE conn_alias = ?; + |] + (Only rcvQueueId :. Only connAlias) + +getConnection :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> ConnAlias -> m (Maybe QueueRowId, Maybe QueueRowId) +getConnection SQLiteStore {conn} connAlias = do + r <- + liftIO $ + DB.queryNamed + conn + "SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias" + [":conn_alias" := connAlias] + case r of + [queueIds] -> return queueIds + _ -> throwError SEInternal + +instance FromRow ConnAlias where + fromRow = field + +getConnAliasByRcvQueue :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> RecipientId -> m ConnAlias +getConnAliasByRcvQueue SQLiteStore {conn} rcvId = do + r <- + liftIO $ + DB.queryNamed + conn + [sql| + SELECT c.conn_alias + FROM connections c + JOIN receive_queues rq + ON c.receive_queue_id = rq.receive_queue_id + WHERE rq.rcv_id = :rcvId; + |] + [":rcvId" := rcvId] + case r of + [connAlias] -> return connAlias + _ -> throwError SEInternal + +deleteRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m () +deleteRcvQueue store rcvQueueId = do + executeWithLock + store + rcvQueuesLock + "DELETE FROM receive_queues WHERE receive_queue_id = ?" + (Only rcvQueueId) + +deleteSndQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m () +deleteSndQueue store sndQueueId = do + executeWithLock + store + sndQueuesLock + "DELETE FROM send_queues WHERE send_queue_id = ?" + (Only sndQueueId) + +deleteConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> m () +deleteConnection store connAlias = do + executeWithLock + store + connectionsLock + "DELETE FROM connections WHERE conn_alias = ?" + (Only connAlias) + +updateReceiveQueueStatus :: MonadUnliftIO m => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> QueueStatus -> m () +updateReceiveQueueStatus store rcvQueueId host port status = + executeWithLock + store + rcvQueuesLock + [sql| + UPDATE receive_queues + SET status = ? + WHERE rcv_id = ? + AND server_id IN ( + SELECT server_id + FROM servers + WHERE host = ? AND port = ? + ); + |] + (Only status :. Only rcvQueueId :. Only host :. Only port) + +updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m () +updateSendQueueStatus store sndQueueId host port status = + executeWithLock + store + sndQueuesLock + [sql| + UPDATE send_queues + SET status = ? + WHERE snd_id = ? + AND server_id IN ( + SELECT server_id + FROM servers + WHERE host = ? AND port = ? + ); + |] + (Only status :. Only sndQueueId :. Only host :. Only port) + +instance ToField QueueDirection where toField = toField . show + +-- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus? +insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m () +insertMsg store connAlias qDirection agentMsgId msg = do + tstamp <- liftIO getCurrentTime + void $ + insertWithLock + store + messagesLock + [sql| + INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status) + VALUES (?,?,?,?,?,"MDTransmitted"); + |] + (Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index bb42d5584..195b04345 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -30,17 +30,8 @@ import Data.Typeable () import Network.Socket import Numeric.Natural import Simplex.Messaging.Agent.Store.Types -import Simplex.Messaging.Server.Transmission - ( CorrId (..), - Encoded, - MsgBody, - PublicKey, - SenderId, - errBadParameters, - errMessageBody, - ) -import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport +import Simplex.Messaging.Types (CorrId (..), Encoded, ErrorType, MsgBody, PublicKey, SenderId, errBadParameters, errMessageBody) import Simplex.Messaging.Util import System.IO import Text.Read @@ -50,7 +41,7 @@ type ARawTransmission = (ByteString, ByteString, ByteString) type ATransmission p = (CorrId, ConnAlias, ACommand p) -type ATransmissionOrError p = (CorrId, ConnAlias, Either ErrorType (ACommand p)) +type ATransmissionOrError p = (CorrId, ConnAlias, Either AgentErrorType (ACommand p)) data AParty = Agent | Client deriving (Eq, Show) @@ -93,7 +84,7 @@ data ACommand (p :: AParty) where -- OFF :: ACommand Client -- DEL :: ACommand Client OK :: ACommand Agent - ERR :: ErrorType -> ACommand Agent + ERR :: AgentErrorType -> ACommand Agent deriving instance Show (ACommand p) @@ -115,7 +106,7 @@ data AMessage where A_MSG :: MsgBody -> AMessage deriving (Show) -parseSMPMessage :: ByteString -> Either ErrorType SMPMessage +parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage where smpMessageP :: Parser SMPMessage @@ -179,13 +170,13 @@ base64P = do pad <- A.takeWhile (== '=') either fail pure $ decode (str <> pad) -parseAgentMessage :: ByteString -> Either ErrorType AMessage +parseAgentMessage :: ByteString -> Either AgentErrorType AMessage parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage parse :: Parser a -> e -> (ByteString -> Either e a) parse parser err = first (const err) . A.parseOnly (parser <* A.endOfInput) -errParams :: Either ErrorType a +errParams :: Either AgentErrorType a errParams = Left $ SYNTAX errBadParameters serializeAgentMessage :: AMessage -> ByteString @@ -241,12 +232,12 @@ data MsgStatus = MsgOk | MsgError MsgErrorType data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash deriving (Show) -data ErrorType +data AgentErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | BROKER Natural - | SMP SMP.ErrorType + | SMP ErrorType | SIZE | STORE StoreError | INTERNAL -- etc. TODO SYNTAX Natural @@ -315,7 +306,7 @@ parseCommandP = <|> "NO_ID " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash -parseCommand :: ByteString -> Either ErrorType ACmd +parseCommand :: ByteString -> Either AgentErrorType ACmd parseCommand = parse parseCommandP $ SYNTAX errBadCommand serializeCommand :: ACommand p -> ByteString @@ -365,8 +356,8 @@ tGetRaw h = do return (corrId, connAlias, command) tPut :: MonadIO m => Handle -> ATransmission p -> m () -tPut h (corrId, connAlias, command) = - liftIO $ tPutRaw h (bs corrId, connAlias, serializeCommand command) +tPut h (CorrId corrId, connAlias, command) = + liftIO $ tPutRaw h (corrId, connAlias, serializeCommand command) -- | get client and agent transmissions tGet :: forall m p. MonadIO m => SAParty p -> Handle -> m (ATransmissionOrError p) @@ -378,12 +369,12 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody fullCmd <- either (return . Left) cmdWithMsgBody cmd return (CorrId corrId, connAlias, fullCmd) - fromParty :: ACmd -> Either ErrorType (ACommand p) + fromParty :: ACmd -> Either AgentErrorType (ACommand p) fromParty (ACmd (p :: p1) cmd) = case testEquality party p of Just Refl -> Right cmd _ -> Left PROHIBITED - tConnAlias :: ARawTransmission -> ACommand p -> Either ErrorType (ACommand p) + tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p) tConnAlias (_, connAlias, _) cmd = case cmd of -- NEW and JOIN have optional connAlias NEW _ -> Right cmd @@ -395,14 +386,14 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody | B.null connAlias -> Left $ SYNTAX errNoConnAlias | otherwise -> Right cmd - cmdWithMsgBody :: ACommand p -> m (Either ErrorType (ACommand p)) + cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) cmdWithMsgBody = \case SEND body -> SEND <$$> getMsgBody body MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body cmd -> return $ Right cmd -- TODO refactor with server - getMsgBody :: MsgBody -> m (Either ErrorType MsgBody) + getMsgBody :: MsgBody -> m (Either AgentErrorType MsgBody) getMsgBody msgBody = case B.unpack msgBody of ':' : body -> return . Right $ B.pack body diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index d8ef4885f..ff885a500 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -37,8 +37,9 @@ import Data.Maybe import Network.Socket (ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.Transmission (SMPServer (..)) -import Simplex.Messaging.Server.Transmission +import Simplex.Messaging.Protocol import Simplex.Messaging.Transport +import Simplex.Messaging.Types import Simplex.Messaging.Util import System.IO diff --git a/src/Simplex/Messaging/Server/Transmission.hs b/src/Simplex/Messaging/Protocol.hs similarity index 89% rename from src/Simplex/Messaging/Server/Transmission.hs rename to src/Simplex/Messaging/Protocol.hs index c37cb92ea..391ac3d40 100644 --- a/src/Simplex/Messaging/Server/Transmission.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -9,7 +9,7 @@ {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} -module Simplex.Messaging.Server.Transmission where +module Simplex.Messaging.Protocol where import Control.Monad import Control.Monad.IO.Class @@ -18,9 +18,9 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (ord) import Data.Kind -import Data.String import Data.Time.Clock import Data.Time.ISO8601 +import Simplex.Messaging.Types import Simplex.Messaging.Transport import Simplex.Messaging.Util import System.IO @@ -143,54 +143,6 @@ serializeCommand = \case where serializeMsg msgBody = " " <> B.pack (show $ B.length msgBody) <> "\n" <> msgBody -type Encoded = ByteString - --- newtype to avoid accidentally changing order of transmission parts -newtype CorrId = CorrId {bs :: ByteString} deriving (Eq, Ord, Show) - -instance IsString CorrId where - fromString = CorrId . fromString - -type PublicKey = Encoded - -type PrivateKey = Encoded - -type Signature = Encoded - -type RecipientKey = PublicKey - -type SenderKey = PublicKey - -type RecipientId = QueueId - -type SenderId = QueueId - -type QueueId = Encoded - -type MsgId = Encoded - -type MsgBody = ByteString - -data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) - -errBadTransmission :: Int -errBadTransmission = 1 - -errBadParameters :: Int -errBadParameters = 2 - -errNoCredentials :: Int -errNoCredentials = 3 - -errHasCredentials :: Int -errHasCredentials = 4 - -errNoQueueId :: Int -errNoQueueId = 5 - -errMessageBody :: Int -errMessageBody = 6 - tPutRaw :: Handle -> RawTransmission -> IO () tPutRaw h (signature, corrId, queueId, command) = do putLn h signature diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 371a2d10c..61dae8643 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -23,13 +23,14 @@ import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import qualified Data.Map.Strict as M import Data.Time.Clock +import Simplex.Messaging.Protocol import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Server.MsgStore import Simplex.Messaging.Server.MsgStore.STM (MsgQueue) import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.STM (QueueStore) -import Simplex.Messaging.Server.Transmission import Simplex.Messaging.Transport +import Simplex.Messaging.Types import Simplex.Messaging.Util import UnliftIO.Async import UnliftIO.Concurrent diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 37324d65d..182542678 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -10,9 +10,10 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Network.Socket (ServiceName) import Numeric.Natural +import Simplex.Messaging.Types +import Simplex.Messaging.Protocol import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore.STM -import Simplex.Messaging.Server.Transmission import UnliftIO.STM data ServerConfig = ServerConfig diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index 1d0d66daa..bf2e2d1d6 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -3,7 +3,7 @@ module Simplex.Messaging.Server.MsgStore where import Data.Time.Clock -import Simplex.Messaging.Server.Transmission +import Simplex.Messaging.Types data Message = Message { msgId :: Encoded, diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 3e124dc88..1c32c7755 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -8,8 +8,8 @@ module Simplex.Messaging.Server.MsgStore.STM where import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Simplex.Messaging.Types import Simplex.Messaging.Server.MsgStore -import Simplex.Messaging.Server.Transmission import UnliftIO.STM newtype MsgQueue = MsgQueue {msgQueue :: TQueue Message} diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index 6ec83f2e9..15357d135 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -5,7 +5,8 @@ module Simplex.Messaging.Server.QueueStore where -import Simplex.Messaging.Server.Transmission +import Simplex.Messaging.Protocol +import Simplex.Messaging.Types data QueueRec = QueueRec { recipientId :: QueueId, diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 28aab19df..be5b8d39e 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -13,8 +13,9 @@ module Simplex.Messaging.Server.QueueStore.STM where import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore -import Simplex.Messaging.Server.Transmission +import Simplex.Messaging.Types import UnliftIO.STM data QueueStoreData = QueueStoreData diff --git a/src/Simplex/Messaging/Types.hs b/src/Simplex/Messaging/Types.hs new file mode 100644 index 000000000..829992537 --- /dev/null +++ b/src/Simplex/Messaging/Types.hs @@ -0,0 +1,60 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} + +module Simplex.Messaging.Types where + +import Data.ByteString.Char8 (ByteString) +import Data.String + +type Encoded = ByteString + +-- newtype to avoid accidentally changing order of transmission parts +newtype CorrId = CorrId {bs :: ByteString} deriving (Eq, Ord, Show) + +instance IsString CorrId where + fromString = CorrId . fromString + +type PublicKey = Encoded + +type PrivateKey = Encoded + +type Signature = Encoded + +type RecipientKey = PublicKey + +type SenderKey = PublicKey + +type RecipientId = QueueId + +type SenderId = QueueId + +type QueueId = Encoded + +type MsgId = Encoded + +type MsgBody = ByteString + +data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) + +errBadTransmission :: Int +errBadTransmission = 1 + +errBadParameters :: Int +errBadParameters = 2 + +errNoCredentials :: Int +errNoCredentials = 3 + +errHasCredentials :: Int +errHasCredentials = 4 + +errNoQueueId :: Int +errNoQueueId = 5 + +errMessageBody :: Int +errMessageBody = 6 diff --git a/tests/AgentTests/SQLite.hs b/tests/AgentTests/SQLite.hs index deb3d3f2f..8b4d68e3a 100644 --- a/tests/AgentTests/SQLite.hs +++ b/tests/AgentTests/SQLite.hs @@ -9,6 +9,7 @@ import Data.Word (Word32) import qualified Database.SQLite.Simple as DB import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite +import Simplex.Messaging.Agent.Store.SQLite.Types import Simplex.Messaging.Agent.Store.Types import Simplex.Messaging.Agent.Transmission import System.Random diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 7399c8695..dbb7d8d28 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -101,7 +101,7 @@ withSmpAgent = withSmpAgentOn (agentTestPort, testDB) testSMPAgentClientOn :: MonadUnliftIO m => ServiceName -> (Handle -> m a) -> m a testSMPAgentClientOn port' client = do - threadDelay 50_000 -- TODO hack: thread delay for SMP agent to start + threadDelay 100_000 -- TODO hack: thread delay for SMP agent to start runTCPClient agentTestHost port' $ \h -> do line <- liftIO $ getLn h if line == "Welcome to SMP v0.2.0 agent" diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 1386014ea..d28ec2030 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -7,9 +7,9 @@ module SMPClient where import Control.Monad.IO.Unlift import Crypto.Random import Network.Socket +import Simplex.Messaging.Protocol import Simplex.Messaging.Server import Simplex.Messaging.Server.Env.STM -import Simplex.Messaging.Server.Transmission import Simplex.Messaging.Transport import Test.Hspec import UnliftIO.Concurrent diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 415f2562f..b0f6f6d3d 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -11,7 +11,8 @@ import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPClient -import Simplex.Messaging.Server.Transmission +import Simplex.Messaging.Types +import Simplex.Messaging.Protocol import System.IO (Handle) import System.Timeout import Test.HUnit