Files
simplexmq/src/Simplex/Messaging/Agent/Store/SQLite.hs
2021-02-26 18:11:22 +04:00

524 lines
19 KiB
Haskell

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Simplex.Messaging.Agent.Store.SQLite where
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Data.Int (Int64)
import Data.Maybe
import qualified Data.Text as T
import Data.Time
import Database.SQLite.Simple hiding (Connection)
import qualified Database.SQLite.Simple as DB
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok
import Database.SQLite.Simple.QQ (sql)
import Database.SQLite.Simple.ToField
import Network.Socket
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite.Schema
import Simplex.Messaging.Agent.Store.Types
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Server.Transmission (RecipientId, SenderId)
import Simplex.Messaging.Util
import Text.Read
import qualified UnliftIO.Exception as E
import UnliftIO.STM
-- | INSERT statement for a @receive_queues@ row that uses named parameters
-- (@:server_id@, @:rcv_id@, ... @:ack_mode@) rather than positional @?@.
--
-- NOTE(review): nothing in this module refers to this query (inserts go
-- through a positional-parameter statement instead); possibly dead code —
-- confirm against the rest of the project before relying on or removing it.
addRcvQueueQuery :: Query
addRcvQueueQuery =
  [sql|
INSERT INTO receive_queues
( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode)
VALUES
(:server_id,:rcv_id,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status,:ack_mode);
|]
-- | Agent store backed by a single SQLite database connection.
--
-- Every query in this module runs on the one connection 'conn'. Each
-- 'TMVar' below is a per-table mutex (taken with 'takeTMVar' and released
-- with 'putTMVar'); holding one table's lock does not exclude work on other
-- tables over the same connection.
data SQLiteStore = SQLiteStore
  { -- | Database file path that was passed to 'DB.open'.
    dbFilename :: String,
    -- | The shared SQLite connection.
    conn :: DB.Connection,
    -- | Lock intended for writes to the @servers@ table.
    serversLock :: TMVar (),
    -- | Lock for writes to the @receive_queues@ table.
    rcvQueuesLock :: TMVar (),
    -- | Lock for writes to the @send_queues@ table.
    sndQueuesLock :: TMVar (),
    -- | Lock for writes to the @connections@ table.
    connectionsLock :: TMVar (),
    -- | Lock for writes to the @messages@ table.
    messagesLock :: TMVar ()
  }
-- | Open (creating if necessary) the SQLite database at the given path,
-- create the schema, and return a store whose table locks are all free.
--
-- If schema creation throws, the freshly opened connection is closed before
-- the exception propagates, so a failed initialisation does not leak a
-- database handle.
newSQLiteStore :: MonadUnliftIO m => String -> m SQLiteStore
newSQLiteStore dbFilename = do
  conn <- liftIO $ DB.open dbFilename
  liftIO (createSchema conn) `E.onException` liftIO (DB.close conn)
  serversLock <- newTMVarIO ()
  rcvQueuesLock <- newTMVarIO ()
  sndQueuesLock <- newTMVarIO ()
  connectionsLock <- newTMVarIO ()
  messagesLock <- newTMVarIO ()
  return
    SQLiteStore
      { dbFilename,
        conn,
        serversLock,
        rcvQueuesLock,
        sndQueuesLock,
        connectionsLock,
        messagesLock
      }
-- | Row id of a queue row: the value of @receive_queue_id@ in
-- @receive_queues@ or @send_queue_id@ in @send_queues@.
type QueueRowId = Int64
-- | Row id of a @connections@ row.
-- NOTE(review): not used by any code in this module — confirm it is needed.
type ConnectionRowId = Int64
-- | Field parser for values stored as the 'show'n text of a 'Read' type.
--
-- A TEXT column is parsed with 'readMaybe'; unparseable text and any
-- non-TEXT column both produce a 'ConversionFailed' error.
fromFieldToReadable :: forall a. (Read a, E.Typeable a) => Field -> Ok a
fromFieldToReadable fld = case fld of
  Field (SQLText txt) _ -> parseText (T.unpack txt)
  _ -> returnError ConversionFailed fld "expecting SQLText column type"
  where
    parseText s =
      maybe (returnError ConversionFailed fld ("invalid string: " <> s)) Ok (readMaybe s)
