From 650971fa02de206ca99adcea71e0e0d4b5af9265 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 3 Jan 2021 18:05:50 +0000 Subject: [PATCH] method to insert connection (#3) --- package.yaml | 1 + src/Simplex/Messaging/Agent/Store.hs | 4 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 140 ++++++++++++++---- .../Messaging/Agent/Store/SQLite/Schema.hs | 26 ++-- src/Simplex/Messaging/Agent/Transmission.hs | 5 +- src/Simplex/Messaging/Server/Transmission.hs | 1 + src/Simplex/Messaging/Transport.hs | 6 +- src/Simplex/Messaging/Util.hs | 5 + 8 files changed, 140 insertions(+), 48 deletions(-) diff --git a/package.yaml b/package.yaml index 2b6ae0b3c..7b907791a 100644 --- a/package.yaml +++ b/package.yaml @@ -24,6 +24,7 @@ dependencies: - sqlite-simple == 0.4.* - stm - template-haskell == 2.15.* + - text == 1.2.* - time == 1.9.* - unliftio == 0.2.* - unliftio-core == 0.1.* diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 67bfb5fa8..db268d1bf 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -73,8 +73,8 @@ type SMPServerId = Int64 class Monad m => MonadAgentStore s m where addServer :: s -> SMPServer -> m (Either StoreError SMPServerId) - createRcvConn :: s -> Maybe ConnAlias -> ReceiveQueue -> m (Either StoreError (Connection CReceive)) - createSndConn :: s -> Maybe ConnAlias -> SendQueue -> m (Either StoreError (Connection CSend)) + createRcvConn :: s -> ConnAlias -> ReceiveQueue -> m (Either StoreError (Connection CReceive)) + createSndConn :: s -> ConnAlias -> SendQueue -> m (Either StoreError (Connection CSend)) getConn :: s -> ConnAlias -> m (Either StoreError SomeConn) deleteConn :: s -> ConnAlias -> m (Either StoreError ()) addSndQueue :: s -> ConnAlias -> SendQueue -> m (Either StoreError (Connection CDuplex)) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 94eeb3a54..0e62f66ce 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,39 +1,50 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Simplex.Messaging.Agent.Store.SQLite where import Control.Monad.IO.Unlift import Data.Int (Int64) -import Database.SQLite.Simple (NamedParam (..)) +import qualified Data.Text as T +import Database.SQLite.Simple hiding (Connection) import qualified Database.SQLite.Simple as DB +import Database.SQLite.Simple.FromField +import Database.SQLite.Simple.Internal (Field (..)) +import Database.SQLite.Simple.Ok +import Database.SQLite.Simple.ToField import Multiline (s) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Schema import Simplex.Messaging.Agent.Transmission +import Simplex.Messaging.Server.Transmission (PublicKey, QueueId) +import Simplex.Messaging.Util +import Text.Read import qualified UnliftIO.Exception as E import UnliftIO.STM -addServerQuery :: DB.Query -addServerQuery = +addRcvQueueQuery :: Query +addRcvQueueQuery = [s| - INSERT INTO servers (host_address, port, key_hash) - VALUES (:host_address, :port, :key_hash) - ON CONFLICT(host_address, port) DO UPDATE SET - host_address=excluded.host_address, - port=excluded.port, - key_hash=excluded.key_hash; + INSERT INTO receive_queues + ( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode) + VALUES + (:server_id,:rcv_id,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status,:ack_mode); |] data SQLiteStore = SQLiteStore { conn :: DB.Connection, serversLock :: TMVar (), - recipientQueuesLock :: TMVar (), - senderQueuesLock :: TMVar (), + rcvQueuesLock :: TMVar (), + sndQueuesLock :: TMVar (), connectionsLock :: TMVar (), messagesLock :: TMVar () } @@ -43,20 +54,46 @@ newSQLiteStore dbFile = do conn <- liftIO $ DB.open dbFile liftIO $ createSchema conn serversLock <- newTMVarIO () - recipientQueuesLock <- newTMVarIO () - senderQueuesLock <- newTMVarIO () + rcvQueuesLock <- newTMVarIO () + sndQueuesLock <- newTMVarIO () connectionsLock <- newTMVarIO () messagesLock <- newTMVarIO () return SQLiteStore { conn, serversLock, - recipientQueuesLock, - senderQueuesLock, + rcvQueuesLock, + sndQueuesLock, connectionsLock, messagesLock } +-- data ReceiveQueueRec = ReceiveQueueRec +-- { rowId :: Maybe Int64, +-- serverId :: Int64, +-- rcvId :: QueueId, +-- rcvPrivateKey :: PrivateKey, +-- sndId :: Maybe QueueId, +-- sndKey :: Maybe PublicKey, +-- decryptKey :: PrivateKey, +-- verifyKey :: Maybe PublicKey, +-- status :: QueueStatus, +-- ackMode :: AckMode +-- } + +type QueueRowId = Int64 + +type ConnectionRowId = Int64 + +fromFieldToReadable :: forall a. (Read a, E.Typeable a) => Field -> Ok a +fromFieldToReadable = \case + f@(Field (SQLText t) _) -> + let s = T.unpack t + in case readMaybe s of + Just x -> Ok x + _ -> returnError ConversionFailed f ("invalid string: " ++ s) + f -> returnError ConversionFailed f "expecting SQLText column type" + withLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> (DB.Connection -> m a) -> m a withLock st tableLock f = do let lock = tableLock st @@ -65,21 +102,72 @@ withLock st tableLock f = do (atomically $ putTMVar lock ()) (f $ conn st) -insertWithLock :: MonadUnliftIO m => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> [DB.NamedParam] -> m Int64 -insertWithLock st tableLock q qParams = do +insertWithLock :: (MonadUnliftIO m, ToRow q) => SQLiteStore -> (SQLiteStore -> TMVar ()) -> DB.Query -> q -> m Int64 +insertWithLock st tableLock queryStr q = do withLock st tableLock $ \c -> liftIO $ do - DB.executeNamed c q qParams + DB.execute c queryStr q DB.lastInsertRowId c -instance MonadUnliftIO m => MonadAgentStore SQLiteStore m where - addServer :: SQLiteStore -> SMPServer -> m (Either StoreError SMPServerId) - addServer st SMPServer {host, port, keyHash} = - Right <$> insertWithLock st serversLock addServerQuery [":host_address" := host, ":port" := port, ":key_hash" := keyHash] +instance ToRow SMPServer where + toRow SMPServer {host, port, keyHash} = toRow (host, port, keyHash) --- createRcvConn :: DB.Connection -> Maybe ConnAlias -> ReceiveQueue -> m (Either StoreError (Connection CReceive)) --- createRcvConn conn connAlias q = do --- id <- query conn "INSERT ..." --- query conn "INSERT ..." +instance FromRow SMPServer where + fromRow = SMPServer <$> field <*> field <*> field + +upsertServer :: MonadUnliftIO m => SQLiteStore -> SMPServer -> m SMPServerId +upsertServer store = + insertWithLock + store + serversLock + "INSERT INTO servers (host, port, key_hash) VALUES (?, ?, ?)" + +instance ToField AckMode where toField (AckMode mode) = toField $ show mode + +instance FromField AckMode where fromField = AckMode <$$> fromFieldToReadable + +instance ToField QueueStatus where toField = toField . show + +instance FromField QueueStatus where fromField = fromFieldToReadable + +instance ToRow ReceiveQueue where + toRow ReceiveQueue {rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode} = + toRow (rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status, ackMode) + +instance FromRow ReceiveQueue where + fromRow = ReceiveQueue undefined <$> field <*> field <*> field <*> field <*> field <*> field <*> field <*> field + +insertRcvQueue :: MonadUnliftIO m => SQLiteStore -> SMPServerId -> ReceiveQueue -> m QueueRowId +insertRcvQueue store serverId rcvQueue = + insertWithLock + store + rcvQueuesLock + [s| + INSERT INTO receive_queues + ( server_id, rcv_id, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status, ack_mode) + VALUES (?,?,?,?,?,?,?,?,?); + |] + (Only serverId :. rcvQueue) + +insertRcvConnection :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueRowId -> m ConnectionRowId +insertRcvConnection store connAlias rcvQueueId = + insertWithLock + store + connectionsLock + "INSERT INTO connections (conn_alias, receive_queue_id, send_queue_id) VALUES (?,?,NULL);" + (Only connAlias :. Only rcvQueueId) + +instance MonadUnliftIO m => MonadAgentStore SQLiteStore m where + addServer store smpServer = Right <$> upsertServer store smpServer + + createRcvConn :: SQLiteStore -> ConnAlias -> ReceiveQueue -> m (Either StoreError (Connection CReceive)) + createRcvConn st connAlias rcvQueue = do + serverId <- upsertServer st $ server (rcvQueue :: ReceiveQueue) + qId <- insertRcvQueue st serverId rcvQueue -- TODO test for duplicate connAlias + insertRcvConnection st connAlias qId + return $ Right (ReceiveConnection connAlias rcvQueue) + +-- id <- query conn "INSERT ..." +-- query conn "INSERT ..." -- sqlite queries to create server, queue and connection diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs index 47f0e255d..9d9a1b0ed 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs @@ -12,19 +12,19 @@ servers = [s| CREATE TABLE IF NOT EXISTS servers ( server_id INTEGER PRIMARY KEY, - host_address TEXT NOT NULL, + host TEXT NOT NULL, port INT NOT NULL, key_hash BLOB, - UNIQUE (host_address, port) + UNIQUE (host, port) ) |] -- TODO unique constraints on (server_id, rcv_id) and (server_id, snd_id) -recipientQueues :: Query -recipientQueues = +receiveQueues :: Query +receiveQueues = [s| - CREATE TABLE IF NOT EXISTS recipient_queues - ( recipient_queue_id INTEGER PRIMARY KEY, + CREATE TABLE IF NOT EXISTS receive_queues + ( receive_queue_id INTEGER PRIMARY KEY, server_id INTEGER REFERENCES servers(server_id) NOT NULL, rcv_id BLOB NOT NULL, rcv_private_key BLOB NOT NULL, @@ -39,11 +39,11 @@ recipientQueues = ) |] -senderQueues :: Query -senderQueues = +sendQueues :: Query +sendQueues = [s| - CREATE TABLE IF NOT EXISTS sender_queues - ( sender_queue_id INTEGER PRIMARY KEY, + CREATE TABLE IF NOT EXISTS send_queues + ( send_queue_id INTEGER PRIMARY KEY, server_id INTEGER REFERENCES servers(server_id) NOT NULL, snd_id BLOB NOT NULL, snd_private_key BLOB NOT NULL, @@ -60,8 +60,8 @@ connections = CREATE TABLE IF NOT EXISTS connections ( connection_id INTEGER PRIMARY KEY, conn_alias TEXT UNIQUE, - recipient_queue_id INTEGER REFERENCES recipient_queues(recipient_queue_id), - sender_queue_id INTEGER REFERENCES sender_queues(sender_queue_id) + receive_queue_id INTEGER REFERENCES recipient_queues(receive_queue_id), + send_queue_id INTEGER REFERENCES sender_queues(send_queue_id) ) |] @@ -81,4 +81,4 @@ messages = createSchema :: Connection -> IO () createSchema conn = - mapM_ (execute_ conn) [servers, recipientQueues, senderQueues, connections, messages] + mapM_ (execute_ conn) [servers, receiveQueues, sendQueues, connections, messages] diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 22176de75..81fb754c0 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -35,6 +35,7 @@ import Simplex.Messaging.Server.Transmission ) import qualified Simplex.Messaging.Server.Transmission as SMP import Simplex.Messaging.Transport +import Simplex.Messaging.Util import System.IO import Text.Read import UnliftIO.Exception @@ -108,7 +109,7 @@ type ConnAlias = ByteString type OtherPartyId = Encoded -data Mode = On | Off deriving (Show) +data Mode = On | Off deriving (Show, Read) newtype AckMode = AckMode Mode deriving (Show) @@ -124,7 +125,7 @@ type VerificationKey = PublicKey data QueueDirection = SND | RCV deriving (Show) data QueueStatus = New | Confirmed | Secured | Active | Disabled - deriving (Show) + deriving (Show, Read) type AgentMsgId = Int diff --git a/src/Simplex/Messaging/Server/Transmission.hs b/src/Simplex/Messaging/Server/Transmission.hs index 1222b08c8..d5d0ce0f3 100644 --- a/src/Simplex/Messaging/Server/Transmission.hs +++ b/src/Simplex/Messaging/Server/Transmission.hs @@ -22,6 +22,7 @@ import Data.String import Data.Time.Clock import Data.Time.ISO8601 import Simplex.Messaging.Transport +import Simplex.Messaging.Util import System.IO import Text.Read diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index b4ef2ce09..58640d054 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -13,6 +13,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import GHC.IO.Exception (IOErrorType (..)) import Network.Socket +import Simplex.Messaging.Util import System.IO import System.IO.Error import UnliftIO.Concurrent @@ -89,8 +90,3 @@ getLn h = B.pack <$> liftIO (hGetLine h) getBytes :: MonadIO m => Handle -> Int -> m ByteString getBytes h = liftIO . B.hGet h - -infixl 4 <$$> - -(<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b) -(<$$>) = fmap . fmap diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 4e2e54789..164984949 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -9,3 +9,8 @@ raceAny_ = r [] where r as (m : ms) = withAsync m $ \a -> r (a : as) ms r as [] = void $ waitAnyCancel as + +infixl 4 <$$> + +(<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b) +(<$$>) = fmap . fmap