broadcast commands (#154)

* broadcast commands (WIP)

* broadcasts: store and commands implementation

* test broadcast

* broadcast test

* rename migration, handle SEBadConnType errors

* query semicolons
This commit is contained in:
Evgeny Poberezkin
2021-06-01 18:11:16 +01:00
committed by GitHub
parent 84ce001598
commit bc780343df
7 changed files with 344 additions and 87 deletions
+1 -1
View File
@@ -15,7 +15,7 @@ cfg :: AgentConfig
cfg =
AgentConfig
{ tcpPort = "5224",
smpServers = L.fromList ["localhost:5223#KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="],
smpServers = L.fromList ["localhost:5223#bU0K+bRg24xWW//lS0umO1Zdw/SXqpJNtm1/RrPLViE="],
rsaKeySize = 2048 `div` 8,
connIdBytes = 12,
tbqSize = 16,
+10
View File
@@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS broadcasts (
broadcast_id BLOB NOT NULL,
PRIMARY KEY (broadcast_id)
) WITHOUT ROWID;
CREATE TABLE IF NOT EXISTS broadcast_connections (
broadcast_id BLOB NOT NULL REFERENCES broadcasts (broadcast_id) ON DELETE CASCADE,
conn_alias BLOB NOT NULL REFERENCES connections (conn_alias),
PRIMARY KEY (broadcast_id, conn_alias)
) WITHOUT ROWID;
+88 -54
View File
@@ -151,31 +151,45 @@ withStore action = do
handleInternal e = throwError . SEInternal $ bshow e
storeError :: StoreError -> AgentErrorType
storeError = \case
SEConnNotFound -> CONN UNKNOWN
SEConnNotFound -> CONN NOT_FOUND
SEConnDuplicate -> CONN DUPLICATE
SEBadConnType CRcv -> CONN SIMPLEX
SEBadConnType CSnd -> CONN SIMPLEX
SEBcastNotFound -> BCAST B_NOT_FOUND
SEBcastDuplicate -> BCAST B_DUPLICATE
e -> INTERNAL $ show e
processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m ()
processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) =
case entity of
Conn cId -> case cmd of
NEW -> createNewConnection cId
JOIN smpQueueInfo replyMode -> joinConnection cId smpQueueInfo replyMode
SUB -> subscribeConnection cId
SUBALL -> subscribeAll
SEND msgBody -> sendMessage cId msgBody
OFF -> suspendConnection cId
DEL -> deleteConnection cId
_ -> atomically . writeTBQueue sndQ . ATransmission corrId entity . ERR $ CMD ENTITY
processCommand c st (ATransmission corrId entity cmd) = process c st corrId entity cmd
where
createNewConnection :: ByteString -> m ()
createNewConnection cId = do
process = case entity of
Conn _ -> processConnCommand
Broadcast _ -> processBroadcastCommand
_ -> unsupportedEntity
unsupportedEntity :: AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity t -> ACommand 'Client c -> m ()
unsupportedEntity c _ corrId entity _ =
atomically . writeTBQueue (sndQ c) . ATransmission corrId entity . ERR $ CMD UNSUPPORTED
processConnCommand ::
forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m ()
processConnCommand c@AgentClient {sndQ} st corrId conn = \case
NEW -> createNewConnection conn
JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode
SUB -> subscribeConnection conn
SUBALL -> subscribeAll
SEND msgBody -> sendMessage c st corrId conn msgBody
OFF -> suspendConnection conn
DEL -> deleteConnection conn
where
createNewConnection :: Entity 'Conn_ -> m ()
createNewConnection (Conn cId) = do
-- TODO create connection alias if not passed
-- make connAlias Maybe?
srv <- getSMPServer
(rq, qInfo) <- newReceiveQueue c srv cId
withStore $ createRcvConn st rq
respond (Conn cId) $ INV qInfo
respond conn $ INV qInfo
getSMPServer :: m SMPServer
getSMPServer =
@@ -186,8 +200,8 @@ processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) =
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
pure $ servers L.!! i
joinConnection :: ByteString -> SMPQueueInfo -> ReplyMode -> m ()
joinConnection cId qInfo (ReplyMode replyMode) = do
joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m ()
joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do
-- TODO create connection alias if not passed
-- make connAlias Maybe?
(sq, senderKey, verifyKey) <- newSendQueue qInfo cId
@@ -195,63 +209,38 @@ processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) =
connectToSendQueue c st sq senderKey verifyKey
when (replyMode == On) $ createReplyQueue cId sq
-- TODO this response is disabled to avoid two responses in terminal client (OK + CON),
-- respond OK
-- respond conn OK
subscribeConnection :: ByteString -> m ()
subscribeConnection cId =
subscribeConnection :: Entity 'Conn_ -> m ()
subscribeConnection conn'@(Conn cId) =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> subscribe rq
SomeConn _ (RcvConnection _ rq) -> subscribe rq
_ -> throwError $ CONN SIMPLEX
where
subscribe rq = subscribeQueue c rq cId >> respond (Conn cId) OK
subscribe rq = subscribeQueue c rq cId >> respond conn' OK
-- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct
subscribeAll :: m ()
subscribeAll = withStore (getAllConnAliases st) >>= mapM_ subscribeConnection
subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn)
sendMessage :: ByteString -> MsgBody -> m ()
sendMessage cId msgBody =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
SomeConn _ (SndConnection _ sq) -> sendMsg sq
_ -> throwError $ CONN SIMPLEX
where
sendMsg sq = do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
let msgStr =
serializeSMPMessage
SMPMessage
{ senderMsgId = unSndId internalSndId,
senderTimestamp = internalTs,
previousMsgHash,
agentMessage = A_MSG msgBody
}
msgHash = C.sha256Hash msgStr
withStore $
createSndMsg st sq $
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
sendAgentMessage c sq msgStr
respond (Conn cId) $ SENT (unId internalId)
suspendConnection :: ByteString -> m ()
suspendConnection cId =
suspendConnection :: Entity 'Conn_ -> m ()
suspendConnection (Conn cId) =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> suspend rq
SomeConn _ (RcvConnection _ rq) -> suspend rq
_ -> throwError $ CONN SIMPLEX
where
suspend rq = suspendQueue c rq >> respond (Conn cId) OK
suspend rq = suspendQueue c rq >> respond conn OK
deleteConnection :: ByteString -> m ()
deleteConnection cId =
deleteConnection :: Entity 'Conn_ -> m ()
deleteConnection (Conn cId) =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> delete rq
SomeConn _ (RcvConnection _ rq) -> delete rq
_ -> delConn
where
delConn = withStore (deleteConn st cId) >> respond (Conn cId) OK
delConn = withStore (deleteConn st cId) >> respond conn OK
delete rq = do
deleteQueue c rq
removeSubscription c cId
@@ -271,9 +260,54 @@ processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) =
agentMessage = REPLY qInfo
}
respond :: EntityCommand t c => Entity t -> ACommand 'Agent c -> m ()
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp
sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
sendMessage c st corrId (Conn cId) msgBody =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
SomeConn _ (SndConnection _ sq) -> sendMsg sq
_ -> throwError $ CONN SIMPLEX
where
sendMsg :: SndQueue -> m ()
sendMsg sq = do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
let msgStr =
serializeSMPMessage
SMPMessage
{ senderMsgId = unSndId internalSndId,
senderTimestamp = internalTs,
previousMsgHash,
agentMessage = A_MSG msgBody
}
msgHash = C.sha256Hash msgStr
withStore $
createSndMsg st sq $
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
sendAgentMessage c sq msgStr
atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId)
processBroadcastCommand ::
forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m ()
processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
NEW -> withStore (createBcast st bId) >> ok
ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok
REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok
LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn
SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0)
DEL -> withStore (deleteBcast st bId) >> ok
where
sendMsg :: MsgBody -> ConnAlias -> m ()
sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody
ok :: m ()
ok = respond bcast OK
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
subscriber c@AgentClient {msgQ} st = forever $ do
-- TODO this will only process messages and notifications
+93 -14
View File
@@ -13,6 +13,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
-- |
@@ -35,6 +36,7 @@ module Simplex.Messaging.Agent.Protocol
EntityCommand,
entityCommand,
ACommand (..),
ACmdTag (..),
AParty (..),
APartyCmd (..),
SAParty (..),
@@ -45,6 +47,7 @@ module Simplex.Messaging.Agent.Protocol
AgentErrorType (..),
CommandErrorType (..),
ConnectionErrorType (..),
BroadcastErrorType (..),
BrokerErrorType (..),
SMPAgentError (..),
ATransmission (..),
@@ -73,7 +76,7 @@ module Simplex.Messaging.Agent.Protocol
serializeSmpQueueInfo,
serializeAgentError,
commandP,
entityP,
anEntityP,
parseSMPMessage,
smpServerP,
smpQueueInfoP,
@@ -106,6 +109,7 @@ import Data.Time.ISO8601
import Data.Type.Equality
import Data.Typeable ()
import GHC.Generics (Generic)
import GHC.TypeLits (ErrorMessage (..), TypeError)
import Generic.Random (genericArbitraryU)
import Network.Socket (HostName, ServiceName)
import qualified Simplex.Messaging.Crypto as C
@@ -158,10 +162,12 @@ instance TestEquality SAParty where
data EntityTag = Conn_ | OpenConn_ | Broadcast_ | AGroup_
data Entity :: EntityTag -> Type where
Conn :: ByteString -> Entity Conn_
OpenConn :: ByteString -> Entity OpenConn_
BroadCast :: ByteString -> Entity Broadcast_
AGroup :: ByteString -> Entity AGroup_
Conn :: {fromConn :: ByteString} -> Entity Conn_
OpenConn :: {fromOpenConn :: ByteString} -> Entity OpenConn_
Broadcast :: {fromBroadcast :: ByteString} -> Entity Broadcast_
AGroup :: {fromAGroup :: ByteString} -> Entity AGroup_
deriving instance Eq (Entity t)
deriving instance Show (Entity t)
@@ -169,7 +175,7 @@ entityId :: Entity t -> ByteString
entityId = \case
Conn bs -> bs
OpenConn bs -> bs
BroadCast bs -> bs
Broadcast bs -> bs
AGroup bs -> bs
data AnEntity = forall t. AE (Entity t)
@@ -200,7 +206,19 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where
EntityCommand Conn_ DEL_ = ()
EntityCommand Conn_ OK_ = ()
EntityCommand Conn_ ERR_ = ()
EntityCommand Broadcast_ NEW_ = ()
EntityCommand Broadcast_ ADD_ = ()
EntityCommand Broadcast_ REM_ = ()
EntityCommand Broadcast_ LS_ = ()
EntityCommand Broadcast_ MS_ = ()
EntityCommand Broadcast_ SEND_ = ()
EntityCommand Broadcast_ SENT_ = ()
EntityCommand Broadcast_ DEL_ = ()
EntityCommand Broadcast_ OK_ = ()
EntityCommand Broadcast_ ERR_ = ()
EntityCommand _ ERR_ = ()
EntityCommand t c =
(Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " does not support command " :<>: ShowType c))
entityCommand :: Entity t -> ACommand p c -> Maybe (Dict (EntityCommand t c))
entityCommand = \case
@@ -219,6 +237,19 @@ entityCommand = \case
DEL -> Just Dict
OK -> Just Dict
ERR _ -> Just Dict
_ -> Nothing
Broadcast _ -> \case
NEW -> Just Dict
ADD _ -> Just Dict
REM _ -> Just Dict
LS -> Just Dict
MS _ -> Just Dict
SEND _ -> Just Dict
SENT _ -> Just Dict
DEL -> Just Dict
OK -> Just Dict
ERR _ -> Just Dict
_ -> Nothing
_ -> \case
ERR _ -> Just Dict
_ -> Nothing
@@ -236,6 +267,10 @@ data ACmdTag
| MSG_
| OFF_
| DEL_
| ADD_
| REM_
| LS_
| MS_
| OK_
| ERR_
@@ -267,6 +302,10 @@ data ACommand (p :: AParty) (c :: ACmdTag) where
-- RCVD :: AgentMsgId -> ACommand Agent
OFF :: ACommand Client MSG_
DEL :: ACommand Client DEL_
ADD :: Entity Conn_ -> ACommand Client ADD_
REM :: Entity Conn_ -> ACommand Client REM_
LS :: ACommand Client LS_
MS :: [Entity Conn_] -> ACommand Agent MS_
OK :: ACommand Agent OK_
ERR :: AgentErrorType -> ACommand Agent ERR_
@@ -287,6 +326,10 @@ instance TestEquality (ACommand p) where
testEquality c@MSG {} c'@MSG {} = refl c c'
testEquality OFF OFF = Just Refl
testEquality DEL DEL = Just Refl
testEquality c@ADD {} c'@ADD {} = refl c c'
testEquality c@REM {} c'@REM {} = refl c c'
testEquality c@LS {} c'@LS {} = refl c c'
testEquality c@MS {} c'@MS {} = refl c c'
testEquality OK OK = Just Refl
testEquality c@ERR {} c'@ERR {} = refl c c'
testEquality _ _ = Nothing
@@ -477,6 +520,8 @@ data AgentErrorType
CMD CommandErrorType
| -- | connection errors
CONN ConnectionErrorType
| -- | broadcast errors
BCAST BroadcastErrorType
| -- | SMP protocol errors forwarded to agent clients
SMP ErrorType
| -- | SMP server errors
@@ -492,7 +537,7 @@ data CommandErrorType
= -- | command is prohibited in this context
PROHIBITED
| -- | command is not supported by this entity
ENTITY
UNSUPPORTED
| -- | command syntax is invalid
SYNTAX
| -- | cannot parse entity
@@ -508,13 +553,21 @@ data CommandErrorType
-- | Connection error.
data ConnectionErrorType
= -- | connection alias is not in the database
UNKNOWN
NOT_FOUND
| -- | connection alias already exists
DUPLICATE
| -- | connection is simplex, but operation requires another queue
SIMPLEX
deriving (Eq, Generic, Read, Show, Exception)
-- | Broadcast error
data BroadcastErrorType
= -- | broadcast ID is not in the database
B_NOT_FOUND
| -- | broadcast ID already exists
B_DUPLICATE
deriving (Eq, Generic, Read, Show, Exception)
-- | SMP server errors.
data BrokerErrorType
= -- | invalid server response (failed to parse)
@@ -547,25 +600,30 @@ instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
instance Arbitrary BroadcastErrorType where arbitrary = genericArbitraryU
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
entityP :: Parser AnEntity
entityP =
anEntityP :: Parser AnEntity
anEntityP =
($)
<$> ( "C:" $> AE . Conn
<|> "O:" $> AE . OpenConn
<|> "B:" $> AE . BroadCast
<|> "B:" $> AE . Broadcast
<|> "G:" $> AE . AGroup
)
<*> A.takeTill (== ' ')
entityConnP :: Parser (Entity Conn_)
entityConnP = "C:" *> (Conn <$> A.takeTill (== ' '))
serializeEntity :: Entity t -> ByteString
serializeEntity = \case
Conn s -> "C:" <> s
OpenConn s -> "O:" <> s
BroadCast s -> "B:" <> s
Broadcast s -> "B:" <> s
AGroup s -> "G:" <> s
-- | SMP agent command and response parser
@@ -582,6 +640,10 @@ commandP =
<|> "MSG " *> message
<|> "OFF" $> ACmd SClient OFF
<|> "DEL" $> ACmd SClient DEL
<|> "ADD " *> addCmd
<|> "REM " *> removeCmd
<|> "LS" $> ACmd SClient LS
<|> "MS " *> membersResp
<|> "ERR " *> agentError
<|> "CON" $> ACmd SAgent CON
<|> "OK" $> ACmd SAgent OK
@@ -590,6 +652,9 @@ commandP =
joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode)
sendCmd = ACmd SClient . SEND <$> A.takeByteString
sentResp = ACmd SAgent . SENT <$> A.decimal
addCmd = ACmd SClient . ADD <$> entityConnP
removeCmd = ACmd SClient . REM <$> entityConnP
membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ')
message = do
msgIntegrity <- msgIntegrityP <* A.space
recipientMeta <- "R=" *> partyMeta A.decimal
@@ -636,6 +701,10 @@ serializeCommand = \case
]
OFF -> "OFF"
DEL -> "DEL"
ADD c -> "ADD " <> serializeEntity c
REM c -> "REM " <> serializeEntity c
LS -> "LS"
MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs)
CON -> "CON"
ERR e -> "ERR " <> serializeAgentError e
OK -> "OK"
@@ -663,6 +732,7 @@ serializeMsgIntegrity = \case
agentErrorTypeP :: Parser AgentErrorType
agentErrorTypeP =
"SMP " *> (SMP <$> SMP.errorTypeP)
<|> "BCAST " *> (BCAST <$> bcastErrorP)
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP)
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
@@ -672,10 +742,19 @@ agentErrorTypeP =
serializeAgentError :: AgentErrorType -> ByteString
serializeAgentError = \case
SMP e -> "SMP " <> SMP.serializeErrorType e
BCAST e -> "BCAST " <> serializeBcastError e
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
e -> bshow e
bcastErrorP :: Parser BroadcastErrorType
bcastErrorP = "NOT_FOUND" $> B_NOT_FOUND <|> "DUPLICATE" $> B_DUPLICATE
serializeBcastError :: BroadcastErrorType -> ByteString
serializeBcastError = \case
B_NOT_FOUND -> "NOT_FOUND"
B_DUPLICATE -> "DUPLICATE"
serializeMsg :: ByteString -> ByteString
serializeMsg body = bshow (B.length body) <> "\n" <> body
@@ -701,7 +780,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
where
tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p)
tParseLoadBody (corrId, entityStr, command) =
case parseAll entityP entityStr of
case parseAll anEntityP entityStr of
Left _ -> pure $ ATransmissionOrError @_ @_ @ERR_ corrId (Conn "") $ Left $ CMD BAD_ENTITY
Right entity -> do
let cmd = parseCommand command >>= fromParty >>= hasEntityId entity
@@ -730,7 +809,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
Left e -> err e
Right (APartyCmd cmd) -> case entityCommand entity cmd of
Just Dict -> ATransmissionOrError corrId entity $ Right cmd
_ -> err $ CMD ENTITY
_ -> err $ CMD UNSUPPORTED
where
err e = ATransmissionOrError @_ @_ @ERR_ corrId entity $ Left e
+15
View File
@@ -51,6 +51,13 @@ class Monad m => MonadAgentStore s m where
getMsg :: s -> ConnAlias -> InternalId -> m Msg
-- Broadcasts
createBcast :: s -> BroadcastId -> m ()
addBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
removeBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
deleteBcast :: s -> BroadcastId -> m ()
getBcast :: s -> BroadcastId -> m [ConnAlias]
-- * Queue types
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
@@ -171,6 +178,10 @@ data SndMsgData = SndMsgData
internalHash :: MsgHash
}
-- * Broadcast types
type BroadcastId = ByteString
-- * Message types
-- | A message in either direction that is stored by the agent.
@@ -283,6 +294,10 @@ data StoreError
| -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa.
-- 'upgradeRcvConnToDuplex' and 'upgradeSndConnToDuplex' do not allow duplex connections - they would also return this error.
SEBadConnType ConnType
| -- | Broadcast ID not found.
SEBcastNotFound
| -- | Broadcast ID already used.
SEBcastDuplicate
| -- | Currently not used. The intention was to pass current expected queue status in methods,
-- as we always know what it should be at any stage of the protocol,
-- and in case it does not match use this error.
+90 -18
View File
@@ -107,12 +107,12 @@ connectSQLiteStore dbFilePath = do
|]
pure SQLiteStore {dbFilePath, dbConn, dbNew}
checkDuplicate :: (MonadUnliftIO m, MonadError StoreError m) => IO () -> m ()
checkDuplicate action = liftIOEither $ first handleError <$> E.try action
checkConstraint :: StoreError -> IO () -> IO (Either StoreError ())
checkConstraint err action = first handleError <$> E.try action
where
handleError :: SQLError -> StoreError
handleError e
| DB.sqlError e == DB.ErrorConstraint = SEConnDuplicate
| DB.sqlError e == DB.ErrorConstraint = err
| otherwise = SEInternal $ bshow e
withTransaction :: forall a. DB.Connection -> IO a -> IO a
@@ -130,19 +130,21 @@ withTransaction db a = loop 100 100_000
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
checkDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertRcvQueue_ dbConn q
insertRcvConnection_ dbConn q
liftIOEither $
checkConstraint SEConnDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertRcvQueue_ dbConn q
insertRcvConnection_ dbConn q
createSndConn :: SQLiteStore -> SndQueue -> m ()
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
checkDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertSndQueue_ dbConn q
insertSndConnection_ dbConn q
liftIOEither $
checkConstraint SEConnDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertSndQueue_ dbConn q
insertSndConnection_ dbConn q
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
getConn SQLiteStore {dbConn} connAlias =
@@ -182,7 +184,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
Right (SomeConn SCRcv (RcvConnection _ _)) -> do
Right (SomeConn _ RcvConnection {}) -> do
upsertServer_ dbConn server
insertSndQueue_ dbConn sq
updateConnWithSndQueue_ dbConn connAlias sq
@@ -194,7 +196,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
Right (SomeConn SCSnd (SndConnection _ _)) -> do
Right (SomeConn _ SndConnection {}) -> do
upsertServer_ dbConn server
insertRcvQueue_ dbConn rq
updateConnWithRcvQueue_ dbConn connAlias rq
@@ -204,7 +206,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m ()
setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status =
-- ? throw error if queue doesn't exist?
-- ? throw error if queue does not exist?
liftIO $
DB.executeNamed
dbConn
@@ -217,7 +219,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m ()
setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey =
-- ? throw error if queue doesn't exist?
-- ? throw error if queue does not exist?
liftIO $
DB.executeNamed
dbConn
@@ -235,7 +237,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m ()
setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status =
-- ? throw error if queue doesn't exist?
-- ? throw error if queue does not exist?
liftIO $
DB.executeNamed
dbConn
@@ -281,6 +283,48 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
getMsg _st _connAlias _id = throwError SENotImplemented
createBcast :: SQLiteStore -> BroadcastId -> m ()
createBcast SQLiteStore {dbConn} bId =
liftIOEither $
checkConstraint SEBcastDuplicate $
DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
addBcastConn SQLiteStore {dbConn} bId connAlias =
liftIOEither . checkBroadcast dbConn bId $
getConn_ dbConn connAlias >>= \case
Left _ -> pure $ Left SEConnNotFound
Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv
Right _ ->
checkConstraint SEConnDuplicate $
DB.execute
dbConn
"INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);"
(bId, connAlias)
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
removeBcastConn SQLiteStore {dbConn} bId connAlias =
liftIOEither . checkBroadcast dbConn bId $
bcastConnExists_ dbConn bId connAlias >>= \case
False -> pure $ Left SEConnNotFound
_ ->
Right
<$> DB.execute
dbConn
"DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;"
(bId, connAlias)
deleteBcast :: SQLiteStore -> BroadcastId -> m ()
deleteBcast SQLiteStore {dbConn} bId =
liftIOEither . checkBroadcast dbConn bId $
Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias]
getBcast SQLiteStore {dbConn} bId =
liftIOEither . checkBroadcast dbConn bId $
Right . map fromOnly
<$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId)
-- * Auxiliary helpers
-- ? replace with ToField? - it's easy to forget to use this
@@ -686,3 +730,31 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} =
":conn_alias" := connAlias,
":last_internal_snd_msg_id" := internalSndId
]
-- * Broadcast helpers
checkBroadcast :: DB.Connection -> BroadcastId -> IO (Either StoreError a) -> IO (Either StoreError a)
checkBroadcast dbConn bId action =
withTransaction dbConn $ do
ok <- bcastExists_ dbConn bId
if ok then action else pure $ Left SEBcastNotFound
bcastExists_ :: DB.Connection -> BroadcastId -> IO Bool
bcastExists_ dbConn bId = not . null <$> queryBcast
where
queryBcast :: IO [Only BroadcastId]
queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool
bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
where
queryBcastConn :: IO [(BroadcastId, ConnAlias)]
queryBcastConn =
DB.query
dbConn
[sql|
SELECT broadcast_id, conn_alias
FROM broadcast_connections
WHERE broadcast_id = ? AND conn_alias = ?;
|]
(bId, connAlias)
+47
View File
@@ -38,6 +38,9 @@ agentTests (ATransport t) = do
smpAgentTest3_1_1 $ testSubscription t
it "should send notifications to client when server disconnects" $
smpAgentServerTest $ testSubscrNotification t
describe "Broadcast" do
it "should create broadcast and send messages" $
smpAgentTest3 $ testBroadcast t
type TestTransmission p = (ACorrId, ByteString, APartyCmd p)
@@ -138,6 +141,50 @@ testSubscrNotification _ (server, _) client = do
killThread server
client <# ("", "C:conn1", END)
testBroadcast :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
testBroadcast _ alice bob tom = do
-- establish connections
(alice, "alice") `connect` (bob, "bob")
(alice, "alice") `connect` (tom, "tom")
-- create and set up broadcast
alice #: ("1", "B:team", "NEW") #> ("1", "B:team", OK)
alice #: ("2", "B:team", "ADD C:bob") #> ("2", "B:team", OK)
alice #: ("3", "B:team", "ADD C:tom") #> ("3", "B:team", OK)
-- commands with errors
alice #: ("e1", "B:team", "NEW") #> ("e1", "B:team", ERR $ BCAST B_DUPLICATE)
alice #: ("e2", "B:group", "ADD C:bob") #> ("e2", "B:group", ERR $ BCAST B_NOT_FOUND)
alice #: ("e3", "B:team", "ADD C:unknown") #> ("e3", "B:team", ERR $ CONN NOT_FOUND)
alice #: ("e4", "B:team", "ADD C:bob") #> ("e4", "B:team", ERR $ CONN DUPLICATE)
-- send message
alice #: ("4", "B:team", "SEND 5\nhello") #> ("4", "C:bob", SENT 1)
alice <# ("4", "C:tom", SENT 1)
alice <# ("4", "B:team", SENT 0)
bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
tom <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
-- remove one connection
alice #: ("5", "B:team", "REM C:tom") #> ("5", "B:team", OK)
alice #: ("6", "B:team", "SEND 11\nhello again") #> ("6", "C:bob", SENT 2)
alice <# ("6", "B:team", SENT 0)
bob <#= \case ("", "C:alice", Msg "hello again") -> True; _ -> False
tom #:# "nothing delivered to tom"
-- commands with errors
alice #: ("e5", "B:group", "REM C:bob") #> ("e5", "B:group", ERR $ BCAST B_NOT_FOUND)
alice #: ("e6", "B:team", "REM C:unknown") #> ("e6", "B:team", ERR $ CONN NOT_FOUND)
alice #: ("e7", "B:team", "REM C:tom") #> ("e7", "B:team", ERR $ CONN NOT_FOUND)
-- delete broadcast
alice #: ("7", "B:team", "DEL") #> ("7", "B:team", OK)
alice #: ("8", "B:team", "SEND 11\ntry sending") #> ("8", "B:team", ERR $ BCAST B_NOT_FOUND)
-- commands with errors
alice #: ("e8", "B:team", "DEL") #> ("e8", "B:team", ERR $ BCAST B_NOT_FOUND)
alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND)
where
connect :: (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Right (Inv qInfo)) <- h1 #: ("c1", "C:" <> name2, "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') =#> \case ("", c1, APartyCmd CON) -> c1 == "C:" <> name1; _ -> False
h1 <#= \case ("", c2, APartyCmd CON) -> c2 == "C:" <> name2; _ -> False
samplePublicKey :: ByteString
samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"