-- | Run an action on the store's connection while holding the table lock
-- selected by @tableLock@.
--
-- The lock is taken before and put back after the action through
-- 'E.bracket_', so it is released even if the action throws.
withLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> (DB.Connection -> m a) -> m a
withLock st tableLock action =
  E.bracket_ acquire release (action $ conn st)
  where
    lock = tableLock st
    acquire = atomically (takeTMVar lock)
    release = atomically (putTMVar lock ())
-- | Execute an INSERT while holding the given table lock, then return
-- 'DB.lastInsertRowId' for the shared connection.
--
-- NOTE(review): 'DB.lastInsertRowId' is per connection, but the lock taken
-- here is per table. An insert into a different table (under a different
-- lock) on the same connection could run between the two calls, in which
-- case the id returned would belong to that other insert. A
-- connection-wide lock would rule this out — confirm whether callers can
-- run concurrently.
insertWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m Int64
insertWithLock st tableLock queryStr q =
  withLock st tableLock $ \c ->
    liftIO $ DB.execute c queryStr q >> DB.lastInsertRowId c
-- | Execute a statement that returns no rows while holding the given table
-- lock.
executeWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m ()
executeWithLock st tableLock queryStr q =
  withLock st tableLock (\c -> liftIO (DB.execute c queryStr q))
-- | Serialises a server as the three columns host, port, key hash.
instance ToRow SMPServer where
  toRow SMPServer {host, port, keyHash} = [toField host, toField port, toField keyHash]

-- | Reads a server from three consecutive columns: host, port, key hash.
instance FromRow SMPServer where
  fromRow = do
    h <- field
    p <- field
    k <- field
    pure (SMPServer h p k)
-- | Store the server if it is not stored yet, otherwise overwrite its key
-- hash; return the server's @server_id@.
--
-- The port may be 'Nothing' (NULL). In SQL, @NULL = NULL@ is not true, so the
-- lookup uses @port IS :port@. The previous @ON CONFLICT (host, port)@ also
-- never fired for a NULL port (NULLs never collide in a UNIQUE constraint),
-- which inserted a duplicate row on every call. Looking the row up first and
-- then updating or inserting handles both cases. The whole step runs under
-- 'serversLock' so two concurrent upserts of one server cannot both insert.
--
-- Throws 'SEInternal' if the server row cannot be identified uniquely.
upsertServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServer -> m SMPServerId
upsertServer st srv@SMPServer {host, port, keyHash} = do
  r <- withLock st serversLock $ \c -> liftIO $ do
    existing <- findServer c
    case existing of
      [Only serverId] -> do
        DB.executeNamed
          c
          "UPDATE servers SET key_hash = :key_hash WHERE server_id = :server_id"
          [":key_hash" := keyHash, ":server_id" := (serverId :: SMPServerId)]
        return existing
      [] -> do
        DB.execute c "INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?)" srv
        findServer c
      _ -> return existing
  case r of
    [Only serverId] -> return serverId
    _ -> throwError SEInternal
  where
    -- NULL-safe lookup of the server row by host and port.
    findServer c =
      DB.queryNamed
        c
        "SELECT server_id FROM servers WHERE host = :host AND port IS :port"
        [":host" := host, ":port" := port]
-- | Load the server with the given @server_id@.
--
-- Throws 'SENotFound' unless exactly one row is returned.
getServer :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> SMPServerId -> m SMPServer
getServer SQLiteStore {conn} serverId =
  liftIO (DB.queryNamed conn "SELECT host, port, key_hash FROM servers WHERE server_id = :server_id" [":server_id" := serverId])
    >>= \case
      [smpServer] -> return smpServer
      _ -> throwError SENotFound
