introduction protocol (#156)

* commands to support introduction

* agent messages / envelopes to support introductions

* introductions and invitations table; insert record with random unique ID

* store class methods and types for introductions

* process INTRO and ACPT commands for connection introductions

* fix tests: add MonadFail constraint, remove OK response to JOIN

* process agent messages for introductions

* ICON notification when introduction is completed

* replace multiway if with case

* correction

* support random connection IDs

* save additional connection fields, refactor create connection funcs

* refactor

* refactor

* test duplex connection with random IDs

* store methods for introductions

* test introduction

* fix parsing of CON agent message

* test introduction with random connection IDs

* broadcast with random connection and broadcast IDs

* clean up sql
This commit is contained in:
Evgeny Poberezkin
2021-06-11 21:33:13 +01:00
committed by GitHub
parent 46c3589604
commit ab89963f45
11 changed files with 1156 additions and 413 deletions
+179 -62
View File
@@ -61,7 +61,7 @@ import UnliftIO.STM
-- | Runs an SMP agent as a TCP service using passed configuration.
--
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
runSMPAgent :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
runSMPAgent t cfg = do
started <- newEmptyTMVarIO
runSMPAgentBlocking t started cfg
@@ -70,10 +70,10 @@ runSMPAgent t cfg = do
--
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
runSMPAgentBlocking :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
where
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
smpAgent :: forall c m'. (Transport c, MonadFail m', MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do
liftIO $ putLn h "Welcome to SMP v0.3.2 agent"
c <- getSMPAgentClient
@@ -97,7 +97,7 @@ logConnection c connected =
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
runSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
runSMPAgentClient :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
runSMPAgentClient c = do
db <- asks $ dbFile . config
s1 <- liftIO $ connectSQLiteStore db
@@ -128,7 +128,7 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a ->
logClient AgentClient {clientId} dir (ATransmission corrId entity cmd) = do
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, serializeEntity entity, B.takeWhile (/= ' ') $ serializeCommand cmd]
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
client :: forall m. (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
client c@AgentClient {rcvQ, sndQ} st = forever loop
where
loop :: m ()
@@ -147,6 +147,8 @@ withStore action = do
Right c -> return c
Left e -> throwError $ storeError e
where
-- TODO when parsing exception happens in store, the agent hangs;
-- changing SQLError to SomeException does not help
handleInternal :: (MonadError StoreError m') => SQLError -> m' a
handleInternal e = throwError . SEInternal $ bshow e
storeError :: StoreError -> AgentErrorType
@@ -173,23 +175,28 @@ unsupportedEntity c _ corrId entity _ =
processConnCommand ::
forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m ()
processConnCommand c@AgentClient {sndQ} st corrId conn = \case
NEW -> createNewConnection conn
JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode
processConnCommand c@AgentClient {sndQ} st corrId conn@(Conn connId) = \case
NEW -> createNewConnection Nothing 0 >>= uncurry respond
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode Nothing 0 >> pure () -- >>= (`respond` OK)
INTRO reEntity reInfo -> makeIntroduction reEntity reInfo
ACPT inv eInfo -> acceptInvitation inv eInfo
SUB -> subscribeConnection conn
SUBALL -> subscribeAll
SEND msgBody -> sendMessage c st corrId conn msgBody
SEND msgBody -> sendClientMessage c st corrId conn msgBody
OFF -> suspendConnection conn
DEL -> deleteConnection conn
where
createNewConnection :: Entity 'Conn_ -> m ()
createNewConnection (Conn cId) = do
createNewConnection :: Maybe InvitationId -> Int -> m (Entity 'Conn_, ACommand 'Agent 'INV_)
createNewConnection viaInv connLevel = do
-- TODO create connection alias if not passed
-- make connAlias Maybe?
-- make connId Maybe?
srv <- getSMPServer
(rq, qInfo) <- newReceiveQueue c srv cId
withStore $ createRcvConn st rq
respond conn $ INV qInfo
(rq, qInfo) <- newReceiveQueue c srv
g <- asks idsDrg
let cData = ConnData {connId, viaInv, connLevel}
connId' <- withStore $ createRcvConn st g cData rq
addSubscription c rq connId'
pure (Conn connId', INV qInfo)
getSMPServer :: m SMPServer
getSMPServer =
@@ -200,16 +207,47 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
pure $ servers L.!! i
joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m ()
joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do
-- TODO create connection alias if not passed
-- make connAlias Maybe?
(sq, senderKey, verifyKey) <- newSendQueue qInfo cId
withStore $ createSndConn st sq
joinConnection :: SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m (Entity 'Conn_)
joinConnection qInfo (ReplyMode replyMode) viaInv connLevel = do
(sq, senderKey, verifyKey) <- newSendQueue qInfo
g <- asks idsDrg
let cData = ConnData {connId, viaInv, connLevel}
connId' <- withStore $ createSndConn st g cData sq
connectToSendQueue c st sq senderKey verifyKey
when (replyMode == On) $ createReplyQueue cId sq
-- TODO this response is disabled to avoid two responses in terminal client (OK + CON),
-- respond conn OK
when (replyMode == On) $ createReplyQueue connId' sq
pure $ Conn connId'
makeIntroduction :: IntroEntity -> EntityInfo -> m ()
makeIntroduction (IE reEntity) reInfo = case reEntity of
Conn reConn ->
withStore ((,) <$> getConn st connId <*> getConn st reConn) >>= \case
(SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do
g <- asks idsDrg
introId <- withStore $ createIntro st g NewIntroduction {toConn = connId, reConn, reInfo}
sendControlMessage c sq $ A_INTRO (IE (Conn introId)) reInfo
respond conn OK
_ -> throwError $ CONN SIMPLEX
_ -> throwError $ CMD UNSUPPORTED
acceptInvitation :: IntroEntity -> EntityInfo -> m ()
acceptInvitation (IE invEntity) eInfo = case invEntity of
Conn invId -> do
withStore (getInvitation st invId) >>= \case
Invitation {viaConn, qInfo, externalIntroId, status = InvNew} ->
withStore (getConn st viaConn) >>= \case
SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of
Nothing -> do
(conn', INV qInfo') <- createNewConnection (Just invId) (connLevel + 1)
withStore $ addInvitationConn st invId $ fromConn conn'
sendControlMessage c sq $ A_INV (Conn externalIntroId) qInfo' eInfo
respond conn' OK
Just qInfo' -> do
conn' <- joinConnection qInfo' (ReplyMode On) (Just invId) (connLevel + 1)
withStore $ addInvitationConn st invId $ fromConn conn'
respond conn' OK
_ -> throwError $ CONN SIMPLEX
_ -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD UNSUPPORTED
subscribeConnection :: Entity 'Conn_ -> m ()
subscribeConnection conn'@(Conn cId) =
@@ -222,7 +260,7 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
-- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct
subscribeAll :: m ()
subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn)
subscribeAll = withStore (getAllConnIds st) >>= mapM_ (subscribeConnection . Conn)
suspendConnection :: Entity 'Conn_ -> m ()
suspendConnection (Conn cId) =
@@ -246,25 +284,30 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
removeSubscription c cId
delConn
createReplyQueue :: ByteString -> SndQueue -> m ()
createReplyQueue :: ConnId -> SndQueue -> m ()
createReplyQueue cId sq = do
srv <- getSMPServer
(rq, qInfo) <- newReceiveQueue c srv cId
(rq, qInfo) <- newReceiveQueue c srv
addSubscription c rq cId
withStore $ upgradeSndConnToDuplex st cId rq
senderTimestamp <- liftIO getCurrentTime
sendAgentMessage c sq . serializeSMPMessage $
SMPMessage
{ senderMsgId = 0,
senderTimestamp,
previousMsgHash = "",
agentMessage = REPLY qInfo
}
sendControlMessage c sq $ REPLY qInfo
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp
sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
sendMessage c st corrId (Conn cId) msgBody =
sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m ()
sendControlMessage c sq agentMessage = do
senderTimestamp <- liftIO getCurrentTime
sendAgentMessage c sq . serializeSMPMessage $
SMPMessage
{ senderMsgId = 0,
senderTimestamp,
previousMsgHash = "",
agentMessage
}
sendClientMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
sendClientMessage c st corrId (Conn cId) msgBody =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
SomeConn _ (SndConnection _ sq) -> sendMsg sq
@@ -273,7 +316,7 @@ sendMessage c st corrId (Conn cId) msgBody =
sendMsg :: SndQueue -> m ()
sendMsg sq = do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st cId
let msgStr =
serializeSMPMessage
SMPMessage
@@ -284,7 +327,7 @@ sendMessage c st corrId (Conn cId) msgBody =
}
msgHash = C.sha256Hash msgStr
withStore $
createSndMsg st sq $
createSndMsg st cId $
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
sendAgentMessage c sq msgStr
atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId)
@@ -292,15 +335,21 @@ sendMessage c st corrId (Conn cId) msgBody =
processBroadcastCommand ::
forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m ()
processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
NEW -> withStore (createBcast st bId) >> ok
NEW -> createNewBroadcast
ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok
REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok
LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn
SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0)
DEL -> withStore (deleteBcast st bId) >> ok
where
sendMsg :: MsgBody -> ConnAlias -> m ()
sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody
createNewBroadcast :: m ()
createNewBroadcast = do
g <- asks idsDrg
bId' <- withStore $ createBcast st g bId
respond (Broadcast bId') OK
sendMsg :: MsgBody -> ConnId -> m ()
sendMsg msgBody cId = sendClientMessage c st corrId (Conn cId) msgBody
ok :: m ()
ok = respond bcast OK
@@ -308,7 +357,7 @@ processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
subscriber :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
subscriber c@AgentClient {msgQ} st = forever $ do
-- TODO this will only process messages and notifications
t <- atomically $ readTBQueue msgQ
@@ -319,29 +368,33 @@ subscriber c@AgentClient {msgQ} st = forever $ do
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
withStore (getRcvConn st srv rId) >>= \case
SomeConn SCDuplex (DuplexConnection _ rq _) -> processSMP SCDuplex rq
SomeConn SCRcv (RcvConnection _ rq) -> processSMP SCRcv rq
SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq
SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq
_ -> atomically . writeTBQueue sndQ $ ATransmission "" (Conn "") (ERR $ CONN SIMPLEX)
where
processSMP :: SConnType c -> RcvQueue -> m ()
processSMP cType rq@RcvQueue {connAlias, status} =
processSMP :: SConnType c -> ConnData -> RcvQueue -> m ()
processSMP cType ConnData {connId} rq@RcvQueue {status} =
case cmd of
SMP.MSG srvMsgId srvTs msgBody -> do
-- TODO deduplicate with previously received
msg <- decryptAndVerify rq msgBody
let msgHash = C.sha256Hash msg
agentMsg <- liftEither $ parseSMPMessage msg
case agentMsg of
SMPConfirmation senderKey -> smpConfirmation senderKey
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
case parseSMPMessage msg of
Left e -> notify $ ERR e
Right (SMPConfirmation senderKey) -> smpConfirmation senderKey
Right SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
case agentMessage of
HELLO verifyKey _ -> helloMsg verifyKey msgBody
REPLY qInfo -> replyMsg qInfo
A_MSG body -> agentClientMsg previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
A_INTRO entity eInfo -> introMsg entity eInfo
A_INV conn qInfo cInfo -> invMsg conn qInfo cInfo
A_REQ conn qInfo cInfo -> reqMsg conn qInfo cInfo
A_CON conn -> conMsg conn
sendAck c rq
return ()
SMP.END -> do
removeSubscription c connAlias
removeSubscription c connId
logServer "<--" c srv rId "END"
notify END
_ -> do
@@ -349,7 +402,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
notify . ERR $ BROKER UNEXPECTED
where
notify :: EntityCommand 'Conn_ c => ACommand 'Agent c -> m ()
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connAlias) msg
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connId) msg
prohibited :: m ()
prohibited = notify . ERR $ AGENT A_PROHIBITED
@@ -376,7 +429,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
void $ verifyMessage (Just verifyKey) msgBody
withStore $ setRcvQueueActive st rq verifyKey
case cType of
SCDuplex -> notify CON
SCDuplex -> connected
_ -> pure ()
replyMsg :: SMPQueueInfo -> m ()
@@ -384,22 +437,87 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
logServer "<--" c srv rId "MSG <REPLY>"
case cType of
SCRcv -> do
(sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias
withStore $ upgradeRcvConnToDuplex st connAlias sq
(sq, senderKey, verifyKey) <- newSendQueue qInfo
withStore $ upgradeRcvConnToDuplex st connId sq
connectToSendQueue c st sq senderKey verifyKey
notify CON
connected
_ -> prohibited
connected :: m ()
connected = do
withStore (getConnInvitation st connId) >>= \case
Just (Invitation {invId, externalIntroId}, DuplexConnection _ _ sq) -> do
withStore $ setInvitationStatus st invId InvCon
sendControlMessage c sq $ A_CON (Conn externalIntroId)
_ -> pure ()
notify CON
introMsg :: IntroEntity -> EntityInfo -> m ()
introMsg (IE entity) entityInfo = do
logServer "<--" c srv rId "MSG <INTRO>"
case (cType, entity) of
(SCDuplex, intro@Conn {}) -> createInv intro Nothing entityInfo
_ -> prohibited
invMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
invMsg (Conn introId) qInfo toInfo = do
logServer "<--" c srv rId "MSG <INV>"
case cType of
SCDuplex ->
withStore (getIntro st introId) >>= \case
Introduction {toConn, toStatus = IntroNew, reConn, reStatus = IntroNew}
| toConn /= connId -> prohibited
| otherwise ->
withStore (addIntroInvitation st introId toInfo qInfo >> getConn st reConn) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> do
sendControlMessage c sq $ A_REQ (Conn introId) qInfo toInfo
withStore $ setIntroReStatus st introId IntroInv
_ -> prohibited
_ -> prohibited
_ -> prohibited
reqMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
reqMsg intro qInfo connInfo = do
logServer "<--" c srv rId "MSG <REQ>"
case cType of
SCDuplex -> createInv intro (Just qInfo) connInfo
_ -> prohibited
createInv :: Entity 'Conn_ -> Maybe SMPQueueInfo -> EntityInfo -> m ()
createInv (Conn externalIntroId) qInfo entityInfo = do
g <- asks idsDrg
let newInv = NewInvitation {viaConn = connId, externalIntroId, entityInfo, qInfo}
invId <- withStore $ createInvitation st g newInv
notify $ REQ (IE (Conn invId)) entityInfo
conMsg :: Entity 'Conn_ -> m ()
conMsg (Conn introId) = do
logServer "<--" c srv rId "MSG <CON>"
withStore (getIntro st introId) >>= \case
Introduction {toConn, toStatus, reConn, reStatus}
| toConn == connId && toStatus == IntroInv -> do
withStore $ setIntroToStatus st introId IntroCon
sendConMsg toConn reConn reStatus
| reConn == connId && reStatus == IntroInv -> do
withStore $ setIntroReStatus st introId IntroCon
sendConMsg toConn reConn toStatus
| otherwise -> prohibited
where
sendConMsg :: ConnId -> ConnId -> IntroStatus -> m ()
sendConMsg toConn reConn IntroCon =
atomically . writeTBQueue sndQ $ ATransmission "" (Conn toConn) $ ICON $ IE (Conn reConn)
sendConMsg _ _ _ = pure ()
agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
logServer "<--" c srv rId "MSG <MSG>"
case status of
Active -> do
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st connId
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash receivedPrevMsgHash
withStore $
createRcvMsg st rq $
createRcvMsg st connId $
RcvMsgData
{ internalId,
internalRcvId,
@@ -438,8 +556,8 @@ connectToSendQueue c st sq senderKey verifyKey = do
withStore $ setSndQueueStatus st sq Active
newSendQueue ::
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey)
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey)
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do
size <- asks $ rsaKeySize . config
(senderKey, sndPrivateKey) <- liftIO $ C.generateKeyPair size
(verifyKey, signKey) <- liftIO $ C.generateKeyPair size
@@ -447,7 +565,6 @@ newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
SndQueue
{ server = smpServer,
sndId = senderId,
connAlias,
sndPrivateKey,
encryptKey,
signKey,
+13 -14
View File
@@ -15,6 +15,7 @@ module Simplex.Messaging.Agent.Client
closeSMPServerClients,
newReceiveQueue,
subscribeQueue,
addSubscription,
sendConfirmation,
sendHello,
secureQueue,
@@ -61,8 +62,8 @@ data AgentClient = AgentClient
sndQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue SMPServerTransmission,
smpClients :: TVar (Map SMPServer SMPClient),
subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)),
subscrConns :: TVar (Map ConnAlias SMPServer),
subscrSrvrs :: TVar (Map SMPServer (Set ConnId)),
subscrConns :: TVar (Map ConnId SMPServer),
clientId :: Int
}
@@ -78,7 +79,7 @@ newAgentClient cc AgentConfig {tbqSize} = do
writeTVar cc clientId
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m, MonadFail m)
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
@@ -106,7 +107,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
removeSubs >>= mapM_ (mapM_ notifySub)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeSubs :: IO (Maybe (Set ConnAlias))
removeSubs :: IO (Maybe (Set ConnId))
removeSubs = atomically $ do
modifyTVar smpClients $ M.delete srv
cs <- M.lookup srv <$> readTVar (subscrSrvrs c)
@@ -117,7 +118,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
deleteKeys :: Ord k => Set k -> Map k a -> Map k a
deleteKeys ks m = S.foldr' M.delete m ks
notifySub :: ConnAlias -> IO ()
notifySub :: ConnId -> IO ()
notifySub connAlias = atomically . writeTBQueue (sndQ c) $ ATransmission "" (Conn connAlias) END
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
@@ -158,8 +159,8 @@ smpClientError = \case
SMPTransportError e -> BROKER $ TRANSPORT e
e -> INTERNAL $ show e
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo)
newReceiveQueue c srv connAlias = do
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueInfo)
newReceiveQueue c srv = do
size <- asks $ rsaKeySize . config
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateKeyPair size
logServer "-->" c srv "" "NEW"
@@ -170,7 +171,6 @@ newReceiveQueue c srv connAlias = do
RcvQueue
{ server = srv,
rcvId,
connAlias,
rcvPrivateKey,
sndId = Just sId,
sndKey = Nothing,
@@ -178,25 +178,24 @@ newReceiveQueue c srv connAlias = do
verifyKey = Nothing,
status = New
}
addSubscription c rq connAlias
return (rq, SMPQueueInfo srv sId encryptKey)
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnAlias -> m ()
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connAlias = do
withLogSMP c server rcvId "SUB" $ \smp ->
subscribeSMPQueue smp rcvPrivateKey rcvId
addSubscription c rq connAlias
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnAlias -> m ()
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c RcvQueue {server} connAlias = atomically $ do
modifyTVar (subscrConns c) $ M.insert connAlias server
modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server
where
addSub :: Maybe (Set ConnAlias) -> Set ConnAlias
addSub :: Maybe (Set ConnId) -> Set ConnId
addSub (Just cs) = S.insert connAlias cs
addSub _ = S.singleton connAlias
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m ()
removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do
cs <- readTVar subscrConns
writeTVar subscrConns $ M.delete connAlias cs
@@ -204,7 +203,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically
(modifyTVar subscrSrvrs . M.alter (>>= delSub))
(M.lookup connAlias cs)
where
delSub :: Set ConnAlias -> Maybe (Set ConnAlias)
delSub :: Set ConnId -> Maybe (Set ConnId)
delSub cs =
let cs' = S.delete connAlias cs
in if S.null cs' then Nothing else Just cs'
+108 -16
View File
@@ -33,6 +33,8 @@ module Simplex.Messaging.Agent.Protocol
Entity (..),
EntityTag (..),
AnEntity (..),
IntroEntity (..),
EntityInfo,
EntityCommand,
entityCommand,
ACommand (..),
@@ -53,7 +55,7 @@ module Simplex.Messaging.Agent.Protocol
ATransmission (..),
ATransmissionOrError (..),
ARawTransmission,
ConnAlias,
ConnId,
ReplyMode (..),
AckMode (..),
OnOff (..),
@@ -171,6 +173,13 @@ deriving instance Eq (Entity t)
deriving instance Show (Entity t)
instance TestEquality Entity where
testEquality (Conn c) (Conn c') = refl c c'
testEquality (OpenConn c) (OpenConn c') = refl c c'
testEquality (Broadcast c) (Broadcast c') = refl c c'
testEquality (AGroup c) (AGroup c') = refl c c'
testEquality _ _ = Nothing
entityId :: Entity t -> ByteString
entityId = \case
Conn bs -> bs
@@ -195,7 +204,11 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where
EntityCommand Conn_ NEW_ = ()
EntityCommand Conn_ INV_ = ()
EntityCommand Conn_ JOIN_ = ()
EntityCommand Conn_ INTRO_ = ()
EntityCommand Conn_ REQ_ = ()
EntityCommand Conn_ ACPT_ = ()
EntityCommand Conn_ CON_ = ()
EntityCommand Conn_ ICON_ = ()
EntityCommand Conn_ SUB_ = ()
EntityCommand Conn_ SUBALL_ = ()
EntityCommand Conn_ END_ = ()
@@ -226,7 +239,11 @@ entityCommand = \case
NEW -> Just Dict
INV _ -> Just Dict
JOIN {} -> Just Dict
INTRO {} -> Just Dict
REQ {} -> Just Dict
ACPT {} -> Just Dict
CON -> Just Dict
ICON {} -> Just Dict
SUB -> Just Dict
SUBALL -> Just Dict
END -> Just Dict
@@ -258,7 +275,11 @@ data ACmdTag
= NEW_
| INV_
| JOIN_
| INTRO_
| REQ_
| ACPT_
| CON_
| ICON_
| SUB_
| SUBALL_
| END_
@@ -274,15 +295,31 @@ data ACmdTag
| OK_
| ERR_
type family Introduction (t :: EntityTag) :: Constraint where
Introduction Conn_ = ()
Introduction OpenConn_ = ()
Introduction AGroup_ = ()
Introduction t = (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " cannot be INTRO'd to"))
data IntroEntity = forall t. Introduction t => IE (Entity t)
instance Eq IntroEntity where
IE e1 == IE e2 = isJust $ testEquality e1 e2
deriving instance Show IntroEntity
type EntityInfo = ByteString
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
data ACommand (p :: AParty) (c :: ACmdTag) where
NEW :: ACommand Client NEW_ -- response INV
INV :: SMPQueueInfo -> ACommand Agent INV_
JOIN :: SMPQueueInfo -> ReplyMode -> ACommand Client JOIN_ -- response OK
INTRO :: IntroEntity -> EntityInfo -> ACommand Client INTRO_
REQ :: IntroEntity -> EntityInfo -> ACommand Agent INTRO_
ACPT :: IntroEntity -> EntityInfo -> ACommand Client ACPT_
CON :: ACommand Agent CON_ -- notification that connection is established
-- TODO currently it automatically allows whoever sends the confirmation
-- CONF :: OtherPartyId -> ACommand Agent
-- LET :: OtherPartyId -> ACommand Client
ICON :: IntroEntity -> ACommand Agent ICON_
SUB :: ACommand Client SUB_
SUBALL :: ACommand Client SUBALL_ -- TODO should be moved to chat protocol - hack for subscribing to all
END :: ACommand Agent END_
@@ -318,6 +355,7 @@ instance TestEquality (ACommand p) where
testEquality c@INV {} c'@INV {} = refl c c'
testEquality c@JOIN {} c'@JOIN {} = refl c c'
testEquality CON CON = Just Refl
testEquality c@ICON {} c'@ICON {} = refl c c'
testEquality SUB SUB = Just Refl
testEquality SUBALL SUBALL = Just Refl
testEquality END END = Just Refl
@@ -334,7 +372,7 @@ instance TestEquality (ACommand p) where
testEquality c@ERR {} c'@ERR {} = refl c c'
testEquality _ _ = Nothing
refl :: Eq (f a) => f a -> f a -> Maybe (a :~: a)
refl :: Eq a => a -> a -> Maybe (t :~: t)
refl x x' = if x == x' then Just Refl else Nothing
-- | SMP message formats.
@@ -366,6 +404,14 @@ data AMessage where
REPLY :: SMPQueueInfo -> AMessage
-- | agent envelope for the client message
A_MSG :: MsgBody -> AMessage
-- | agent message for introduction
A_INTRO :: IntroEntity -> EntityInfo -> AMessage
-- | agent envelope for the sent invitation
A_INV :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
-- | agent envelope for the forwarded invitation
A_REQ :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
-- | agent message for intro/group request
A_CON :: Entity Conn_ -> AMessage
deriving (Show)
-- | Parse SMP message.
@@ -408,12 +454,22 @@ agentMessageP =
"HELLO " *> hello
<|> "REPLY " *> reply
<|> "MSG " *> a_msg
<|> "INTRO " *> a_intro
<|> "INV " *> a_inv
<|> "REQ " *> a_req
<|> "CON " *> a_con
where
hello = HELLO <$> C.pubKeyP <*> ackMode
reply = REPLY <$> smpQueueInfoP
a_msg = do
a_msg = A_MSG <$> binaryBody
a_intro = A_INTRO <$> introEntityP <* A.space <*> binaryBody
a_inv = invP A_INV
a_req = invP A_REQ
a_con = A_CON <$> connEntityP
invP f = f <$> connEntityP <* A.space <*> smpQueueInfoP <* A.space <*> binaryBody
binaryBody = do
size :: Int <- A.decimal <* A.endOfLine
A_MSG <$> A.take size <* A.endOfLine
A.take size <* A.endOfLine
ackMode = AckMode <$> (" NO_ACK" $> Off <|> pure On)
-- | SMP queue information parser.
@@ -434,6 +490,13 @@ serializeAgentMessage = \case
HELLO verifyKey ackMode -> "HELLO " <> C.serializePubKey verifyKey <> if ackMode == AckMode Off then " NO_ACK" else ""
REPLY qInfo -> "REPLY " <> serializeSmpQueueInfo qInfo
A_MSG body -> "MSG " <> serializeMsg body <> "\n"
A_INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo <> "\n"
A_INV conn qInfo eInfo -> "INV " <> serializeInv conn qInfo eInfo
A_REQ conn qInfo eInfo -> "REQ " <> serializeInv conn qInfo eInfo
A_CON conn -> "CON " <> serializeEntity conn
where
serializeInv conn qInfo eInfo =
B.intercalate " " [serializeEntity conn, serializeSmpQueueInfo qInfo, serializeMsg eInfo] <> "\n"
-- | Serialize SMP queue information that is sent out-of-band.
serializeSmpQueueInfo :: SMPQueueInfo -> ByteString
@@ -457,7 +520,7 @@ instance IsString SMPServer where
fromString = parseString . parseAll $ smpServerP
-- | SMP agent connection alias.
type ConnAlias = ByteString
type ConnId = ByteString
-- | Connection modes.
data OnOff = On | Off deriving (Eq, Show, Read)
@@ -614,10 +677,19 @@ anEntityP =
<|> "B:" $> AE . Broadcast
<|> "G:" $> AE . AGroup
)
<*> A.takeTill (== ' ')
<*> A.takeTill wordEnd
entityConnP :: Parser (Entity Conn_)
entityConnP = "C:" *> (Conn <$> A.takeTill (== ' '))
connEntityP :: Parser (Entity Conn_)
connEntityP = "C:" *> (Conn <$> A.takeTill wordEnd)
introEntityP :: Parser IntroEntity
introEntityP =
($)
<$> ( "C:" $> IE . Conn
<|> "O:" $> IE . OpenConn
<|> "G:" $> IE . AGroup
)
<*> A.takeTill wordEnd
serializeEntity :: Entity t -> ByteString
serializeEntity = \case
@@ -632,6 +704,9 @@ commandP =
"NEW" $> ACmd SClient NEW
<|> "INV " *> invResp
<|> "JOIN " *> joinCmd
<|> "INTRO " *> introCmd
<|> "REQ " *> reqCmd
<|> "ACPT " *> acptCmd
<|> "SUB" $> ACmd SClient SUB
<|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all
<|> "END" $> ACmd SAgent END
@@ -645,16 +720,21 @@ commandP =
<|> "LS" $> ACmd SClient LS
<|> "MS " *> membersResp
<|> "ERR " *> agentError
<|> "ICON " *> iconMsg
<|> "CON" $> ACmd SAgent CON
<|> "OK" $> ACmd SAgent OK
where
invResp = ACmd SAgent . INV <$> smpQueueInfoP
joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode)
introCmd = ACmd SClient <$> introP INTRO
reqCmd = ACmd SAgent <$> introP REQ
acptCmd = ACmd SClient <$> introP ACPT
sendCmd = ACmd SClient . SEND <$> A.takeByteString
sentResp = ACmd SAgent . SENT <$> A.decimal
addCmd = ACmd SClient . ADD <$> entityConnP
removeCmd = ACmd SClient . REM <$> entityConnP
membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ')
addCmd = ACmd SClient . ADD <$> connEntityP
removeCmd = ACmd SClient . REM <$> connEntityP
membersResp = ACmd SAgent . MS <$> (connEntityP `A.sepBy'` A.char ' ')
iconMsg = ACmd SAgent . ICON <$> introEntityP
message = do
msgIntegrity <- msgIntegrityP <* A.space
recipientMeta <- "R=" *> partyMeta A.decimal
@@ -662,6 +742,7 @@ commandP =
senderMeta <- "S=" *> partyMeta A.decimal
msgBody <- A.takeByteString
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
introP f = f <$> introEntityP <* A.space <*> A.takeByteString
replyMode = ReplyMode <$> (" NO_REPLY" $> Off <|> pure On)
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
@@ -685,6 +766,9 @@ serializeCommand = \case
NEW -> "NEW"
INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo
JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode
INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo
REQ (IE entity) eInfo -> "REQ " <> serializeIntro entity eInfo
ACPT (IE entity) eInfo -> "ACPT " <> serializeIntro entity eInfo
SUB -> "SUB"
SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all
END -> "END"
@@ -706,6 +790,7 @@ serializeCommand = \case
LS -> "LS"
MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs)
CON -> "CON"
ICON (IE entity) -> "ICON " <> serializeEntity entity
ERR e -> "ERR " <> serializeAgentError e
OK -> "OK"
where
@@ -716,6 +801,9 @@ serializeCommand = \case
showTs :: UTCTime -> ByteString
showTs = B.pack . formatISO8601Millis
serializeIntro :: Entity t -> ByteString -> ByteString
serializeIntro entity eInfo = serializeEntity entity <> " " <> serializeMsg eInfo
-- | Serialize message integrity validation result.
serializeMsgIntegrity :: MsgIntegrity -> ByteString
serializeMsgIntegrity = \case
@@ -794,9 +882,10 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
hasEntityId :: AnEntity -> APartyCmd p -> Either AgentErrorType (APartyCmd p)
hasEntityId (AE entity) (APartyCmd cmd) =
APartyCmd <$> case cmd of
-- NEW and JOIN have optional entity
-- NEW, JOIN and ACPT have optional entity
NEW -> Right cmd
JOIN _ _ -> Right cmd
JOIN {} -> Right cmd
ACPT {} -> Right cmd
-- ERROR response does not always have entity
ERR _ -> Right cmd
-- other responses must have entity
@@ -818,6 +907,9 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
APartyCmd <$$> case cmd of
SEND body -> SEND <$$> getMsgBody body
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
INTRO entity eInfo -> INTRO entity <$$> getMsgBody eInfo
REQ entity eInfo -> REQ entity <$$> getMsgBody eInfo
ACPT entity eInfo -> ACPT entity <$$> getMsgBody eInfo
_ -> pure $ Right cmd
-- TODO refactor with server
+119 -22
View File
@@ -3,16 +3,21 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
module Simplex.Messaging.Agent.Store where
import Control.Concurrent.STM (TVar)
import Control.Exception (Exception)
import Crypto.Random (ChaChaDRG)
import Data.ByteString.Char8 (ByteString)
import Data.Int (Int64)
import Data.Kind (Type)
import Data.Text (Text)
import Data.Time (UTCTime)
import Data.Type.Equality
import Simplex.Messaging.Agent.Protocol
@@ -30,33 +35,45 @@ import qualified Simplex.Messaging.Protocol as SMP
-- | Store class type. Defines store access methods for implementations.
class Monad m => MonadAgentStore s m where
-- Queue and Connection management
createRcvConn :: s -> RcvQueue -> m ()
createSndConn :: s -> SndQueue -> m ()
getConn :: s -> ConnAlias -> m SomeConn
getAllConnAliases :: s -> m [ConnAlias] -- TODO remove - hack for subscribing to all
createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
getConn :: s -> ConnId -> m SomeConn
getAllConnIds :: s -> m [ConnId] -- TODO remove - hack for subscribing to all
getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn
deleteConn :: s -> ConnAlias -> m ()
upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m ()
upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m ()
deleteConn :: s -> ConnId -> m ()
upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m ()
upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m ()
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m ()
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
-- Msg management
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
createRcvMsg :: s -> ConnId -> RcvMsgData -> m ()
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
createSndMsg :: s -> ConnId -> SndMsgData -> m ()
getMsg :: s -> ConnAlias -> InternalId -> m Msg
getMsg :: s -> ConnId -> InternalId -> m Msg
-- Broadcasts
createBcast :: s -> BroadcastId -> m ()
addBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
removeBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
createBcast :: s -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
addBcastConn :: s -> BroadcastId -> ConnId -> m ()
removeBcastConn :: s -> BroadcastId -> ConnId -> m ()
deleteBcast :: s -> BroadcastId -> m ()
getBcast :: s -> BroadcastId -> m [ConnAlias]
getBcast :: s -> BroadcastId -> m [ConnId]
-- Introductions
createIntro :: s -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
getIntro :: s -> IntroId -> m Introduction
addIntroInvitation :: s -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
setIntroToStatus :: s -> IntroId -> IntroStatus -> m ()
setIntroReStatus :: s -> IntroId -> IntroStatus -> m ()
createInvitation :: s -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
getInvitation :: s -> InvitationId -> m Invitation
addInvitationConn :: s -> InvitationId -> ConnId -> m ()
getConnInvitation :: s -> ConnId -> m (Maybe (Invitation, Connection CDuplex))
setInvitationStatus :: s -> InvitationId -> InvitationStatus -> m ()
-- * Queue types
@@ -64,7 +81,6 @@ class Monad m => MonadAgentStore s m where
data RcvQueue = RcvQueue
{ server :: SMPServer,
rcvId :: SMP.RecipientId,
connAlias :: ConnAlias,
rcvPrivateKey :: RecipientPrivateKey,
sndId :: Maybe SMP.SenderId,
sndKey :: Maybe SenderPublicKey,
@@ -78,7 +94,6 @@ data RcvQueue = RcvQueue
data SndQueue = SndQueue
{ server :: SMPServer,
sndId :: SMP.SenderId,
connAlias :: ConnAlias,
sndPrivateKey :: SenderPrivateKey,
encryptKey :: EncryptionKey,
signKey :: SignatureKey,
@@ -102,9 +117,9 @@ data ConnType = CRcv | CSnd | CDuplex deriving (Eq, Show)
-- - DuplexConnection is a connection that has both receive and send queues set up,
-- typically created by upgrading a receive or a send connection with a missing queue.
data Connection (d :: ConnType) where
RcvConnection :: ConnAlias -> RcvQueue -> Connection CRcv
SndConnection :: ConnAlias -> SndQueue -> Connection CSnd
DuplexConnection :: ConnAlias -> RcvQueue -> SndQueue -> Connection CDuplex
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
SndConnection :: ConnData -> SndQueue -> Connection CSnd
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
deriving instance Eq (Connection d)
@@ -141,6 +156,9 @@ instance Eq SomeConn where
deriving instance Show SomeConn
data ConnData = ConnData {connId :: ConnId, viaInv :: Maybe InvitationId, connLevel :: Int}
deriving (Eq, Show)
-- * Message integrity validation types
type MsgHash = ByteString
@@ -263,7 +281,7 @@ type DeliveredTs = UTCTime
-- | Base message data independent of direction.
data MsgBase = MsgBase
{ connAlias :: ConnAlias,
{ connAlias :: ConnId,
-- | Monotonically increasing id of a message per connection, internal to the agent.
-- Internal Id preserves ordering between both received and sent messages, and is needed
-- to track the order of the conversation (which can be different for the sender / receiver)
@@ -281,12 +299,87 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show)
type InternalTs = UTCTime
-- * Introduction types
data NewIntroduction = NewIntroduction
{ toConn :: ConnId,
reConn :: ConnId,
reInfo :: ByteString
}
data Introduction = Introduction
{ introId :: IntroId,
toConn :: ConnId,
toInfo :: Maybe ByteString,
toStatus :: IntroStatus,
reConn :: ConnId,
reInfo :: ByteString,
reStatus :: IntroStatus,
qInfo :: Maybe SMPQueueInfo
}
data IntroStatus = IntroNew | IntroInv | IntroCon
deriving (Eq)
serializeIntroStatus :: IntroStatus -> Text
serializeIntroStatus = \case
IntroNew -> ""
IntroInv -> "INV"
IntroCon -> "CON"
introStatusT :: Text -> Maybe IntroStatus
introStatusT = \case
"" -> Just IntroNew
"INV" -> Just IntroInv
"CON" -> Just IntroCon
_ -> Nothing
type IntroId = ByteString
data NewInvitation = NewInvitation
{ viaConn :: ConnId,
externalIntroId :: IntroId,
entityInfo :: EntityInfo,
qInfo :: Maybe SMPQueueInfo
}
data Invitation = Invitation
{ invId :: InvitationId,
viaConn :: ConnId,
externalIntroId :: IntroId,
entityInfo :: EntityInfo,
qInfo :: Maybe SMPQueueInfo,
connId :: Maybe ConnId,
status :: InvitationStatus
}
deriving (Show)
data InvitationStatus = InvNew | InvAcpt | InvCon
deriving (Eq, Show)
serializeInvStatus :: InvitationStatus -> Text
serializeInvStatus = \case
InvNew -> ""
InvAcpt -> "ACPT"
InvCon -> "CON"
invStatusT :: Text -> Maybe InvitationStatus
invStatusT = \case
"" -> Just InvNew
"ACPT" -> Just InvAcpt
"CON" -> Just InvCon
_ -> Nothing
type InvitationId = ByteString
-- * Store errors
-- | Agent store error.
data StoreError
= -- | IO exceptions in store actions.
SEInternal ByteString
| -- | failed to generate unique random ID
SEUniqueID
| -- | Connection alias not found (or both queues absent).
SEConnNotFound
| -- | Connection alias already used.
@@ -298,6 +391,10 @@ data StoreError
SEBcastNotFound
| -- | Broadcast ID already used.
SEBcastDuplicate
| -- | Introduction ID not found.
SEIntroNotFound
| -- | Invitation ID not found.
SEInvitationNotFound
| -- | Currently not used. The intention was to pass current expected queue status in methods,
-- as we always know what it should be at any stage of the protocol,
-- and in case it does not match use this error.
+413 -185
View File
@@ -1,3 +1,4 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -11,6 +12,7 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
@@ -22,14 +24,19 @@ module Simplex.Messaging.Agent.Store.SQLite
where
import Control.Concurrent (threadDelay)
import Control.Monad (unless, when)
import Control.Concurrent.STM (TVar, atomically, stateTVar)
import Control.Monad (join, unless, when)
import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
import Data.Bifunctor (first)
import Data.ByteString (ByteString)
import Data.ByteString.Base64 (encode)
import Data.Char (toLower)
import Data.Functor (($>))
import Data.List (find)
import Data.Maybe (fromMaybe)
import Data.Text (isPrefixOf)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, field)
@@ -66,8 +73,8 @@ createSQLiteStore dbFilePath = do
let dbDir = takeDirectory dbFilePath
createDirectoryIfMissing False dbDir
store <- connectSQLiteStore dbFilePath
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions)
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[Text]]
let threadsafeOption = find (T.isPrefixOf "THREADSAFE=") (concat compileOptions)
case threadsafeOption of
Just "THREADSAFE=0" -> confirmOrExit "SQLite compiled with non-threadsafe code."
Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found"
@@ -107,13 +114,16 @@ connectSQLiteStore dbFilePath = do
|]
pure SQLiteStore {dbFilePath, dbConn, dbNew}
checkConstraint :: StoreError -> IO () -> IO (Either StoreError ())
checkConstraint err action = first handleError <$> E.try action
where
handleError :: SQLError -> StoreError
handleError e
| DB.sqlError e == DB.ErrorConstraint = err
| otherwise = SEInternal $ bshow e
checkConstraint :: StoreError -> IO a -> IO (Either StoreError a)
checkConstraint err action = first (handleSQLError err) <$> E.try action
checkConstraint' :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a)
checkConstraint' err action = action `E.catch` (pure . Left . handleSQLError err)
handleSQLError :: StoreError -> SQLError -> StoreError
handleSQLError err e
| DB.sqlError e == DB.ErrorConstraint = err
| otherwise = SEInternal $ bshow e
withTransaction :: forall a. DB.Connection -> IO a -> IO a
withTransaction db a = loop 100 100_000
@@ -128,33 +138,43 @@ withTransaction db a = loop 100 100_000
else E.throwIO e
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
liftIOEither $
checkConstraint SEConnDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertRcvQueue_ dbConn q
insertRcvConnection_ dbConn q
createRcvConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
createRcvConn SQLiteStore {dbConn} gVar cData q@RcvQueue {server} =
-- TODO if schema has to be restarted, this function can be refactored
-- to create connection first using createWithRandomId
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
getConnId_ dbConn gVar cData >>= traverse create
where
create :: ConnId -> IO ConnId
create connId = do
upsertServer_ dbConn server
insertRcvQueue_ dbConn connId q
insertRcvConnection_ dbConn cData {connId} q
pure connId
createSndConn :: SQLiteStore -> SndQueue -> m ()
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
liftIOEither $
checkConstraint SEConnDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertSndQueue_ dbConn q
insertSndConnection_ dbConn q
createSndConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
createSndConn SQLiteStore {dbConn} gVar cData q@SndQueue {server} =
-- TODO if schema has to be restarted, this function can be refactored
-- to create connection first using createWithRandomId
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
getConnId_ dbConn gVar cData >>= traverse create
where
create :: ConnId -> IO ConnId
create connId = do
upsertServer_ dbConn server
insertSndQueue_ dbConn connId q
insertSndConnection_ dbConn cData {connId} q
pure connId
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
getConn SQLiteStore {dbConn} connAlias =
getConn :: SQLiteStore -> ConnId -> m SomeConn
getConn SQLiteStore {dbConn} connId =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias
getConn_ dbConn connId
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
getAllConnAliases SQLiteStore {dbConn} =
getAllConnIds :: SQLiteStore -> m [ConnId]
getAllConnIds SQLiteStore {dbConn} =
liftIO $ do
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnId]]
return (concat r)
getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn
@@ -169,37 +189,37 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|]
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
>>= \case
[Only connAlias] -> getConn_ dbConn connAlias
[Only connId] -> getConn_ dbConn connId
_ -> pure $ Left SEConnNotFound
deleteConn :: SQLiteStore -> ConnAlias -> m ()
deleteConn SQLiteStore {dbConn} connAlias =
deleteConn :: SQLiteStore -> ConnId -> m ()
deleteConn SQLiteStore {dbConn} connId =
liftIO $
DB.executeNamed
dbConn
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
[":conn_alias" := connAlias]
[":conn_alias" := connId]
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
upgradeRcvConnToDuplex :: SQLiteStore -> ConnId -> SndQueue -> m ()
upgradeRcvConnToDuplex SQLiteStore {dbConn} connId sq@SndQueue {server} =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
getConn_ dbConn connId >>= \case
Right (SomeConn _ RcvConnection {}) -> do
upsertServer_ dbConn server
insertSndQueue_ dbConn sq
updateConnWithSndQueue_ dbConn connAlias sq
insertSndQueue_ dbConn connId sq
updateConnWithSndQueue_ dbConn connId sq
pure $ Right ()
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
upgradeSndConnToDuplex :: SQLiteStore -> ConnId -> RcvQueue -> m ()
upgradeSndConnToDuplex SQLiteStore {dbConn} connId rq@RcvQueue {server} =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
getConn_ dbConn connId >>= \case
Right (SomeConn _ SndConnection {}) -> do
upsertServer_ dbConn server
insertRcvQueue_ dbConn rq
updateConnWithRcvQueue_ dbConn connAlias rq
insertRcvQueue_ dbConn connId rq
updateConnWithRcvQueue_ dbConn connId rq
pure $ Right ()
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
@@ -248,83 +268,233 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|]
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
updateRcvIds :: SQLiteStore -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
updateRcvIds SQLiteStore {dbConn} connId =
liftIO . withTransaction dbConn $ do
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connId
let internalId = InternalId $ unId lastInternalId + 1
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
updateLastIdsRcv_ dbConn connId internalId internalRcvId
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
createRcvMsg :: SQLiteStore -> ConnId -> RcvMsgData -> m ()
createRcvMsg SQLiteStore {dbConn} connId rcvMsgData =
liftIO . withTransaction dbConn $ do
insertRcvMsgBase_ dbConn connAlias rcvMsgData
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
updateHashRcv_ dbConn connAlias rcvMsgData
insertRcvMsgBase_ dbConn connId rcvMsgData
insertRcvMsgDetails_ dbConn connId rcvMsgData
updateHashRcv_ dbConn connId rcvMsgData
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
updateSndIds :: SQLiteStore -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
updateSndIds SQLiteStore {dbConn} connId =
liftIO . withTransaction dbConn $ do
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connId
let internalId = InternalId $ unId lastInternalId + 1
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
updateLastIdsSnd_ dbConn connId internalId internalSndId
pure (internalId, internalSndId, prevSndHash)
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
createSndMsg :: SQLiteStore -> ConnId -> SndMsgData -> m ()
createSndMsg SQLiteStore {dbConn} connId sndMsgData =
liftIO . withTransaction dbConn $ do
insertSndMsgBase_ dbConn connAlias sndMsgData
insertSndMsgDetails_ dbConn connAlias sndMsgData
updateHashSnd_ dbConn connAlias sndMsgData
insertSndMsgBase_ dbConn connId sndMsgData
insertSndMsgDetails_ dbConn connId sndMsgData
updateHashSnd_ dbConn connId sndMsgData
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
getMsg :: SQLiteStore -> ConnId -> InternalId -> m Msg
getMsg _st _connAlias _id = throwError SENotImplemented
createBcast :: SQLiteStore -> BroadcastId -> m ()
createBcast SQLiteStore {dbConn} bId =
liftIOEither $
checkConstraint SEBcastDuplicate $
DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
createBcast :: SQLiteStore -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
createBcast SQLiteStore {dbConn} gVar bcastId = liftIOEither $ case bcastId of
"" -> createWithRandomId gVar create
bId -> checkConstraint SEBcastDuplicate $ create bId $> bId
where
create bId = DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
addBcastConn SQLiteStore {dbConn} bId connAlias =
addBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
addBcastConn SQLiteStore {dbConn} bId connId =
liftIOEither . checkBroadcast dbConn bId $
getConn_ dbConn connAlias >>= \case
getConn_ dbConn connId >>= \case
Left _ -> pure $ Left SEConnNotFound
Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv
Right _ ->
checkConstraint SEConnDuplicate $
DB.execute
dbConn
"INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);"
(bId, connAlias)
[sql|
INSERT INTO broadcast_connections
(broadcast_id, conn_alias) VALUES (?, ?);
|]
(bId, connId)
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
removeBcastConn SQLiteStore {dbConn} bId connAlias =
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
removeBcastConn SQLiteStore {dbConn} bId connId =
liftIOEither . checkBroadcast dbConn bId $
bcastConnExists_ dbConn bId connAlias >>= \case
bcastConnExists_ dbConn bId connId >>= \case
False -> pure $ Left SEConnNotFound
_ ->
Right
<$> DB.execute
dbConn
"DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;"
(bId, connAlias)
[sql|
DELETE FROM broadcast_connections
WHERE broadcast_id = ? AND conn_alias = ?;
|]
(bId, connId)
deleteBcast :: SQLiteStore -> BroadcastId -> m ()
deleteBcast SQLiteStore {dbConn} bId =
liftIOEither . checkBroadcast dbConn bId $
Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias]
getBcast :: SQLiteStore -> BroadcastId -> m [ConnId]
getBcast SQLiteStore {dbConn} bId =
liftIOEither . checkBroadcast dbConn bId $
Right . map fromOnly
<$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId)
createIntro :: SQLiteStore -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
createIntro SQLiteStore {dbConn} gVar NewIntroduction {toConn, reConn, reInfo} =
liftIOEither . createWithRandomId gVar $ \introId ->
DB.execute
dbConn
[sql|
INSERT INTO conn_intros
(intro_id, to_conn, re_conn, re_info) VALUES (?, ?, ?, ?);
|]
(introId, toConn, reConn, reInfo)
getIntro :: SQLiteStore -> IntroId -> m Introduction
getIntro SQLiteStore {dbConn} introId =
liftIOEither $
intro
<$> DB.query
dbConn
[sql|
SELECT to_conn, to_info, to_status, re_conn, re_info, re_status, queue_info
FROM conn_intros
WHERE intro_id = ?;
|]
(Only introId)
where
intro [(toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo)] =
Right $ Introduction {introId, toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo}
intro _ = Left SEIntroNotFound
addIntroInvitation :: SQLiteStore -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
addIntroInvitation SQLiteStore {dbConn} introId toInfo qInfo =
liftIO $
DB.executeNamed
dbConn
[sql|
UPDATE conn_intros
SET to_info = :to_info,
queue_info = :queue_info,
to_status = :to_status
WHERE intro_id = :intro_id;
|]
[ ":to_info" := toInfo,
":queue_info" := Just qInfo,
":to_status" := IntroInv,
":intro_id" := introId
]
setIntroToStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
setIntroToStatus SQLiteStore {dbConn} introId toStatus =
liftIO $
DB.execute
dbConn
[sql|
UPDATE conn_intros
SET to_status = ?
WHERE intro_id = ?;
|]
(toStatus, introId)
setIntroReStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
setIntroReStatus SQLiteStore {dbConn} introId reStatus =
liftIO $
DB.execute
dbConn
[sql|
UPDATE conn_intros
SET re_status = ?
WHERE intro_id = ?;
|]
(reStatus, introId)
createInvitation :: SQLiteStore -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
createInvitation SQLiteStore {dbConn} gVar NewInvitation {viaConn, externalIntroId, entityInfo, qInfo} =
liftIOEither . createWithRandomId gVar $ \invId ->
DB.execute
dbConn
[sql|
INSERT INTO conn_invitations
(inv_id, via_conn, external_intro_id, conn_info, queue_info) VALUES (?, ?, ?, ?, ?);
|]
(invId, viaConn, externalIntroId, entityInfo, qInfo)
getInvitation :: SQLiteStore -> InvitationId -> m Invitation
getInvitation SQLiteStore {dbConn} invId =
liftIOEither $
invitation
<$> DB.query
dbConn
[sql|
SELECT via_conn, external_intro_id, conn_info, queue_info, conn_id, status
FROM conn_invitations
WHERE inv_id = ?;
|]
(Only invId)
where
invitation [(viaConn, externalIntroId, entityInfo, qInfo, connId, status)] =
Right $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId, status}
invitation _ = Left SEInvitationNotFound
addInvitationConn :: SQLiteStore -> InvitationId -> ConnId -> m ()
addInvitationConn SQLiteStore {dbConn} invId connId =
liftIO $
DB.executeNamed
dbConn
[sql|
UPDATE conn_invitations
SET conn_id = :conn_id, status = :status
WHERE inv_id = :inv_id;
|]
[":conn_id" := connId, ":status" := InvAcpt, ":inv_id" := invId]
getConnInvitation :: SQLiteStore -> ConnId -> m (Maybe (Invitation, Connection 'CDuplex))
getConnInvitation SQLiteStore {dbConn} cId =
liftIO . withTransaction dbConn $
DB.query
dbConn
[sql|
SELECT inv_id, via_conn, external_intro_id, conn_info, queue_info, status
FROM conn_invitations
WHERE conn_id = ?;
|]
(Only cId)
>>= fmap join . traverse getViaConn . invitation
where
invitation [(invId, viaConn, externalIntroId, entityInfo, qInfo, status)] =
Just $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId = Just cId, status}
invitation _ = Nothing
getViaConn :: Invitation -> IO (Maybe (Invitation, Connection 'CDuplex))
getViaConn inv@Invitation {viaConn} = fmap (inv,) . duplexConn <$> getConn_ dbConn viaConn
duplexConn :: Either StoreError SomeConn -> Maybe (Connection 'CDuplex)
duplexConn (Right (SomeConn SCDuplex conn)) = Just conn
duplexConn _ = Nothing
setInvitationStatus :: SQLiteStore -> InvitationId -> InvitationStatus -> m ()
setInvitationStatus SQLiteStore {dbConn} invId status =
liftIO $
DB.execute
dbConn
[sql|
UPDATE conn_invitations
SET status = ? WHERE inv_id = ?;
|]
(status, invId)
-- * Auxiliary helpers
-- ? replace with ToField? - it's easy to forget to use this
@@ -337,7 +507,7 @@ deserializePort_ port = Just port
instance ToField QueueStatus where toField = toField . show
instance FromField QueueStatus where fromField = fromFieldToReadable_
instance FromField QueueStatus where fromField = fromTextField_ $ readMaybe . T.unpack
instance ToField InternalRcvId where toField (InternalRcvId x) = toField x
@@ -359,13 +529,24 @@ instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity
instance FromField MsgIntegrity where fromField = blobFieldParser msgIntegrityP
fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a
fromFieldToReadable_ = \case
instance ToField IntroStatus where toField = toField . serializeIntroStatus
instance FromField IntroStatus where fromField = fromTextField_ introStatusT
instance ToField InvitationStatus where toField = toField . serializeInvStatus
instance FromField InvitationStatus where fromField = fromTextField_ invStatusT
instance ToField SMPQueueInfo where toField = toField . serializeSmpQueueInfo
instance FromField SMPQueueInfo where fromField = blobFieldParser smpQueueInfoP
fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ fromText = \case
f@(Field (SQLText t) _) ->
let str = T.unpack t
in case readMaybe str of
Just x -> Ok x
_ -> returnError ConversionFailed f ("invalid string: " <> str)
case fromText t of
Just x -> Ok x
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
f -> returnError ConversionFailed f "expecting SQLText column type"
{- ORMOLU_DISABLE -}
@@ -397,8 +578,8 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
-- * createRcvConn helpers
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
insertRcvQueue_ dbConn RcvQueue {..} = do
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
insertRcvQueue_ dbConn connId RcvQueue {..} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -411,7 +592,7 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
[ ":host" := host server,
":port" := port_,
":rcv_id" := rcvId,
":conn_alias" := connAlias,
":conn_alias" := connId,
":rcv_private_key" := rcvPrivateKey,
":snd_id" := sndId,
":snd_key" := sndKey,
@@ -420,26 +601,29 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
":status" := status
]
insertRcvConnection_ :: DB.Connection -> RcvQueue -> IO ()
insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
insertRcvConnection_ :: DB.Connection -> ConnData -> RcvQueue -> IO ()
insertRcvConnection_ dbConn ConnData {connId, viaInv, connLevel} RcvQueue {server, rcvId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
[sql|
INSERT INTO connections
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
VALUES
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
0, 0, 0, 0, x'', x'');
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, :via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|]
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
[ ":conn_alias" := connId,
":rcv_host" := host server,
":rcv_port" := port_,
":rcv_id" := rcvId,
":via_inv" := viaInv,
":conn_level" := connLevel
]
-- * createSndConn helpers
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
insertSndQueue_ dbConn SndQueue {..} = do
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
insertSndQueue_ dbConn connId SndQueue {..} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -452,85 +636,96 @@ insertSndQueue_ dbConn SndQueue {..} = do
[ ":host" := host server,
":port" := port_,
":snd_id" := sndId,
":conn_alias" := connAlias,
":conn_alias" := connId,
":snd_private_key" := sndPrivateKey,
":encrypt_key" := encryptKey,
":sign_key" := signKey,
":status" := status
]
insertSndConnection_ :: DB.Connection -> SndQueue -> IO ()
insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
insertSndConnection_ :: DB.Connection -> ConnData -> SndQueue -> IO ()
insertSndConnection_ dbConn ConnData {connId, viaInv, connLevel} SndQueue {server, sndId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
[sql|
INSERT INTO connections
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
VALUES
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
0, 0, 0, 0, x'', x'');
(:conn_alias, NULL, NULL, NULL, :snd_host,:snd_port,:snd_id,:via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|]
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
[ ":conn_alias" := connId,
":snd_host" := host server,
":snd_port" := port_,
":snd_id" := sndId,
":via_inv" := viaInv,
":conn_level" := connLevel
]
-- * getConn helpers
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
getConn_ dbConn connAlias = do
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
pure $ case (rQ, sQ) of
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
_ -> Left SEConnNotFound
getConn_ :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn_ dbConn connId =
getConnData_ dbConn connId >>= \case
Nothing -> pure $ Left SEConnNotFound
Just connData -> do
rQ <- getRcvQueueByConnAlias_ dbConn connId
sQ <- getSndQueueByConnAlias_ dbConn connId
pure $ case (rQ, sQ) of
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
_ -> Left SEConnNotFound
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
r <-
DB.queryNamed
getConnData_ :: DB.Connection -> ConnId -> IO (Maybe ConnData)
getConnData_ dbConn connId =
connData
<$> DB.query dbConn "SELECT via_inv, conn_level FROM connections WHERE conn_alias = ?;" (Only connId)
where
connData [(viaInv, connLevel)] = Just ConnData {connId, viaInv, connLevel}
connData _ = Nothing
getRcvQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
getRcvQueueByConnAlias_ dbConn connId =
rcvQueue
<$> DB.query
dbConn
[sql|
SELECT
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key,
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.conn_alias = :conn_alias;
WHERE q.conn_alias = ?;
|]
[":conn_alias" := connAlias]
case r of
[(keyHash, host, port, rcvId, cAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
(Only connId)
where
rcvQueue [(keyHash, host, port, rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] =
let srv = SMPServer host (deserializePort_ port) keyHash
return . Just $ RcvQueue srv rcvId cAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
_ -> return Nothing
in Just $ RcvQueue srv rcvId rcvPrivateKey sndId sndKey decryptKey verifyKey status
rcvQueue _ = Nothing
retrieveSndQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe SndQueue)
retrieveSndQueueByConnAlias_ dbConn connAlias = do
r <-
DB.queryNamed
getSndQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
getSndQueueByConnAlias_ dbConn connId =
sndQueue
<$> DB.query
dbConn
[sql|
SELECT
s.key_hash, q.host, q.port, q.snd_id, q.conn_alias,
q.snd_private_key, q.encrypt_key, q.sign_key, q.status
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status
FROM snd_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.conn_alias = :conn_alias;
WHERE q.conn_alias = ?;
|]
[":conn_alias" := connAlias]
case r of
[(keyHash, host, port, sndId, cAlias, sndPrivateKey, encryptKey, signKey, status)] -> do
(Only connId)
where
sndQueue [(keyHash, host, port, sndId, sndPrivateKey, encryptKey, signKey, status)] =
let srv = SMPServer host (deserializePort_ port) keyHash
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
_ -> return Nothing
in Just $ SndQueue srv sndId sndPrivateKey encryptKey signKey status
sndQueue _ = Nothing
-- * upgradeRcvConnToDuplex helpers
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
updateConnWithSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
updateConnWithSndQueue_ dbConn connId SndQueue {server, sndId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -539,12 +734,12 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
SET snd_host = :snd_host, snd_port = :snd_port, snd_id = :snd_id
WHERE conn_alias = :conn_alias;
|]
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connAlias]
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connId]
-- * upgradeSndConnToDuplex helpers
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
updateConnWithRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
updateConnWithRcvQueue_ dbConn connId RcvQueue {server, rcvId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -553,12 +748,12 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
SET rcv_host = :rcv_host, rcv_port = :rcv_port, rcv_id = :rcv_id
WHERE conn_alias = :conn_alias;
|]
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connId]
-- * updateRcvIds helpers
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
retrieveLastIdsAndHashRcv_ dbConn connId = do
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
DB.queryNamed
dbConn
@@ -567,11 +762,11 @@ retrieveLastIdsAndHashRcv_ dbConn connAlias = do
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
[":conn_alias" := connId]
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO ()
updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
DB.executeNamed
dbConn
[sql|
@@ -582,13 +777,13 @@ updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_rcv_msg_id" := newInternalRcvId,
":conn_alias" := connAlias
":conn_alias" := connId
]
-- * createRcvMsg helpers
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connId RcvMsgData {..} = do
DB.executeNamed
dbConn
[sql|
@@ -597,15 +792,15 @@ insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
VALUES
(:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL,:body);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_id" := internalId,
":internal_ts" := internalTs,
":internal_rcv_id" := internalRcvId,
":body" := decodeUtf8 msgBody
]
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connId RcvMsgData {..} =
DB.executeNamed
dbConn
[sql|
@@ -618,7 +813,7 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
:broker_id,:broker_ts,:rcv_status, NULL, NULL,
:internal_hash,:external_prev_snd_hash,:integrity);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_rcv_id" := internalRcvId,
":internal_id" := internalId,
":external_snd_id" := fst senderMeta,
@@ -631,8 +826,8 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
":integrity" := msgIntegrity
]
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
updateHashRcv_ dbConn connId RcvMsgData {..} =
DB.executeNamed
dbConn
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
@@ -645,14 +840,14 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|]
[ ":last_external_snd_msg_id" := fst senderMeta,
":last_rcv_msg_hash" := internalHash,
":conn_alias" := connAlias,
":conn_alias" := connId,
":last_internal_rcv_msg_id" := internalRcvId
]
-- * updateSndIds helpers
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
retrieveLastIdsAndHashSnd_ dbConn connId = do
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
DB.queryNamed
dbConn
@@ -661,11 +856,11 @@ retrieveLastIdsAndHashSnd_ dbConn connAlias = do
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
[":conn_alias" := connId]
return (lastInternalId, lastInternalSndId, lastSndHash)
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId =
DB.executeNamed
dbConn
[sql|
@@ -676,13 +871,13 @@ updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_snd_msg_id" := newInternalSndId,
":conn_alias" := connAlias
":conn_alias" := connId
]
-- * createSndMsg helpers
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connId SndMsgData {..} = do
DB.executeNamed
dbConn
[sql|
@@ -691,15 +886,15 @@ insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
VALUES
(:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id,:body);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_id" := internalId,
":internal_ts" := internalTs,
":internal_snd_id" := internalSndId,
":body" := decodeUtf8 msgBody
]
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgDetails_ dbConn connId SndMsgData {..} =
DB.executeNamed
dbConn
[sql|
@@ -708,15 +903,15 @@ insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
VALUES
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_snd_id" := internalSndId,
":internal_id" := internalId,
":snd_status" := Created,
":internal_hash" := internalHash
]
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
updateHashSnd_ dbConn connAlias SndMsgData {..} =
updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
updateHashSnd_ dbConn connId SndMsgData {..} =
DB.executeNamed
dbConn
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
@@ -727,7 +922,7 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} =
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|]
[ ":last_snd_msg_hash" := internalHash,
":conn_alias" := connAlias,
":conn_alias" := connId,
":last_internal_snd_msg_id" := internalSndId
]
@@ -745,10 +940,10 @@ bcastExists_ dbConn bId = not . null <$> queryBcast
queryBcast :: IO [Only BroadcastId]
queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool
bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnId -> IO Bool
bcastConnExists_ dbConn bId connId = not . null <$> queryBcastConn
where
queryBcastConn :: IO [(BroadcastId, ConnAlias)]
queryBcastConn :: IO [(BroadcastId, ConnId)]
queryBcastConn =
DB.query
dbConn
@@ -757,4 +952,37 @@ bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
FROM broadcast_connections
WHERE broadcast_id = ? AND conn_alias = ?;
|]
(bId, connAlias)
(bId, connId)
-- create record with a random ID
getConnId_ :: DB.Connection -> TVar ChaChaDRG -> ConnData -> IO (Either StoreError ConnId)
getConnId_ dbConn gVar ConnData {connId = ""} = getUniqueRandomId gVar $ getConnData_ dbConn
getConnId_ _ _ ConnData {connId} = pure $ Right connId
getUniqueRandomId :: TVar ChaChaDRG -> (ByteString -> IO (Maybe a)) -> IO (Either StoreError ByteString)
getUniqueRandomId gVar get = tryGet 3
where
tryGet :: Int -> IO (Either StoreError ByteString)
tryGet 0 = pure $ Left SEUniqueID
tryGet n = do
id' <- randomId gVar 12
get id' >>= \case
Nothing -> pure $ Right id'
Just _ -> tryGet (n - 1)
createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString)
createWithRandomId gVar create = tryCreate 3
where
tryCreate :: Int -> IO (Either StoreError ByteString)
tryCreate 0 = pure $ Left SEUniqueID
tryCreate n = do
id' <- randomId gVar 12
E.try (create id') >>= \case
Right _ -> pure $ Right id'
Left e
| DB.sqlError e == DB.ErrorConstraint -> tryCreate (n - 1)
| otherwise -> pure . Left . SEInternal $ bshow e
randomId :: TVar ChaChaDRG -> Int -> IO ByteString
randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n)
+7 -4
View File
@@ -30,7 +30,7 @@ base64StringP = do
pure $ str <> pad
tsISO8601P :: Parser UTCTime
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ')
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
parse :: Parser a -> e -> (ByteString -> Either e a)
parse parser err = first (const err) . parseAll parser
@@ -42,14 +42,17 @@ parseRead :: Read a => Parser ByteString -> Parser a
parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack)
parseRead1 :: Read a => Parser a
parseRead1 = parseRead $ A.takeTill (== ' ')
parseRead1 = parseRead $ A.takeTill wordEnd
parseRead2 :: Read a => Parser a
parseRead2 = parseRead $ do
w1 <- A.takeTill (== ' ') <* A.char ' '
w2 <- A.takeTill (== ' ')
w1 <- A.takeTill wordEnd <* A.char ' '
w2 <- A.takeTill wordEnd
pure $ w1 <> " " <> w2
wordEnd :: Char -> Bool
wordEnd c = c == ' ' || c == '\n'
parseString :: (ByteString -> Either String a) -> (String -> a)
parseString p = either error id . p . B.pack