diff --git a/migrations/20210529_broadcasts.sql b/migrations/20210529_broadcasts.sql index 3095f0572..91cba0eb6 100644 --- a/migrations/20210529_broadcasts.sql +++ b/migrations/20210529_broadcasts.sql @@ -1,9 +1,8 @@ -CREATE TABLE IF NOT EXISTS broadcasts ( - broadcast_id BLOB NOT NULL, - PRIMARY KEY (broadcast_id) +CREATE TABLE broadcasts ( + broadcast_id BLOB NOT NULL PRIMARY KEY ) WITHOUT ROWID; -CREATE TABLE IF NOT EXISTS broadcast_connections ( +CREATE TABLE broadcast_connections ( broadcast_id BLOB NOT NULL REFERENCES broadcasts (broadcast_id) ON DELETE CASCADE, conn_alias BLOB NOT NULL REFERENCES connections (conn_alias), PRIMARY KEY (broadcast_id, conn_alias) diff --git a/migrations/20210602_introductions.sql b/migrations/20210602_introductions.sql new file mode 100644 index 000000000..d382b2961 --- /dev/null +++ b/migrations/20210602_introductions.sql @@ -0,0 +1,27 @@ +CREATE TABLE conn_intros ( + intro_id BLOB NOT NULL PRIMARY KEY, + to_conn BLOB NOT NULL REFERENCES connections (conn_alias) ON DELETE CASCADE, + to_info BLOB, -- info about "to" connection sent to "re" connection + to_status TEXT NOT NULL DEFAULT '', -- '', INV, CON + re_conn BLOB NOT NULL REFERENCES connections (conn_alias) ON DELETE CASCADE, + re_info BLOB NOT NULL, -- info about "re" connection sent to "to" connection + re_status TEXT NOT NULL DEFAULT '', -- '', INV, CON + queue_info BLOB +) WITHOUT ROWID; + +CREATE TABLE conn_invitations ( + inv_id BLOB NOT NULL PRIMARY KEY, + via_conn BLOB REFERENCES connections (conn_alias) ON DELETE SET NULL, + external_intro_id BLOB NOT NULL, + conn_info BLOB, -- info about another connection + queue_info BLOB, -- NULL if it's an initial introduction + conn_id BLOB REFERENCES connections (conn_alias) -- created connection + ON DELETE CASCADE + DEFERRABLE INITIALLY DEFERRED, + status TEXT DEFAULT '' -- '', 'ACPT', 'CON' +) WITHOUT ROWID; + +ALTER TABLE connections + ADD via_inv BLOB REFERENCES conn_invitations (inv_id) ON DELETE RESTRICT; +ALTER TABLE connections + ADD conn_level INTEGER DEFAULT 0; diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 25bc15c91..e29ea41b0 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -61,7 +61,7 @@ import UnliftIO.STM -- | Runs an SMP agent as a TCP service using passed configuration. -- -- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs -runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m () +runSMPAgent :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m () runSMPAgent t cfg = do started <- newEmptyTMVarIO runSMPAgentBlocking t started cfg @@ -70,10 +70,10 @@ runSMPAgent t cfg = do -- -- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True) -- and when it is disconnected from the TCP socket once the server thread is killed (False). -runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m () +runSMPAgentBlocking :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m () runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReaderT (smpAgent t) =<< newSMPAgentEnv cfg where - smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' () + smpAgent :: forall c m'. (Transport c, MonadFail m', MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' () smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do liftIO $ putLn h "Welcome to SMP v0.3.2 agent" c <- getSMPAgentClient @@ -97,7 +97,7 @@ logConnection c connected = in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"] -- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's. -runSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () +runSMPAgentClient :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () runSMPAgentClient c = do db <- asks $ dbFile . config s1 <- liftIO $ connectSQLiteStore db @@ -128,7 +128,7 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> logClient AgentClient {clientId} dir (ATransmission corrId entity cmd) = do logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, serializeEntity entity, B.takeWhile (/= ' ') $ serializeCommand cmd] -client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () +client :: forall m. (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () client c@AgentClient {rcvQ, sndQ} st = forever loop where loop :: m () @@ -147,6 +147,8 @@ withStore action = do Right c -> return c Left e -> throwError $ storeError e where + -- TODO when parsing exception happens in store, the agent hangs; + -- changing SQLError to SomeException does not help handleInternal :: (MonadError StoreError m') => SQLError -> m' a handleInternal e = throwError . SEInternal $ bshow e storeError :: StoreError -> AgentErrorType @@ -173,23 +175,28 @@ unsupportedEntity c _ corrId entity _ = processConnCommand :: forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m () -processConnCommand c@AgentClient {sndQ} st corrId conn = \case - NEW -> createNewConnection conn - JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode +processConnCommand c@AgentClient {sndQ} st corrId conn@(Conn connId) = \case + NEW -> createNewConnection Nothing 0 >>= uncurry respond + JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode Nothing 0 >> pure () -- >>= (`respond` OK) + INTRO reEntity reInfo -> makeIntroduction reEntity reInfo + ACPT inv eInfo -> acceptInvitation inv eInfo SUB -> subscribeConnection conn SUBALL -> subscribeAll - SEND msgBody -> sendMessage c st corrId conn msgBody + SEND msgBody -> sendClientMessage c st corrId conn msgBody OFF -> suspendConnection conn DEL -> deleteConnection conn where - createNewConnection :: Entity 'Conn_ -> m () - createNewConnection (Conn cId) = do + createNewConnection :: Maybe InvitationId -> Int -> m (Entity 'Conn_, ACommand 'Agent 'INV_) + createNewConnection viaInv connLevel = do -- TODO create connection alias if not passed - -- make connAlias Maybe? + -- make connId Maybe? srv <- getSMPServer - (rq, qInfo) <- newReceiveQueue c srv cId - withStore $ createRcvConn st rq - respond conn $ INV qInfo + (rq, qInfo) <- newReceiveQueue c srv + g <- asks idsDrg + let cData = ConnData {connId, viaInv, connLevel} + connId' <- withStore $ createRcvConn st g cData rq + addSubscription c rq connId' + pure (Conn connId', INV qInfo) getSMPServer :: m SMPServer getSMPServer = @@ -200,16 +207,47 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1) pure $ servers L.!! i - joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m () - joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do - -- TODO create connection alias if not passed - -- make connAlias Maybe? - (sq, senderKey, verifyKey) <- newSendQueue qInfo cId - withStore $ createSndConn st sq + joinConnection :: SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m (Entity 'Conn_) + joinConnection qInfo (ReplyMode replyMode) viaInv connLevel = do + (sq, senderKey, verifyKey) <- newSendQueue qInfo + g <- asks idsDrg + let cData = ConnData {connId, viaInv, connLevel} + connId' <- withStore $ createSndConn st g cData sq connectToSendQueue c st sq senderKey verifyKey - when (replyMode == On) $ createReplyQueue cId sq - -- TODO this response is disabled to avoid two responses in terminal client (OK + CON), - -- respond conn OK + when (replyMode == On) $ createReplyQueue connId' sq + pure $ Conn connId' + + makeIntroduction :: IntroEntity -> EntityInfo -> m () + makeIntroduction (IE reEntity) reInfo = case reEntity of + Conn reConn -> + withStore ((,) <$> getConn st connId <*> getConn st reConn) >>= \case + (SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do + g <- asks idsDrg + introId <- withStore $ createIntro st g NewIntroduction {toConn = connId, reConn, reInfo} + sendControlMessage c sq $ A_INTRO (IE (Conn introId)) reInfo + respond conn OK + _ -> throwError $ CONN SIMPLEX + _ -> throwError $ CMD UNSUPPORTED + + acceptInvitation :: IntroEntity -> EntityInfo -> m () + acceptInvitation (IE invEntity) eInfo = case invEntity of + Conn invId -> do + withStore (getInvitation st invId) >>= \case + Invitation {viaConn, qInfo, externalIntroId, status = InvNew} -> + withStore (getConn st viaConn) >>= \case + SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of + Nothing -> do + (conn', INV qInfo') <- createNewConnection (Just invId) (connLevel + 1) + withStore $ addInvitationConn st invId $ fromConn conn' + sendControlMessage c sq $ A_INV (Conn externalIntroId) qInfo' eInfo + respond conn' OK + Just qInfo' -> do + conn' <- joinConnection qInfo' (ReplyMode On) (Just invId) (connLevel + 1) + withStore $ addInvitationConn st invId $ fromConn conn' + respond conn' OK + _ -> throwError $ CONN SIMPLEX + _ -> throwError $ CMD PROHIBITED + _ -> throwError $ CMD UNSUPPORTED subscribeConnection :: Entity 'Conn_ -> m () subscribeConnection conn'@(Conn cId) = @@ -222,7 +260,7 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case -- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct subscribeAll :: m () - subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn) + subscribeAll = withStore (getAllConnIds st) >>= mapM_ (subscribeConnection . Conn) suspendConnection :: Entity 'Conn_ -> m () suspendConnection (Conn cId) = @@ -246,25 +284,30 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case removeSubscription c cId delConn - createReplyQueue :: ByteString -> SndQueue -> m () + createReplyQueue :: ConnId -> SndQueue -> m () createReplyQueue cId sq = do srv <- getSMPServer - (rq, qInfo) <- newReceiveQueue c srv cId + (rq, qInfo) <- newReceiveQueue c srv + addSubscription c rq cId withStore $ upgradeSndConnToDuplex st cId rq - senderTimestamp <- liftIO getCurrentTime - sendAgentMessage c sq . serializeSMPMessage $ - SMPMessage - { senderMsgId = 0, - senderTimestamp, - previousMsgHash = "", - agentMessage = REPLY qInfo - } + sendControlMessage c sq $ REPLY qInfo respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m () respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp -sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m () -sendMessage c st corrId (Conn cId) msgBody = +sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m () +sendControlMessage c sq agentMessage = do + senderTimestamp <- liftIO getCurrentTime + sendAgentMessage c sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage + } + +sendClientMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m () +sendClientMessage c st corrId (Conn cId) msgBody = withStore (getConn st cId) >>= \case SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq SomeConn _ (SndConnection _ sq) -> sendMsg sq @@ -273,7 +316,7 @@ sendMessage c st corrId (Conn cId) msgBody = sendMsg :: SndQueue -> m () sendMsg sq = do internalTs <- liftIO getCurrentTime - (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq + (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st cId let msgStr = serializeSMPMessage SMPMessage @@ -284,7 +327,7 @@ sendMessage c st corrId (Conn cId) msgBody = } msgHash = C.sha256Hash msgStr withStore $ - createSndMsg st sq $ + createSndMsg st cId $ SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash} sendAgentMessage c sq msgStr atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId) @@ -292,15 +335,21 @@ sendMessage c st corrId (Conn cId) msgBody = processBroadcastCommand :: forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m () processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case - NEW -> withStore (createBcast st bId) >> ok + NEW -> createNewBroadcast ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0) DEL -> withStore (deleteBcast st bId) >> ok where - sendMsg :: MsgBody -> ConnAlias -> m () - sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody + createNewBroadcast :: m () + createNewBroadcast = do + g <- asks idsDrg + bId' <- withStore $ createBcast st g bId + respond (Broadcast bId') OK + + sendMsg :: MsgBody -> ConnId -> m () + sendMsg msgBody cId = sendClientMessage c st corrId (Conn cId) msgBody ok :: m () ok = respond bcast OK @@ -308,7 +357,7 @@ processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m () respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp -subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () +subscriber :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () subscriber c@AgentClient {msgQ} st = forever $ do -- TODO this will only process messages and notifications t <- atomically $ readTBQueue msgQ @@ -319,29 +368,33 @@ subscriber c@AgentClient {msgQ} st = forever $ do processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m () processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do withStore (getRcvConn st srv rId) >>= \case - SomeConn SCDuplex (DuplexConnection _ rq _) -> processSMP SCDuplex rq - SomeConn SCRcv (RcvConnection _ rq) -> processSMP SCRcv rq + SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq + SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq _ -> atomically . writeTBQueue sndQ $ ATransmission "" (Conn "") (ERR $ CONN SIMPLEX) where - processSMP :: SConnType c -> RcvQueue -> m () - processSMP cType rq@RcvQueue {connAlias, status} = + processSMP :: SConnType c -> ConnData -> RcvQueue -> m () + processSMP cType ConnData {connId} rq@RcvQueue {status} = case cmd of SMP.MSG srvMsgId srvTs msgBody -> do -- TODO deduplicate with previously received msg <- decryptAndVerify rq msgBody let msgHash = C.sha256Hash msg - agentMsg <- liftEither $ parseSMPMessage msg - case agentMsg of - SMPConfirmation senderKey -> smpConfirmation senderKey - SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> + case parseSMPMessage msg of + Left e -> notify $ ERR e + Right (SMPConfirmation senderKey) -> smpConfirmation senderKey + Right SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> case agentMessage of HELLO verifyKey _ -> helloMsg verifyKey msgBody REPLY qInfo -> replyMsg qInfo A_MSG body -> agentClientMsg previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash + A_INTRO entity eInfo -> introMsg entity eInfo + A_INV conn qInfo cInfo -> invMsg conn qInfo cInfo + A_REQ conn qInfo cInfo -> reqMsg conn qInfo cInfo + A_CON conn -> conMsg conn sendAck c rq return () SMP.END -> do - removeSubscription c connAlias + removeSubscription c connId logServer "<--" c srv rId "END" notify END _ -> do @@ -349,7 +402,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do notify . ERR $ BROKER UNEXPECTED where notify :: EntityCommand 'Conn_ c => ACommand 'Agent c -> m () - notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connAlias) msg + notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connId) msg prohibited :: m () prohibited = notify . ERR $ AGENT A_PROHIBITED @@ -376,7 +429,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do void $ verifyMessage (Just verifyKey) msgBody withStore $ setRcvQueueActive st rq verifyKey case cType of - SCDuplex -> notify CON + SCDuplex -> connected _ -> pure () replyMsg :: SMPQueueInfo -> m () @@ -384,22 +437,87 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do logServer "<--" c srv rId "MSG " case cType of SCRcv -> do - (sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias - withStore $ upgradeRcvConnToDuplex st connAlias sq + (sq, senderKey, verifyKey) <- newSendQueue qInfo + withStore $ upgradeRcvConnToDuplex st connId sq connectToSendQueue c st sq senderKey verifyKey - notify CON + connected _ -> prohibited + connected :: m () + connected = do + withStore (getConnInvitation st connId) >>= \case + Just (Invitation {invId, externalIntroId}, DuplexConnection _ _ sq) -> do + withStore $ setInvitationStatus st invId InvCon + sendControlMessage c sq $ A_CON (Conn externalIntroId) + _ -> pure () + notify CON + + introMsg :: IntroEntity -> EntityInfo -> m () + introMsg (IE entity) entityInfo = do + logServer "<--" c srv rId "MSG " + case (cType, entity) of + (SCDuplex, intro@Conn {}) -> createInv intro Nothing entityInfo + _ -> prohibited + + invMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m () + invMsg (Conn introId) qInfo toInfo = do + logServer "<--" c srv rId "MSG " + case cType of + SCDuplex -> + withStore (getIntro st introId) >>= \case + Introduction {toConn, toStatus = IntroNew, reConn, reStatus = IntroNew} + | toConn /= connId -> prohibited + | otherwise -> + withStore (addIntroInvitation st introId toInfo qInfo >> getConn st reConn) >>= \case + SomeConn _ (DuplexConnection _ _ sq) -> do + sendControlMessage c sq $ A_REQ (Conn introId) qInfo toInfo + withStore $ setIntroReStatus st introId IntroInv + _ -> prohibited + _ -> prohibited + _ -> prohibited + + reqMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m () + reqMsg intro qInfo connInfo = do + logServer "<--" c srv rId "MSG " + case cType of + SCDuplex -> createInv intro (Just qInfo) connInfo + _ -> prohibited + + createInv :: Entity 'Conn_ -> Maybe SMPQueueInfo -> EntityInfo -> m () + createInv (Conn externalIntroId) qInfo entityInfo = do + g <- asks idsDrg + let newInv = NewInvitation {viaConn = connId, externalIntroId, entityInfo, qInfo} + invId <- withStore $ createInvitation st g newInv + notify $ REQ (IE (Conn invId)) entityInfo + + conMsg :: Entity 'Conn_ -> m () + conMsg (Conn introId) = do + logServer "<--" c srv rId "MSG " + withStore (getIntro st introId) >>= \case + Introduction {toConn, toStatus, reConn, reStatus} + | toConn == connId && toStatus == IntroInv -> do + withStore $ setIntroToStatus st introId IntroCon + sendConMsg toConn reConn reStatus + | reConn == connId && reStatus == IntroInv -> do + withStore $ setIntroReStatus st introId IntroCon + sendConMsg toConn reConn toStatus + | otherwise -> prohibited + where + sendConMsg :: ConnId -> ConnId -> IntroStatus -> m () + sendConMsg toConn reConn IntroCon = + atomically . writeTBQueue sndQ $ ATransmission "" (Conn toConn) $ ICON $ IE (Conn reConn) + sendConMsg _ _ _ = pure () + agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m () agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do logServer "<--" c srv rId "MSG " case status of Active -> do internalTs <- liftIO getCurrentTime - (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st connId let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash receivedPrevMsgHash withStore $ - createRcvMsg st rq $ + createRcvMsg st connId $ RcvMsgData { internalId, internalRcvId, @@ -438,8 +556,8 @@ connectToSendQueue c st sq senderKey verifyKey = do withStore $ setSndQueueStatus st sq Active newSendQueue :: - (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey) -newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do + (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey) +newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do size <- asks $ rsaKeySize . config (senderKey, sndPrivateKey) <- liftIO $ C.generateKeyPair size (verifyKey, signKey) <- liftIO $ C.generateKeyPair size @@ -447,7 +565,6 @@ newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do SndQueue { server = smpServer, sndId = senderId, - connAlias, sndPrivateKey, encryptKey, signKey, diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 2c6bea6f1..f440a6df9 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -15,6 +15,7 @@ module Simplex.Messaging.Agent.Client closeSMPServerClients, newReceiveQueue, subscribeQueue, + addSubscription, sendConfirmation, sendHello, secureQueue, @@ -61,8 +62,8 @@ data AgentClient = AgentClient sndQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue SMPServerTransmission, smpClients :: TVar (Map SMPServer SMPClient), - subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)), - subscrConns :: TVar (Map ConnAlias SMPServer), + subscrSrvrs :: TVar (Map SMPServer (Set ConnId)), + subscrConns :: TVar (Map ConnId SMPServer), clientId :: Int } @@ -78,7 +79,7 @@ newAgentClient cc AgentConfig {tbqSize} = do writeTVar cc clientId return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId} -type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) +type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m, MonadFail m) getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient getSMPServerClient c@AgentClient {smpClients, msgQ} srv = @@ -106,7 +107,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = removeSubs >>= mapM_ (mapM_ notifySub) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeSubs :: IO (Maybe (Set ConnAlias)) + removeSubs :: IO (Maybe (Set ConnId)) removeSubs = atomically $ do modifyTVar smpClients $ M.delete srv cs <- M.lookup srv <$> readTVar (subscrSrvrs c) @@ -117,7 +118,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = deleteKeys :: Ord k => Set k -> Map k a -> Map k a deleteKeys ks m = S.foldr' M.delete m ks - notifySub :: ConnAlias -> IO () + notifySub :: ConnId -> IO () notifySub connAlias = atomically . writeTBQueue (sndQ c) $ ATransmission "" (Conn connAlias) END closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m () @@ -158,8 +159,8 @@ smpClientError = \case SMPTransportError e -> BROKER $ TRANSPORT e e -> INTERNAL $ show e -newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo) -newReceiveQueue c srv connAlias = do +newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueInfo) +newReceiveQueue c srv = do size <- asks $ rsaKeySize . config (recipientKey, rcvPrivateKey) <- liftIO $ C.generateKeyPair size logServer "-->" c srv "" "NEW" @@ -170,7 +171,6 @@ newReceiveQueue c srv connAlias = do RcvQueue { server = srv, rcvId, - connAlias, rcvPrivateKey, sndId = Just sId, sndKey = Nothing, @@ -178,25 +178,24 @@ newReceiveQueue c srv connAlias = do verifyKey = Nothing, status = New } - addSubscription c rq connAlias return (rq, SMPQueueInfo srv sId encryptKey) -subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnAlias -> m () +subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connAlias = do withLogSMP c server rcvId "SUB" $ \smp -> subscribeSMPQueue smp rcvPrivateKey rcvId addSubscription c rq connAlias -addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnAlias -> m () +addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c RcvQueue {server} connAlias = atomically $ do modifyTVar (subscrConns c) $ M.insert connAlias server modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server where - addSub :: Maybe (Set ConnAlias) -> Set ConnAlias + addSub :: Maybe (Set ConnId) -> Set ConnId addSub (Just cs) = S.insert connAlias cs addSub _ = S.singleton connAlias -removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m () +removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m () removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do cs <- readTVar subscrConns writeTVar subscrConns $ M.delete connAlias cs @@ -204,7 +203,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically (modifyTVar subscrSrvrs . M.alter (>>= delSub)) (M.lookup connAlias cs) where - delSub :: Set ConnAlias -> Maybe (Set ConnAlias) + delSub :: Set ConnId -> Maybe (Set ConnId) delSub cs = let cs' = S.delete connAlias cs in if S.null cs' then Nothing else Just cs' diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 07f135440..16f579504 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -33,6 +33,8 @@ module Simplex.Messaging.Agent.Protocol Entity (..), EntityTag (..), AnEntity (..), + IntroEntity (..), + EntityInfo, EntityCommand, entityCommand, ACommand (..), @@ -53,7 +55,7 @@ module Simplex.Messaging.Agent.Protocol ATransmission (..), ATransmissionOrError (..), ARawTransmission, - ConnAlias, + ConnId, ReplyMode (..), AckMode (..), OnOff (..), @@ -171,6 +173,13 @@ deriving instance Eq (Entity t) deriving instance Show (Entity t) +instance TestEquality Entity where + testEquality (Conn c) (Conn c') = refl c c' + testEquality (OpenConn c) (OpenConn c') = refl c c' + testEquality (Broadcast c) (Broadcast c') = refl c c' + testEquality (AGroup c) (AGroup c') = refl c c' + testEquality _ _ = Nothing + entityId :: Entity t -> ByteString entityId = \case Conn bs -> bs @@ -195,7 +204,11 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where EntityCommand Conn_ NEW_ = () EntityCommand Conn_ INV_ = () EntityCommand Conn_ JOIN_ = () + EntityCommand Conn_ INTRO_ = () + EntityCommand Conn_ REQ_ = () + EntityCommand Conn_ ACPT_ = () EntityCommand Conn_ CON_ = () + EntityCommand Conn_ ICON_ = () EntityCommand Conn_ SUB_ = () EntityCommand Conn_ SUBALL_ = () EntityCommand Conn_ END_ = () @@ -226,7 +239,11 @@ entityCommand = \case NEW -> Just Dict INV _ -> Just Dict JOIN {} -> Just Dict + INTRO {} -> Just Dict + REQ {} -> Just Dict + ACPT {} -> Just Dict CON -> Just Dict + ICON {} -> Just Dict SUB -> Just Dict SUBALL -> Just Dict END -> Just Dict @@ -258,7 +275,11 @@ data ACmdTag = NEW_ | INV_ | JOIN_ + | INTRO_ + | REQ_ + | ACPT_ | CON_ + | ICON_ | SUB_ | SUBALL_ | END_ @@ -274,15 +295,31 @@ data ACmdTag | OK_ | ERR_ +type family Introduction (t :: EntityTag) :: Constraint where + Introduction Conn_ = () + Introduction OpenConn_ = () + Introduction AGroup_ = () + Introduction t = (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " cannot be INTRO'd to")) + +data IntroEntity = forall t. Introduction t => IE (Entity t) + +instance Eq IntroEntity where + IE e1 == IE e2 = isJust $ testEquality e1 e2 + +deriving instance Show IntroEntity + +type EntityInfo = ByteString + -- | Parameterized type for SMP agent protocol commands and responses from all participants. data ACommand (p :: AParty) (c :: ACmdTag) where NEW :: ACommand Client NEW_ -- response INV INV :: SMPQueueInfo -> ACommand Agent INV_ JOIN :: SMPQueueInfo -> ReplyMode -> ACommand Client JOIN_ -- response OK + INTRO :: IntroEntity -> EntityInfo -> ACommand Client INTRO_ + REQ :: IntroEntity -> EntityInfo -> ACommand Agent INTRO_ + ACPT :: IntroEntity -> EntityInfo -> ACommand Client ACPT_ CON :: ACommand Agent CON_ -- notification that connection is established - -- TODO currently it automatically allows whoever sends the confirmation - -- CONF :: OtherPartyId -> ACommand Agent - -- LET :: OtherPartyId -> ACommand Client + ICON :: IntroEntity -> ACommand Agent ICON_ SUB :: ACommand Client SUB_ SUBALL :: ACommand Client SUBALL_ -- TODO should be moved to chat protocol - hack for subscribing to all END :: ACommand Agent END_ @@ -318,6 +355,7 @@ instance TestEquality (ACommand p) where testEquality c@INV {} c'@INV {} = refl c c' testEquality c@JOIN {} c'@JOIN {} = refl c c' testEquality CON CON = Just Refl + testEquality c@ICON {} c'@ICON {} = refl c c' testEquality SUB SUB = Just Refl testEquality SUBALL SUBALL = Just Refl testEquality END END = Just Refl @@ -334,7 +372,7 @@ instance TestEquality (ACommand p) where testEquality c@ERR {} c'@ERR {} = refl c c' testEquality _ _ = Nothing -refl :: Eq (f a) => f a -> f a -> Maybe (a :~: a) +refl :: Eq a => a -> a -> Maybe (t :~: t) refl x x' = if x == x' then Just Refl else Nothing -- | SMP message formats. @@ -366,6 +404,14 @@ data AMessage where REPLY :: SMPQueueInfo -> AMessage -- | agent envelope for the client message A_MSG :: MsgBody -> AMessage + -- | agent message for introduction + A_INTRO :: IntroEntity -> EntityInfo -> AMessage + -- | agent envelope for the sent invitation + A_INV :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage + -- | agent envelope for the forwarded invitation + A_REQ :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage + -- | agent message for intro/group request + A_CON :: Entity Conn_ -> AMessage deriving (Show) -- | Parse SMP message. @@ -408,12 +454,22 @@ agentMessageP = "HELLO " *> hello <|> "REPLY " *> reply <|> "MSG " *> a_msg + <|> "INTRO " *> a_intro + <|> "INV " *> a_inv + <|> "REQ " *> a_req + <|> "CON " *> a_con where hello = HELLO <$> C.pubKeyP <*> ackMode reply = REPLY <$> smpQueueInfoP - a_msg = do + a_msg = A_MSG <$> binaryBody + a_intro = A_INTRO <$> introEntityP <* A.space <*> binaryBody + a_inv = invP A_INV + a_req = invP A_REQ + a_con = A_CON <$> connEntityP + invP f = f <$> connEntityP <* A.space <*> smpQueueInfoP <* A.space <*> binaryBody + binaryBody = do size :: Int <- A.decimal <* A.endOfLine - A_MSG <$> A.take size <* A.endOfLine + A.take size <* A.endOfLine ackMode = AckMode <$> (" NO_ACK" $> Off <|> pure On) -- | SMP queue information parser. @@ -434,6 +490,13 @@ serializeAgentMessage = \case HELLO verifyKey ackMode -> "HELLO " <> C.serializePubKey verifyKey <> if ackMode == AckMode Off then " NO_ACK" else "" REPLY qInfo -> "REPLY " <> serializeSmpQueueInfo qInfo A_MSG body -> "MSG " <> serializeMsg body <> "\n" + A_INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo <> "\n" + A_INV conn qInfo eInfo -> "INV " <> serializeInv conn qInfo eInfo + A_REQ conn qInfo eInfo -> "REQ " <> serializeInv conn qInfo eInfo + A_CON conn -> "CON " <> serializeEntity conn + where + serializeInv conn qInfo eInfo = + B.intercalate " " [serializeEntity conn, serializeSmpQueueInfo qInfo, serializeMsg eInfo] <> "\n" -- | Serialize SMP queue information that is sent out-of-band. serializeSmpQueueInfo :: SMPQueueInfo -> ByteString @@ -457,7 +520,7 @@ instance IsString SMPServer where fromString = parseString . parseAll $ smpServerP -- | SMP agent connection alias. -type ConnAlias = ByteString +type ConnId = ByteString -- | Connection modes. data OnOff = On | Off deriving (Eq, Show, Read) @@ -614,10 +677,19 @@ anEntityP = <|> "B:" $> AE . Broadcast <|> "G:" $> AE . AGroup ) - <*> A.takeTill (== ' ') + <*> A.takeTill wordEnd -entityConnP :: Parser (Entity Conn_) -entityConnP = "C:" *> (Conn <$> A.takeTill (== ' ')) +connEntityP :: Parser (Entity Conn_) +connEntityP = "C:" *> (Conn <$> A.takeTill wordEnd) + +introEntityP :: Parser IntroEntity +introEntityP = + ($) + <$> ( "C:" $> IE . Conn + <|> "O:" $> IE . OpenConn + <|> "G:" $> IE . AGroup + ) + <*> A.takeTill wordEnd serializeEntity :: Entity t -> ByteString serializeEntity = \case @@ -632,6 +704,9 @@ commandP = "NEW" $> ACmd SClient NEW <|> "INV " *> invResp <|> "JOIN " *> joinCmd + <|> "INTRO " *> introCmd + <|> "REQ " *> reqCmd + <|> "ACPT " *> acptCmd <|> "SUB" $> ACmd SClient SUB <|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all <|> "END" $> ACmd SAgent END @@ -645,16 +720,21 @@ commandP = <|> "LS" $> ACmd SClient LS <|> "MS " *> membersResp <|> "ERR " *> agentError + <|> "ICON " *> iconMsg <|> "CON" $> ACmd SAgent CON <|> "OK" $> ACmd SAgent OK where invResp = ACmd SAgent . INV <$> smpQueueInfoP joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode) + introCmd = ACmd SClient <$> introP INTRO + reqCmd = ACmd SAgent <$> introP REQ + acptCmd = ACmd SClient <$> introP ACPT sendCmd = ACmd SClient . SEND <$> A.takeByteString sentResp = ACmd SAgent . SENT <$> A.decimal - addCmd = ACmd SClient . ADD <$> entityConnP - removeCmd = ACmd SClient . REM <$> entityConnP - membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ') + addCmd = ACmd SClient . ADD <$> connEntityP + removeCmd = ACmd SClient . REM <$> connEntityP + membersResp = ACmd SAgent . MS <$> (connEntityP `A.sepBy'` A.char ' ') + iconMsg = ACmd SAgent . ICON <$> introEntityP message = do msgIntegrity <- msgIntegrityP <* A.space recipientMeta <- "R=" *> partyMeta A.decimal @@ -662,6 +742,7 @@ commandP = senderMeta <- "S=" *> partyMeta A.decimal msgBody <- A.takeByteString return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody} + introP f = f <$> introEntityP <* A.space <*> A.takeByteString replyMode = ReplyMode <$> (" NO_REPLY" $> Off <|> pure On) partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space agentError = ACmd SAgent . ERR <$> agentErrorTypeP @@ -685,6 +766,9 @@ serializeCommand = \case NEW -> "NEW" INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode + INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo + REQ (IE entity) eInfo -> "REQ " <> serializeIntro entity eInfo + ACPT (IE entity) eInfo -> "ACPT " <> serializeIntro entity eInfo SUB -> "SUB" SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all END -> "END" @@ -706,6 +790,7 @@ serializeCommand = \case LS -> "LS" MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs) CON -> "CON" + ICON (IE entity) -> "ICON " <> serializeEntity entity ERR e -> "ERR " <> serializeAgentError e OK -> "OK" where @@ -716,6 +801,9 @@ serializeCommand = \case showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis +serializeIntro :: Entity t -> ByteString -> ByteString +serializeIntro entity eInfo = serializeEntity entity <> " " <> serializeMsg eInfo + -- | Serialize message integrity validation result. serializeMsgIntegrity :: MsgIntegrity -> ByteString serializeMsgIntegrity = \case @@ -794,9 +882,10 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody hasEntityId :: AnEntity -> APartyCmd p -> Either AgentErrorType (APartyCmd p) hasEntityId (AE entity) (APartyCmd cmd) = APartyCmd <$> case cmd of - -- NEW and JOIN have optional entity + -- NEW, JOIN and ACPT have optional entity NEW -> Right cmd - JOIN _ _ -> Right cmd + JOIN {} -> Right cmd + ACPT {} -> Right cmd -- ERROR response does not always have entity ERR _ -> Right cmd -- other responses must have entity @@ -818,6 +907,9 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody APartyCmd <$$> case cmd of SEND body -> SEND <$$> getMsgBody body MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body + INTRO entity eInfo -> INTRO entity <$$> getMsgBody eInfo + REQ entity eInfo -> REQ entity <$$> getMsgBody eInfo + ACPT entity eInfo -> ACPT entity <$$> getMsgBody eInfo _ -> pure $ Right cmd -- TODO refactor with server diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 6d3dc606f..d1e71a3f6 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -3,16 +3,21 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} module Simplex.Messaging.Agent.Store where +import Control.Concurrent.STM (TVar) import Control.Exception (Exception) +import Crypto.Random (ChaChaDRG) import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Kind (Type) +import Data.Text (Text) import Data.Time (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Protocol @@ -30,33 +35,45 @@ import qualified Simplex.Messaging.Protocol as SMP -- | Store class type. Defines store access methods for implementations. class Monad m => MonadAgentStore s m where -- Queue and Connection management - createRcvConn :: s -> RcvQueue -> m () - createSndConn :: s -> SndQueue -> m () - getConn :: s -> ConnAlias -> m SomeConn - getAllConnAliases :: s -> m [ConnAlias] -- TODO remove - hack for subscribing to all + createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId + createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId + getConn :: s -> ConnId -> m SomeConn + getAllConnIds :: s -> m [ConnId] -- TODO remove - hack for subscribing to all getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn - deleteConn :: s -> ConnAlias -> m () - upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m () - upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m () + deleteConn :: s -> ConnId -> m () + upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m () + upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m () setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m () setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m () setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m () -- Msg management - updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) - createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m () + updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + createRcvMsg :: s -> ConnId -> RcvMsgData -> m () - updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) - createSndMsg :: s -> SndQueue -> SndMsgData -> m () + updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash) + createSndMsg :: s -> ConnId -> SndMsgData -> m () - getMsg :: s -> ConnAlias -> InternalId -> m Msg + getMsg :: s -> ConnId -> InternalId -> m Msg -- Broadcasts - createBcast :: s -> BroadcastId -> m () - addBcastConn :: s -> BroadcastId -> ConnAlias -> m () - removeBcastConn :: s -> BroadcastId -> ConnAlias -> m () + createBcast :: s -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId + addBcastConn :: s -> BroadcastId -> ConnId -> m () + removeBcastConn :: s -> BroadcastId -> ConnId -> m () deleteBcast :: s -> BroadcastId -> m () - getBcast :: s -> BroadcastId -> m [ConnAlias] + getBcast :: s -> BroadcastId -> m [ConnId] + + -- Introductions + createIntro :: s -> TVar ChaChaDRG -> NewIntroduction -> m IntroId + getIntro :: s -> IntroId -> m Introduction + addIntroInvitation :: s -> IntroId -> EntityInfo -> SMPQueueInfo -> m () + setIntroToStatus :: s -> IntroId -> IntroStatus -> m () + setIntroReStatus :: s -> IntroId -> IntroStatus -> m () + createInvitation :: s -> TVar ChaChaDRG -> NewInvitation -> m InvitationId + getInvitation :: s -> InvitationId -> m Invitation + addInvitationConn :: s -> InvitationId -> ConnId -> m () + getConnInvitation :: s -> ConnId -> m (Maybe (Invitation, Connection CDuplex)) + setInvitationStatus :: s -> InvitationId -> InvitationStatus -> m () -- * Queue types @@ -64,7 +81,6 @@ class Monad m => MonadAgentStore s m where data RcvQueue = RcvQueue { server :: SMPServer, rcvId :: SMP.RecipientId, - connAlias :: ConnAlias, rcvPrivateKey :: RecipientPrivateKey, sndId :: Maybe SMP.SenderId, sndKey :: Maybe SenderPublicKey, @@ -78,7 +94,6 @@ data RcvQueue = RcvQueue data SndQueue = SndQueue { server :: SMPServer, sndId :: SMP.SenderId, - connAlias :: ConnAlias, sndPrivateKey :: SenderPrivateKey, encryptKey :: EncryptionKey, signKey :: SignatureKey, @@ -102,9 +117,9 @@ data ConnType = CRcv | CSnd | CDuplex deriving (Eq, Show) -- - DuplexConnection is a connection that has both receive and send queues set up, -- typically created by upgrading a receive or a send connection with a missing queue. data Connection (d :: ConnType) where - RcvConnection :: ConnAlias -> RcvQueue -> Connection CRcv - SndConnection :: ConnAlias -> SndQueue -> Connection CSnd - DuplexConnection :: ConnAlias -> RcvQueue -> SndQueue -> Connection CDuplex + RcvConnection :: ConnData -> RcvQueue -> Connection CRcv + SndConnection :: ConnData -> SndQueue -> Connection CSnd + DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex deriving instance Eq (Connection d) @@ -141,6 +156,9 @@ instance Eq SomeConn where deriving instance Show SomeConn +data ConnData = ConnData {connId :: ConnId, viaInv :: Maybe InvitationId, connLevel :: Int} + deriving (Eq, Show) + -- * Message integrity validation types type MsgHash = ByteString @@ -263,7 +281,7 @@ type DeliveredTs = UTCTime -- | Base message data independent of direction. data MsgBase = MsgBase - { connAlias :: ConnAlias, + { connAlias :: ConnId, -- | Monotonically increasing id of a message per connection, internal to the agent. -- Internal Id preserves ordering between both received and sent messages, and is needed -- to track the order of the conversation (which can be different for the sender / receiver) @@ -281,12 +299,87 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show) type InternalTs = UTCTime +-- * Introduction types + +data NewIntroduction = NewIntroduction + { toConn :: ConnId, + reConn :: ConnId, + reInfo :: ByteString + } + +data Introduction = Introduction + { introId :: IntroId, + toConn :: ConnId, + toInfo :: Maybe ByteString, + toStatus :: IntroStatus, + reConn :: ConnId, + reInfo :: ByteString, + reStatus :: IntroStatus, + qInfo :: Maybe SMPQueueInfo + } + +data IntroStatus = IntroNew | IntroInv | IntroCon + deriving (Eq) + +serializeIntroStatus :: IntroStatus -> Text +serializeIntroStatus = \case + IntroNew -> "" + IntroInv -> "INV" + IntroCon -> "CON" + +introStatusT :: Text -> Maybe IntroStatus +introStatusT = \case + "" -> Just IntroNew + "INV" -> Just IntroInv + "CON" -> Just IntroCon + _ -> Nothing + +type IntroId = ByteString + +data NewInvitation = NewInvitation + { viaConn :: ConnId, + externalIntroId :: IntroId, + entityInfo :: EntityInfo, + qInfo :: Maybe SMPQueueInfo + } + +data Invitation = Invitation + { invId :: InvitationId, + viaConn :: ConnId, + externalIntroId :: IntroId, + entityInfo :: EntityInfo, + qInfo :: Maybe SMPQueueInfo, + connId :: Maybe ConnId, + status :: InvitationStatus + } + deriving (Show) + +data InvitationStatus = InvNew | InvAcpt | InvCon + deriving (Eq, Show) + +serializeInvStatus :: InvitationStatus -> Text +serializeInvStatus = \case + InvNew -> "" + InvAcpt -> "ACPT" + InvCon -> "CON" + +invStatusT :: Text -> Maybe InvitationStatus +invStatusT = \case + "" -> Just InvNew + "ACPT" -> Just InvAcpt + "CON" -> Just InvCon + _ -> Nothing + +type InvitationId = ByteString + -- * Store errors -- | Agent store error. data StoreError = -- | IO exceptions in store actions. SEInternal ByteString + | -- | failed to generate unique random ID + SEUniqueID | -- | Connection alias not found (or both queues absent). SEConnNotFound | -- | Connection alias already used. @@ -298,6 +391,10 @@ data StoreError SEBcastNotFound | -- | Broadcast ID already used. SEBcastDuplicate + | -- | Introduction ID not found. + SEIntroNotFound + | -- | Invitation ID not found. + SEInvitationNotFound | -- | Currently not used. The intention was to pass current expected queue status in methods, -- as we always know what it should be at any stage of the protocol, -- and in case it does not match use this error. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 9dcf7edd1..d619026c4 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -11,6 +12,7 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -22,14 +24,19 @@ module Simplex.Messaging.Agent.Store.SQLite where import Control.Concurrent (threadDelay) -import Control.Monad (unless, when) +import Control.Concurrent.STM (TVar, atomically, stateTVar) +import Control.Monad (join, unless, when) import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO)) import Control.Monad.IO.Unlift (MonadUnliftIO) +import Crypto.Random (ChaChaDRG, randomBytesGenerate) import Data.Bifunctor (first) +import Data.ByteString (ByteString) +import Data.ByteString.Base64 (encode) import Data.Char (toLower) +import Data.Functor (($>)) import Data.List (find) import Data.Maybe (fromMaybe) -import Data.Text (isPrefixOf) +import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, field) @@ -66,8 +73,8 @@ createSQLiteStore dbFilePath = do let dbDir = takeDirectory dbFilePath createDirectoryIfMissing False dbDir store <- connectSQLiteStore dbFilePath - compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] - let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions) + compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[Text]] + let threadsafeOption = find (T.isPrefixOf "THREADSAFE=") (concat compileOptions) case threadsafeOption of Just "THREADSAFE=0" -> confirmOrExit "SQLite compiled with non-threadsafe code." Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found" @@ -107,13 +114,16 @@ connectSQLiteStore dbFilePath = do |] pure SQLiteStore {dbFilePath, dbConn, dbNew} -checkConstraint :: StoreError -> IO () -> IO (Either StoreError ()) -checkConstraint err action = first handleError <$> E.try action - where - handleError :: SQLError -> StoreError - handleError e - | DB.sqlError e == DB.ErrorConstraint = err - | otherwise = SEInternal $ bshow e +checkConstraint :: StoreError -> IO a -> IO (Either StoreError a) +checkConstraint err action = first (handleSQLError err) <$> E.try action + +checkConstraint' :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a) +checkConstraint' err action = action `E.catch` (pure . Left . handleSQLError err) + +handleSQLError :: StoreError -> SQLError -> StoreError +handleSQLError err e + | DB.sqlError e == DB.ErrorConstraint = err + | otherwise = SEInternal $ bshow e withTransaction :: forall a. DB.Connection -> IO a -> IO a withTransaction db a = loop 100 100_000 @@ -128,33 +138,43 @@ withTransaction db a = loop 100 100_000 else E.throwIO e instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where - createRcvConn :: SQLiteStore -> RcvQueue -> m () - createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} = - liftIOEither $ - checkConstraint SEConnDuplicate $ - withTransaction dbConn $ do - upsertServer_ dbConn server - insertRcvQueue_ dbConn q - insertRcvConnection_ dbConn q + createRcvConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId + createRcvConn SQLiteStore {dbConn} gVar cData q@RcvQueue {server} = + -- TODO if schema has to be restarted, this function can be refactored + -- to create connection first using createWithRandomId + liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $ + getConnId_ dbConn gVar cData >>= traverse create + where + create :: ConnId -> IO ConnId + create connId = do + upsertServer_ dbConn server + insertRcvQueue_ dbConn connId q + insertRcvConnection_ dbConn cData {connId} q + pure connId - createSndConn :: SQLiteStore -> SndQueue -> m () - createSndConn SQLiteStore {dbConn} q@SndQueue {server} = - liftIOEither $ - checkConstraint SEConnDuplicate $ - withTransaction dbConn $ do - upsertServer_ dbConn server - insertSndQueue_ dbConn q - insertSndConnection_ dbConn q + createSndConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId + createSndConn SQLiteStore {dbConn} gVar cData q@SndQueue {server} = + -- TODO if schema has to be restarted, this function can be refactored + -- to create connection first using createWithRandomId + liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $ + getConnId_ dbConn gVar cData >>= traverse create + where + create :: ConnId -> IO ConnId + create connId = do + upsertServer_ dbConn server + insertSndQueue_ dbConn connId q + insertSndConnection_ dbConn cData {connId} q + pure connId - getConn :: SQLiteStore -> ConnAlias -> m SomeConn - getConn SQLiteStore {dbConn} connAlias = + getConn :: SQLiteStore -> ConnId -> m SomeConn + getConn SQLiteStore {dbConn} connId = liftIOEither . withTransaction dbConn $ - getConn_ dbConn connAlias + getConn_ dbConn connId - getAllConnAliases :: SQLiteStore -> m [ConnAlias] - getAllConnAliases SQLiteStore {dbConn} = + getAllConnIds :: SQLiteStore -> m [ConnId] + getAllConnIds SQLiteStore {dbConn} = liftIO $ do - r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] + r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnId]] return (concat r) getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn @@ -169,37 +189,37 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto |] [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] >>= \case - [Only connAlias] -> getConn_ dbConn connAlias + [Only connId] -> getConn_ dbConn connId _ -> pure $ Left SEConnNotFound - deleteConn :: SQLiteStore -> ConnAlias -> m () - deleteConn SQLiteStore {dbConn} connAlias = + deleteConn :: SQLiteStore -> ConnId -> m () + deleteConn SQLiteStore {dbConn} connId = liftIO $ DB.executeNamed dbConn "DELETE FROM connections WHERE conn_alias = :conn_alias;" - [":conn_alias" := connAlias] + [":conn_alias" := connId] - upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m () - upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} = + upgradeRcvConnToDuplex :: SQLiteStore -> ConnId -> SndQueue -> m () + upgradeRcvConnToDuplex SQLiteStore {dbConn} connId sq@SndQueue {server} = liftIOEither . withTransaction dbConn $ - getConn_ dbConn connAlias >>= \case + getConn_ dbConn connId >>= \case Right (SomeConn _ RcvConnection {}) -> do upsertServer_ dbConn server - insertSndQueue_ dbConn sq - updateConnWithSndQueue_ dbConn connAlias sq + insertSndQueue_ dbConn connId sq + updateConnWithSndQueue_ dbConn connId sq pure $ Right () Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c _ -> pure $ Left SEConnNotFound - upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m () - upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} = + upgradeSndConnToDuplex :: SQLiteStore -> ConnId -> RcvQueue -> m () + upgradeSndConnToDuplex SQLiteStore {dbConn} connId rq@RcvQueue {server} = liftIOEither . withTransaction dbConn $ - getConn_ dbConn connAlias >>= \case + getConn_ dbConn connId >>= \case Right (SomeConn _ SndConnection {}) -> do upsertServer_ dbConn server - insertRcvQueue_ dbConn rq - updateConnWithRcvQueue_ dbConn connAlias rq + insertRcvQueue_ dbConn connId rq + updateConnWithRcvQueue_ dbConn connId rq pure $ Right () Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c _ -> pure $ Left SEConnNotFound @@ -248,83 +268,233 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto |] [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) - updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} = + updateRcvIds :: SQLiteStore -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + updateRcvIds SQLiteStore {dbConn} connId = liftIO . withTransaction dbConn $ do - (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias + (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connId let internalId = InternalId $ unId lastInternalId + 1 internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 - updateLastIdsRcv_ dbConn connAlias internalId internalRcvId + updateLastIdsRcv_ dbConn connId internalId internalRcvId pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) - createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m () - createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData = + createRcvMsg :: SQLiteStore -> ConnId -> RcvMsgData -> m () + createRcvMsg SQLiteStore {dbConn} connId rcvMsgData = liftIO . withTransaction dbConn $ do - insertRcvMsgBase_ dbConn connAlias rcvMsgData - insertRcvMsgDetails_ dbConn connAlias rcvMsgData - updateHashRcv_ dbConn connAlias rcvMsgData + insertRcvMsgBase_ dbConn connId rcvMsgData + insertRcvMsgDetails_ dbConn connId rcvMsgData + updateHashRcv_ dbConn connId rcvMsgData - updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) - updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} = + updateSndIds :: SQLiteStore -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash) + updateSndIds SQLiteStore {dbConn} connId = liftIO . withTransaction dbConn $ do - (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias + (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connId let internalId = InternalId $ unId lastInternalId + 1 internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 - updateLastIdsSnd_ dbConn connAlias internalId internalSndId + updateLastIdsSnd_ dbConn connId internalId internalSndId pure (internalId, internalSndId, prevSndHash) - createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m () - createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData = + createSndMsg :: SQLiteStore -> ConnId -> SndMsgData -> m () + createSndMsg SQLiteStore {dbConn} connId sndMsgData = liftIO . withTransaction dbConn $ do - insertSndMsgBase_ dbConn connAlias sndMsgData - insertSndMsgDetails_ dbConn connAlias sndMsgData - updateHashSnd_ dbConn connAlias sndMsgData + insertSndMsgBase_ dbConn connId sndMsgData + insertSndMsgDetails_ dbConn connId sndMsgData + updateHashSnd_ dbConn connId sndMsgData - getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg + getMsg :: SQLiteStore -> ConnId -> InternalId -> m Msg getMsg _st _connAlias _id = throwError SENotImplemented - createBcast :: SQLiteStore -> BroadcastId -> m () - createBcast SQLiteStore {dbConn} bId = - liftIOEither $ - checkConstraint SEBcastDuplicate $ - DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId) + createBcast :: SQLiteStore -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId + createBcast SQLiteStore {dbConn} gVar bcastId = liftIOEither $ case bcastId of + "" -> createWithRandomId gVar create + bId -> checkConstraint SEBcastDuplicate $ create bId $> bId + where + create bId = DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId) - addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m () - addBcastConn SQLiteStore {dbConn} bId connAlias = + addBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m () + addBcastConn SQLiteStore {dbConn} bId connId = liftIOEither . checkBroadcast dbConn bId $ - getConn_ dbConn connAlias >>= \case + getConn_ dbConn connId >>= \case Left _ -> pure $ Left SEConnNotFound Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv Right _ -> checkConstraint SEConnDuplicate $ DB.execute dbConn - "INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);" - (bId, connAlias) + [sql| + INSERT INTO broadcast_connections + (broadcast_id, conn_alias) VALUES (?, ?); + |] + (bId, connId) - removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m () - removeBcastConn SQLiteStore {dbConn} bId connAlias = + removeBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m () + removeBcastConn SQLiteStore {dbConn} bId connId = liftIOEither . checkBroadcast dbConn bId $ - bcastConnExists_ dbConn bId connAlias >>= \case + bcastConnExists_ dbConn bId connId >>= \case False -> pure $ Left SEConnNotFound _ -> Right <$> DB.execute dbConn - "DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;" - (bId, connAlias) + [sql| + DELETE FROM broadcast_connections + WHERE broadcast_id = ? AND conn_alias = ?; + |] + (bId, connId) deleteBcast :: SQLiteStore -> BroadcastId -> m () deleteBcast SQLiteStore {dbConn} bId = liftIOEither . checkBroadcast dbConn bId $ Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId) - getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias] + getBcast :: SQLiteStore -> BroadcastId -> m [ConnId] getBcast SQLiteStore {dbConn} bId = liftIOEither . checkBroadcast dbConn bId $ Right . map fromOnly <$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId) + createIntro :: SQLiteStore -> TVar ChaChaDRG -> NewIntroduction -> m IntroId + createIntro SQLiteStore {dbConn} gVar NewIntroduction {toConn, reConn, reInfo} = + liftIOEither . createWithRandomId gVar $ \introId -> + DB.execute + dbConn + [sql| + INSERT INTO conn_intros + (intro_id, to_conn, re_conn, re_info) VALUES (?, ?, ?, ?); + |] + (introId, toConn, reConn, reInfo) + + getIntro :: SQLiteStore -> IntroId -> m Introduction + getIntro SQLiteStore {dbConn} introId = + liftIOEither $ + intro + <$> DB.query + dbConn + [sql| + SELECT to_conn, to_info, to_status, re_conn, re_info, re_status, queue_info + FROM conn_intros + WHERE intro_id = ?; + |] + (Only introId) + where + intro [(toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo)] = + Right $ Introduction {introId, toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo} + intro _ = Left SEIntroNotFound + + addIntroInvitation :: SQLiteStore -> IntroId -> EntityInfo -> SMPQueueInfo -> m () + addIntroInvitation SQLiteStore {dbConn} introId toInfo qInfo = + liftIO $ + DB.executeNamed + dbConn + [sql| + UPDATE conn_intros + SET to_info = :to_info, + queue_info = :queue_info, + to_status = :to_status + WHERE intro_id = :intro_id; + |] + [ ":to_info" := toInfo, + ":queue_info" := Just qInfo, + ":to_status" := IntroInv, + ":intro_id" := introId + ] + + setIntroToStatus :: SQLiteStore -> IntroId -> IntroStatus -> m () + setIntroToStatus SQLiteStore {dbConn} introId toStatus = + liftIO $ + DB.execute + dbConn + [sql| + UPDATE conn_intros + SET to_status = ? + WHERE intro_id = ?; + |] + (toStatus, introId) + + setIntroReStatus :: SQLiteStore -> IntroId -> IntroStatus -> m () + setIntroReStatus SQLiteStore {dbConn} introId reStatus = + liftIO $ + DB.execute + dbConn + [sql| + UPDATE conn_intros + SET re_status = ? + WHERE intro_id = ?; + |] + (reStatus, introId) + + createInvitation :: SQLiteStore -> TVar ChaChaDRG -> NewInvitation -> m InvitationId + createInvitation SQLiteStore {dbConn} gVar NewInvitation {viaConn, externalIntroId, entityInfo, qInfo} = + liftIOEither . createWithRandomId gVar $ \invId -> + DB.execute + dbConn + [sql| + INSERT INTO conn_invitations + (inv_id, via_conn, external_intro_id, conn_info, queue_info) VALUES (?, ?, ?, ?, ?); + |] + (invId, viaConn, externalIntroId, entityInfo, qInfo) + + getInvitation :: SQLiteStore -> InvitationId -> m Invitation + getInvitation SQLiteStore {dbConn} invId = + liftIOEither $ + invitation + <$> DB.query + dbConn + [sql| + SELECT via_conn, external_intro_id, conn_info, queue_info, conn_id, status + FROM conn_invitations + WHERE inv_id = ?; + |] + (Only invId) + where + invitation [(viaConn, externalIntroId, entityInfo, qInfo, connId, status)] = + Right $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId, status} + invitation _ = Left SEInvitationNotFound + + addInvitationConn :: SQLiteStore -> InvitationId -> ConnId -> m () + addInvitationConn SQLiteStore {dbConn} invId connId = + liftIO $ + DB.executeNamed + dbConn + [sql| + UPDATE conn_invitations + SET conn_id = :conn_id, status = :status + WHERE inv_id = :inv_id; + |] + [":conn_id" := connId, ":status" := InvAcpt, ":inv_id" := invId] + + getConnInvitation :: SQLiteStore -> ConnId -> m (Maybe (Invitation, Connection 'CDuplex)) + getConnInvitation SQLiteStore {dbConn} cId = + liftIO . withTransaction dbConn $ + DB.query + dbConn + [sql| + SELECT inv_id, via_conn, external_intro_id, conn_info, queue_info, status + FROM conn_invitations + WHERE conn_id = ?; + |] + (Only cId) + >>= fmap join . traverse getViaConn . invitation + where + invitation [(invId, viaConn, externalIntroId, entityInfo, qInfo, status)] = + Just $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId = Just cId, status} + invitation _ = Nothing + getViaConn :: Invitation -> IO (Maybe (Invitation, Connection 'CDuplex)) + getViaConn inv@Invitation {viaConn} = fmap (inv,) . duplexConn <$> getConn_ dbConn viaConn + duplexConn :: Either StoreError SomeConn -> Maybe (Connection 'CDuplex) + duplexConn (Right (SomeConn SCDuplex conn)) = Just conn + duplexConn _ = Nothing + + setInvitationStatus :: SQLiteStore -> InvitationId -> InvitationStatus -> m () + setInvitationStatus SQLiteStore {dbConn} invId status = + liftIO $ + DB.execute + dbConn + [sql| + UPDATE conn_invitations + SET status = ? WHERE inv_id = ?; + |] + (status, invId) + -- * Auxiliary helpers -- ? replace with ToField? - it's easy to forget to use this @@ -337,7 +507,7 @@ deserializePort_ port = Just port instance ToField QueueStatus where toField = toField . show -instance FromField QueueStatus where fromField = fromFieldToReadable_ +instance FromField QueueStatus where fromField = fromTextField_ $ readMaybe . T.unpack instance ToField InternalRcvId where toField (InternalRcvId x) = toField x @@ -359,13 +529,24 @@ instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity instance FromField MsgIntegrity where fromField = blobFieldParser msgIntegrityP -fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a -fromFieldToReadable_ = \case +instance ToField IntroStatus where toField = toField . serializeIntroStatus + +instance FromField IntroStatus where fromField = fromTextField_ introStatusT + +instance ToField InvitationStatus where toField = toField . serializeInvStatus + +instance FromField InvitationStatus where fromField = fromTextField_ invStatusT + +instance ToField SMPQueueInfo where toField = toField . serializeSmpQueueInfo + +instance FromField SMPQueueInfo where fromField = blobFieldParser smpQueueInfoP + +fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a +fromTextField_ fromText = \case f@(Field (SQLText t) _) -> - let str = T.unpack t - in case readMaybe str of - Just x -> Ok x - _ -> returnError ConversionFailed f ("invalid string: " <> str) + case fromText t of + Just x -> Ok x + _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t) f -> returnError ConversionFailed f "expecting SQLText column type" {- ORMOLU_DISABLE -} @@ -397,8 +578,8 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do -- * createRcvConn helpers -insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO () -insertRcvQueue_ dbConn RcvQueue {..} = do +insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () +insertRcvQueue_ dbConn connId RcvQueue {..} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -411,7 +592,7 @@ insertRcvQueue_ dbConn RcvQueue {..} = do [ ":host" := host server, ":port" := port_, ":rcv_id" := rcvId, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":rcv_private_key" := rcvPrivateKey, ":snd_id" := sndId, ":snd_key" := sndKey, @@ -420,26 +601,29 @@ insertRcvQueue_ dbConn RcvQueue {..} = do ":status" := status ] -insertRcvConnection_ :: DB.Connection -> RcvQueue -> IO () -insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do +insertRcvConnection_ :: DB.Connection -> ConnData -> RcvQueue -> IO () +insertRcvConnection_ dbConn ConnData {connId, viaInv, connLevel} RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn [sql| INSERT INTO connections - ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, - last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) + ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES - (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, - 0, 0, 0, 0, x'', x''); + (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, :via_inv,:conn_level, 0, 0, 0, 0, x'', x''); |] - [":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId] + [ ":conn_alias" := connId, + ":rcv_host" := host server, + ":rcv_port" := port_, + ":rcv_id" := rcvId, + ":via_inv" := viaInv, + ":conn_level" := connLevel + ] -- * createSndConn helpers -insertSndQueue_ :: DB.Connection -> SndQueue -> IO () -insertSndQueue_ dbConn SndQueue {..} = do +insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO () +insertSndQueue_ dbConn connId SndQueue {..} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -452,85 +636,96 @@ insertSndQueue_ dbConn SndQueue {..} = do [ ":host" := host server, ":port" := port_, ":snd_id" := sndId, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":snd_private_key" := sndPrivateKey, ":encrypt_key" := encryptKey, ":sign_key" := signKey, ":status" := status ] -insertSndConnection_ :: DB.Connection -> SndQueue -> IO () -insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do +insertSndConnection_ :: DB.Connection -> ConnData -> SndQueue -> IO () +insertSndConnection_ dbConn ConnData {connId, viaInv, connLevel} SndQueue {server, sndId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn [sql| INSERT INTO connections - ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, - last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) + ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES - (:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id, - 0, 0, 0, 0, x'', x''); + (:conn_alias, NULL, NULL, NULL, :snd_host,:snd_port,:snd_id,:via_inv,:conn_level, 0, 0, 0, 0, x'', x''); |] - [":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId] + [ ":conn_alias" := connId, + ":snd_host" := host server, + ":snd_port" := port_, + ":snd_id" := sndId, + ":via_inv" := viaInv, + ":conn_level" := connLevel + ] -- * getConn helpers -getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn) -getConn_ dbConn connAlias = do - rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias - sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias - pure $ case (rQ, sQ) of - (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) - (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ) - (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> Left SEConnNotFound +getConn_ :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getConn_ dbConn connId = + getConnData_ dbConn connId >>= \case + Nothing -> pure $ Left SEConnNotFound + Just connData -> do + rQ <- getRcvQueueByConnAlias_ dbConn connId + sQ <- getSndQueueByConnAlias_ dbConn connId + pure $ case (rQ, sQ) of + (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ) + (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ) + (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connData sndQ) + _ -> Left SEConnNotFound -retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue) -retrieveRcvQueueByConnAlias_ dbConn connAlias = do - r <- - DB.queryNamed +getConnData_ :: DB.Connection -> ConnId -> IO (Maybe ConnData) +getConnData_ dbConn connId = + connData + <$> DB.query dbConn "SELECT via_inv, conn_level FROM connections WHERE conn_alias = ?;" (Only connId) + where + connData [(viaInv, connLevel)] = Just ConnData {connId, viaInv, connLevel} + connData _ = Nothing + +getRcvQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue) +getRcvQueueByConnAlias_ dbConn connId = + rcvQueue + <$> DB.query dbConn [sql| - SELECT - s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, + SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status FROM rcv_queues q INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.conn_alias = :conn_alias; + WHERE q.conn_alias = ?; |] - [":conn_alias" := connAlias] - case r of - [(keyHash, host, port, rcvId, cAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do + (Only connId) + where + rcvQueue [(keyHash, host, port, rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] = let srv = SMPServer host (deserializePort_ port) keyHash - return . Just $ RcvQueue srv rcvId cAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status - _ -> return Nothing + in Just $ RcvQueue srv rcvId rcvPrivateKey sndId sndKey decryptKey verifyKey status + rcvQueue _ = Nothing -retrieveSndQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe SndQueue) -retrieveSndQueueByConnAlias_ dbConn connAlias = do - r <- - DB.queryNamed +getSndQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue) +getSndQueueByConnAlias_ dbConn connId = + sndQueue + <$> DB.query dbConn [sql| - SELECT - s.key_hash, q.host, q.port, q.snd_id, q.conn_alias, - q.snd_private_key, q.encrypt_key, q.sign_key, q.status + SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status FROM snd_queues q INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.conn_alias = :conn_alias; + WHERE q.conn_alias = ?; |] - [":conn_alias" := connAlias] - case r of - [(keyHash, host, port, sndId, cAlias, sndPrivateKey, encryptKey, signKey, status)] -> do + (Only connId) + where + sndQueue [(keyHash, host, port, sndId, sndPrivateKey, encryptKey, signKey, status)] = let srv = SMPServer host (deserializePort_ port) keyHash - return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status - _ -> return Nothing + in Just $ SndQueue srv sndId sndPrivateKey encryptKey signKey status + sndQueue _ = Nothing -- * upgradeRcvConnToDuplex helpers -updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () -updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do +updateConnWithSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO () +updateConnWithSndQueue_ dbConn connId SndQueue {server, sndId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -539,12 +734,12 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do SET snd_host = :snd_host, snd_port = :snd_port, snd_id = :snd_id WHERE conn_alias = :conn_alias; |] - [":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connAlias] + [":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connId] -- * upgradeSndConnToDuplex helpers -updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () -updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do +updateConnWithRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () +updateConnWithRcvQueue_ dbConn connId RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -553,12 +748,12 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do SET rcv_host = :rcv_host, rcv_port = :rcv_port, rcv_id = :rcv_id WHERE conn_alias = :conn_alias; |] - [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias] + [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connId] -- * updateRcvIds helpers -retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) -retrieveLastIdsAndHashRcv_ dbConn connAlias = do +retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +retrieveLastIdsAndHashRcv_ dbConn connId = do [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- DB.queryNamed dbConn @@ -567,11 +762,11 @@ retrieveLastIdsAndHashRcv_ dbConn connAlias = do FROM connections WHERE conn_alias = :conn_alias; |] - [":conn_alias" := connAlias] + [":conn_alias" := connId] return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) -updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () -updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = +updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO () +updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = DB.executeNamed dbConn [sql| @@ -582,13 +777,13 @@ updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = |] [ ":last_internal_msg_id" := newInternalId, ":last_internal_rcv_msg_id" := newInternalRcvId, - ":conn_alias" := connAlias + ":conn_alias" := connId ] -- * createRcvMsg helpers -insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () -insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do +insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +insertRcvMsgBase_ dbConn connId RcvMsgData {..} = do DB.executeNamed dbConn [sql| @@ -597,15 +792,15 @@ insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do VALUES (:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL,:body); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_id" := internalId, ":internal_ts" := internalTs, ":internal_rcv_id" := internalRcvId, ":body" := decodeUtf8 msgBody ] -insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () -insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = +insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +insertRcvMsgDetails_ dbConn connId RcvMsgData {..} = DB.executeNamed dbConn [sql| @@ -618,7 +813,7 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = :broker_id,:broker_ts,:rcv_status, NULL, NULL, :internal_hash,:external_prev_snd_hash,:integrity); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_rcv_id" := internalRcvId, ":internal_id" := internalId, ":external_snd_id" := fst senderMeta, @@ -631,8 +826,8 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = ":integrity" := msgIntegrity ] -updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () -updateHashRcv_ dbConn connAlias RcvMsgData {..} = +updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +updateHashRcv_ dbConn connId RcvMsgData {..} = DB.executeNamed dbConn -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved @@ -645,14 +840,14 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} = |] [ ":last_external_snd_msg_id" := fst senderMeta, ":last_rcv_msg_hash" := internalHash, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":last_internal_rcv_msg_id" := internalRcvId ] -- * updateSndIds helpers -retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash) -retrieveLastIdsAndHashSnd_ dbConn connAlias = do +retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash) +retrieveLastIdsAndHashSnd_ dbConn connId = do [(lastInternalId, lastInternalSndId, lastSndHash)] <- DB.queryNamed dbConn @@ -661,11 +856,11 @@ retrieveLastIdsAndHashSnd_ dbConn connAlias = do FROM connections WHERE conn_alias = :conn_alias; |] - [":conn_alias" := connAlias] + [":conn_alias" := connId] return (lastInternalId, lastInternalSndId, lastSndHash) -updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () -updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = +updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO () +updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = DB.executeNamed dbConn [sql| @@ -676,13 +871,13 @@ updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = |] [ ":last_internal_msg_id" := newInternalId, ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_alias" := connAlias + ":conn_alias" := connId ] -- * createSndMsg helpers -insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () -insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do +insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +insertSndMsgBase_ dbConn connId SndMsgData {..} = do DB.executeNamed dbConn [sql| @@ -691,15 +886,15 @@ insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do VALUES (:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id,:body); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_id" := internalId, ":internal_ts" := internalTs, ":internal_snd_id" := internalSndId, ":body" := decodeUtf8 msgBody ] -insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () -insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = +insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +insertSndMsgDetails_ dbConn connId SndMsgData {..} = DB.executeNamed dbConn [sql| @@ -708,15 +903,15 @@ insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = VALUES (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_snd_id" := internalSndId, ":internal_id" := internalId, ":snd_status" := Created, ":internal_hash" := internalHash ] -updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () -updateHashSnd_ dbConn connAlias SndMsgData {..} = +updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +updateHashSnd_ dbConn connId SndMsgData {..} = DB.executeNamed dbConn -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved @@ -727,7 +922,7 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} = AND last_internal_snd_msg_id = :last_internal_snd_msg_id; |] [ ":last_snd_msg_hash" := internalHash, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":last_internal_snd_msg_id" := internalSndId ] @@ -745,10 +940,10 @@ bcastExists_ dbConn bId = not . null <$> queryBcast queryBcast :: IO [Only BroadcastId] queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId) -bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool -bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn +bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnId -> IO Bool +bcastConnExists_ dbConn bId connId = not . null <$> queryBcastConn where - queryBcastConn :: IO [(BroadcastId, ConnAlias)] + queryBcastConn :: IO [(BroadcastId, ConnId)] queryBcastConn = DB.query dbConn @@ -757,4 +952,37 @@ bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?; |] - (bId, connAlias) + (bId, connId) + +-- create record with a random ID + +getConnId_ :: DB.Connection -> TVar ChaChaDRG -> ConnData -> IO (Either StoreError ConnId) +getConnId_ dbConn gVar ConnData {connId = ""} = getUniqueRandomId gVar $ getConnData_ dbConn +getConnId_ _ _ ConnData {connId} = pure $ Right connId + +getUniqueRandomId :: TVar ChaChaDRG -> (ByteString -> IO (Maybe a)) -> IO (Either StoreError ByteString) +getUniqueRandomId gVar get = tryGet 3 + where + tryGet :: Int -> IO (Either StoreError ByteString) + tryGet 0 = pure $ Left SEUniqueID + tryGet n = do + id' <- randomId gVar 12 + get id' >>= \case + Nothing -> pure $ Right id' + Just _ -> tryGet (n - 1) + +createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString) +createWithRandomId gVar create = tryCreate 3 + where + tryCreate :: Int -> IO (Either StoreError ByteString) + tryCreate 0 = pure $ Left SEUniqueID + tryCreate n = do + id' <- randomId gVar 12 + E.try (create id') >>= \case + Right _ -> pure $ Right id' + Left e + | DB.sqlError e == DB.ErrorConstraint -> tryCreate (n - 1) + | otherwise -> pure . Left . SEInternal $ bshow e + +randomId :: TVar ChaChaDRG -> Int -> IO ByteString +randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n) diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 2b9522e3e..5e7741c29 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -30,7 +30,7 @@ base64StringP = do pure $ str <> pad tsISO8601P :: Parser UTCTime -tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ') +tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd parse :: Parser a -> e -> (ByteString -> Either e a) parse parser err = first (const err) . parseAll parser @@ -42,14 +42,17 @@ parseRead :: Read a => Parser ByteString -> Parser a parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack) parseRead1 :: Read a => Parser a -parseRead1 = parseRead $ A.takeTill (== ' ') +parseRead1 = parseRead $ A.takeTill wordEnd parseRead2 :: Read a => Parser a parseRead2 = parseRead $ do - w1 <- A.takeTill (== ' ') <* A.char ' ' - w2 <- A.takeTill (== ' ') + w1 <- A.takeTill wordEnd <* A.char ' ' + w2 <- A.takeTill wordEnd pure $ w1 <> " " <> w2 +wordEnd :: Char -> Bool +wordEnd c = c == ' ' || c == '\n' + parseString :: (ByteString -> Either String a) -> (String -> a) parseString p = either error id . p . B.pack diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index a3d9d184f..57c7ad760 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -5,6 +5,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PostfixOperators #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -17,6 +18,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPAgentClient import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.Store (InvitationId) import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..)) import System.Timeout @@ -29,10 +31,16 @@ agentTests (ATransport t) = do describe "Establishing duplex connection" do it "should connect via one server and one agent" $ smpAgentTest2_1_1 $ testDuplexConnection t + it "should connect via one server and one agent (random IDs)" $ + smpAgentTest2_1_1 $ testDuplexConnRandomIds t it "should connect via one server and 2 agents" $ smpAgentTest2_2_1 $ testDuplexConnection t + it "should connect via one server and 2 agents (random IDs)" $ + smpAgentTest2_2_1 $ testDuplexConnRandomIds t it "should connect via 2 servers and 2 agents" $ smpAgentTest2_2_2 $ testDuplexConnection t + it "should connect via 2 servers and 2 agents (random IDs)" $ + smpAgentTest2_2_2 $ testDuplexConnRandomIds t describe "Connection subscriptions" do it "should connect via one server and one agent" $ smpAgentTest3_1_1 $ testSubscription t @@ -41,6 +49,13 @@ agentTests (ATransport t) = do describe "Broadcast" do it "should create broadcast and send messages" $ smpAgentTest3 $ testBroadcast t + it "should create broadcast and send messages (random IDs)" $ + smpAgentTest3 $ testBroadcastRandomIds t + describe "Introduction" do + it "should send and accept introduction" $ + smpAgentTest3 $ testIntroduction t + it "should send and accept introduction (random IDs)" $ + smpAgentTest3 $ testIntroductionRandomIds t type TestTransmission p = (ACorrId, ByteString, APartyCmd p) @@ -54,9 +69,13 @@ testTE (ATransmissionOrError corrId entity cmdOrErr) = Right cmd -> Right $ APartyCmd cmd Left e -> Left e +-- | receive message to handle `h` +(<#:) :: Transport c => c -> IO (TestTransmissionOrError 'Agent) +(<#:) h = testTE <$> tGet SAgent h + -- | send transmission `t` to handle `h` and get response (#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (TestTransmissionOrError 'Agent) -h #: t = tPutRaw h t >> testTE <$> tGet SAgent h +h #: t = tPutRaw h t >> (h <#:) -- | action and expected response -- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r` @@ -75,11 +94,11 @@ correctTransmission (corrId, cAlias, cmdOrErr) = case cmdOrErr of -- | receive message to handle `h` and validate that it is the expected one (<#) :: Transport c => c -> TestTransmission' 'Agent c' -> Expectation -h <# (corrId, cAlias, cmd) = tGet SAgent h >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd))) . testTE +h <# (corrId, cAlias, cmd) = (h <#:) >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd))) -- | receive message to handle `h` and validate it using predicate `p` (<#=) :: Transport c => c -> (TestTransmission 'Agent -> Bool) -> Expectation -h <#= p = tGet SAgent h >>= (`shouldSatisfy` p . correctTransmission . testTE) +h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission) -- | test that nothing is delivered to handle `h` during 10ms (#:#) :: Transport c => c -> String -> Expectation @@ -90,53 +109,75 @@ h #:# err = tryGet `shouldReturn` () Just _ -> error err _ -> return () -pattern Msg :: MsgBody -> APartyCmd 'Agent -pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk} - pattern Sent :: AgentMsgId -> APartyCmd 'Agent pattern Sent msgId <- APartyCmd (SENT msgId) -pattern Inv :: SMPQueueInfo -> APartyCmd 'Agent -pattern Inv invitation <- APartyCmd (INV invitation) +pattern Msg :: MsgBody -> APartyCmd 'Agent +pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk} + +pattern Inv :: SMPQueueInfo -> Either AgentErrorType (APartyCmd 'Agent) +pattern Inv invitation <- Right (APartyCmd (INV invitation)) + +pattern Req :: InvitationId -> EntityInfo -> Either AgentErrorType (APartyCmd 'Agent) +pattern Req invId eInfo <- Right (APartyCmd (REQ (IE (Conn invId)) eInfo)) testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO () testDuplexConnection _ alice bob = do - ("1", "C:bob", Right (Inv qInfo)) <- alice #: ("1", "C:bob", "NEW") + ("1", "C:bob", Inv qInfo) <- alice #: ("1", "C:bob", "NEW") let qInfo' = serializeSmpQueueInfo qInfo bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON) alice <# ("", "C:bob", CON) - alice #: ("2", "C:bob", "SEND :hello") =#> \case ("2", "C:bob", Sent 1) -> True; _ -> False - alice #: ("3", "C:bob", "SEND :how are you?") =#> \case ("3", "C:bob", Sent 2) -> True; _ -> False + alice #: ("2", "C:bob", "SEND :hello") #> ("2", "C:bob", SENT 1) + alice #: ("3", "C:bob", "SEND :how are you?") #> ("3", "C:bob", SENT 2) bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False bob <#= \case ("", "C:alice", Msg "how are you?") -> True; _ -> False - bob #: ("14", "C:alice", "SEND 9\nhello too") =#> \case ("14", "C:alice", Sent 3) -> True; _ -> False + bob #: ("14", "C:alice", "SEND 9\nhello too") #> ("14", "C:alice", SENT 3) alice <#= \case ("", "C:bob", Msg "hello too") -> True; _ -> False - bob #: ("15", "C:alice", "SEND 9\nmessage 1") =#> \case ("15", "C:alice", Sent 4) -> True; _ -> False + bob #: ("15", "C:alice", "SEND 9\nmessage 1") #> ("15", "C:alice", SENT 4) alice <#= \case ("", "C:bob", Msg "message 1") -> True; _ -> False alice #: ("5", "C:bob", "OFF") #> ("5", "C:bob", OK) bob #: ("17", "C:alice", "SEND 9\nmessage 3") #> ("17", "C:alice", ERR (SMP AUTH)) alice #: ("6", "C:bob", "DEL") #> ("6", "C:bob", OK) alice #:# "nothing else should be delivered to alice" +testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO () +testDuplexConnRandomIds _ alice bob = do + ("1", bobConn, Inv qInfo) <- alice #: ("1", "C:", "NEW") + let qInfo' = serializeSmpQueueInfo qInfo + ("", aliceConn, Right (APartyCmd CON)) <- bob #: ("11", "C:", "JOIN " <> qInfo') + alice <# ("", bobConn, CON) + alice #: ("2", bobConn, "SEND :hello") #> ("2", bobConn, SENT 1) + alice #: ("3", bobConn, "SEND :how are you?") #> ("3", bobConn, SENT 2) + bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False + bob #: ("14", aliceConn, "SEND 9\nhello too") #> ("14", aliceConn, SENT 3) + alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False + bob #: ("15", aliceConn, "SEND 9\nmessage 1") #> ("15", aliceConn, SENT 4) + alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False + alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK) + bob #: ("17", aliceConn, "SEND 9\nmessage 3") #> ("17", aliceConn, ERR (SMP AUTH)) + alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK) + alice #:# "nothing else should be delivered to alice" + testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO () testSubscription _ alice1 alice2 bob = do - ("1", "C:bob", Right (Inv qInfo)) <- alice1 #: ("1", "C:bob", "NEW") + ("1", "C:bob", Inv qInfo) <- alice1 #: ("1", "C:bob", "NEW") let qInfo' = serializeSmpQueueInfo qInfo bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON) - bob #: ("12", "C:alice", "SEND 5\nhello") =#> \case ("12", "C:alice", Sent _) -> True; _ -> False - bob #: ("13", "C:alice", "SEND 11\nhello again") =#> \case ("13", "C:alice", Sent _) -> True; _ -> False + bob #: ("12", "C:alice", "SEND 5\nhello") #> ("12", "C:alice", SENT 1) + bob #: ("13", "C:alice", "SEND 11\nhello again") #> ("13", "C:alice", SENT 2) alice1 <# ("", "C:bob", CON) alice1 <#= \case ("", "C:bob", Msg "hello") -> True; _ -> False alice1 <#= \case ("", "C:bob", Msg "hello again") -> True; _ -> False alice2 #: ("21", "C:bob", "SUB") #> ("21", "C:bob", OK) alice1 <# ("", "C:bob", END) - bob #: ("14", "C:alice", "SEND 2\nhi") =#> \case ("14", "C:alice", Sent _) -> True; _ -> False + bob #: ("14", "C:alice", "SEND 2\nhi") #> ("14", "C:alice", SENT 3) alice2 <#= \case ("", "C:bob", Msg "hi") -> True; _ -> False alice1 #:# "nothing else should be delivered to alice1" testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO () testSubscrNotification _ (server, _) client = do - client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", Inv _) -> True; _ -> False + client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", APartyCmd INV {}) -> True; _ -> False client #:# "nothing should be delivered to client before the server is killed" killThread server client <# ("", "C:conn1", END) @@ -156,8 +197,8 @@ testBroadcast _ alice bob tom = do alice #: ("e3", "B:team", "ADD C:unknown") #> ("e3", "B:team", ERR $ CONN NOT_FOUND) alice #: ("e4", "B:team", "ADD C:bob") #> ("e4", "B:team", ERR $ CONN DUPLICATE) -- send message - alice #: ("4", "B:team", "SEND 5\nhello") #> ("4", "C:bob", SENT 1) - alice <# ("4", "C:tom", SENT 1) + alice #: ("4", "B:team", "SEND 5\nhello") =#> \case ("4", c, Sent 1) -> c == "C:bob" || c == "C:tom"; _ -> False + alice <#= \case ("4", c, Sent 1) -> c == "C:bob" || c == "C:tom"; _ -> False alice <# ("4", "B:team", SENT 0) bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False tom <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False @@ -177,13 +218,104 @@ testBroadcast _ alice bob tom = do -- commands with errors alice #: ("e8", "B:team", "DEL") #> ("e8", "B:team", ERR $ BCAST B_NOT_FOUND) alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND) - where - connect :: (c, ByteString) -> (c, ByteString) -> IO () - connect (h1, name1) (h2, name2) = do - ("c1", _, Right (Inv qInfo)) <- h1 #: ("c1", "C:" <> name2, "NEW") - let qInfo' = serializeSmpQueueInfo qInfo - h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') =#> \case ("", c1, APartyCmd CON) -> c1 == "C:" <> name1; _ -> False - h1 <#= \case ("", c2, APartyCmd CON) -> c2 == "C:" <> name2; _ -> False + +testBroadcastRandomIds :: forall c. Transport c => TProxy c -> c -> c -> c -> IO () +testBroadcastRandomIds _ alice bob tom = do + -- establish connections + (aliceB, bobA) <- alice `connect'` bob + (aliceT, tomA) <- alice `connect'` tom + -- create and set up broadcast + ("1", team, Right (APartyCmd OK)) <- alice #: ("1", "B:", "NEW") + alice #: ("2", team, "ADD " <> bobA) #> ("2", team, OK) + alice #: ("3", team, "ADD " <> tomA) #> ("3", team, OK) + -- commands with errors + alice #: ("e1", team, "NEW") #> ("e1", team, ERR $ BCAST B_DUPLICATE) + alice #: ("e2", "B:group", "ADD " <> bobA) #> ("e2", "B:group", ERR $ BCAST B_NOT_FOUND) + alice #: ("e3", team, "ADD C:unknown") #> ("e3", team, ERR $ CONN NOT_FOUND) + alice #: ("e4", team, "ADD " <> bobA) #> ("e4", team, ERR $ CONN DUPLICATE) + -- send message + alice #: ("4", team, "SEND 5\nhello") =#> \case ("4", c, Sent 1) -> c == bobA || c == tomA; _ -> False + alice <#= \case ("4", c, Sent 1) -> c == bobA || c == tomA; _ -> False + alice <# ("4", team, SENT 0) + bob <#= \case ("", c, Msg "hello") -> c == aliceB; _ -> False + tom <#= \case ("", c, Msg "hello") -> c == aliceT; _ -> False + -- remove one connection + alice #: ("5", team, "REM " <> tomA) #> ("5", team, OK) + alice #: ("6", team, "SEND 11\nhello again") #> ("6", bobA, SENT 2) + alice <# ("6", team, SENT 0) + bob <#= \case ("", c, Msg "hello again") -> c == aliceB; _ -> False + tom #:# "nothing delivered to tom" + -- commands with errors + alice #: ("e5", "B:group", "REM " <> bobA) #> ("e5", "B:group", ERR $ BCAST B_NOT_FOUND) + alice #: ("e6", team, "REM C:unknown") #> ("e6", team, ERR $ CONN NOT_FOUND) + alice #: ("e7", team, "REM " <> tomA) #> ("e7", team, ERR $ CONN NOT_FOUND) + -- delete broadcast + alice #: ("7", team, "DEL") #> ("7", team, OK) + alice #: ("8", team, "SEND 11\ntry sending") #> ("8", team, ERR $ BCAST B_NOT_FOUND) + -- commands with errors + alice #: ("e8", team, "DEL") #> ("e8", team, ERR $ BCAST B_NOT_FOUND) + alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND) + +testIntroduction :: forall c. Transport c => TProxy c -> c -> c -> c -> IO () +testIntroduction _ alice bob tom = do + -- establish connections + (alice, "alice") `connect` (bob, "bob") + (alice, "alice") `connect` (tom, "tom") + -- send introduction of tom to bob + alice #: ("1", "C:bob", "INTRO C:tom 8\nmeet tom") #> ("1", "C:bob", OK) + ("", "C:alice", Req invId1 "meet tom") <- (bob <#:) + bob #: ("2", "C:tom_via_alice", "ACPT C:" <> invId1 <> " 7\nI'm bob") #> ("2", "C:tom_via_alice", OK) + ("", "C:alice", Req invId2 "I'm bob") <- (tom <#:) + -- TODO info "tom here" is not used, either JOIN command also should have eInfo parameter + -- or this should be another command, not ACPT + tom #: ("3", "C:bob_via_alice", "ACPT C:" <> invId2 <> " 8\ntom here") #> ("3", "C:bob_via_alice", OK) + tom <# ("", "C:bob_via_alice", CON) + bob <# ("", "C:tom_via_alice", CON) + alice <# ("", "C:bob", ICON (IE (Conn "tom"))) + -- they can message each other now + tom #: ("4", "C:bob_via_alice", "SEND :hello") #> ("4", "C:bob_via_alice", SENT 1) + bob <#= \case ("", "C:tom_via_alice", Msg "hello") -> True; _ -> False + bob #: ("5", "C:tom_via_alice", "SEND 9\nhello too") #> ("5", "C:tom_via_alice", SENT 2) + tom <#= \case ("", "C:bob_via_alice", Msg "hello too") -> True; _ -> False + +testIntroductionRandomIds :: forall c. Transport c => TProxy c -> c -> c -> c -> IO () +testIntroductionRandomIds _ alice bob tom = do + -- establish connections + (aliceB, bobA) <- alice `connect'` bob + (aliceT, tomA) <- alice `connect'` tom + -- send introduction of tom to bob + alice #: ("1", bobA, "INTRO " <> tomA <> " 8\nmeet tom") #> ("1", bobA, OK) + ("", aliceB', Req invId1 "meet tom") <- (bob <#:) + aliceB' `shouldBe` aliceB + ("2", tomB, Right (APartyCmd OK)) <- bob #: ("2", "C:", "ACPT C:" <> invId1 <> " 7\nI'm bob") + ("", aliceT', Req invId2 "I'm bob") <- (tom <#:) + aliceT' `shouldBe` aliceT + -- TODO info "tom here" is not used, either JOIN command also should have eInfo parameter + -- or this should be another command, not ACPT + ("3", bobT, Right (APartyCmd OK)) <- tom #: ("3", "C:", "ACPT C:" <> invId2 <> " 8\ntom here") + tom <# ("", bobT, CON) + bob <# ("", tomB, CON) + alice <# ("", bobA, ICON . IE . Conn $ B.drop 2 tomA) + -- they can message each other now + tom #: ("4", bobT, "SEND :hello") #> ("4", bobT, SENT 1) + bob <#= \case ("", c, Msg "hello") -> c == tomB; _ -> False + bob #: ("5", tomB, "SEND 9\nhello too") #> ("5", tomB, SENT 2) + tom <#= \case ("", c, Msg "hello too") -> c == bobT; _ -> False + +connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO () +connect (h1, name1) (h2, name2) = do + ("c1", _, Inv qInfo) <- h1 #: ("c1", "C:" <> name2, "NEW") + let qInfo' = serializeSmpQueueInfo qInfo + h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') #> ("", "C:" <> name1, CON) + h1 <# ("", "C:" <> name2, CON) + +connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString) +connect' h1 h2 = do + ("c1", conn2, Inv qInfo) <- h1 #: ("c1", "C:", "NEW") + let qInfo' = serializeSmpQueueInfo qInfo + ("", conn1, Right (APartyCmd CON)) <- h2 #: ("c2", "C:", "JOIN " <> qInfo') + h1 <# ("", conn2, CON) + pure (conn1, conn2) samplePublicKey :: ByteString samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" @@ -205,7 +337,7 @@ syntaxTests t = do -- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided) -- TODO: add tests with defined connection alias it "using same server as in invitation" $ - ("311", "C:", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:", "ERR SMP AUTH") + ("311", "C:a", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:a", "ERR SMP AUTH") describe "invalid" do -- TODO: JOIN is not merged yet - to be added it "no parameters" $ ("321", "C:", "JOIN") >#> ("321", "C:", "ERR CMD SYNTAX") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 834720645..2f8383a8c 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -8,9 +8,11 @@ module AgentTests.SQLiteTests (storeTests) where import Control.Concurrent.Async (concurrently_) +import Control.Concurrent.STM (newTVarIO) import Control.Monad (replicateM_) import Control.Monad.Except (ExceptT, runExceptT) import qualified Crypto.PubKey.RSA as R +import Crypto.Random (drgNew) import Data.ByteString.Char8 (ByteString) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) @@ -73,11 +75,13 @@ storeTests = do describe "Queue and Connection management" do describe "createRcvConn" do testCreateRcvConn + testCreateRcvConnRandomId testCreateRcvConnDuplicate describe "createSndConn" do testCreateSndConn + testCreateSndConnRandomID testCreateSndConnDuplicate - describe "getAllConnAliases" testGetAllConnAliases + describe "getAllConnIds" testGetAllConnIds describe "getRcvConn" testGetRcvConn describe "deleteConn" do testDeleteRcvConn @@ -104,14 +108,16 @@ storeTests = do testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore) testConcurrentWrites = it "should complete multiple concurrent write transactions w/t sqlite busy errors" $ \(s1, s2) -> do - _ <- runExceptT $ createRcvConn s1 rcvQueue1 - concurrently_ (runTest s1) (runTest s2) + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn s1 g cData1 rcvQueue1 + let ConnData {connId} = cData1 + concurrently_ (runTest s1 connId) (runTest s2 connId) where - runTest :: SQLiteStore -> IO (Either StoreError ()) - runTest store = runExceptT . replicateM_ 100 $ do - (internalId, internalRcvId, _, _) <- updateRcvIds store rcvQueue1 + runTest :: SQLiteStore -> ConnId -> IO (Either StoreError ()) + runTest store connId = runExceptT . replicateM_ 100 $ do + (internalId, internalRcvId, _, _) <- updateRcvIds store connId let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy" - createRcvMsg store rcvQueue1 rcvMsgData + createRcvMsg store connId rcvMsgData testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = @@ -132,12 +138,14 @@ testForeignKeysEnabled = DB.execute_ (dbConn store) inconsistentQuery `shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint) +cData1 :: ConnData +cData1 = ConnData {connId = "conn1", viaInv = Nothing, connLevel = 1} + rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "1234", - connAlias = "conn1", rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "2345", sndKey = Nothing, @@ -151,7 +159,6 @@ sndQueue1 = SndQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "3456", - connAlias = "conn1", sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, signKey = C.safePrivateKey (1, 2, 3), @@ -161,64 +168,95 @@ sndQueue1 = testCreateRcvConn :: SpecWith SQLiteStore testCreateRcvConn = it "should create RcvConnection and add SndQueue" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + g <- newTVarIO =<< drgNew + createRcvConn store g cData1 rcvQueue1 + `returnsResult` "conn1" getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) upgradeRcvConnToDuplex store "conn1" sndQueue1 `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) + +testCreateRcvConnRandomId :: SpecWith SQLiteStore +testCreateRcvConnRandomId = + it "should create RcvConnection and add SndQueue with random ID" $ \store -> do + g <- newTVarIO =<< drgNew + Right connId <- runExceptT $ createRcvConn store g cData1 {connId = ""} rcvQueue1 + getConn store connId + `returnsResult` SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1) + upgradeRcvConnToDuplex store connId sndQueue1 + `returnsResult` () + getConn store connId + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1) testCreateRcvConnDuplicate :: SpecWith SQLiteStore testCreateRcvConnDuplicate = it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 - createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 + createRcvConn store g cData1 rcvQueue1 `throwsError` SEConnDuplicate testCreateSndConn :: SpecWith SQLiteStore testCreateSndConn = it "should create SndConnection and add RcvQueue" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + g <- newTVarIO =<< drgNew + createSndConn store g cData1 sndQueue1 + `returnsResult` "conn1" getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1) upgradeSndConnToDuplex store "conn1" rcvQueue1 `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) + +testCreateSndConnRandomID :: SpecWith SQLiteStore +testCreateSndConnRandomID = + it "should create SndConnection and add RcvQueue with random ID" $ \store -> do + g <- newTVarIO =<< drgNew + Right connId <- runExceptT $ createSndConn store g cData1 {connId = ""} sndQueue1 + getConn store connId + `returnsResult` SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1) + upgradeSndConnToDuplex store connId rcvQueue1 + `returnsResult` () + getConn store connId + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1) testCreateSndConnDuplicate :: SpecWith SQLiteStore testCreateSndConnDuplicate = it "should throw error on attempt to create duplicate SndConnection" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 - createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 + createSndConn store g cData1 sndQueue1 `throwsError` SEConnDuplicate -testGetAllConnAliases :: SpecWith SQLiteStore -testGetAllConnAliases = +testGetAllConnIds :: SpecWith SQLiteStore +testGetAllConnIds = it "should get all conn aliases" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 - _ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"} - getAllConnAliases store - `returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias] + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 + _ <- runExceptT $ createSndConn store g cData1 {connId = "conn2"} sndQueue1 + getAllConnIds store + `returnsResult` ["conn1" :: ConnId, "conn2" :: ConnId] testGetRcvConn :: SpecWith SQLiteStore testGetRcvConn = it "should get connection using rcv queue id and server" $ \store -> do let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash let recipientId = "1234" - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 getRcvConn store smpServer recipientId - `returnsResult` SomeConn SCRcv (RcvConnection (connAlias (rcvQueue1 :: RcvQueue)) rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn = it "should create RcvConnection and delete it" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well @@ -228,9 +266,10 @@ testDeleteRcvConn = testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = it "should create SndConnection and delete it" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well @@ -240,10 +279,11 @@ testDeleteSndConn = testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = it "should create DuplexConnection and delete it" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well @@ -253,12 +293,12 @@ testDeleteDuplexConn = testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 let anotherSndQueue = SndQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "2345", - connAlias = "conn1", sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, signKey = C.safePrivateKey (1, 2, 3), @@ -273,12 +313,12 @@ testUpgradeRcvConnToDuplex = testUpgradeSndConnToDuplex :: SpecWith SQLiteStore testUpgradeSndConnToDuplex = it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 let anotherRcvQueue = RcvQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "3456", - connAlias = "conn1", rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "4567", sndKey = Nothing, @@ -295,40 +335,43 @@ testUpgradeSndConnToDuplex = testSetRcvQueueStatus :: SpecWith SQLiteStore testSetRcvQueueStatus = it "should update status of RcvQueue" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) setRcvQueueStatus store rcvQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1 {status = Confirmed}) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1 {status = Confirmed}) testSetSndQueueStatus :: SpecWith SQLiteStore testSetSndQueueStatus = it "should update status of SndQueue" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1) setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1 {status = Confirmed}) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1 {status = Confirmed}) testSetQueueStatusDuplex :: SpecWith SQLiteStore testSetQueueStatusDuplex = it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) setRcvQueueStatus store rcvQueue1 Secured `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1) setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = @@ -362,20 +405,22 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = msgIntegrity = MsgOk } -testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation -testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do - updateRcvIds store rcvQueue +testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation +testCreateRcvMsg' st expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do + updateRcvIds st connId `returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) - createRcvMsg store rcvQueue rcvMsgData + createRcvMsg st connId rcvMsgData `returnsResult` () testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = - it "should reserve internal ids and create a RcvMsg" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + it "should reserve internal ids and create a RcvMsg" $ \st -> do + g <- newTVarIO =<< drgNew + let ConnData {connId} = cData1 + _ <- runExceptT $ createRcvConn st g cData1 rcvQueue1 -- TODO getMsg to check message - testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" - testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" + testCreateRcvMsg' st 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" + testCreateRcvMsg' st 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData mkSndMsgData internalId internalSndId internalHash = @@ -387,29 +432,33 @@ mkSndMsgData internalId internalSndId internalHash = internalHash } -testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation -testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do - updateSndIds store sndQueue +testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> ConnId -> SndMsgData -> Expectation +testCreateSndMsg' store expectedPrevHash connId sndMsgData@SndMsgData {..} = do + updateSndIds store connId `returnsResult` (internalId, internalSndId, expectedPrevHash) - createSndMsg store sndQueue sndMsgData + createSndMsg store connId sndMsgData `returnsResult` () testCreateSndMsg :: SpecWith SQLiteStore testCreateSndMsg = it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + let ConnData {connId} = cData1 + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 -- TODO getMsg to check message - testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" - testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" + testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" + testCreateSndMsg' store "hash_dummy" connId $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" testCreateRcvAndSndMsgs :: SpecWith SQLiteStore testCreateRcvAndSndMsgs = it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + let ConnData {connId} = cData1 + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 - testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" - testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" - testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" - testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" - testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" - testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" + testCreateRcvMsg' store 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" + testCreateRcvMsg' store 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" + testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" + testCreateRcvMsg' store 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" + testCreateSndMsg' store "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" + testCreateSndMsg' store "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 918b276f0..fbbfd7ccb 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -54,12 +54,12 @@ testDB3 = "tests/tmp/smp-agent3.test.protocol.db" smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h -runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a +runSmpAgentTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test where t = transport @c -runSmpAgentServerTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a +runSmpAgentServerTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a runSmpAgentServerTest test = withSmpServerThreadOn t testPort $ \server -> withSmpAgentThreadOn t (agentTestPort, testPort, testDB) $ @@ -70,7 +70,7 @@ runSmpAgentServerTest test = smpAgentServerTest :: Transport c => ((ThreadId, ThreadId) -> c -> IO ()) -> Expectation smpAgentServerTest test' = runSmpAgentServerTest test' `shouldReturn` () -runSmpAgentTestN :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a +runSmpAgentTestN :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a runSmpAgentTestN agents test = withSmpServer t $ run agents [] where run :: [(ServiceName, ServiceName, String)] -> [c] -> m a @@ -78,7 +78,7 @@ runSmpAgentTestN agents test = withSmpServer t $ run agents [] run (a@(p, _, _) : as) hs = withSmpAgentOn t a $ testSMPAgentClientOn p $ \h -> run as (h : hs) t = transport @c -runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a +runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a runSmpAgentTestN_1 nClients test = withSmpServer t . withSmpAgent t $ run nClients [] where run :: Int -> [c] -> m a @@ -156,17 +156,17 @@ cfg = } } -withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a +withSmpAgentThreadOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a withSmpAgentThreadOn t (port', smpPort', db') = let cfg' = cfg {tcpPort = port', dbFile = db', smpServers = L.fromList [SMPServer "localhost" (Just smpPort') testKeyHash]} in serverBracket (\started -> runSMPAgentBlocking t started cfg') (removeFile db') -withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a +withSmpAgentOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const -withSmpAgent :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a +withSmpAgent :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB) testSMPAgentClientOn :: (Transport c, MonadUnliftIO m) => ServiceName -> (c -> m a) -> m a