diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index 15fb33599..d2e7ae835 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -15,7 +15,7 @@ cfg :: AgentConfig cfg = AgentConfig { tcpPort = "5224", - smpServers = L.fromList ["localhost:5223#KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="], + smpServers = L.fromList ["localhost:5223#bU0K+bRg24xWW//lS0umO1Zdw/SXqpJNtm1/RrPLViE="], rsaKeySize = 2048 `div` 8, connIdBytes = 12, tbqSize = 16, diff --git a/migrations/20210529_broadcasts.sql b/migrations/20210529_broadcasts.sql new file mode 100644 index 000000000..3095f0572 --- /dev/null +++ b/migrations/20210529_broadcasts.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS broadcasts ( + broadcast_id BLOB NOT NULL, + PRIMARY KEY (broadcast_id) +) WITHOUT ROWID; + +CREATE TABLE IF NOT EXISTS broadcast_connections ( + broadcast_id BLOB NOT NULL REFERENCES broadcasts (broadcast_id) ON DELETE CASCADE, + conn_alias BLOB NOT NULL REFERENCES connections (conn_alias), + PRIMARY KEY (broadcast_id, conn_alias) +) WITHOUT ROWID; diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 50bdad522..25bc15c91 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -151,31 +151,45 @@ withStore action = do handleInternal e = throwError . SEInternal $ bshow e storeError :: StoreError -> AgentErrorType storeError = \case - SEConnNotFound -> CONN UNKNOWN + SEConnNotFound -> CONN NOT_FOUND SEConnDuplicate -> CONN DUPLICATE + SEBadConnType CRcv -> CONN SIMPLEX + SEBadConnType CSnd -> CONN SIMPLEX + SEBcastNotFound -> BCAST B_NOT_FOUND + SEBcastDuplicate -> BCAST B_DUPLICATE e -> INTERNAL $ show e processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m () -processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) = - case entity of - Conn cId -> case cmd of - NEW -> createNewConnection cId - JOIN smpQueueInfo replyMode -> joinConnection cId smpQueueInfo replyMode - SUB -> subscribeConnection cId - SUBALL -> subscribeAll - SEND msgBody -> sendMessage cId msgBody - OFF -> suspendConnection cId - DEL -> deleteConnection cId - _ -> atomically . writeTBQueue sndQ . ATransmission corrId entity . ERR $ CMD ENTITY +processCommand c st (ATransmission corrId entity cmd) = process c st corrId entity cmd where - createNewConnection :: ByteString -> m () - createNewConnection cId = do + process = case entity of + Conn _ -> processConnCommand + Broadcast _ -> processBroadcastCommand + _ -> unsupportedEntity + +unsupportedEntity :: AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity t -> ACommand 'Client c -> m () +unsupportedEntity c _ corrId entity _ = + atomically . writeTBQueue (sndQ c) . ATransmission corrId entity . ERR $ CMD UNSUPPORTED + +processConnCommand :: + forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m () +processConnCommand c@AgentClient {sndQ} st corrId conn = \case + NEW -> createNewConnection conn + JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode + SUB -> subscribeConnection conn + SUBALL -> subscribeAll + SEND msgBody -> sendMessage c st corrId conn msgBody + OFF -> suspendConnection conn + DEL -> deleteConnection conn + where + createNewConnection :: Entity 'Conn_ -> m () + createNewConnection (Conn cId) = do -- TODO create connection alias if not passed -- make connAlias Maybe? srv <- getSMPServer (rq, qInfo) <- newReceiveQueue c srv cId withStore $ createRcvConn st rq - respond (Conn cId) $ INV qInfo + respond conn $ INV qInfo getSMPServer :: m SMPServer getSMPServer = @@ -186,8 +200,8 @@ processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) = i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1) pure $ servers L.!! i - joinConnection :: ByteString -> SMPQueueInfo -> ReplyMode -> m () - joinConnection cId qInfo (ReplyMode replyMode) = do + joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m () + joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do -- TODO create connection alias if not passed -- make connAlias Maybe? (sq, senderKey, verifyKey) <- newSendQueue qInfo cId @@ -195,63 +209,38 @@ processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) = connectToSendQueue c st sq senderKey verifyKey when (replyMode == On) $ createReplyQueue cId sq -- TODO this response is disabled to avoid two responses in terminal client (OK + CON), - -- respond OK + -- respond conn OK - subscribeConnection :: ByteString -> m () - subscribeConnection cId = + subscribeConnection :: Entity 'Conn_ -> m () + subscribeConnection conn'@(Conn cId) = withStore (getConn st cId) >>= \case SomeConn _ (DuplexConnection _ rq _) -> subscribe rq SomeConn _ (RcvConnection _ rq) -> subscribe rq _ -> throwError $ CONN SIMPLEX where - subscribe rq = subscribeQueue c rq cId >> respond (Conn cId) OK + subscribe rq = subscribeQueue c rq cId >> respond conn' OK -- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct subscribeAll :: m () - subscribeAll = withStore (getAllConnAliases st) >>= mapM_ subscribeConnection + subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn) - sendMessage :: ByteString -> MsgBody -> m () - sendMessage cId msgBody = - withStore (getConn st cId) >>= \case - SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq - SomeConn _ (SndConnection _ sq) -> sendMsg sq - _ -> throwError $ CONN SIMPLEX - where - sendMsg sq = do - internalTs <- liftIO getCurrentTime - (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq - let msgStr = - serializeSMPMessage - SMPMessage - { senderMsgId = unSndId internalSndId, - senderTimestamp = internalTs, - previousMsgHash, - agentMessage = A_MSG msgBody - } - msgHash = C.sha256Hash msgStr - withStore $ - createSndMsg st sq $ - SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash} - sendAgentMessage c sq msgStr - respond (Conn cId) $ SENT (unId internalId) - - suspendConnection :: ByteString -> m () - suspendConnection cId = + suspendConnection :: Entity 'Conn_ -> m () + suspendConnection (Conn cId) = withStore (getConn st cId) >>= \case SomeConn _ (DuplexConnection _ rq _) -> suspend rq SomeConn _ (RcvConnection _ rq) -> suspend rq _ -> throwError $ CONN SIMPLEX where - suspend rq = suspendQueue c rq >> respond (Conn cId) OK + suspend rq = suspendQueue c rq >> respond conn OK - deleteConnection :: ByteString -> m () - deleteConnection cId = + deleteConnection :: Entity 'Conn_ -> m () + deleteConnection (Conn cId) = withStore (getConn st cId) >>= \case SomeConn _ (DuplexConnection _ rq _) -> delete rq SomeConn _ (RcvConnection _ rq) -> delete rq _ -> delConn where - delConn = withStore (deleteConn st cId) >> respond (Conn cId) OK + delConn = withStore (deleteConn st cId) >> respond conn OK delete rq = do deleteQueue c rq removeSubscription c cId @@ -271,9 +260,54 @@ processCommand c@AgentClient {sndQ} st (ATransmission corrId entity cmd) = agentMessage = REPLY qInfo } - respond :: EntityCommand t c => Entity t -> ACommand 'Agent c -> m () + respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m () respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp +sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m () +sendMessage c st corrId (Conn cId) msgBody = + withStore (getConn st cId) >>= \case + SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq + SomeConn _ (SndConnection _ sq) -> sendMsg sq + _ -> throwError $ CONN SIMPLEX + where + sendMsg :: SndQueue -> m () + sendMsg sq = do + internalTs <- liftIO getCurrentTime + (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq + let msgStr = + serializeSMPMessage + SMPMessage + { senderMsgId = unSndId internalSndId, + senderTimestamp = internalTs, + previousMsgHash, + agentMessage = A_MSG msgBody + } + msgHash = C.sha256Hash msgStr + withStore $ + createSndMsg st sq $ + SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash} + sendAgentMessage c sq msgStr + atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId) + +processBroadcastCommand :: + forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m () +processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case + NEW -> withStore (createBcast st bId) >> ok + ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok + REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok + LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn + SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0) + DEL -> withStore (deleteBcast st bId) >> ok + where + sendMsg :: MsgBody -> ConnAlias -> m () + sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody + + ok :: m () + ok = respond bcast OK + + respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m () + respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp + subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () subscriber c@AgentClient {msgQ} st = forever $ do -- TODO this will only process messages and notifications diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index f7be994ce..07f135440 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -13,6 +13,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} -- | @@ -35,6 +36,7 @@ module Simplex.Messaging.Agent.Protocol EntityCommand, entityCommand, ACommand (..), + ACmdTag (..), AParty (..), APartyCmd (..), SAParty (..), @@ -45,6 +47,7 @@ module Simplex.Messaging.Agent.Protocol AgentErrorType (..), CommandErrorType (..), ConnectionErrorType (..), + BroadcastErrorType (..), BrokerErrorType (..), SMPAgentError (..), ATransmission (..), @@ -73,7 +76,7 @@ module Simplex.Messaging.Agent.Protocol serializeSmpQueueInfo, serializeAgentError, commandP, - entityP, + anEntityP, parseSMPMessage, smpServerP, smpQueueInfoP, @@ -106,6 +109,7 @@ import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () import GHC.Generics (Generic) +import GHC.TypeLits (ErrorMessage (..), TypeError) import Generic.Random (genericArbitraryU) import Network.Socket (HostName, ServiceName) import qualified Simplex.Messaging.Crypto as C @@ -158,10 +162,12 @@ instance TestEquality SAParty where data EntityTag = Conn_ | OpenConn_ | Broadcast_ | AGroup_ data Entity :: EntityTag -> Type where - Conn :: ByteString -> Entity Conn_ - OpenConn :: ByteString -> Entity OpenConn_ - BroadCast :: ByteString -> Entity Broadcast_ - AGroup :: ByteString -> Entity AGroup_ + Conn :: {fromConn :: ByteString} -> Entity Conn_ + OpenConn :: {fromOpenConn :: ByteString} -> Entity OpenConn_ + Broadcast :: {fromBroadcast :: ByteString} -> Entity Broadcast_ + AGroup :: {fromAGroup :: ByteString} -> Entity AGroup_ + +deriving instance Eq (Entity t) deriving instance Show (Entity t) @@ -169,7 +175,7 @@ entityId :: Entity t -> ByteString entityId = \case Conn bs -> bs OpenConn bs -> bs - BroadCast bs -> bs + Broadcast bs -> bs AGroup bs -> bs data AnEntity = forall t. AE (Entity t) @@ -200,7 +206,19 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where EntityCommand Conn_ DEL_ = () EntityCommand Conn_ OK_ = () EntityCommand Conn_ ERR_ = () + EntityCommand Broadcast_ NEW_ = () + EntityCommand Broadcast_ ADD_ = () + EntityCommand Broadcast_ REM_ = () + EntityCommand Broadcast_ LS_ = () + EntityCommand Broadcast_ MS_ = () + EntityCommand Broadcast_ SEND_ = () + EntityCommand Broadcast_ SENT_ = () + EntityCommand Broadcast_ DEL_ = () + EntityCommand Broadcast_ OK_ = () + EntityCommand Broadcast_ ERR_ = () EntityCommand _ ERR_ = () + EntityCommand t c = + (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " does not support command " :<>: ShowType c)) entityCommand :: Entity t -> ACommand p c -> Maybe (Dict (EntityCommand t c)) entityCommand = \case @@ -219,6 +237,19 @@ entityCommand = \case DEL -> Just Dict OK -> Just Dict ERR _ -> Just Dict + _ -> Nothing + Broadcast _ -> \case + NEW -> Just Dict + ADD _ -> Just Dict + REM _ -> Just Dict + LS -> Just Dict + MS _ -> Just Dict + SEND _ -> Just Dict + SENT _ -> Just Dict + DEL -> Just Dict + OK -> Just Dict + ERR _ -> Just Dict + _ -> Nothing _ -> \case ERR _ -> Just Dict _ -> Nothing @@ -236,6 +267,10 @@ data ACmdTag | MSG_ | OFF_ | DEL_ + | ADD_ + | REM_ + | LS_ + | MS_ | OK_ | ERR_ @@ -267,6 +302,10 @@ data ACommand (p :: AParty) (c :: ACmdTag) where -- RCVD :: AgentMsgId -> ACommand Agent OFF :: ACommand Client MSG_ DEL :: ACommand Client DEL_ + ADD :: Entity Conn_ -> ACommand Client ADD_ + REM :: Entity Conn_ -> ACommand Client REM_ + LS :: ACommand Client LS_ + MS :: [Entity Conn_] -> ACommand Agent MS_ OK :: ACommand Agent OK_ ERR :: AgentErrorType -> ACommand Agent ERR_ @@ -287,6 +326,10 @@ instance TestEquality (ACommand p) where testEquality c@MSG {} c'@MSG {} = refl c c' testEquality OFF OFF = Just Refl testEquality DEL DEL = Just Refl + testEquality c@ADD {} c'@ADD {} = refl c c' + testEquality c@REM {} c'@REM {} = refl c c' + testEquality c@LS {} c'@LS {} = refl c c' + testEquality c@MS {} c'@MS {} = refl c c' testEquality OK OK = Just Refl testEquality c@ERR {} c'@ERR {} = refl c c' testEquality _ _ = Nothing @@ -477,6 +520,8 @@ data AgentErrorType CMD CommandErrorType | -- | connection errors CONN ConnectionErrorType + | -- | broadcast errors + BCAST BroadcastErrorType | -- | SMP protocol errors forwarded to agent clients SMP ErrorType | -- | SMP server errors @@ -492,7 +537,7 @@ data CommandErrorType = -- | command is prohibited in this context PROHIBITED | -- | command is not supported by this entity - ENTITY + UNSUPPORTED | -- | command syntax is invalid SYNTAX | -- | cannot parse entity @@ -508,13 +553,21 @@ data CommandErrorType -- | Connection error. data ConnectionErrorType = -- | connection alias is not in the database - UNKNOWN + NOT_FOUND | -- | connection alias already exists DUPLICATE | -- | connection is simplex, but operation requires another queue SIMPLEX deriving (Eq, Generic, Read, Show, Exception) +-- | Broadcast error +data BroadcastErrorType + = -- | broadcast ID is not in the database + B_NOT_FOUND + | -- | broadcast ID already exists + B_DUPLICATE + deriving (Eq, Generic, Read, Show, Exception) + -- | SMP server errors. data BrokerErrorType = -- | invalid server response (failed to parse) @@ -547,25 +600,30 @@ instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU +instance Arbitrary BroadcastErrorType where arbitrary = genericArbitraryU + instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU -entityP :: Parser AnEntity -entityP = +anEntityP :: Parser AnEntity +anEntityP = ($) <$> ( "C:" $> AE . Conn <|> "O:" $> AE . OpenConn - <|> "B:" $> AE . BroadCast + <|> "B:" $> AE . Broadcast <|> "G:" $> AE . AGroup ) <*> A.takeTill (== ' ') +entityConnP :: Parser (Entity Conn_) +entityConnP = "C:" *> (Conn <$> A.takeTill (== ' ')) + serializeEntity :: Entity t -> ByteString serializeEntity = \case Conn s -> "C:" <> s OpenConn s -> "O:" <> s - BroadCast s -> "B:" <> s + Broadcast s -> "B:" <> s AGroup s -> "G:" <> s -- | SMP agent command and response parser @@ -582,6 +640,10 @@ commandP = <|> "MSG " *> message <|> "OFF" $> ACmd SClient OFF <|> "DEL" $> ACmd SClient DEL + <|> "ADD " *> addCmd + <|> "REM " *> removeCmd + <|> "LS" $> ACmd SClient LS + <|> "MS " *> membersResp <|> "ERR " *> agentError <|> "CON" $> ACmd SAgent CON <|> "OK" $> ACmd SAgent OK @@ -590,6 +652,9 @@ commandP = joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode) sendCmd = ACmd SClient . SEND <$> A.takeByteString sentResp = ACmd SAgent . SENT <$> A.decimal + addCmd = ACmd SClient . ADD <$> entityConnP + removeCmd = ACmd SClient . REM <$> entityConnP + membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ') message = do msgIntegrity <- msgIntegrityP <* A.space recipientMeta <- "R=" *> partyMeta A.decimal @@ -636,6 +701,10 @@ serializeCommand = \case ] OFF -> "OFF" DEL -> "DEL" + ADD c -> "ADD " <> serializeEntity c + REM c -> "REM " <> serializeEntity c + LS -> "LS" + MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs) CON -> "CON" ERR e -> "ERR " <> serializeAgentError e OK -> "OK" @@ -663,6 +732,7 @@ serializeMsgIntegrity = \case agentErrorTypeP :: Parser AgentErrorType agentErrorTypeP = "SMP " *> (SMP <$> SMP.errorTypeP) + <|> "BCAST " *> (BCAST <$> bcastErrorP) <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP) <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) @@ -672,10 +742,19 @@ agentErrorTypeP = serializeAgentError :: AgentErrorType -> ByteString serializeAgentError = \case SMP e -> "SMP " <> SMP.serializeErrorType e + BCAST e -> "BCAST " <> serializeBcastError e BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e e -> bshow e +bcastErrorP :: Parser BroadcastErrorType +bcastErrorP = "NOT_FOUND" $> B_NOT_FOUND <|> "DUPLICATE" $> B_DUPLICATE + +serializeBcastError :: BroadcastErrorType -> ByteString +serializeBcastError = \case + B_NOT_FOUND -> "NOT_FOUND" + B_DUPLICATE -> "DUPLICATE" + serializeMsg :: ByteString -> ByteString serializeMsg body = bshow (B.length body) <> "\n" <> body @@ -701,7 +780,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody where tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p) tParseLoadBody (corrId, entityStr, command) = - case parseAll entityP entityStr of + case parseAll anEntityP entityStr of Left _ -> pure $ ATransmissionOrError @_ @_ @ERR_ corrId (Conn "") $ Left $ CMD BAD_ENTITY Right entity -> do let cmd = parseCommand command >>= fromParty >>= hasEntityId entity @@ -730,7 +809,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody Left e -> err e Right (APartyCmd cmd) -> case entityCommand entity cmd of Just Dict -> ATransmissionOrError corrId entity $ Right cmd - _ -> err $ CMD ENTITY + _ -> err $ CMD UNSUPPORTED where err e = ATransmissionOrError @_ @_ @ERR_ corrId entity $ Left e diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index e31bb548f..6d3dc606f 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -51,6 +51,13 @@ class Monad m => MonadAgentStore s m where getMsg :: s -> ConnAlias -> InternalId -> m Msg + -- Broadcasts + createBcast :: s -> BroadcastId -> m () + addBcastConn :: s -> BroadcastId -> ConnAlias -> m () + removeBcastConn :: s -> BroadcastId -> ConnAlias -> m () + deleteBcast :: s -> BroadcastId -> m () + getBcast :: s -> BroadcastId -> m [ConnAlias] + -- * Queue types -- | A receive queue. SMP queue through which the agent receives messages from a sender. @@ -171,6 +178,10 @@ data SndMsgData = SndMsgData internalHash :: MsgHash } +-- * Broadcast types + +type BroadcastId = ByteString + -- * Message types -- | A message in either direction that is stored by the agent. @@ -283,6 +294,10 @@ data StoreError | -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. -- 'upgradeRcvConnToDuplex' and 'upgradeSndConnToDuplex' do not allow duplex connections - they would also return this error. SEBadConnType ConnType + | -- | Broadcast ID not found. + SEBcastNotFound + | -- | Broadcast ID already used. + SEBcastDuplicate | -- | Currently not used. The intention was to pass current expected queue status in methods, -- as we always know what it should be at any stage of the protocol, -- and in case it does not match use this error. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index a0c91ec15..9dcf7edd1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -107,12 +107,12 @@ connectSQLiteStore dbFilePath = do |] pure SQLiteStore {dbFilePath, dbConn, dbNew} -checkDuplicate :: (MonadUnliftIO m, MonadError StoreError m) => IO () -> m () -checkDuplicate action = liftIOEither $ first handleError <$> E.try action +checkConstraint :: StoreError -> IO () -> IO (Either StoreError ()) +checkConstraint err action = first handleError <$> E.try action where handleError :: SQLError -> StoreError handleError e - | DB.sqlError e == DB.ErrorConstraint = SEConnDuplicate + | DB.sqlError e == DB.ErrorConstraint = err | otherwise = SEInternal $ bshow e withTransaction :: forall a. DB.Connection -> IO a -> IO a @@ -130,19 +130,21 @@ withTransaction db a = loop 100 100_000 instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where createRcvConn :: SQLiteStore -> RcvQueue -> m () createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} = - checkDuplicate $ - withTransaction dbConn $ do - upsertServer_ dbConn server - insertRcvQueue_ dbConn q - insertRcvConnection_ dbConn q + liftIOEither $ + checkConstraint SEConnDuplicate $ + withTransaction dbConn $ do + upsertServer_ dbConn server + insertRcvQueue_ dbConn q + insertRcvConnection_ dbConn q createSndConn :: SQLiteStore -> SndQueue -> m () createSndConn SQLiteStore {dbConn} q@SndQueue {server} = - checkDuplicate $ - withTransaction dbConn $ do - upsertServer_ dbConn server - insertSndQueue_ dbConn q - insertSndConnection_ dbConn q + liftIOEither $ + checkConstraint SEConnDuplicate $ + withTransaction dbConn $ do + upsertServer_ dbConn server + insertSndQueue_ dbConn q + insertSndConnection_ dbConn q getConn :: SQLiteStore -> ConnAlias -> m SomeConn getConn SQLiteStore {dbConn} connAlias = @@ -182,7 +184,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} = liftIOEither . withTransaction dbConn $ getConn_ dbConn connAlias >>= \case - Right (SomeConn SCRcv (RcvConnection _ _)) -> do + Right (SomeConn _ RcvConnection {}) -> do upsertServer_ dbConn server insertSndQueue_ dbConn sq updateConnWithSndQueue_ dbConn connAlias sq @@ -194,7 +196,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} = liftIOEither . withTransaction dbConn $ getConn_ dbConn connAlias >>= \case - Right (SomeConn SCSnd (SndConnection _ _)) -> do + Right (SomeConn _ SndConnection {}) -> do upsertServer_ dbConn server insertRcvQueue_ dbConn rq updateConnWithRcvQueue_ dbConn connAlias rq @@ -204,7 +206,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m () setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status = - -- ? throw error if queue doesn't exist? + -- ? throw error if queue does not exist? liftIO $ DB.executeNamed dbConn @@ -217,7 +219,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m () setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = - -- ? throw error if queue doesn't exist? + -- ? throw error if queue does not exist? liftIO $ DB.executeNamed dbConn @@ -235,7 +237,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status = - -- ? throw error if queue doesn't exist? + -- ? throw error if queue does not exist? liftIO $ DB.executeNamed dbConn @@ -281,6 +283,48 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg getMsg _st _connAlias _id = throwError SENotImplemented + createBcast :: SQLiteStore -> BroadcastId -> m () + createBcast SQLiteStore {dbConn} bId = + liftIOEither $ + checkConstraint SEBcastDuplicate $ + DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId) + + addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m () + addBcastConn SQLiteStore {dbConn} bId connAlias = + liftIOEither . checkBroadcast dbConn bId $ + getConn_ dbConn connAlias >>= \case + Left _ -> pure $ Left SEConnNotFound + Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv + Right _ -> + checkConstraint SEConnDuplicate $ + DB.execute + dbConn + "INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);" + (bId, connAlias) + + removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m () + removeBcastConn SQLiteStore {dbConn} bId connAlias = + liftIOEither . checkBroadcast dbConn bId $ + bcastConnExists_ dbConn bId connAlias >>= \case + False -> pure $ Left SEConnNotFound + _ -> + Right + <$> DB.execute + dbConn + "DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;" + (bId, connAlias) + + deleteBcast :: SQLiteStore -> BroadcastId -> m () + deleteBcast SQLiteStore {dbConn} bId = + liftIOEither . checkBroadcast dbConn bId $ + Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId) + + getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias] + getBcast SQLiteStore {dbConn} bId = + liftIOEither . checkBroadcast dbConn bId $ + Right . map fromOnly + <$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId) + -- * Auxiliary helpers -- ? replace with ToField? - it's easy to forget to use this @@ -686,3 +730,31 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} = ":conn_alias" := connAlias, ":last_internal_snd_msg_id" := internalSndId ] + +-- * Broadcast helpers + +checkBroadcast :: DB.Connection -> BroadcastId -> IO (Either StoreError a) -> IO (Either StoreError a) +checkBroadcast dbConn bId action = + withTransaction dbConn $ do + ok <- bcastExists_ dbConn bId + if ok then action else pure $ Left SEBcastNotFound + +bcastExists_ :: DB.Connection -> BroadcastId -> IO Bool +bcastExists_ dbConn bId = not . null <$> queryBcast + where + queryBcast :: IO [Only BroadcastId] + queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId) + +bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool +bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn + where + queryBcastConn :: IO [(BroadcastId, ConnAlias)] + queryBcastConn = + DB.query + dbConn + [sql| + SELECT broadcast_id, conn_alias + FROM broadcast_connections + WHERE broadcast_id = ? AND conn_alias = ?; + |] + (bId, connAlias) diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 2e4f7395f..a3d9d184f 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -38,6 +38,9 @@ agentTests (ATransport t) = do smpAgentTest3_1_1 $ testSubscription t it "should send notifications to client when server disconnects" $ smpAgentServerTest $ testSubscrNotification t + describe "Broadcast" do + it "should create broadcast and send messages" $ + smpAgentTest3 $ testBroadcast t type TestTransmission p = (ACorrId, ByteString, APartyCmd p) @@ -138,6 +141,50 @@ testSubscrNotification _ (server, _) client = do killThread server client <# ("", "C:conn1", END) +testBroadcast :: forall c. Transport c => TProxy c -> c -> c -> c -> IO () +testBroadcast _ alice bob tom = do + -- establish connections + (alice, "alice") `connect` (bob, "bob") + (alice, "alice") `connect` (tom, "tom") + -- create and set up broadcast + alice #: ("1", "B:team", "NEW") #> ("1", "B:team", OK) + alice #: ("2", "B:team", "ADD C:bob") #> ("2", "B:team", OK) + alice #: ("3", "B:team", "ADD C:tom") #> ("3", "B:team", OK) + -- commands with errors + alice #: ("e1", "B:team", "NEW") #> ("e1", "B:team", ERR $ BCAST B_DUPLICATE) + alice #: ("e2", "B:group", "ADD C:bob") #> ("e2", "B:group", ERR $ BCAST B_NOT_FOUND) + alice #: ("e3", "B:team", "ADD C:unknown") #> ("e3", "B:team", ERR $ CONN NOT_FOUND) + alice #: ("e4", "B:team", "ADD C:bob") #> ("e4", "B:team", ERR $ CONN DUPLICATE) + -- send message + alice #: ("4", "B:team", "SEND 5\nhello") #> ("4", "C:bob", SENT 1) + alice <# ("4", "C:tom", SENT 1) + alice <# ("4", "B:team", SENT 0) + bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False + tom <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False + -- remove one connection + alice #: ("5", "B:team", "REM C:tom") #> ("5", "B:team", OK) + alice #: ("6", "B:team", "SEND 11\nhello again") #> ("6", "C:bob", SENT 2) + alice <# ("6", "B:team", SENT 0) + bob <#= \case ("", "C:alice", Msg "hello again") -> True; _ -> False + tom #:# "nothing delivered to tom" + -- commands with errors + alice #: ("e5", "B:group", "REM C:bob") #> ("e5", "B:group", ERR $ BCAST B_NOT_FOUND) + alice #: ("e6", "B:team", "REM C:unknown") #> ("e6", "B:team", ERR $ CONN NOT_FOUND) + alice #: ("e7", "B:team", "REM C:tom") #> ("e7", "B:team", ERR $ CONN NOT_FOUND) + -- delete broadcast + alice #: ("7", "B:team", "DEL") #> ("7", "B:team", OK) + alice #: ("8", "B:team", "SEND 11\ntry sending") #> ("8", "B:team", ERR $ BCAST B_NOT_FOUND) + -- commands with errors + alice #: ("e8", "B:team", "DEL") #> ("e8", "B:team", ERR $ BCAST B_NOT_FOUND) + alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND) + where + connect :: (c, ByteString) -> (c, ByteString) -> IO () + connect (h1, name1) (h2, name2) = do + ("c1", _, Right (Inv qInfo)) <- h1 #: ("c1", "C:" <> name2, "NEW") + let qInfo' = serializeSmpQueueInfo qInfo + h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') =#> \case ("", c1, APartyCmd CON) -> c1 == "C:" <> name1; _ -> False + h1 <#= \case ("", c2, APartyCmd CON) -> c2 == "C:" <> name2; _ -> False + samplePublicKey :: ByteString samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"