-- | Stored as the 'show'n text of the wrapped mode.
instance ToField AckMode where
  toField (AckMode mode) = toField (show mode)

-- | Parsed back from the 'show'n text of the wrapped mode.
instance FromField AckMode where
  fromField fld = AckMode <$> fromFieldToReadable fld

-- | Stored as the 'show'n text of the status.
instance ToField QueueStatus where
  toField status = toField (show status)

-- | Parsed back from the 'show'n text of the status.
instance FromField QueueStatus where
  fromField fld = fromFieldToReadable fld
-- | Writes the stored receive-queue fields in the order rcv_id,
-- rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status,
-- ack_mode. The server is not part of this row; it is stored separately as
-- a server id.
instance ToRow ReceiveQueue where
  toRow ReceiveQueue {rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode} =
    toRow (rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode)

-- | Reads the eight stored receive-queue fields (same order as 'toRow').
--
-- The row does not contain the server, so the @server@ field is a
-- placeholder that callers must overwrite. Forcing it raises a descriptive
-- error rather than an anonymous 'undefined'.
instance FromRow ReceiveQueue where
  fromRow = ReceiveQueue noServer <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field
    where
      noServer = error "SQLiteStore: ReceiveQueue.server is not stored in the queue row and must be set by the caller"
-- | Load a receive queue by its @receive_queue_id@, together with its server.
--
-- Queue and server columns are read in one JOIN query rather than a queue
-- lookup followed by a separate server lookup. Throws 'SENotFound' when no
-- matching queue (with an existing server row) is found.
getRcvQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m ReceiveQueue
getRcvQueue SQLiteStore {conn} queueRowId = do
  r <-
    liftIO $
      DB.queryNamed
        conn
        [sql|
          SELECT s.host, s.port, s.key_hash,
            q.rcv_id, q.rcv_private_key, q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status, q.ack_mode
          FROM receive_queues q
          JOIN servers s ON q.server_id = s.server_id
          WHERE q.receive_queue_id = :rowId;
        |]
        [":rowId" := queueRowId]
  case r of
    [srv :. rcvQueue] -> return (rcvQueue {server = srv} :: ReceiveQueue)
    _ -> throwError SENotFound
-- | Load a receive queue by recipient id on the server with the given host
-- and port, together with that server.
--
-- The port may be 'Nothing' (NULL), so it is compared with @IS@. With @=@
-- the comparison @NULL = NULL@ is not true, and a queue on a server without
-- an explicit port was never found. Queue and server are read in one JOIN
-- query. Throws 'SENotFound' unless exactly one row matches.
getRcvQueueByRecipientId :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> m ReceiveQueue
getRcvQueueByRecipientId SQLiteStore {conn} rcvId host port = do
  r <-
    liftIO $
      DB.queryNamed
        conn
        [sql|
          SELECT s.host, s.port, s.key_hash,
            q.rcv_id, q.rcv_private_key, q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status, q.ack_mode
          FROM receive_queues q
          JOIN servers s ON q.server_id = s.server_id
          WHERE q.rcv_id = :rcvId AND s.host = :host AND s.port IS :port;
        |]
        [":rcvId" := rcvId, ":host" := host, ":port" := port]
  case r of
    [srv :. rcvQueue] -> return (rcvQueue {server = srv} :: ReceiveQueue)
    _ -> throwError SENotFound
-- | Load a send queue by its @send_queue_id@, together with its server.
--
-- Queue and server columns are read in one JOIN query rather than a queue
-- lookup followed by a separate server lookup. Throws 'SENotFound' when no
-- matching queue (with an existing server row) is found.
getSndQueue :: (MonadUnliftIO m, MonadError StoreError m) => SQLiteStore -> QueueRowId -> m SendQueue
getSndQueue SQLiteStore {conn} queueRowId = do
  r <-
    liftIO $
      DB.queryNamed
        conn
        [sql|
          SELECT s.host, s.port, s.key_hash,
            q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status, q.ack_mode
          FROM send_queues q
          JOIN servers s ON q.server_id = s.server_id
          WHERE q.send_queue_id = :rowId;
        |]
        [":rowId" := queueRowId]
  case r of
    [srv :. sndQueue] -> return (sndQueue {server = srv} :: SendQueue)
    _ -> throwError SENotFound
