mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 22:55:50 +00:00
reorganize Protocol and Agent Store (#25)
* chore: move members from Server/Transmission.hs to Protocol.hs * chore: revert qualified SMP import for server and client * chore: fix corrId call * chore: move common types to Common.hs * chore: decompose SQLite.hs * chore: rename Agent/Transmission.hs ErrorType -> AgentErrorType * chore: move Protocol ErrorType -> Common SMPErrorType * chore: rename Common -> Types * chore: revert SMPErrorType -> ErrorType
This commit is contained in:
@@ -21,13 +21,13 @@ import Data.Text.Encoding
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client (SMPServerTransmission)
|
||||
import Simplex.Messaging.Types (CorrId (..), MsgBody, PrivateKey, SenderKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server (randomBytes)
|
||||
import Simplex.Messaging.Server.Transmission (CorrId (..), PrivateKey)
|
||||
import qualified Simplex.Messaging.Server.Transmission as SMP
|
||||
import Simplex.Messaging.Transport
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Exception (SomeException)
|
||||
@@ -129,7 +129,7 @@ processCommand c@AgentClient {sndQ} (corrId, connAlias, cmd) =
|
||||
ReplyOff -> return ()
|
||||
respond CON
|
||||
|
||||
sendMessage :: SMP.MsgBody -> m ()
|
||||
sendMessage :: MsgBody -> m ()
|
||||
sendMessage msgBody =
|
||||
withStore (`getConn` connAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
|
||||
@@ -222,7 +222,7 @@ processSMPTransmission c@AgentClient {sndQ} (srv, rId, cmd) = do
|
||||
notify :: ConnAlias -> ACommand 'Agent -> m ()
|
||||
notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg)
|
||||
|
||||
connectToSendQueue :: AgentMonad m => AgentClient -> SendQueue -> SMP.SenderKey -> m ()
|
||||
connectToSendQueue :: AgentMonad m => AgentClient -> SendQueue -> SenderKey -> m ()
|
||||
connectToSendQueue c sq senderKey = do
|
||||
sendConfirmation c sq senderKey
|
||||
withStore $ \st -> updateSndQueueStatus st sq Confirmed
|
||||
@@ -233,7 +233,7 @@ decryptMessage :: MonadUnliftIO m => PrivateKey -> ByteString -> m ByteString
|
||||
decryptMessage _decryptKey = return
|
||||
|
||||
newSendQueue ::
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SendQueue, SMP.SenderKey)
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SendQueue, SenderKey)
|
||||
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do
|
||||
g <- asks idsDrg
|
||||
senderKey <- atomically $ randomBytes 16 g -- TODO replace with cryptographic key pair
|
||||
|
||||
@@ -41,8 +41,7 @@ import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Server (randomBytes)
|
||||
import Simplex.Messaging.Server.Transmission (PrivateKey, PublicKey, SenderKey)
|
||||
import qualified Simplex.Messaging.Server.Transmission as SMP
|
||||
import Simplex.Messaging.Types (ErrorType (AUTH), MsgBody, PrivateKey, PublicKey, QueueId, SenderKey)
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Exception (SomeException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
@@ -66,7 +65,7 @@ newAgentClient cc qSize = do
|
||||
writeTVar cc clientId
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, clientId}
|
||||
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError ErrorType m)
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient AgentClient {smpClients, msgQ} srv =
|
||||
@@ -80,7 +79,7 @@ getSMPServerClient AgentClient {smpClients, msgQ} srv =
|
||||
atomically . modifyTVar smpClients $ M.insert srv c
|
||||
return c
|
||||
|
||||
throwErr :: ErrorType -> SomeException -> m a
|
||||
throwErr :: AgentErrorType -> SomeException -> m a
|
||||
throwErr err e = do
|
||||
liftIO . putStrLn $ "Exception: " ++ show e -- TODO remove
|
||||
throwError err
|
||||
@@ -94,18 +93,18 @@ withSMP c srv action =
|
||||
liftIO (first smpClientError <$> runExceptT (action smp))
|
||||
>>= liftEither
|
||||
|
||||
smpClientError :: SMPClientError -> ErrorType
|
||||
smpClientError :: SMPClientError -> AgentErrorType
|
||||
smpClientError = \case
|
||||
SMPServerError e -> SMP e
|
||||
-- TODO handle other errors
|
||||
_ -> INTERNAL
|
||||
|
||||
logServerError :: ErrorType -> m a
|
||||
logServerError :: AgentErrorType -> m a
|
||||
logServerError e = do
|
||||
logServer "<--" c srv "" $ (B.pack . show) e
|
||||
throwError e
|
||||
|
||||
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> SMP.QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withLogSMP c srv qId cmdStr action = do
|
||||
logServer "-->" c srv qId cmdStr
|
||||
res <- withSMP c srv action
|
||||
@@ -136,7 +135,7 @@ newReceiveQueue c srv = do
|
||||
}
|
||||
return (rcvQueue, SMPQueueInfo srv sId encryptKey)
|
||||
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> SMP.QueueId -> ByteString -> m ()
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} SMPServer {host, port} qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, server, ":", logSecret qId, cmdStr]
|
||||
where
|
||||
@@ -152,7 +151,7 @@ sendConfirmation c SendQueue {server, sndId} senderKey = do
|
||||
withLogSMP c server sndId "SEND <KEY>" $ \smp ->
|
||||
sendSMPMessage smp "" sndId msg
|
||||
where
|
||||
mkConfirmation :: m SMP.MsgBody
|
||||
mkConfirmation :: m MsgBody
|
||||
mkConfirmation = do
|
||||
let msg = serializeSMPMessage $ SMPConfirmation senderKey
|
||||
-- TODO encryption
|
||||
@@ -172,7 +171,7 @@ sendHello c SendQueue {server, sndId, sndPrivateKey, encryptKey} = do
|
||||
send 0 _ _ = throwE SMPResponseTimeout -- TODO different error
|
||||
send retry msg smp =
|
||||
sendSMPMessage smp sndPrivateKey sndId msg `catchE` \case
|
||||
SMPServerError SMP.AUTH -> do
|
||||
SMPServerError AUTH -> do
|
||||
threadDelay 100000
|
||||
send (retry - 1) msg smp
|
||||
e -> throwE e
|
||||
|
||||
@@ -10,6 +10,7 @@ import Crypto.Random
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Client
|
||||
import UnliftIO.STM
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import Data.Time.Clock (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Server.Transmission (PrivateKey, PublicKey, RecipientId, SenderId)
|
||||
import Simplex.Messaging.Types
|
||||
|
||||
data ReceiveQueue = ReceiveQueue
|
||||
{ server :: SMPServer,
|
||||
|
||||
@@ -8,9 +8,7 @@
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
@@ -19,47 +17,17 @@ module Simplex.Messaging.Agent.Store.SQLite where
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe
|
||||
import qualified Data.Text as T
|
||||
import Data.Time
|
||||
import Database.SQLite.Simple hiding (Connection)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Schema
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Util
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Server.Transmission (RecipientId, SenderId)
|
||||
import Simplex.Messaging.Util
|
||||
import Text.Read
|
||||
import qualified UnliftIO.Exception as E
|
||||
import Simplex.Messaging.Types
|
||||
import UnliftIO.STM
|
||||
|
||||
addRcvQueueQuery :: Query
|
||||
addRcvQueueQuery =
|
||||
[sql|
|
||||
INSERT INTO receive_queues
|
||||
( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode)
|
||||
VALUES
|
||||
(:server_id,:rcv_id,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status,:ack_mode);
|
||||
|]
|
||||
|
||||
data SQLiteStore = SQLiteStore
|
||||
{ dbFilename :: String,
|
||||
conn :: DB.Connection,
|
||||
serversLock :: TMVar (),
|
||||
rcvQueuesLock :: TMVar (),
|
||||
sndQueuesLock :: TMVar (),
|
||||
connectionsLock :: TMVar (),
|
||||
messagesLock :: TMVar ()
|
||||
}
|
||||
|
||||
newSQLiteStore :: MonadUnliftIO m => String -> m SQLiteStore
|
||||
newSQLiteStore dbFilename = do
|
||||
conn <- liftIO $ DB.open dbFilename
|
||||
@@ -80,329 +48,6 @@ newSQLiteStore dbFilename = do
|
||||
messagesLock
|
||||
}
|
||||
|
||||
type QueueRowId = Int64
|
||||
|
||||
type ConnectionRowId = Int64
|
||||
|
||||
fromFieldToReadable :: forall a. (Read a, E.Typeable a) => Field -> Ok a
|
||||
fromFieldToReadable = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
let str = T.unpack t
|
||||
in case readMaybe str of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid string: " <> str)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
withLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> (DB.Connection -> m a) -> m a
|
||||
withLock st tableLock f = do
|
||||
let lock = tableLock st
|
||||
E.bracket_
|
||||
(atomically $ takeTMVar lock)
|
||||
(atomically $ putTMVar lock ())
|
||||
(f $ conn st)
|
||||
|
||||
insertWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m Int64
|
||||
insertWithLock st tableLock queryStr q = do
|
||||
withLock st tableLock $ \c -> liftIO $ do
|
||||
DB.execute c queryStr q
|
||||
DB.lastInsertRowId c
|
||||
|
||||
executeWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m ()
|
||||
executeWithLock st tableLock queryStr q = do
|
||||
withLock st tableLock $ \c -> liftIO $ do
|
||||
DB.execute c queryStr q
|
||||
|
||||
instance ToRow SMPServer where
|
||||
toRow SMPServer {host, port, keyHash} = toRow (host, port, keyHash)
|
||||
|
||||
instance FromRow SMPServer where
|
||||
fromRow = SMPServer <$> field <*> field <*> field
|
||||
|
||||
upsertServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServer -> m SMPServerId
|
||||
upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = do
|
||||
r <- liftIO $ do
|
||||
DB.execute
|
||||
conn
|
||||
[sql|
|
||||
INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?)
|
||||
ON CONFLICT (host, port) DO UPDATE SET
|
||||
host=excluded.host,
|
||||
port=excluded.port,
|
||||
key_hash=excluded.key_hash;
|
||||
|]
|
||||
srv
|
||||
DB.queryNamed
|
||||
conn
|
||||
"SELECT server_id FROM servers WHERE host = :host AND port = :port"
|
||||
[":host" := host, ":port" := port]
|
||||
case r of
|
||||
[Only serverId] -> return serverId
|
||||
_ -> throwError SEInternal
|
||||
|
||||
getServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServerId -> m SMPServer
|
||||
getServer SQLiteStore {conn} serverId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
"SELECT host, port, key_hash FROM servers WHERE server_id = :server_id"
|
||||
[":server_id" := serverId]
|
||||
case r of
|
||||
[smpServer] -> return smpServer
|
||||
_ -> throwError SENotFound
|
||||
|
||||
instance ToField AckMode where toField (AckMode mode) = toField $ show mode
|
||||
|
||||
instance FromField AckMode where fromField = AckMode <$$> fromFieldToReadable
|
||||
|
||||
instance ToField QueueStatus where toField = toField . show
|
||||
|
||||
instance FromField QueueStatus where fromField = fromFieldToReadable
|
||||
|
||||
instance ToRow ReceiveQueue where
|
||||
toRow ReceiveQueue {rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode} =
|
||||
toRow (rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode)
|
||||
|
||||
instance FromRow ReceiveQueue where
|
||||
fromRow = ReceiveQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field
|
||||
|
||||
-- TODO refactor into a single query with join
|
||||
getRcvQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m ReceiveQueue
|
||||
getRcvQueue st@SQLiteStore {conn} queueRowId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode
|
||||
FROM receive_queues
|
||||
WHERE receive_queue_id = :rowId;
|
||||
|]
|
||||
[":rowId" := queueRowId]
|
||||
case r of
|
||||
[Only serverId :. rcvQueue] ->
|
||||
(\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId
|
||||
_ -> throwError SENotFound
|
||||
|
||||
getRcvQueueByRecipientId :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> m ReceiveQueue
|
||||
getRcvQueueByRecipientId st@SQLiteStore {conn} rcvId host port = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode
|
||||
FROM receive_queues
|
||||
WHERE rcv_id = :rcvId AND server_id IN (
|
||||
SELECT server_id
|
||||
FROM servers
|
||||
WHERE host = :host AND port = :port
|
||||
);
|
||||
|]
|
||||
[":rcvId" := rcvId, ":host" := host, ":port" := port]
|
||||
case r of
|
||||
[Only serverId :. rcvQueue] ->
|
||||
(\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId
|
||||
_ -> throwError SENotFound
|
||||
|
||||
-- TODO refactor into a single query with join
|
||||
getSndQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m SendQueue
|
||||
getSndQueue st@SQLiteStore {conn} queueRowId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode
|
||||
FROM send_queues
|
||||
WHERE send_queue_id = :rowId;
|
||||
|]
|
||||
[":rowId" := queueRowId]
|
||||
case r of
|
||||
[Only serverId :. sndQueue] ->
|
||||
(\srv -> (sndQueue {server = srv} :: SendQueue)) <$> getServer st serverId
|
||||
_ -> throwError SENotFound
|
||||
|
||||
insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId
|
||||
insertRcvQueue store serverId rcvQueue =
|
||||
insertWithLock
|
||||
store
|
||||
rcvQueuesLock
|
||||
[sql|
|
||||
INSERT INTO receive_queues
|
||||
( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode)
|
||||
VALUES (?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
(Only serverId :. rcvQueue)
|
||||
|
||||
insertRcvConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
insertRcvConnection store connAlias rcvQueueId =
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
connectionsLock
|
||||
"INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,?,NULL);"
|
||||
(Only connAlias :. Only rcvQueueId)
|
||||
|
||||
updateRcvConnectionWithSndQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
updateRcvConnectionWithSndQueue store connAlias sndQueueId =
|
||||
executeWithLock
|
||||
store
|
||||
connectionsLock
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET send_queue_id = ?
|
||||
WHERE conn_alias = ?;
|
||||
|]
|
||||
(Only sndQueueId :. Only connAlias)
|
||||
|
||||
instance ToRow SendQueue where
|
||||
toRow SendQueue {sndId, sndPrivateKey, encryptKey, signKey, status, ackMode} =
|
||||
toRow (sndId, sndPrivateKey, encryptKey, signKey, status, ackMode)
|
||||
|
||||
instance FromRow SendQueue where
|
||||
fromRow = SendQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field
|
||||
|
||||
insertSndQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> SendQueue -> m QueueRowId
|
||||
insertSndQueue store serverId sndQueue =
|
||||
insertWithLock
|
||||
store
|
||||
sndQueuesLock
|
||||
[sql|
|
||||
INSERT INTO send_queues
|
||||
( server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode)
|
||||
VALUES (?,?,?,?,?,?,?);
|
||||
|]
|
||||
(Only serverId :. sndQueue)
|
||||
|
||||
insertSndConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
insertSndConnection store connAlias sndQueueId =
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
connectionsLock
|
||||
"INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,NULL,?);"
|
||||
(Only connAlias :. Only sndQueueId)
|
||||
|
||||
updateSndConnectionWithRcvQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
updateSndConnectionWithRcvQueue store connAlias rcvQueueId =
|
||||
executeWithLock
|
||||
store
|
||||
connectionsLock
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET receive_queue_id = ?
|
||||
WHERE conn_alias = ?;
|
||||
|]
|
||||
(Only rcvQueueId :. Only connAlias)
|
||||
|
||||
getConnection :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> ConnAlias -> m (Maybe QueueRowId, Maybe QueueRowId)
|
||||
getConnection SQLiteStore {conn} connAlias = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
"SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias"
|
||||
[":conn_alias" := connAlias]
|
||||
case r of
|
||||
[queueIds] -> return queueIds
|
||||
_ -> throwError SEInternal
|
||||
|
||||
instance FromRow ConnAlias where
|
||||
fromRow = field
|
||||
|
||||
getConnAliasByRcvQueue :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> RecipientId -> m ConnAlias
|
||||
getConnAliasByRcvQueue SQLiteStore {conn} rcvId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT c.conn_alias
|
||||
FROM connections c
|
||||
JOIN receive_queues rq
|
||||
ON c.receive_queue_id = rq.receive_queue_id
|
||||
WHERE rq.rcv_id = :rcvId;
|
||||
|]
|
||||
[":rcvId" := rcvId]
|
||||
case r of
|
||||
[connAlias] -> return connAlias
|
||||
_ -> throwError SEInternal
|
||||
|
||||
deleteRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
|
||||
deleteRcvQueue store rcvQueueId = do
|
||||
executeWithLock
|
||||
store
|
||||
rcvQueuesLock
|
||||
"DELETE FROM receive_queues WHERE receive_queue_id = ?"
|
||||
(Only rcvQueueId)
|
||||
|
||||
deleteSndQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
|
||||
deleteSndQueue store sndQueueId = do
|
||||
executeWithLock
|
||||
store
|
||||
sndQueuesLock
|
||||
"DELETE FROM send_queues WHERE send_queue_id = ?"
|
||||
(Only sndQueueId)
|
||||
|
||||
deleteConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> m ()
|
||||
deleteConnection store connAlias = do
|
||||
executeWithLock
|
||||
store
|
||||
connectionsLock
|
||||
"DELETE FROM connections WHERE conn_alias = ?"
|
||||
(Only connAlias)
|
||||
|
||||
updateReceiveQueueStatus :: MonadUnliftIO m => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
|
||||
updateReceiveQueueStatus store rcvQueueId host port status =
|
||||
executeWithLock
|
||||
store
|
||||
rcvQueuesLock
|
||||
[sql|
|
||||
UPDATE receive_queues
|
||||
SET status = ?
|
||||
WHERE rcv_id = ?
|
||||
AND server_id IN (
|
||||
SELECT server_id
|
||||
FROM servers
|
||||
WHERE host = ? AND port = ?
|
||||
);
|
||||
|]
|
||||
(Only status :. Only rcvQueueId :. Only host :. Only port)
|
||||
|
||||
updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
|
||||
updateSendQueueStatus store sndQueueId host port status =
|
||||
executeWithLock
|
||||
store
|
||||
sndQueuesLock
|
||||
[sql|
|
||||
UPDATE send_queues
|
||||
SET status = ?
|
||||
WHERE snd_id = ?
|
||||
AND server_id IN (
|
||||
SELECT server_id
|
||||
FROM servers
|
||||
WHERE host = ? AND port = ?
|
||||
);
|
||||
|]
|
||||
(Only status :. Only sndQueueId :. Only host :. Only port)
|
||||
|
||||
instance ToField QueueDirection where toField = toField . show
|
||||
|
||||
-- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus?
|
||||
insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m ()
|
||||
insertMsg store connAlias qDirection agentMsgId msg = do
|
||||
tstamp <- liftIO getCurrentTime
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
messagesLock
|
||||
[sql|
|
||||
INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status)
|
||||
VALUES (?,?,?,?,?,"MDTransmitted");
|
||||
|]
|
||||
(Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg)
|
||||
|
||||
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
|
||||
addServer store smpServer = upsertServer store smpServer
|
||||
|
||||
|
||||
29
src/Simplex/Messaging/Agent/Store/SQLite/Types.hs
Normal file
29
src/Simplex/Messaging/Agent/Store/SQLite/Types.hs
Normal file
@@ -0,0 +1,29 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Types where
|
||||
|
||||
import Data.Int (Int64)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import UnliftIO.STM
|
||||
|
||||
data SQLiteStore = SQLiteStore
|
||||
{ dbFilename :: String,
|
||||
conn :: DB.Connection,
|
||||
serversLock :: TMVar (),
|
||||
rcvQueuesLock :: TMVar (),
|
||||
sndQueuesLock :: TMVar (),
|
||||
connectionsLock :: TMVar (),
|
||||
messagesLock :: TMVar ()
|
||||
}
|
||||
|
||||
type QueueRowId = Int64
|
||||
|
||||
type ConnectionRowId = Int64
|
||||
369
src/Simplex/Messaging/Agent/Store/SQLite/Util.hs
Normal file
369
src/Simplex/Messaging/Agent/Store/SQLite/Util.hs
Normal file
@@ -0,0 +1,369 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Util where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time
|
||||
import Database.SQLite.Simple hiding (Connection)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Util
|
||||
import Text.Read
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
addRcvQueueQuery :: Query
|
||||
addRcvQueueQuery =
|
||||
[sql|
|
||||
INSERT INTO receive_queues
|
||||
( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode)
|
||||
VALUES
|
||||
(:server_id,:rcv_id,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status,:ack_mode);
|
||||
|]
|
||||
|
||||
fromFieldToReadable :: forall a. (Read a, E.Typeable a) => Field -> Ok a
|
||||
fromFieldToReadable = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
let str = T.unpack t
|
||||
in case readMaybe str of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid string: " <> str)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
withLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> (DB.Connection -> m a) -> m a
|
||||
withLock st tableLock f = do
|
||||
let lock = tableLock st
|
||||
E.bracket_
|
||||
(atomically $ takeTMVar lock)
|
||||
(atomically $ putTMVar lock ())
|
||||
(f $ conn st)
|
||||
|
||||
insertWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m Int64
|
||||
insertWithLock st tableLock queryStr q = do
|
||||
withLock st tableLock $ \c -> liftIO $ do
|
||||
DB.execute c queryStr q
|
||||
DB.lastInsertRowId c
|
||||
|
||||
executeWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m ()
|
||||
executeWithLock st tableLock queryStr q = do
|
||||
withLock st tableLock $ \c -> liftIO $ do
|
||||
DB.execute c queryStr q
|
||||
|
||||
instance ToRow SMPServer where
|
||||
toRow SMPServer {host, port, keyHash} = toRow (host, port, keyHash)
|
||||
|
||||
instance FromRow SMPServer where
|
||||
fromRow = SMPServer <$> field <*> field <*> field
|
||||
|
||||
upsertServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServer -> m SMPServerId
|
||||
upsertServer SQLiteStore {conn} srv@SMPServer {host, port} = do
|
||||
r <- liftIO $ do
|
||||
DB.execute
|
||||
conn
|
||||
[sql|
|
||||
INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?)
|
||||
ON CONFLICT (host, port) DO UPDATE SET
|
||||
host=excluded.host,
|
||||
port=excluded.port,
|
||||
key_hash=excluded.key_hash;
|
||||
|]
|
||||
srv
|
||||
DB.queryNamed
|
||||
conn
|
||||
"SELECT server_id FROM servers WHERE host = :host AND port = :port"
|
||||
[":host" := host, ":port" := port]
|
||||
case r of
|
||||
[Only serverId] -> return serverId
|
||||
_ -> throwError SEInternal
|
||||
|
||||
getServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServerId -> m SMPServer
|
||||
getServer SQLiteStore {conn} serverId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
"SELECT host, port, key_hash FROM servers WHERE server_id = :server_id"
|
||||
[":server_id" := serverId]
|
||||
case r of
|
||||
[smpServer] -> return smpServer
|
||||
_ -> throwError SENotFound
|
||||
|
||||
instance ToField AckMode where toField (AckMode mode) = toField $ show mode
|
||||
|
||||
instance FromField AckMode where fromField = AckMode <$$> fromFieldToReadable
|
||||
|
||||
instance ToField QueueStatus where toField = toField . show
|
||||
|
||||
instance FromField QueueStatus where fromField = fromFieldToReadable
|
||||
|
||||
instance ToRow ReceiveQueue where
|
||||
toRow ReceiveQueue {rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode} =
|
||||
toRow (rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode)
|
||||
|
||||
instance FromRow ReceiveQueue where
|
||||
fromRow = ReceiveQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field
|
||||
|
||||
-- TODO refactor into a single query with join
|
||||
getRcvQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m ReceiveQueue
|
||||
getRcvQueue st@SQLiteStore {conn} queueRowId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode
|
||||
FROM receive_queues
|
||||
WHERE receive_queue_id = :rowId;
|
||||
|]
|
||||
[":rowId" := queueRowId]
|
||||
case r of
|
||||
[Only serverId :. rcvQueue] ->
|
||||
(\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId
|
||||
_ -> throwError SENotFound
|
||||
|
||||
getRcvQueueByRecipientId :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> m ReceiveQueue
|
||||
getRcvQueueByRecipientId st@SQLiteStore {conn} rcvId host port = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode
|
||||
FROM receive_queues
|
||||
WHERE rcv_id = :rcvId AND server_id IN (
|
||||
SELECT server_id
|
||||
FROM servers
|
||||
WHERE host = :host AND port = :port
|
||||
);
|
||||
|]
|
||||
[":rcvId" := rcvId, ":host" := host, ":port" := port]
|
||||
case r of
|
||||
[Only serverId :. rcvQueue] ->
|
||||
(\srv -> (rcvQueue {server = srv} :: ReceiveQueue)) <$> getServer st serverId
|
||||
_ -> throwError SENotFound
|
||||
|
||||
-- TODO refactor into a single query with join
|
||||
getSndQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m SendQueue
|
||||
getSndQueue st@SQLiteStore {conn} queueRowId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode
|
||||
FROM send_queues
|
||||
WHERE send_queue_id = :rowId;
|
||||
|]
|
||||
[":rowId" := queueRowId]
|
||||
case r of
|
||||
[Only serverId :. sndQueue] ->
|
||||
(\srv -> (sndQueue {server = srv} :: SendQueue)) <$> getServer st serverId
|
||||
_ -> throwError SENotFound
|
||||
|
||||
insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId
|
||||
insertRcvQueue store serverId rcvQueue =
|
||||
insertWithLock
|
||||
store
|
||||
rcvQueuesLock
|
||||
[sql|
|
||||
INSERT INTO receive_queues
|
||||
( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode)
|
||||
VALUES (?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
(Only serverId :. rcvQueue)
|
||||
|
||||
insertRcvConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
insertRcvConnection store connAlias rcvQueueId =
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
connectionsLock
|
||||
"INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,?,NULL);"
|
||||
(Only connAlias :. Only rcvQueueId)
|
||||
|
||||
updateRcvConnectionWithSndQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
updateRcvConnectionWithSndQueue store connAlias sndQueueId =
|
||||
executeWithLock
|
||||
store
|
||||
connectionsLock
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET send_queue_id = ?
|
||||
WHERE conn_alias = ?;
|
||||
|]
|
||||
(Only sndQueueId :. Only connAlias)
|
||||
|
||||
instance ToRow SendQueue where
|
||||
toRow SendQueue {sndId, sndPrivateKey, encryptKey, signKey, status, ackMode} =
|
||||
toRow (sndId, sndPrivateKey, encryptKey, signKey, status, ackMode)
|
||||
|
||||
instance FromRow SendQueue where
|
||||
fromRow = SendQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field
|
||||
|
||||
insertSndQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> SendQueue -> m QueueRowId
|
||||
insertSndQueue store serverId sndQueue =
|
||||
insertWithLock
|
||||
store
|
||||
sndQueuesLock
|
||||
[sql|
|
||||
INSERT INTO send_queues
|
||||
( server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode)
|
||||
VALUES (?,?,?,?,?,?,?);
|
||||
|]
|
||||
(Only serverId :. sndQueue)
|
||||
|
||||
insertSndConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
insertSndConnection store connAlias sndQueueId =
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
connectionsLock
|
||||
"INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,NULL,?);"
|
||||
(Only connAlias :. Only sndQueueId)
|
||||
|
||||
updateSndConnectionWithRcvQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
|
||||
updateSndConnectionWithRcvQueue store connAlias rcvQueueId =
|
||||
executeWithLock
|
||||
store
|
||||
connectionsLock
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET receive_queue_id = ?
|
||||
WHERE conn_alias = ?;
|
||||
|]
|
||||
(Only rcvQueueId :. Only connAlias)
|
||||
|
||||
getConnection :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> ConnAlias -> m (Maybe QueueRowId, Maybe QueueRowId)
|
||||
getConnection SQLiteStore {conn} connAlias = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
"SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias"
|
||||
[":conn_alias" := connAlias]
|
||||
case r of
|
||||
[queueIds] -> return queueIds
|
||||
_ -> throwError SEInternal
|
||||
|
||||
instance FromRow ConnAlias where
|
||||
fromRow = field
|
||||
|
||||
getConnAliasByRcvQueue :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> RecipientId -> m ConnAlias
|
||||
getConnAliasByRcvQueue SQLiteStore {conn} rcvId = do
|
||||
r <-
|
||||
liftIO $
|
||||
DB.queryNamed
|
||||
conn
|
||||
[sql|
|
||||
SELECT c.conn_alias
|
||||
FROM connections c
|
||||
JOIN receive_queues rq
|
||||
ON c.receive_queue_id = rq.receive_queue_id
|
||||
WHERE rq.rcv_id = :rcvId;
|
||||
|]
|
||||
[":rcvId" := rcvId]
|
||||
case r of
|
||||
[connAlias] -> return connAlias
|
||||
_ -> throwError SEInternal
|
||||
|
||||
deleteRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
|
||||
deleteRcvQueue store rcvQueueId = do
|
||||
executeWithLock
|
||||
store
|
||||
rcvQueuesLock
|
||||
"DELETE FROM receive_queues WHERE receive_queue_id = ?"
|
||||
(Only rcvQueueId)
|
||||
|
||||
deleteSndQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
|
||||
deleteSndQueue store sndQueueId = do
|
||||
executeWithLock
|
||||
store
|
||||
sndQueuesLock
|
||||
"DELETE FROM send_queues WHERE send_queue_id = ?"
|
||||
(Only sndQueueId)
|
||||
|
||||
deleteConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> m ()
|
||||
deleteConnection store connAlias = do
|
||||
executeWithLock
|
||||
store
|
||||
connectionsLock
|
||||
"DELETE FROM connections WHERE conn_alias = ?"
|
||||
(Only connAlias)
|
||||
|
||||
updateReceiveQueueStatus :: MonadUnliftIO m => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
|
||||
updateReceiveQueueStatus store rcvQueueId host port status =
|
||||
executeWithLock
|
||||
store
|
||||
rcvQueuesLock
|
||||
[sql|
|
||||
UPDATE receive_queues
|
||||
SET status = ?
|
||||
WHERE rcv_id = ?
|
||||
AND server_id IN (
|
||||
SELECT server_id
|
||||
FROM servers
|
||||
WHERE host = ? AND port = ?
|
||||
);
|
||||
|]
|
||||
(Only status :. Only rcvQueueId :. Only host :. Only port)
|
||||
|
||||
updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
|
||||
updateSendQueueStatus store sndQueueId host port status =
|
||||
executeWithLock
|
||||
store
|
||||
sndQueuesLock
|
||||
[sql|
|
||||
UPDATE send_queues
|
||||
SET status = ?
|
||||
WHERE snd_id = ?
|
||||
AND server_id IN (
|
||||
SELECT server_id
|
||||
FROM servers
|
||||
WHERE host = ? AND port = ?
|
||||
);
|
||||
|]
|
||||
(Only status :. Only sndQueueId :. Only host :. Only port)
|
||||
|
||||
instance ToField QueueDirection where toField = toField . show
|
||||
|
||||
-- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus?
|
||||
insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m ()
|
||||
insertMsg store connAlias qDirection agentMsgId msg = do
|
||||
tstamp <- liftIO getCurrentTime
|
||||
void $
|
||||
insertWithLock
|
||||
store
|
||||
messagesLock
|
||||
[sql|
|
||||
INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status)
|
||||
VALUES (?,?,?,?,?,"MDTransmitted");
|
||||
|]
|
||||
(Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg)
|
||||
@@ -30,17 +30,8 @@ import Data.Typeable ()
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
( CorrId (..),
|
||||
Encoded,
|
||||
MsgBody,
|
||||
PublicKey,
|
||||
SenderId,
|
||||
errBadParameters,
|
||||
errMessageBody,
|
||||
)
|
||||
import qualified Simplex.Messaging.Server.Transmission as SMP
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types (CorrId (..), Encoded, ErrorType, MsgBody, PublicKey, SenderId, errBadParameters, errMessageBody)
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
import Text.Read
|
||||
@@ -50,7 +41,7 @@ type ARawTransmission = (ByteString, ByteString, ByteString)
|
||||
|
||||
type ATransmission p = (CorrId, ConnAlias, ACommand p)
|
||||
|
||||
type ATransmissionOrError p = (CorrId, ConnAlias, Either ErrorType (ACommand p))
|
||||
type ATransmissionOrError p = (CorrId, ConnAlias, Either AgentErrorType (ACommand p))
|
||||
|
||||
data AParty = Agent | Client
|
||||
deriving (Eq, Show)
|
||||
@@ -93,7 +84,7 @@ data ACommand (p :: AParty) where
|
||||
-- OFF :: ACommand Client
|
||||
-- DEL :: ACommand Client
|
||||
OK :: ACommand Agent
|
||||
ERR :: ErrorType -> ACommand Agent
|
||||
ERR :: AgentErrorType -> ACommand Agent
|
||||
|
||||
deriving instance Show (ACommand p)
|
||||
|
||||
@@ -115,7 +106,7 @@ data AMessage where
|
||||
A_MSG :: MsgBody -> AMessage
|
||||
deriving (Show)
|
||||
|
||||
parseSMPMessage :: ByteString -> Either ErrorType SMPMessage
|
||||
parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage
|
||||
parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage
|
||||
where
|
||||
smpMessageP :: Parser SMPMessage
|
||||
@@ -179,13 +170,13 @@ base64P = do
|
||||
pad <- A.takeWhile (== '=')
|
||||
either fail pure $ decode (str <> pad)
|
||||
|
||||
parseAgentMessage :: ByteString -> Either ErrorType AMessage
|
||||
parseAgentMessage :: ByteString -> Either AgentErrorType AMessage
|
||||
parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage
|
||||
|
||||
parse :: Parser a -> e -> (ByteString -> Either e a)
|
||||
parse parser err = first (const err) . A.parseOnly (parser <* A.endOfInput)
|
||||
|
||||
errParams :: Either ErrorType a
|
||||
errParams :: Either AgentErrorType a
|
||||
errParams = Left $ SYNTAX errBadParameters
|
||||
|
||||
serializeAgentMessage :: AMessage -> ByteString
|
||||
@@ -241,12 +232,12 @@ data MsgStatus = MsgOk | MsgError MsgErrorType
|
||||
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash
|
||||
deriving (Show)
|
||||
|
||||
data ErrorType
|
||||
data AgentErrorType
|
||||
= UNKNOWN
|
||||
| PROHIBITED
|
||||
| SYNTAX Int
|
||||
| BROKER Natural
|
||||
| SMP SMP.ErrorType
|
||||
| SMP ErrorType
|
||||
| SIZE
|
||||
| STORE StoreError
|
||||
| INTERNAL -- etc. TODO SYNTAX Natural
|
||||
@@ -315,7 +306,7 @@ parseCommandP =
|
||||
<|> "NO_ID " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal)
|
||||
<|> "HASH" $> MsgBadHash
|
||||
|
||||
parseCommand :: ByteString -> Either ErrorType ACmd
|
||||
parseCommand :: ByteString -> Either AgentErrorType ACmd
|
||||
parseCommand = parse parseCommandP $ SYNTAX errBadCommand
|
||||
|
||||
serializeCommand :: ACommand p -> ByteString
|
||||
@@ -365,8 +356,8 @@ tGetRaw h = do
|
||||
return (corrId, connAlias, command)
|
||||
|
||||
tPut :: MonadIO m => Handle -> ATransmission p -> m ()
|
||||
tPut h (corrId, connAlias, command) =
|
||||
liftIO $ tPutRaw h (bs corrId, connAlias, serializeCommand command)
|
||||
tPut h (CorrId corrId, connAlias, command) =
|
||||
liftIO $ tPutRaw h (corrId, connAlias, serializeCommand command)
|
||||
|
||||
-- | get client and agent transmissions
|
||||
tGet :: forall m p. MonadIO m => SAParty p -> Handle -> m (ATransmissionOrError p)
|
||||
@@ -378,12 +369,12 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
fullCmd <- either (return . Left) cmdWithMsgBody cmd
|
||||
return (CorrId corrId, connAlias, fullCmd)
|
||||
|
||||
fromParty :: ACmd -> Either ErrorType (ACommand p)
|
||||
fromParty :: ACmd -> Either AgentErrorType (ACommand p)
|
||||
fromParty (ACmd (p :: p1) cmd) = case testEquality party p of
|
||||
Just Refl -> Right cmd
|
||||
_ -> Left PROHIBITED
|
||||
|
||||
tConnAlias :: ARawTransmission -> ACommand p -> Either ErrorType (ACommand p)
|
||||
tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p)
|
||||
tConnAlias (_, connAlias, _) cmd = case cmd of
|
||||
-- NEW and JOIN have optional connAlias
|
||||
NEW _ -> Right cmd
|
||||
@@ -395,14 +386,14 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
| B.null connAlias -> Left $ SYNTAX errNoConnAlias
|
||||
| otherwise -> Right cmd
|
||||
|
||||
cmdWithMsgBody :: ACommand p -> m (Either ErrorType (ACommand p))
|
||||
cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p))
|
||||
cmdWithMsgBody = \case
|
||||
SEND body -> SEND <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body
|
||||
cmd -> return $ Right cmd
|
||||
|
||||
-- TODO refactor with server
|
||||
getMsgBody :: MsgBody -> m (Either ErrorType MsgBody)
|
||||
getMsgBody :: MsgBody -> m (Either AgentErrorType MsgBody)
|
||||
getMsgBody msgBody =
|
||||
case B.unpack msgBody of
|
||||
':' : body -> return . Right $ B.pack body
|
||||
|
||||
@@ -37,8 +37,9 @@ import Data.Maybe
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Transmission (SMPServer (..))
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Simplex.Messaging.Server.Transmission where
|
||||
module Simplex.Messaging.Protocol where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Class
|
||||
@@ -18,9 +18,9 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (ord)
|
||||
import Data.Kind
|
||||
import Data.String
|
||||
import Data.Time.Clock
|
||||
import Data.Time.ISO8601
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
@@ -143,54 +143,6 @@ serializeCommand = \case
|
||||
where
|
||||
serializeMsg msgBody = " " <> B.pack (show $ B.length msgBody) <> "\n" <> msgBody
|
||||
|
||||
type Encoded = ByteString
|
||||
|
||||
-- newtype to avoid accidentally changing order of transmission parts
|
||||
newtype CorrId = CorrId {bs :: ByteString} deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString CorrId where
|
||||
fromString = CorrId . fromString
|
||||
|
||||
type PublicKey = Encoded
|
||||
|
||||
type PrivateKey = Encoded
|
||||
|
||||
type Signature = Encoded
|
||||
|
||||
type RecipientKey = PublicKey
|
||||
|
||||
type SenderKey = PublicKey
|
||||
|
||||
type RecipientId = QueueId
|
||||
|
||||
type SenderId = QueueId
|
||||
|
||||
type QueueId = Encoded
|
||||
|
||||
type MsgId = Encoded
|
||||
|
||||
type MsgBody = ByteString
|
||||
|
||||
data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq)
|
||||
|
||||
errBadTransmission :: Int
|
||||
errBadTransmission = 1
|
||||
|
||||
errBadParameters :: Int
|
||||
errBadParameters = 2
|
||||
|
||||
errNoCredentials :: Int
|
||||
errNoCredentials = 3
|
||||
|
||||
errHasCredentials :: Int
|
||||
errHasCredentials = 4
|
||||
|
||||
errNoQueueId :: Int
|
||||
errNoQueueId = 5
|
||||
|
||||
errMessageBody :: Int
|
||||
errMessageBody = 6
|
||||
|
||||
tPutRaw :: Handle -> RawTransmission -> IO ()
|
||||
tPutRaw h (signature, corrId, queueId, command) = do
|
||||
putLn h signature
|
||||
@@ -23,13 +23,14 @@ import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Time.Clock
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
|
||||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Util
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Concurrent
|
||||
|
||||
@@ -10,9 +10,10 @@ import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
import Simplex.Messaging.Server.QueueStore.STM
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import UnliftIO.STM
|
||||
|
||||
data ServerConfig = ServerConfig
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
module Simplex.Messaging.Server.MsgStore where
|
||||
|
||||
import Data.Time.Clock
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Types
|
||||
|
||||
data Message = Message
|
||||
{ msgId :: Encoded,
|
||||
|
||||
@@ -8,8 +8,8 @@ module Simplex.Messaging.Server.MsgStore.STM where
|
||||
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import UnliftIO.STM
|
||||
|
||||
newtype MsgQueue = MsgQueue {msgQueue :: TQueue Message}
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
|
||||
module Simplex.Messaging.Server.QueueStore where
|
||||
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Types
|
||||
|
||||
data QueueRec = QueueRec
|
||||
{ recipientId :: QueueId,
|
||||
|
||||
@@ -13,8 +13,9 @@ module Simplex.Messaging.Server.QueueStore.STM where
|
||||
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Types
|
||||
import UnliftIO.STM
|
||||
|
||||
data QueueStoreData = QueueStoreData
|
||||
|
||||
60
src/Simplex/Messaging/Types.hs
Normal file
60
src/Simplex/Messaging/Types.hs
Normal file
@@ -0,0 +1,60 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Simplex.Messaging.Types where
|
||||
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.String
|
||||
|
||||
type Encoded = ByteString
|
||||
|
||||
-- newtype to avoid accidentally changing order of transmission parts
|
||||
newtype CorrId = CorrId {bs :: ByteString} deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString CorrId where
|
||||
fromString = CorrId . fromString
|
||||
|
||||
type PublicKey = Encoded
|
||||
|
||||
type PrivateKey = Encoded
|
||||
|
||||
type Signature = Encoded
|
||||
|
||||
type RecipientKey = PublicKey
|
||||
|
||||
type SenderKey = PublicKey
|
||||
|
||||
type RecipientId = QueueId
|
||||
|
||||
type SenderId = QueueId
|
||||
|
||||
type QueueId = Encoded
|
||||
|
||||
type MsgId = Encoded
|
||||
|
||||
type MsgBody = ByteString
|
||||
|
||||
data ErrorType = UNKNOWN | PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq)
|
||||
|
||||
errBadTransmission :: Int
|
||||
errBadTransmission = 1
|
||||
|
||||
errBadParameters :: Int
|
||||
errBadParameters = 2
|
||||
|
||||
errNoCredentials :: Int
|
||||
errNoCredentials = 3
|
||||
|
||||
errHasCredentials :: Int
|
||||
errHasCredentials = 4
|
||||
|
||||
errNoQueueId :: Int
|
||||
errNoQueueId = 5
|
||||
|
||||
errMessageBody :: Int
|
||||
errMessageBody = 6
|
||||
@@ -9,6 +9,7 @@ import Data.Word (Word32)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Types
|
||||
import Simplex.Messaging.Agent.Store.Types
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import System.Random
|
||||
|
||||
@@ -101,7 +101,7 @@ withSmpAgent = withSmpAgentOn (agentTestPort, testDB)
|
||||
|
||||
testSMPAgentClientOn :: MonadUnliftIO m => ServiceName -> (Handle -> m a) -> m a
|
||||
testSMPAgentClientOn port' client = do
|
||||
threadDelay 50_000 -- TODO hack: thread delay for SMP agent to start
|
||||
threadDelay 100_000 -- TODO hack: thread delay for SMP agent to start
|
||||
runTCPClient agentTestHost port' $ \h -> do
|
||||
line <- liftIO $ getLn h
|
||||
if line == "Welcome to SMP v0.2.0 agent"
|
||||
|
||||
@@ -7,9 +7,9 @@ module SMPClient where
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Transport
|
||||
import Test.Hspec
|
||||
import UnliftIO.Concurrent
|
||||
|
||||
@@ -11,7 +11,8 @@ import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import SMPClient
|
||||
import Simplex.Messaging.Server.Transmission
|
||||
import Simplex.Messaging.Types
|
||||
import Simplex.Messaging.Protocol
|
||||
import System.IO (Handle)
|
||||
import System.Timeout
|
||||
import Test.HUnit
|
||||
|
||||
Reference in New Issue
Block a user