diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index f85e5c95a..a60608bef 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -92,6 +92,7 @@ data DeliveryStatus type SMPServerId = Int64 +-- TODO rework types - decouple Transmission types from Store? Convert on the agent instead? class Monad m => MonadAgentStore s m where addServer :: s -> SMPServer -> m SMPServerId createRcvConn :: s -> ConnAlias -> ReceiveQueue -> m () @@ -102,7 +103,7 @@ class Monad m => MonadAgentStore s m where addRcvQueue :: s -> ConnAlias -> ReceiveQueue -> m () removeSndAuth :: s -> ConnAlias -> m () updateQueueStatus :: s -> ConnAlias -> QueueDirection -> QueueStatus -> m () - createMsg :: s -> ConnAlias -> QueueDirection -> AMessage -> m MessageDelivery + createMsg :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> AMessage -> m () getLastMsg :: s -> ConnAlias -> QueueDirection -> m MessageDelivery getMsg :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m MessageDelivery updateMsgStatus :: s -> ConnAlias -> QueueDirection -> AgentMsgId -> m () diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 468eb4b26..8c608447e 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -22,6 +22,7 @@ import Control.Monad.IO.Unlift import Data.Int (Int64) import Data.Maybe import qualified Data.Text as T +import Data.Time import Database.SQLite.Simple hiding (Connection) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField @@ -330,6 +331,22 @@ updateSndQueueStatus store sndQueueId status = |] (Only status :. Only sndQueueId) +instance ToField QueueDirection where toField = toField . show + +-- TODO add parser and serializer for DeliveryStatus? Pass DeliveryStatus? +insertMsg :: MonadUnliftIO m => SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> Message -> m () +insertMsg store connAlias qDirection agentMsgId msg = do + tstamp <- liftIO getCurrentTime + void $ + insertWithLock + store + messagesLock + [s| + INSERT INTO messages (conn_alias, agent_msg_id, timestamp, message, direction, msg_status) + VALUES (?,?,?,?,?,"MDTransmitted"); + |] + (Only connAlias :. Only agentMsgId :. Only tstamp :. Only qDirection :. Only msg) + instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where addServer store smpServer = upsertServer store smpServer @@ -412,15 +429,27 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto (rcvQId, _) <- getConnection st connAlias case rcvQId of Just qId -> updateRcvQueueStatus st qId status - Nothing -> throwError SEBadConn + Nothing -> throwError SEBadQueueDirection SND -> do (_, sndQId) <- getConnection st connAlias case sndQId of Just qId -> updateSndQueueStatus st qId status - Nothing -> throwError SEBadConn + Nothing -> throwError SEBadQueueDirection - createMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> AMessage -> m MessageDelivery - createMsg _st _connAlias _dir _msg = throwError SEInternal + -- TODO decrease duplication of queue direction checks? + createMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> AgentMsgId -> AMessage -> m () + createMsg st connAlias qDirection agentMsgId msg = do + case qDirection of + RCV -> do + (rcvQId, _) <- getConnection st connAlias + case rcvQId of + Just _ -> insertMsg st connAlias qDirection agentMsgId $ serializeMsg msg + Nothing -> throwError SEBadQueueDirection + SND -> do + (_, sndQId) <- getConnection st connAlias + case sndQId of + Just _ -> insertMsg st connAlias qDirection agentMsgId $ serializeMsg msg + Nothing -> throwError SEBadQueueDirection getLastMsg :: SQLiteStore -> ConnAlias -> QueueDirection -> m MessageDelivery getLastMsg _st _connAlias _dir = throwError SEInternal diff --git a/src/Simplex/Messaging/Agent/Store/Types.hs b/src/Simplex/Messaging/Agent/Store/Types.hs index c4d4ae9af..30695d29d 100644 --- a/src/Simplex/Messaging/Agent/Store/Types.hs +++ b/src/Simplex/Messaging/Agent/Store/Types.hs @@ -12,4 +12,5 @@ data StoreError | SEBadConn | SEBadConnType ConnType | SEBadQueueStatus + | SEBadQueueDirection deriving (Eq, Show, Exception) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 8cb95ffe4..bd272bb5b 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -14,7 +14,6 @@ module Simplex.Messaging.Agent.Transmission where import Control.Monad import Control.Monad.IO.Class --- import Numeric.Natural import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -91,13 +90,64 @@ data ACommand (p :: AParty) where deriving instance Show (ACommand p) +type Message = ByteString + data AMessage where HELLO :: VerificationKey -> AckMode -> AMessage REPLY :: SMPQueueInfo -> AMessage A_MSG :: MsgBody -> AMessage --- A_ACK :: AgentMsgId -> AckStatus -> AMessage --- A_DEL :: AMessage +parseMessage :: Message -> Either ErrorType AMessage +parseMessage msg = case B.words msg of + ["HELLO", key, ackMode] -> HELLO key <$> parseAckMode ackMode + ["REPLY", qInfo] -> REPLY <$> parseSmpQueueInfo qInfo + ["A_MSG", msgBody] -> Right $ A_MSG msgBody + _ -> Left UNKNOWN + +parseSmpQueueInfo :: ByteString -> Either ErrorType SMPQueueInfo +parseSmpQueueInfo qInfo = case splitOn "::" $ B.unpack qInfo of + ["smp", srv, qId, ek] -> liftM3 SMPQueueInfo (parseSmpServer $ B.pack srv) (parseDec64 qId) (parseDec64 ek) + _ -> Left $ SYNTAX errBadInvitation + +parseSmpServer :: ByteString -> Either ErrorType SMPServer +parseSmpServer srv = + let (s, kf) = span (/= '#') $ B.unpack srv + (h, p) = span (/= ':') s + in SMPServer h (parseSrvPart p) <$> traverse parseDec64 (parseSrvPart kf) + +parseDec64 :: String -> Either ErrorType ByteString +parseDec64 s = case decode $ B.pack s of + Left _ -> Left $ SYNTAX errBadEncoding + Right b -> Right b + +parseSrvPart :: String -> Maybe String +parseSrvPart s = if length s > 1 then Just $ tail s else Nothing + +parseAckMode :: ByteString -> Either ErrorType AckMode +parseAckMode am = case B.split '=' am of + ["ACK", mode] -> AckMode <$> getMode mode + _ -> errParams + +getMode :: ByteString -> Either ErrorType Mode +getMode mode = case mode of + "ON" -> Right On + "OFF" -> Right Off + _ -> errParams + +errParams :: Either ErrorType a +errParams = Left $ SYNTAX errBadParameters + +serializeMsg :: AMessage -> Message +serializeMsg = \case + HELLO _verKey _ackMode -> "HELLO" -- TODO + REPLY qInfo -> "REPLY" <> serializeSmpQueueInfo qInfo + A_MSG msgBody -> "A_MSG" <> msgBody -- ? whitespaces missing + +serializeSmpQueueInfo :: SMPQueueInfo -> ByteString +serializeSmpQueueInfo (SMPQueueInfo srv qId ek) = "smp::" <> serializeServer srv <> "::" <> encode qId <> "::" <> encode ek + +serializeServer :: SMPServer -> ByteString +serializeServer SMPServer {host, port, keyHash} = B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash data SMPServer = SMPServer { host :: HostName, @@ -179,8 +229,8 @@ smpUnexpectedResponse = 3 parseCommand :: ByteString -> Either ErrorType ACmd parseCommand command = case B.words command of ["NEW", srv] -> newConn srv -- . Right $ AckMode On - -- ["NEW", srv, am] -> newConn srv $ ackMode am - ["INV", qInfo] -> ACmd SAgent . INV <$> smpQueueInfo qInfo + -- ["NEW", srv, am] -> newConn srv $ parseAckMode am + ["INV", qInfo] -> ACmd SAgent . INV <$> parseSmpQueueInfo qInfo "JOIN" : qInfo : ws -> joinConn qInfo ws ["CON"] -> Right . ACmd SAgent $ CON "NEW" : _ -> errParams @@ -190,72 +240,33 @@ parseCommand command = case B.words command of _ -> Left UNKNOWN where newConn :: ByteString -> Either ErrorType ACmd - newConn srv = ACmd SClient . NEW <$> smpServer srv + newConn srv = ACmd SClient . NEW <$> parseSmpServer srv joinConn :: ByteString -> [ByteString] -> Either ErrorType ACmd joinConn qInfo ws = do - q <- smpQueueInfo qInfo + q <- parseSmpQueueInfo qInfo case ws of [] -> let SMPQueueInfo srv _ _ = q in joinCmd q $ ReplyOn srv ["NO_REPLY"] -> joinCmd q ReplyOff [srv] -> do - s <- smpServer srv + s <- parseSmpServer srv joinCmd q $ ReplyOn s _ -> errParams where joinCmd q r = return $ ACmd SClient $ JOIN q r - smpServer :: ByteString -> Either ErrorType SMPServer - smpServer srv = - let (s, kf) = span (/= '#') $ B.unpack srv - (h, p) = span (/= ':') s - in SMPServer h (srvPart p) <$> traverse dec64 (srvPart kf) - - smpQueueInfo :: ByteString -> Either ErrorType SMPQueueInfo - smpQueueInfo qInfo = case splitOn "::" $ B.unpack qInfo of - ["smp", srv, qId, ek] -> liftM3 SMPQueueInfo (smpServer $ B.pack srv) (dec64 qId) (dec64 ek) - _ -> Left $ SYNTAX errBadInvitation - - dec64 :: String -> Either ErrorType ByteString - dec64 s = case decode $ B.pack s of - Left _ -> Left $ SYNTAX errBadEncoding - Right b -> Right b - - srvPart :: String -> Maybe String - srvPart s = if length s > 1 then Just $ tail s else Nothing - - -- ackMode :: ByteString -> Either ErrorType AckMode - -- ackMode am = case B.split '=' am of - -- ["ACK", mode] -> AckMode <$> getMode mode - -- _ -> errParams - - -- getMode :: ByteString -> Either ErrorType Mode - -- getMode mode = case mode of - -- "ON" -> Right On - -- "OFF" -> Right Off - -- _ -> errParams - - errParams :: Either ErrorType a - errParams = Left $ SYNTAX errBadParameters - serializeCommand :: ACommand p -> ByteString serializeCommand = \case - NEW srv -> "NEW " <> server srv - INV qInfo -> "INV " <> smpQueueInfo qInfo + NEW srv -> "NEW " <> serializeServer srv + INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo JOIN qInfo rMode -> - "JOIN " <> smpQueueInfo qInfo <> " " + "JOIN " <> serializeSmpQueueInfo qInfo <> " " <> case rMode of ReplyOff -> "NO_REPLY" - ReplyOn srv -> server srv + ReplyOn srv -> serializeServer srv CON -> "CON" ERR e -> "ERR " <> B.pack (show e) c -> B.pack $ show c - where - server :: SMPServer -> ByteString - server SMPServer {host, port, keyHash} = B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash - - smpQueueInfo :: SMPQueueInfo -> ByteString - smpQueueInfo (SMPQueueInfo srv qId ek) = "smp::" <> server srv <> "::" <> encode qId <> "::" <> encode ek tPutRaw :: MonadIO m => Handle -> ARawTransmission -> m () tPutRaw h (corrId, connAlias, command) = do diff --git a/tests/AgentTests/SQLite.hs b/tests/AgentTests/SQLite.hs index 28bc315ca..22cbd22dd 100644 --- a/tests/AgentTests/SQLite.hs +++ b/tests/AgentTests/SQLite.hs @@ -44,6 +44,13 @@ storeTests = withStore do describe "Duplex connection" testUpdateQueueStatusConnDuplex describe "Bad queue direction - SND" testUpdateQueueStatusBadDirectionSnd describe "Bad queue direction - RCV" testUpdateQueueStatusBadDirectionRcv + describe "createMsg" do + describe "A_MSG in RCV direction" testCreateMsgRcv + describe "A_MSG in SND direction" testCreateMsgSnd + describe "HELLO message" testCreateMsgHello + describe "REPLY message" testCreateMsgReply + describe "Bad queue direction - SND" testCreateMsgBadDirectionSnd + describe "Bad queue direction - RCV" testCreateMsgBadDirectionRcv testCreateRcvConn :: SpecWith SQLiteStore testCreateRcvConn = do @@ -391,7 +398,7 @@ testUpdateQueueStatusBadDirectionSnd = do getConn store "conn1" `returnsResult` SomeConn SCReceive (ReceiveConnection "conn1" rcvQueue) updateQueueStatus store "conn1" SND Confirmed - `throwsError` SEBadConn + `throwsError` SEBadQueueDirection getConn store "conn1" `returnsResult` SomeConn SCReceive (ReceiveConnection "conn1" rcvQueue) @@ -413,6 +420,143 @@ testUpdateQueueStatusBadDirectionRcv = do getConn store "conn1" `returnsResult` SomeConn SCSend (SendConnection "conn1" sndQueue) updateQueueStatus store "conn1" RCV Confirmed - `throwsError` SEBadConn + `throwsError` SEBadQueueDirection getConn store "conn1" `returnsResult` SomeConn SCSend (SendConnection "conn1" sndQueue) + +testCreateMsgRcv :: SpecWith SQLiteStore +testCreateMsgRcv = do + it "should create a message in RCV direction" $ \store -> do + let rcvQueue = + ReceiveQueue + { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + rcvId = "1234", + rcvPrivateKey = "abcd", + sndId = Just "2345", + sndKey = Nothing, + decryptKey = "dcba", + verifyKey = Nothing, + status = New, + ackMode = AckMode On + } + createRcvConn store "conn1" rcvQueue + `returnsResult` () + let msg = A_MSG "hello" + let msgId = 1 + -- TODO getMsg to check message + createMsg store "conn1" RCV msgId msg + `returnsResult` () + +testCreateMsgSnd :: SpecWith SQLiteStore +testCreateMsgSnd = do + it "should create a message in SND direction" $ \store -> do + let sndQueue = + SendQueue + { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + sndId = "1234", + sndPrivateKey = "abcd", + encryptKey = "dcba", + signKey = "edcb", + status = New, + ackMode = AckMode On + } + createSndConn store "conn1" sndQueue + `returnsResult` () + let msg = A_MSG "hi" + let msgId = 1 + -- TODO getMsg to check message + createMsg store "conn1" SND msgId msg + `returnsResult` () + +testCreateMsgHello :: SpecWith SQLiteStore +testCreateMsgHello = do + it "should create a HELLO message" $ \store -> do + let rcvQueue = + ReceiveQueue + { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + rcvId = "1234", + rcvPrivateKey = "abcd", + sndId = Just "2345", + sndKey = Nothing, + decryptKey = "dcba", + verifyKey = Nothing, + status = New, + ackMode = AckMode On + } + createRcvConn store "conn1" rcvQueue + `returnsResult` () + let verificationKey = "abcd" + let am = AckMode On + let msg = HELLO verificationKey am + let msgId = 1 + -- TODO getMsg to check message + createMsg store "conn1" RCV msgId msg + `returnsResult` () + +testCreateMsgReply :: SpecWith SQLiteStore +testCreateMsgReply = do + it "should create a REPLY message" $ \store -> do + let rcvQueue = + ReceiveQueue + { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + rcvId = "1234", + rcvPrivateKey = "abcd", + sndId = Just "2345", + sndKey = Nothing, + decryptKey = "dcba", + verifyKey = Nothing, + status = New, + ackMode = AckMode On + } + createRcvConn store "conn1" rcvQueue + `returnsResult` () + let smpServer = SMPServer "smp.simplex.im" (Just "5223") (Just "1234") + let senderId = "sender1" + let encryptionKey = "abcd" + let msg = REPLY $ SMPQueueInfo smpServer senderId encryptionKey + let msgId = 1 + -- TODO getMsg to check message + createMsg store "conn1" RCV msgId msg + `returnsResult` () + +testCreateMsgBadDirectionSnd :: SpecWith SQLiteStore +testCreateMsgBadDirectionSnd = do + it "should throw error on attempt to create a message in ineligible SND direction" $ \store -> do + let rcvQueue = + ReceiveQueue + { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + rcvId = "1234", + rcvPrivateKey = "abcd", + sndId = Just "2345", + sndKey = Nothing, + decryptKey = "dcba", + verifyKey = Nothing, + status = New, + ackMode = AckMode On + } + createRcvConn store "conn1" rcvQueue + `returnsResult` () + let msg = A_MSG "hello" + let msgId = 1 + createMsg store "conn1" SND msgId msg + `throwsError` SEBadQueueDirection + +testCreateMsgBadDirectionRcv :: SpecWith SQLiteStore +testCreateMsgBadDirectionRcv = do + it "should throw error on attempt to create a message in ineligible RCV direction" $ \store -> do + let sndQueue = + SendQueue + { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + sndId = "1234", + sndPrivateKey = "abcd", + encryptKey = "dcba", + signKey = "edcb", + status = New, + ackMode = AckMode On + } + createSndConn store "conn1" sndQueue + `returnsResult` () + let msg = A_MSG "hello" + let msgId = 1 + createMsg store "conn1" RCV msgId msg + `throwsError` SEBadQueueDirection