-- | Insert a receive queue that belongs to the given server row and return
-- the new queue's row id. Runs under 'rcvQueuesLock'.
insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId
insertRcvQueue store serverId rcvQueue =
  insertWithLock store rcvQueuesLock insertQueueSql (Only serverId :. rcvQueue)
  where
    insertQueueSql :: DB.Query
    insertQueueSql =
      [sql|
INSERT INTO receive_queues
( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode)
VALUES (?,?,?,?,?,?,?,?,?);
|]
-- | Create a connection row that has only a receive queue (the send queue
-- id is NULL). Runs under 'connectionsLock'.
insertRcvConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
insertRcvConnection store connAlias rcvQueueId =
  executeWithLock store connectionsLock insertConnSql (Only connAlias :. Only rcvQueueId)
  where
    insertConnSql :: DB.Query
    insertConnSql = "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,?,NULL);"
-- | Attach a send queue to an existing connection by setting its
-- @send_queue_id@. Runs under 'connectionsLock'.
updateRcvConnectionWithSndQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
updateRcvConnectionWithSndQueue store connAlias sndQueueId =
  executeWithLock store connectionsLock setSndQueueSql (Only sndQueueId :. Only connAlias)
  where
    setSndQueueSql :: DB.Query
    setSndQueueSql =
      [sql|
UPDATE connections
SET send_queue_id = ?
WHERE conn_alias = ?;
|]
-- | Writes the stored send-queue fields in the order snd_id,
-- snd_private_key, encrypt_key, sign_key, status, ack_mode. The server is
-- not part of this row; it is stored separately as a server id.
instance ToRow SendQueue where
  toRow SendQueue {sndId, sndPrivateKey, encryptKey, signKey, status, ackMode} =
    toRow (sndId, sndPrivateKey, encryptKey, signKey, status, ackMode)

-- | Reads the six stored send-queue fields (same order as 'toRow').
--
-- The row does not contain the server, so the @server@ field is a
-- placeholder that callers must overwrite. Forcing it raises a descriptive
-- error rather than an anonymous 'undefined'.
instance FromRow SendQueue where
  fromRow = SendQueue noServer <$> field <*> field <*> field <*> field <*> field <*> field
    where
      noServer = error "SQLiteStore: SendQueue.server is not stored in the queue row and must be set by the caller"
-- | Insert a send queue that belongs to the given server row and return the
-- new queue's row id. Runs under 'sndQueuesLock'.
insertSndQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> SendQueue -> m QueueRowId
insertSndQueue store serverId sndQueue =
  insertWithLock store sndQueuesLock insertQueueSql (Only serverId :. sndQueue)
  where
    insertQueueSql :: DB.Query
    insertQueueSql =
      [sql|
INSERT INTO send_queues
( server_id, snd_id, snd_private_key, encrypt_key, sign_key, status, ack_mode)
VALUES (?,?,?,?,?,?,?);
|]
-- | Create a connection row that has only a send queue (the receive queue
-- id is NULL). Runs under 'connectionsLock'.
insertSndConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
insertSndConnection store connAlias sndQueueId =
  executeWithLock store connectionsLock insertConnSql (Only connAlias :. Only sndQueueId)
  where
    insertConnSql :: DB.Query
    insertConnSql = "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,NULL,?);"
-- | Attach a receive queue to an existing connection by setting its
-- @receive_queue_id@. Runs under 'connectionsLock'.
updateSndConnectionWithRcvQueue :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ()
updateSndConnectionWithRcvQueue store connAlias rcvQueueId =
  executeWithLock store connectionsLock setRcvQueueSql (Only rcvQueueId :. Only connAlias)
  where
    setRcvQueueSql :: DB.Query
    setRcvQueueSql =
      [sql|
UPDATE connections
SET receive_queue_id = ?
WHERE conn_alias = ?;
|]
-- | Look up the (receive, send) queue row ids of a connection; either may be
-- absent (NULL).
--
-- Throws 'SEInternal' unless exactly one connection row matches, which
-- includes the case where the alias does not exist.
getConnection :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> ConnAlias -> m (Maybe QueueRowId, Maybe QueueRowId)
getConnection SQLiteStore {conn} connAlias =
  liftIO (DB.queryNamed conn "SELECT receive_queue_id, send_queue_id FROM connections WHERE conn_alias = :conn_alias" [":conn_alias" := connAlias])
    >>= \case
      [queueIds] -> return queueIds
      _ -> throwError SEInternal
-- | A connection alias is read from a single column.
instance FromRow ConnAlias where
  fromRow = do
    alias <- field
    pure alias
-- | Find the alias of the connection whose receive queue has the given
-- recipient id.
--
-- Throws 'SEInternal' unless exactly one row matches.
--
-- NOTE(review): the match is on @rcv_id@ alone, with no server condition.
-- If the same recipient id were stored for two different servers, two rows
-- would match and this would throw — confirm that recipient ids are unique
-- across servers.
getConnAliasByRcvQueue :: (MonadError StoreError m, MonadUnliftIO m) => SQLiteStore -> RecipientId -> m ConnAlias
getConnAliasByRcvQueue SQLiteStore {conn} rcvId =
  liftIO (DB.queryNamed conn aliasSql [":rcvId" := rcvId]) >>= \case
    [connAlias] -> return connAlias
    _ -> throwError SEInternal
  where
    aliasSql :: DB.Query
    aliasSql =
      [sql|
SELECT c.conn_alias
FROM connections c
JOIN receive_queues rq
ON c.receive_queue_id = rq.receive_queue_id
WHERE rq.rcv_id = :rcvId;
|]
-- | Delete the receive queue with the given row id (no error if it is
-- absent). Runs under 'rcvQueuesLock'.
deleteRcvQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
deleteRcvQueue store rcvQueueId =
  executeWithLock store rcvQueuesLock "DELETE FROM receive_queues WHERE receive_queue_id = ?" (Only rcvQueueId)
-- | Delete the send queue with the given row id (no error if it is absent).
-- Runs under 'sndQueuesLock'.
deleteSndQueue :: MonadUnliftIO m => SQLiteStore -> QueueRowId -> m ()
deleteSndQueue store sndQueueId =
  executeWithLock store sndQueuesLock "DELETE FROM send_queues WHERE send_queue_id = ?" (Only sndQueueId)
-- | Delete the connection row with the given alias (no error if it is
-- absent). Runs under 'connectionsLock'.
deleteConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> m ()
deleteConnection store connAlias =
  executeWithLock store connectionsLock "DELETE FROM connections WHERE conn_alias = ?" (Only connAlias)
-- | Set the status of the receive queue with the given recipient id on the
-- server with the given host and port. Runs under 'rcvQueuesLock'.
--
-- The port may be 'Nothing' (NULL), so it is compared with @IS@. With @=@
-- the comparison @NULL = NULL@ is not true, and queues on a server without
-- an explicit port were silently left unchanged.
--
-- No error is raised if no row matches.
updateReceiveQueueStatus :: MonadUnliftIO m => SQLiteStore -> RecipientId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
updateReceiveQueueStatus store rcvQueueId host port status =
  executeWithLock
    store
    rcvQueuesLock
    [sql|
      UPDATE receive_queues
      SET status = ?
      WHERE rcv_id = ?
        AND server_id IN (
          SELECT server_id
          FROM servers
          WHERE host = ? AND port IS ?
        );
    |]
    (Only status :. Only rcvQueueId :. Only host :. Only port)
-- | Set the status of the send queue with the given sender id on the server
-- with the given host and port. Runs under 'sndQueuesLock'.
--
-- The port may be 'Nothing' (NULL), so it is compared with @IS@. With @=@
-- the comparison @NULL = NULL@ is not true, and queues on a server without
-- an explicit port were silently left unchanged.
--
-- No error is raised if no row matches.
updateSendQueueStatus :: MonadUnliftIO m => SQLiteStore -> SenderId -> HostName -> Maybe ServiceName -> QueueStatus -> m ()
updateSendQueueStatus store sndQueueId host port status =
  executeWithLock
    store
    sndQueuesLock
    [sql|
      UPDATE send_queues
      SET status = ?
      WHERE snd_id = ?
        AND server_id IN (
          SELECT server_id
          FROM servers
          WHERE host = ? AND port IS ?
        );
    |]
    (Only status :. Only sndQueueId :. Only host :. Only port)
-- | A queue direction is stored as its 'show'n text.
instance ToField QueueDirection where
  toField dir = toField (show dir)
-- | Append a message row for the connection. The timestamp is the current
-- UTC time, and @msg_status@ starts as @MDTransmitted@. Runs under
-- 'messagesLock'.
--
-- The status is now a single-quoted SQL string literal. Double quotes
-- denote an identifier in SQL; SQLite only falls back to treating
-- @"MDTransmitted"@ as a string when no column of that name exists, and
-- builds with double-quoted string literals disabled reject it.
--
-- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus?
insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m ()
insertMsg store connAlias qDirection agentMsgId msg = do
  tstamp <- liftIO getCurrentTime
  executeWithLock
    store
    messagesLock
    [sql|
      INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status)
      VALUES (?,?,?,?,?,'MDTransmitted');
    |]
    (Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg)
-- | 'MonadAgentStore' implementation for 'SQLiteStore'.
--
-- Methods delegate to the row-level helpers of this module. Methods with
-- several steps issue independent statements without wrapping them in a
-- transaction.
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
  addServer = upsertServer

  createRcvConn :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m ()
  createRcvConn st connAlias rcvQueue = do
    -- TODO test for duplicate connAlias
    serverRowId <- upsertServer st (server (rcvQueue :: ReceiveQueue))
    insertRcvQueue st serverRowId rcvQueue >>= insertRcvConnection st connAlias

  createSndConn :: SQLiteStore -> ConnAlias -> SendQueue -> m ()
  createSndConn st connAlias sndQueue = do
    -- TODO test for duplicate connAlias
    serverRowId <- upsertServer st (server (sndQueue :: SendQueue))
    insertSndQueue st serverRowId sndQueue >>= insertSndConnection st connAlias

  -- Builds a duplex, receive-only or send-only connection depending on which
  -- queue ids are present; SEBadConn if neither is.
  -- TODO refactor into a single query with join, and parse as `Only connAlias :. rcvQueue :. sndQueue`
  getConn :: SQLiteStore -> ConnAlias -> m SomeConn
  getConn st connAlias = do
    queueIds <- getConnection st connAlias
    case queueIds of
      (Just rcvQId, Just sndQId) ->
        SomeConn SCDuplex
          <$> (DuplexConnection connAlias <$> getRcvQueue st rcvQId <*> getSndQueue st sndQId)
      (Just rcvQId, Nothing) ->
        SomeConn SCReceive . ReceiveConnection connAlias <$> getRcvQueue st rcvQId
      (Nothing, Just sndQId) ->
        SomeConn SCSend . SendConnection connAlias <$> getSndQueue st sndQId
      (Nothing, Nothing) -> throwError SEBadConn

  -- The queue is looked up first, then the alias of its connection.
  getReceiveQueue :: SQLiteStore -> SMPServer -> RecipientId -> m (ConnAlias, ReceiveQueue)
  getReceiveQueue st SMPServer {host, port} recipientId =
    flip (,)
      <$> getRcvQueueByRecipientId st recipientId host port
      <*> getConnAliasByRcvQueue st recipientId

  -- Only a receive-only connection can gain a send queue.
  -- TODO make transactional
  addSndQueue :: SQLiteStore -> ConnAlias -> SendQueue -> m ()
  addSndQueue st connAlias sndQueue =
    getConn st connAlias >>= \case
      SomeConn SCReceive _ -> do
        serverRowId <- upsertServer st (server (sndQueue :: SendQueue))
        insertSndQueue st serverRowId sndQueue >>= updateRcvConnectionWithSndQueue st connAlias
      SomeConn SCDuplex _ -> throwError (SEBadConnType CDuplex)
      SomeConn SCSend _ -> throwError (SEBadConnType CSend)

  -- Only a send-only connection can gain a receive queue.
  -- TODO make transactional
  addRcvQueue :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m ()
  addRcvQueue st connAlias rcvQueue =
    getConn st connAlias >>= \case
      SomeConn SCSend _ -> do
        serverRowId <- upsertServer st (server (rcvQueue :: ReceiveQueue))
        insertRcvQueue st serverRowId rcvQueue >>= updateSndConnectionWithRcvQueue st connAlias
      SomeConn SCDuplex _ -> throwError (SEBadConnType CDuplex)
      SomeConn SCReceive _ -> throwError (SEBadConnType CReceive)

  -- TODO think about design of one-to-one relationships between connections and send/receive queues
  -- - Make wide `connections` table? -> Leads to inability to constrain queue fields on SQL level
  -- - Make bi-directional foreign keys deferred on queue side?
  --   * Involves populating foreign keys on queues' tables and reworking store
  --   * Enables cascade deletes
  --   ? See https://sqlite.org/foreignkeys.html#fk_deferred
  -- - Keep as is and just wrap in transaction?
  --
  -- Deletes the connection's queues and then the connection row itself; a
  -- connection with no queues is still deleted, after which SEBadConn is thrown.
  deleteConn :: SQLiteStore -> ConnAlias -> m ()
  deleteConn st connAlias = do
    (rcvQId, sndQId) <- getConnection st connAlias
    mapM_ (deleteRcvQueue st) rcvQId
    mapM_ (deleteSndQueue st) sndQId
    deleteConnection st connAlias
    unless (isJust rcvQId || isJust sndQId) $ throwError SEBadConn

  removeSndAuth :: SQLiteStore -> ConnAlias -> m ()
  removeSndAuth _ _ = throwError SENotImplemented

  -- TODO throw error if queue doesn't exist
  updateRcvQueueStatus :: SQLiteStore -> ReceiveQueue -> QueueStatus -> m ()
  updateRcvQueueStatus st ReceiveQueue {rcvId, server = SMPServer {host, port}} =
    updateReceiveQueueStatus st rcvId host port

  -- TODO throw error if queue doesn't exist
  updateSndQueueStatus :: SQLiteStore -> SendQueue -> QueueStatus -> m ()
  updateSndQueueStatus st SendQueue {sndId, server = SMPServer {host, port}} =
    updateSendQueueStatus st sndId host port

  -- The message is stored only if the connection has a queue in the
  -- requested direction; otherwise SEBadQueueDirection.
  createMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> AMessage -> m ()
  createMsg st connAlias qDirection agentMsgId msg = do
    (rcvQId, sndQId) <- getConnection st connAlias
    let queueForDirection = case qDirection of
          RCV -> rcvQId
          SND -> sndQId
    maybe
      (throwError SEBadQueueDirection)
      (const $ insertMsg st connAlias qDirection agentMsgId (serializeAgentMessage msg))
      queueForDirection

  getLastMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> m MessageDelivery
  getLastMsg _ _ _ = throwError SENotImplemented

  getMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> m MessageDelivery
  getMsg _ _ _ _ = throwError SENotImplemented

  -- TODO missing status parameter?
  updateMsgStatus :: SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> m ()
  updateMsgStatus _ _ _ _ = throwError SENotImplemented

  deleteMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> m ()
  deleteMsg _ _ _ _ = throwError SENotImplemented