mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-24 23:26:00 +00:00
introduction protocol (#156)
* commands to support introduction * agent messages / envelopes to support introductions * introductions and invitations table; insert record with random unique ID * store class methods and types for introductions * process INTRO and ACPT commands for connection introductions * fix tests: add MonadFail constraint, remove OK response to JOIN * process agent messages for introductions * ICON notification when introduction is completed * replace multiway if with case * correction * support random connection IDs * save additional connection fields, refactor create connection funcs * refactor * refactor * test duplex connection with random IDs * store methods for introductions * test introduction * fix parsing of CON agent message * test introduction with random connection IDs * broadcast with random connection and broadcast IDs * clean up sql
This commit is contained in:
committed by
GitHub
parent
46c3589604
commit
ab89963f45
+179
-62
@@ -61,7 +61,7 @@ import UnliftIO.STM
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent t cfg = do
|
||||
started <- newEmptyTMVarIO
|
||||
runSMPAgentBlocking t started cfg
|
||||
@@ -70,10 +70,10 @@ runSMPAgent t cfg = do
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
|
||||
where
|
||||
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent :: forall c m'. (Transport c, MonadFail m', MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do
|
||||
liftIO $ putLn h "Welcome to SMP v0.3.2 agent"
|
||||
c <- getSMPAgentClient
|
||||
@@ -97,7 +97,7 @@ logConnection c connected =
|
||||
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
|
||||
|
||||
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
|
||||
runSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runSMPAgentClient :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runSMPAgentClient c = do
|
||||
db <- asks $ dbFile . config
|
||||
s1 <- liftIO $ connectSQLiteStore db
|
||||
@@ -128,7 +128,7 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a ->
|
||||
logClient AgentClient {clientId} dir (ATransmission corrId entity cmd) = do
|
||||
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, serializeEntity entity, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
client :: forall m. (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
client c@AgentClient {rcvQ, sndQ} st = forever loop
|
||||
where
|
||||
loop :: m ()
|
||||
@@ -147,6 +147,8 @@ withStore action = do
|
||||
Right c -> return c
|
||||
Left e -> throwError $ storeError e
|
||||
where
|
||||
-- TODO when parsing exception happens in store, the agent hangs;
|
||||
-- changing SQLError to SomeException does not help
|
||||
handleInternal :: (MonadError StoreError m') => SQLError -> m' a
|
||||
handleInternal e = throwError . SEInternal $ bshow e
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
@@ -173,23 +175,28 @@ unsupportedEntity c _ corrId entity _ =
|
||||
|
||||
processConnCommand ::
|
||||
forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m ()
|
||||
processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
NEW -> createNewConnection conn
|
||||
JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode
|
||||
processConnCommand c@AgentClient {sndQ} st corrId conn@(Conn connId) = \case
|
||||
NEW -> createNewConnection Nothing 0 >>= uncurry respond
|
||||
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode Nothing 0 >> pure () -- >>= (`respond` OK)
|
||||
INTRO reEntity reInfo -> makeIntroduction reEntity reInfo
|
||||
ACPT inv eInfo -> acceptInvitation inv eInfo
|
||||
SUB -> subscribeConnection conn
|
||||
SUBALL -> subscribeAll
|
||||
SEND msgBody -> sendMessage c st corrId conn msgBody
|
||||
SEND msgBody -> sendClientMessage c st corrId conn msgBody
|
||||
OFF -> suspendConnection conn
|
||||
DEL -> deleteConnection conn
|
||||
where
|
||||
createNewConnection :: Entity 'Conn_ -> m ()
|
||||
createNewConnection (Conn cId) = do
|
||||
createNewConnection :: Maybe InvitationId -> Int -> m (Entity 'Conn_, ACommand 'Agent 'INV_)
|
||||
createNewConnection viaInv connLevel = do
|
||||
-- TODO create connection alias if not passed
|
||||
-- make connAlias Maybe?
|
||||
-- make connId Maybe?
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv cId
|
||||
withStore $ createRcvConn st rq
|
||||
respond conn $ INV qInfo
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createRcvConn st g cData rq
|
||||
addSubscription c rq connId'
|
||||
pure (Conn connId', INV qInfo)
|
||||
|
||||
getSMPServer :: m SMPServer
|
||||
getSMPServer =
|
||||
@@ -200,16 +207,47 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
|
||||
pure $ servers L.!! i
|
||||
|
||||
joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m ()
|
||||
joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do
|
||||
-- TODO create connection alias if not passed
|
||||
-- make connAlias Maybe?
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo cId
|
||||
withStore $ createSndConn st sq
|
||||
joinConnection :: SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m (Entity 'Conn_)
|
||||
joinConnection qInfo (ReplyMode replyMode) viaInv connLevel = do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createSndConn st g cData sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
when (replyMode == On) $ createReplyQueue cId sq
|
||||
-- TODO this response is disabled to avoid two responses in terminal client (OK + CON),
|
||||
-- respond conn OK
|
||||
when (replyMode == On) $ createReplyQueue connId' sq
|
||||
pure $ Conn connId'
|
||||
|
||||
makeIntroduction :: IntroEntity -> EntityInfo -> m ()
|
||||
makeIntroduction (IE reEntity) reInfo = case reEntity of
|
||||
Conn reConn ->
|
||||
withStore ((,) <$> getConn st connId <*> getConn st reConn) >>= \case
|
||||
(SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do
|
||||
g <- asks idsDrg
|
||||
introId <- withStore $ createIntro st g NewIntroduction {toConn = connId, reConn, reInfo}
|
||||
sendControlMessage c sq $ A_INTRO (IE (Conn introId)) reInfo
|
||||
respond conn OK
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ CMD UNSUPPORTED
|
||||
|
||||
acceptInvitation :: IntroEntity -> EntityInfo -> m ()
|
||||
acceptInvitation (IE invEntity) eInfo = case invEntity of
|
||||
Conn invId -> do
|
||||
withStore (getInvitation st invId) >>= \case
|
||||
Invitation {viaConn, qInfo, externalIntroId, status = InvNew} ->
|
||||
withStore (getConn st viaConn) >>= \case
|
||||
SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of
|
||||
Nothing -> do
|
||||
(conn', INV qInfo') <- createNewConnection (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId $ fromConn conn'
|
||||
sendControlMessage c sq $ A_INV (Conn externalIntroId) qInfo' eInfo
|
||||
respond conn' OK
|
||||
Just qInfo' -> do
|
||||
conn' <- joinConnection qInfo' (ReplyMode On) (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId $ fromConn conn'
|
||||
respond conn' OK
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
_ -> throwError $ CMD UNSUPPORTED
|
||||
|
||||
subscribeConnection :: Entity 'Conn_ -> m ()
|
||||
subscribeConnection conn'@(Conn cId) =
|
||||
@@ -222,7 +260,7 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
|
||||
-- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct
|
||||
subscribeAll :: m ()
|
||||
subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn)
|
||||
subscribeAll = withStore (getAllConnIds st) >>= mapM_ (subscribeConnection . Conn)
|
||||
|
||||
suspendConnection :: Entity 'Conn_ -> m ()
|
||||
suspendConnection (Conn cId) =
|
||||
@@ -246,25 +284,30 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
removeSubscription c cId
|
||||
delConn
|
||||
|
||||
createReplyQueue :: ByteString -> SndQueue -> m ()
|
||||
createReplyQueue :: ConnId -> SndQueue -> m ()
|
||||
createReplyQueue cId sq = do
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv cId
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
addSubscription c rq cId
|
||||
withStore $ upgradeSndConnToDuplex st cId rq
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage = REPLY qInfo
|
||||
}
|
||||
sendControlMessage c sq $ REPLY qInfo
|
||||
|
||||
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
|
||||
respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp
|
||||
|
||||
sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
|
||||
sendMessage c st corrId (Conn cId) msgBody =
|
||||
sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m ()
|
||||
sendControlMessage c sq agentMessage = do
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage
|
||||
}
|
||||
|
||||
sendClientMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
|
||||
sendClientMessage c st corrId (Conn cId) msgBody =
|
||||
withStore (getConn st cId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
|
||||
SomeConn _ (SndConnection _ sq) -> sendMsg sq
|
||||
@@ -273,7 +316,7 @@ sendMessage c st corrId (Conn cId) msgBody =
|
||||
sendMsg :: SndQueue -> m ()
|
||||
sendMsg sq = do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st cId
|
||||
let msgStr =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
@@ -284,7 +327,7 @@ sendMessage c st corrId (Conn cId) msgBody =
|
||||
}
|
||||
msgHash = C.sha256Hash msgStr
|
||||
withStore $
|
||||
createSndMsg st sq $
|
||||
createSndMsg st cId $
|
||||
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
|
||||
sendAgentMessage c sq msgStr
|
||||
atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId)
|
||||
@@ -292,15 +335,21 @@ sendMessage c st corrId (Conn cId) msgBody =
|
||||
processBroadcastCommand ::
|
||||
forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m ()
|
||||
processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
|
||||
NEW -> withStore (createBcast st bId) >> ok
|
||||
NEW -> createNewBroadcast
|
||||
ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok
|
||||
REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok
|
||||
LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn
|
||||
SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0)
|
||||
DEL -> withStore (deleteBcast st bId) >> ok
|
||||
where
|
||||
sendMsg :: MsgBody -> ConnAlias -> m ()
|
||||
sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody
|
||||
createNewBroadcast :: m ()
|
||||
createNewBroadcast = do
|
||||
g <- asks idsDrg
|
||||
bId' <- withStore $ createBcast st g bId
|
||||
respond (Broadcast bId') OK
|
||||
|
||||
sendMsg :: MsgBody -> ConnId -> m ()
|
||||
sendMsg msgBody cId = sendClientMessage c st corrId (Conn cId) msgBody
|
||||
|
||||
ok :: m ()
|
||||
ok = respond bcast OK
|
||||
@@ -308,7 +357,7 @@ processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
|
||||
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
|
||||
respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
subscriber :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
subscriber c@AgentClient {msgQ} st = forever $ do
|
||||
-- TODO this will only process messages and notifications
|
||||
t <- atomically $ readTBQueue msgQ
|
||||
@@ -319,29 +368,33 @@ subscriber c@AgentClient {msgQ} st = forever $ do
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
|
||||
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
withStore (getRcvConn st srv rId) >>= \case
|
||||
SomeConn SCDuplex (DuplexConnection _ rq _) -> processSMP SCDuplex rq
|
||||
SomeConn SCRcv (RcvConnection _ rq) -> processSMP SCRcv rq
|
||||
SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq
|
||||
SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq
|
||||
_ -> atomically . writeTBQueue sndQ $ ATransmission "" (Conn "") (ERR $ CONN SIMPLEX)
|
||||
where
|
||||
processSMP :: SConnType c -> RcvQueue -> m ()
|
||||
processSMP cType rq@RcvQueue {connAlias, status} =
|
||||
processSMP :: SConnType c -> ConnData -> RcvQueue -> m ()
|
||||
processSMP cType ConnData {connId} rq@RcvQueue {status} =
|
||||
case cmd of
|
||||
SMP.MSG srvMsgId srvTs msgBody -> do
|
||||
-- TODO deduplicate with previously received
|
||||
msg <- decryptAndVerify rq msgBody
|
||||
let msgHash = C.sha256Hash msg
|
||||
agentMsg <- liftEither $ parseSMPMessage msg
|
||||
case agentMsg of
|
||||
SMPConfirmation senderKey -> smpConfirmation senderKey
|
||||
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
|
||||
case parseSMPMessage msg of
|
||||
Left e -> notify $ ERR e
|
||||
Right (SMPConfirmation senderKey) -> smpConfirmation senderKey
|
||||
Right SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
|
||||
case agentMessage of
|
||||
HELLO verifyKey _ -> helloMsg verifyKey msgBody
|
||||
REPLY qInfo -> replyMsg qInfo
|
||||
A_MSG body -> agentClientMsg previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
|
||||
A_INTRO entity eInfo -> introMsg entity eInfo
|
||||
A_INV conn qInfo cInfo -> invMsg conn qInfo cInfo
|
||||
A_REQ conn qInfo cInfo -> reqMsg conn qInfo cInfo
|
||||
A_CON conn -> conMsg conn
|
||||
sendAck c rq
|
||||
return ()
|
||||
SMP.END -> do
|
||||
removeSubscription c connAlias
|
||||
removeSubscription c connId
|
||||
logServer "<--" c srv rId "END"
|
||||
notify END
|
||||
_ -> do
|
||||
@@ -349,7 +402,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
notify . ERR $ BROKER UNEXPECTED
|
||||
where
|
||||
notify :: EntityCommand 'Conn_ c => ACommand 'Agent c -> m ()
|
||||
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connAlias) msg
|
||||
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connId) msg
|
||||
|
||||
prohibited :: m ()
|
||||
prohibited = notify . ERR $ AGENT A_PROHIBITED
|
||||
@@ -376,7 +429,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
void $ verifyMessage (Just verifyKey) msgBody
|
||||
withStore $ setRcvQueueActive st rq verifyKey
|
||||
case cType of
|
||||
SCDuplex -> notify CON
|
||||
SCDuplex -> connected
|
||||
_ -> pure ()
|
||||
|
||||
replyMsg :: SMPQueueInfo -> m ()
|
||||
@@ -384,22 +437,87 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
logServer "<--" c srv rId "MSG <REPLY>"
|
||||
case cType of
|
||||
SCRcv -> do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias
|
||||
withStore $ upgradeRcvConnToDuplex st connAlias sq
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
withStore $ upgradeRcvConnToDuplex st connId sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
notify CON
|
||||
connected
|
||||
_ -> prohibited
|
||||
|
||||
connected :: m ()
|
||||
connected = do
|
||||
withStore (getConnInvitation st connId) >>= \case
|
||||
Just (Invitation {invId, externalIntroId}, DuplexConnection _ _ sq) -> do
|
||||
withStore $ setInvitationStatus st invId InvCon
|
||||
sendControlMessage c sq $ A_CON (Conn externalIntroId)
|
||||
_ -> pure ()
|
||||
notify CON
|
||||
|
||||
introMsg :: IntroEntity -> EntityInfo -> m ()
|
||||
introMsg (IE entity) entityInfo = do
|
||||
logServer "<--" c srv rId "MSG <INTRO>"
|
||||
case (cType, entity) of
|
||||
(SCDuplex, intro@Conn {}) -> createInv intro Nothing entityInfo
|
||||
_ -> prohibited
|
||||
|
||||
invMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
|
||||
invMsg (Conn introId) qInfo toInfo = do
|
||||
logServer "<--" c srv rId "MSG <INV>"
|
||||
case cType of
|
||||
SCDuplex ->
|
||||
withStore (getIntro st introId) >>= \case
|
||||
Introduction {toConn, toStatus = IntroNew, reConn, reStatus = IntroNew}
|
||||
| toConn /= connId -> prohibited
|
||||
| otherwise ->
|
||||
withStore (addIntroInvitation st introId toInfo qInfo >> getConn st reConn) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> do
|
||||
sendControlMessage c sq $ A_REQ (Conn introId) qInfo toInfo
|
||||
withStore $ setIntroReStatus st introId IntroInv
|
||||
_ -> prohibited
|
||||
_ -> prohibited
|
||||
_ -> prohibited
|
||||
|
||||
reqMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
|
||||
reqMsg intro qInfo connInfo = do
|
||||
logServer "<--" c srv rId "MSG <REQ>"
|
||||
case cType of
|
||||
SCDuplex -> createInv intro (Just qInfo) connInfo
|
||||
_ -> prohibited
|
||||
|
||||
createInv :: Entity 'Conn_ -> Maybe SMPQueueInfo -> EntityInfo -> m ()
|
||||
createInv (Conn externalIntroId) qInfo entityInfo = do
|
||||
g <- asks idsDrg
|
||||
let newInv = NewInvitation {viaConn = connId, externalIntroId, entityInfo, qInfo}
|
||||
invId <- withStore $ createInvitation st g newInv
|
||||
notify $ REQ (IE (Conn invId)) entityInfo
|
||||
|
||||
conMsg :: Entity 'Conn_ -> m ()
|
||||
conMsg (Conn introId) = do
|
||||
logServer "<--" c srv rId "MSG <CON>"
|
||||
withStore (getIntro st introId) >>= \case
|
||||
Introduction {toConn, toStatus, reConn, reStatus}
|
||||
| toConn == connId && toStatus == IntroInv -> do
|
||||
withStore $ setIntroToStatus st introId IntroCon
|
||||
sendConMsg toConn reConn reStatus
|
||||
| reConn == connId && reStatus == IntroInv -> do
|
||||
withStore $ setIntroReStatus st introId IntroCon
|
||||
sendConMsg toConn reConn toStatus
|
||||
| otherwise -> prohibited
|
||||
where
|
||||
sendConMsg :: ConnId -> ConnId -> IntroStatus -> m ()
|
||||
sendConMsg toConn reConn IntroCon =
|
||||
atomically . writeTBQueue sndQ $ ATransmission "" (Conn toConn) $ ICON $ IE (Conn reConn)
|
||||
sendConMsg _ _ _ = pure ()
|
||||
|
||||
agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
|
||||
agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
case status of
|
||||
Active -> do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st connId
|
||||
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash receivedPrevMsgHash
|
||||
withStore $
|
||||
createRcvMsg st rq $
|
||||
createRcvMsg st connId $
|
||||
RcvMsgData
|
||||
{ internalId,
|
||||
internalRcvId,
|
||||
@@ -438,8 +556,8 @@ connectToSendQueue c st sq senderKey verifyKey = do
|
||||
withStore $ setSndQueueStatus st sq Active
|
||||
|
||||
newSendQueue ::
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey)
|
||||
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey)
|
||||
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do
|
||||
size <- asks $ rsaKeySize . config
|
||||
(senderKey, sndPrivateKey) <- liftIO $ C.generateKeyPair size
|
||||
(verifyKey, signKey) <- liftIO $ C.generateKeyPair size
|
||||
@@ -447,7 +565,6 @@ newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
|
||||
SndQueue
|
||||
{ server = smpServer,
|
||||
sndId = senderId,
|
||||
connAlias,
|
||||
sndPrivateKey,
|
||||
encryptKey,
|
||||
signKey,
|
||||
|
||||
@@ -15,6 +15,7 @@ module Simplex.Messaging.Agent.Client
|
||||
closeSMPServerClients,
|
||||
newReceiveQueue,
|
||||
subscribeQueue,
|
||||
addSubscription,
|
||||
sendConfirmation,
|
||||
sendHello,
|
||||
secureQueue,
|
||||
@@ -61,8 +62,8 @@ data AgentClient = AgentClient
|
||||
sndQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
smpClients :: TVar (Map SMPServer SMPClient),
|
||||
subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)),
|
||||
subscrConns :: TVar (Map ConnAlias SMPServer),
|
||||
subscrSrvrs :: TVar (Map SMPServer (Set ConnId)),
|
||||
subscrConns :: TVar (Map ConnId SMPServer),
|
||||
clientId :: Int
|
||||
}
|
||||
|
||||
@@ -78,7 +79,7 @@ newAgentClient cc AgentConfig {tbqSize} = do
|
||||
writeTVar cc clientId
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
|
||||
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m, MonadFail m)
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
@@ -106,7 +107,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
removeSubs >>= mapM_ (mapM_ notifySub)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeSubs :: IO (Maybe (Set ConnAlias))
|
||||
removeSubs :: IO (Maybe (Set ConnId))
|
||||
removeSubs = atomically $ do
|
||||
modifyTVar smpClients $ M.delete srv
|
||||
cs <- M.lookup srv <$> readTVar (subscrSrvrs c)
|
||||
@@ -117,7 +118,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
deleteKeys :: Ord k => Set k -> Map k a -> Map k a
|
||||
deleteKeys ks m = S.foldr' M.delete m ks
|
||||
|
||||
notifySub :: ConnAlias -> IO ()
|
||||
notifySub :: ConnId -> IO ()
|
||||
notifySub connAlias = atomically . writeTBQueue (sndQ c) $ ATransmission "" (Conn connAlias) END
|
||||
|
||||
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
|
||||
@@ -158,8 +159,8 @@ smpClientError = \case
|
||||
SMPTransportError e -> BROKER $ TRANSPORT e
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv connAlias = do
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv = do
|
||||
size <- asks $ rsaKeySize . config
|
||||
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateKeyPair size
|
||||
logServer "-->" c srv "" "NEW"
|
||||
@@ -170,7 +171,6 @@ newReceiveQueue c srv connAlias = do
|
||||
RcvQueue
|
||||
{ server = srv,
|
||||
rcvId,
|
||||
connAlias,
|
||||
rcvPrivateKey,
|
||||
sndId = Just sId,
|
||||
sndKey = Nothing,
|
||||
@@ -178,25 +178,24 @@ newReceiveQueue c srv connAlias = do
|
||||
verifyKey = Nothing,
|
||||
status = New
|
||||
}
|
||||
addSubscription c rq connAlias
|
||||
return (rq, SMPQueueInfo srv sId encryptKey)
|
||||
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnAlias -> m ()
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connAlias = do
|
||||
withLogSMP c server rcvId "SUB" $ \smp ->
|
||||
subscribeSMPQueue smp rcvPrivateKey rcvId
|
||||
addSubscription c rq connAlias
|
||||
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnAlias -> m ()
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c RcvQueue {server} connAlias = atomically $ do
|
||||
modifyTVar (subscrConns c) $ M.insert connAlias server
|
||||
modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server
|
||||
where
|
||||
addSub :: Maybe (Set ConnAlias) -> Set ConnAlias
|
||||
addSub :: Maybe (Set ConnId) -> Set ConnId
|
||||
addSub (Just cs) = S.insert connAlias cs
|
||||
addSub _ = S.singleton connAlias
|
||||
|
||||
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
|
||||
removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do
|
||||
cs <- readTVar subscrConns
|
||||
writeTVar subscrConns $ M.delete connAlias cs
|
||||
@@ -204,7 +203,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically
|
||||
(modifyTVar subscrSrvrs . M.alter (>>= delSub))
|
||||
(M.lookup connAlias cs)
|
||||
where
|
||||
delSub :: Set ConnAlias -> Maybe (Set ConnAlias)
|
||||
delSub :: Set ConnId -> Maybe (Set ConnId)
|
||||
delSub cs =
|
||||
let cs' = S.delete connAlias cs
|
||||
in if S.null cs' then Nothing else Just cs'
|
||||
|
||||
@@ -33,6 +33,8 @@ module Simplex.Messaging.Agent.Protocol
|
||||
Entity (..),
|
||||
EntityTag (..),
|
||||
AnEntity (..),
|
||||
IntroEntity (..),
|
||||
EntityInfo,
|
||||
EntityCommand,
|
||||
entityCommand,
|
||||
ACommand (..),
|
||||
@@ -53,7 +55,7 @@ module Simplex.Messaging.Agent.Protocol
|
||||
ATransmission (..),
|
||||
ATransmissionOrError (..),
|
||||
ARawTransmission,
|
||||
ConnAlias,
|
||||
ConnId,
|
||||
ReplyMode (..),
|
||||
AckMode (..),
|
||||
OnOff (..),
|
||||
@@ -171,6 +173,13 @@ deriving instance Eq (Entity t)
|
||||
|
||||
deriving instance Show (Entity t)
|
||||
|
||||
instance TestEquality Entity where
|
||||
testEquality (Conn c) (Conn c') = refl c c'
|
||||
testEquality (OpenConn c) (OpenConn c') = refl c c'
|
||||
testEquality (Broadcast c) (Broadcast c') = refl c c'
|
||||
testEquality (AGroup c) (AGroup c') = refl c c'
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
entityId :: Entity t -> ByteString
|
||||
entityId = \case
|
||||
Conn bs -> bs
|
||||
@@ -195,7 +204,11 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where
|
||||
EntityCommand Conn_ NEW_ = ()
|
||||
EntityCommand Conn_ INV_ = ()
|
||||
EntityCommand Conn_ JOIN_ = ()
|
||||
EntityCommand Conn_ INTRO_ = ()
|
||||
EntityCommand Conn_ REQ_ = ()
|
||||
EntityCommand Conn_ ACPT_ = ()
|
||||
EntityCommand Conn_ CON_ = ()
|
||||
EntityCommand Conn_ ICON_ = ()
|
||||
EntityCommand Conn_ SUB_ = ()
|
||||
EntityCommand Conn_ SUBALL_ = ()
|
||||
EntityCommand Conn_ END_ = ()
|
||||
@@ -226,7 +239,11 @@ entityCommand = \case
|
||||
NEW -> Just Dict
|
||||
INV _ -> Just Dict
|
||||
JOIN {} -> Just Dict
|
||||
INTRO {} -> Just Dict
|
||||
REQ {} -> Just Dict
|
||||
ACPT {} -> Just Dict
|
||||
CON -> Just Dict
|
||||
ICON {} -> Just Dict
|
||||
SUB -> Just Dict
|
||||
SUBALL -> Just Dict
|
||||
END -> Just Dict
|
||||
@@ -258,7 +275,11 @@ data ACmdTag
|
||||
= NEW_
|
||||
| INV_
|
||||
| JOIN_
|
||||
| INTRO_
|
||||
| REQ_
|
||||
| ACPT_
|
||||
| CON_
|
||||
| ICON_
|
||||
| SUB_
|
||||
| SUBALL_
|
||||
| END_
|
||||
@@ -274,15 +295,31 @@ data ACmdTag
|
||||
| OK_
|
||||
| ERR_
|
||||
|
||||
type family Introduction (t :: EntityTag) :: Constraint where
|
||||
Introduction Conn_ = ()
|
||||
Introduction OpenConn_ = ()
|
||||
Introduction AGroup_ = ()
|
||||
Introduction t = (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " cannot be INTRO'd to"))
|
||||
|
||||
data IntroEntity = forall t. Introduction t => IE (Entity t)
|
||||
|
||||
instance Eq IntroEntity where
|
||||
IE e1 == IE e2 = isJust $ testEquality e1 e2
|
||||
|
||||
deriving instance Show IntroEntity
|
||||
|
||||
type EntityInfo = ByteString
|
||||
|
||||
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
|
||||
data ACommand (p :: AParty) (c :: ACmdTag) where
|
||||
NEW :: ACommand Client NEW_ -- response INV
|
||||
INV :: SMPQueueInfo -> ACommand Agent INV_
|
||||
JOIN :: SMPQueueInfo -> ReplyMode -> ACommand Client JOIN_ -- response OK
|
||||
INTRO :: IntroEntity -> EntityInfo -> ACommand Client INTRO_
|
||||
REQ :: IntroEntity -> EntityInfo -> ACommand Agent INTRO_
|
||||
ACPT :: IntroEntity -> EntityInfo -> ACommand Client ACPT_
|
||||
CON :: ACommand Agent CON_ -- notification that connection is established
|
||||
-- TODO currently it automatically allows whoever sends the confirmation
|
||||
-- CONF :: OtherPartyId -> ACommand Agent
|
||||
-- LET :: OtherPartyId -> ACommand Client
|
||||
ICON :: IntroEntity -> ACommand Agent ICON_
|
||||
SUB :: ACommand Client SUB_
|
||||
SUBALL :: ACommand Client SUBALL_ -- TODO should be moved to chat protocol - hack for subscribing to all
|
||||
END :: ACommand Agent END_
|
||||
@@ -318,6 +355,7 @@ instance TestEquality (ACommand p) where
|
||||
testEquality c@INV {} c'@INV {} = refl c c'
|
||||
testEquality c@JOIN {} c'@JOIN {} = refl c c'
|
||||
testEquality CON CON = Just Refl
|
||||
testEquality c@ICON {} c'@ICON {} = refl c c'
|
||||
testEquality SUB SUB = Just Refl
|
||||
testEquality SUBALL SUBALL = Just Refl
|
||||
testEquality END END = Just Refl
|
||||
@@ -334,7 +372,7 @@ instance TestEquality (ACommand p) where
|
||||
testEquality c@ERR {} c'@ERR {} = refl c c'
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
refl :: Eq (f a) => f a -> f a -> Maybe (a :~: a)
|
||||
refl :: Eq a => a -> a -> Maybe (t :~: t)
|
||||
refl x x' = if x == x' then Just Refl else Nothing
|
||||
|
||||
-- | SMP message formats.
|
||||
@@ -366,6 +404,14 @@ data AMessage where
|
||||
REPLY :: SMPQueueInfo -> AMessage
|
||||
-- | agent envelope for the client message
|
||||
A_MSG :: MsgBody -> AMessage
|
||||
-- | agent message for introduction
|
||||
A_INTRO :: IntroEntity -> EntityInfo -> AMessage
|
||||
-- | agent envelope for the sent invitation
|
||||
A_INV :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
|
||||
-- | agent envelope for the forwarded invitation
|
||||
A_REQ :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
|
||||
-- | agent message for intro/group request
|
||||
A_CON :: Entity Conn_ -> AMessage
|
||||
deriving (Show)
|
||||
|
||||
-- | Parse SMP message.
|
||||
@@ -408,12 +454,22 @@ agentMessageP =
|
||||
"HELLO " *> hello
|
||||
<|> "REPLY " *> reply
|
||||
<|> "MSG " *> a_msg
|
||||
<|> "INTRO " *> a_intro
|
||||
<|> "INV " *> a_inv
|
||||
<|> "REQ " *> a_req
|
||||
<|> "CON " *> a_con
|
||||
where
|
||||
hello = HELLO <$> C.pubKeyP <*> ackMode
|
||||
reply = REPLY <$> smpQueueInfoP
|
||||
a_msg = do
|
||||
a_msg = A_MSG <$> binaryBody
|
||||
a_intro = A_INTRO <$> introEntityP <* A.space <*> binaryBody
|
||||
a_inv = invP A_INV
|
||||
a_req = invP A_REQ
|
||||
a_con = A_CON <$> connEntityP
|
||||
invP f = f <$> connEntityP <* A.space <*> smpQueueInfoP <* A.space <*> binaryBody
|
||||
binaryBody = do
|
||||
size :: Int <- A.decimal <* A.endOfLine
|
||||
A_MSG <$> A.take size <* A.endOfLine
|
||||
A.take size <* A.endOfLine
|
||||
ackMode = AckMode <$> (" NO_ACK" $> Off <|> pure On)
|
||||
|
||||
-- | SMP queue information parser.
|
||||
@@ -434,6 +490,13 @@ serializeAgentMessage = \case
|
||||
HELLO verifyKey ackMode -> "HELLO " <> C.serializePubKey verifyKey <> if ackMode == AckMode Off then " NO_ACK" else ""
|
||||
REPLY qInfo -> "REPLY " <> serializeSmpQueueInfo qInfo
|
||||
A_MSG body -> "MSG " <> serializeMsg body <> "\n"
|
||||
A_INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo <> "\n"
|
||||
A_INV conn qInfo eInfo -> "INV " <> serializeInv conn qInfo eInfo
|
||||
A_REQ conn qInfo eInfo -> "REQ " <> serializeInv conn qInfo eInfo
|
||||
A_CON conn -> "CON " <> serializeEntity conn
|
||||
where
|
||||
serializeInv conn qInfo eInfo =
|
||||
B.intercalate " " [serializeEntity conn, serializeSmpQueueInfo qInfo, serializeMsg eInfo] <> "\n"
|
||||
|
||||
-- | Serialize SMP queue information that is sent out-of-band.
|
||||
serializeSmpQueueInfo :: SMPQueueInfo -> ByteString
|
||||
@@ -457,7 +520,7 @@ instance IsString SMPServer where
|
||||
fromString = parseString . parseAll $ smpServerP
|
||||
|
||||
-- | SMP agent connection alias.
|
||||
type ConnAlias = ByteString
|
||||
type ConnId = ByteString
|
||||
|
||||
-- | Connection modes.
|
||||
data OnOff = On | Off deriving (Eq, Show, Read)
|
||||
@@ -614,10 +677,19 @@ anEntityP =
|
||||
<|> "B:" $> AE . Broadcast
|
||||
<|> "G:" $> AE . AGroup
|
||||
)
|
||||
<*> A.takeTill (== ' ')
|
||||
<*> A.takeTill wordEnd
|
||||
|
||||
entityConnP :: Parser (Entity Conn_)
|
||||
entityConnP = "C:" *> (Conn <$> A.takeTill (== ' '))
|
||||
connEntityP :: Parser (Entity Conn_)
|
||||
connEntityP = "C:" *> (Conn <$> A.takeTill wordEnd)
|
||||
|
||||
introEntityP :: Parser IntroEntity
|
||||
introEntityP =
|
||||
($)
|
||||
<$> ( "C:" $> IE . Conn
|
||||
<|> "O:" $> IE . OpenConn
|
||||
<|> "G:" $> IE . AGroup
|
||||
)
|
||||
<*> A.takeTill wordEnd
|
||||
|
||||
serializeEntity :: Entity t -> ByteString
|
||||
serializeEntity = \case
|
||||
@@ -632,6 +704,9 @@ commandP =
|
||||
"NEW" $> ACmd SClient NEW
|
||||
<|> "INV " *> invResp
|
||||
<|> "JOIN " *> joinCmd
|
||||
<|> "INTRO " *> introCmd
|
||||
<|> "REQ " *> reqCmd
|
||||
<|> "ACPT " *> acptCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all
|
||||
<|> "END" $> ACmd SAgent END
|
||||
@@ -645,16 +720,21 @@ commandP =
|
||||
<|> "LS" $> ACmd SClient LS
|
||||
<|> "MS " *> membersResp
|
||||
<|> "ERR " *> agentError
|
||||
<|> "ICON " *> iconMsg
|
||||
<|> "CON" $> ACmd SAgent CON
|
||||
<|> "OK" $> ACmd SAgent OK
|
||||
where
|
||||
invResp = ACmd SAgent . INV <$> smpQueueInfoP
|
||||
joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode)
|
||||
introCmd = ACmd SClient <$> introP INTRO
|
||||
reqCmd = ACmd SAgent <$> introP REQ
|
||||
acptCmd = ACmd SClient <$> introP ACPT
|
||||
sendCmd = ACmd SClient . SEND <$> A.takeByteString
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
addCmd = ACmd SClient . ADD <$> entityConnP
|
||||
removeCmd = ACmd SClient . REM <$> entityConnP
|
||||
membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ')
|
||||
addCmd = ACmd SClient . ADD <$> connEntityP
|
||||
removeCmd = ACmd SClient . REM <$> connEntityP
|
||||
membersResp = ACmd SAgent . MS <$> (connEntityP `A.sepBy'` A.char ' ')
|
||||
iconMsg = ACmd SAgent . ICON <$> introEntityP
|
||||
message = do
|
||||
msgIntegrity <- msgIntegrityP <* A.space
|
||||
recipientMeta <- "R=" *> partyMeta A.decimal
|
||||
@@ -662,6 +742,7 @@ commandP =
|
||||
senderMeta <- "S=" *> partyMeta A.decimal
|
||||
msgBody <- A.takeByteString
|
||||
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
|
||||
introP f = f <$> introEntityP <* A.space <*> A.takeByteString
|
||||
replyMode = ReplyMode <$> (" NO_REPLY" $> Off <|> pure On)
|
||||
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
|
||||
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
|
||||
@@ -685,6 +766,9 @@ serializeCommand = \case
|
||||
NEW -> "NEW"
|
||||
INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo
|
||||
JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode
|
||||
INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo
|
||||
REQ (IE entity) eInfo -> "REQ " <> serializeIntro entity eInfo
|
||||
ACPT (IE entity) eInfo -> "ACPT " <> serializeIntro entity eInfo
|
||||
SUB -> "SUB"
|
||||
SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all
|
||||
END -> "END"
|
||||
@@ -706,6 +790,7 @@ serializeCommand = \case
|
||||
LS -> "LS"
|
||||
MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs)
|
||||
CON -> "CON"
|
||||
ICON (IE entity) -> "ICON " <> serializeEntity entity
|
||||
ERR e -> "ERR " <> serializeAgentError e
|
||||
OK -> "OK"
|
||||
where
|
||||
@@ -716,6 +801,9 @@ serializeCommand = \case
|
||||
showTs :: UTCTime -> ByteString
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
|
||||
serializeIntro :: Entity t -> ByteString -> ByteString
|
||||
serializeIntro entity eInfo = serializeEntity entity <> " " <> serializeMsg eInfo
|
||||
|
||||
-- | Serialize message integrity validation result.
|
||||
serializeMsgIntegrity :: MsgIntegrity -> ByteString
|
||||
serializeMsgIntegrity = \case
|
||||
@@ -794,9 +882,10 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
hasEntityId :: AnEntity -> APartyCmd p -> Either AgentErrorType (APartyCmd p)
|
||||
hasEntityId (AE entity) (APartyCmd cmd) =
|
||||
APartyCmd <$> case cmd of
|
||||
-- NEW and JOIN have optional entity
|
||||
-- NEW, JOIN and ACPT have optional entity
|
||||
NEW -> Right cmd
|
||||
JOIN _ _ -> Right cmd
|
||||
JOIN {} -> Right cmd
|
||||
ACPT {} -> Right cmd
|
||||
-- ERROR response does not always have entity
|
||||
ERR _ -> Right cmd
|
||||
-- other responses must have entity
|
||||
@@ -818,6 +907,9 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
APartyCmd <$$> case cmd of
|
||||
SEND body -> SEND <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
|
||||
INTRO entity eInfo -> INTRO entity <$$> getMsgBody eInfo
|
||||
REQ entity eInfo -> REQ entity <$$> getMsgBody eInfo
|
||||
ACPT entity eInfo -> ACPT entity <$$> getMsgBody eInfo
|
||||
_ -> pure $ Right cmd
|
||||
|
||||
-- TODO refactor with server
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store where
|
||||
|
||||
import Control.Concurrent.STM (TVar)
|
||||
import Control.Exception (Exception)
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
import Data.Text (Text)
|
||||
import Data.Time (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
@@ -30,33 +35,45 @@ import qualified Simplex.Messaging.Protocol as SMP
|
||||
-- | Store class type. Defines store access methods for implementations.
|
||||
class Monad m => MonadAgentStore s m where
|
||||
-- Queue and Connection management
|
||||
createRcvConn :: s -> RcvQueue -> m ()
|
||||
createSndConn :: s -> SndQueue -> m ()
|
||||
getConn :: s -> ConnAlias -> m SomeConn
|
||||
getAllConnAliases :: s -> m [ConnAlias] -- TODO remove - hack for subscribing to all
|
||||
createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
|
||||
createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
|
||||
getConn :: s -> ConnId -> m SomeConn
|
||||
getAllConnIds :: s -> m [ConnId] -- TODO remove - hack for subscribing to all
|
||||
getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn
|
||||
deleteConn :: s -> ConnAlias -> m ()
|
||||
upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m ()
|
||||
deleteConn :: s -> ConnId -> m ()
|
||||
upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m ()
|
||||
upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m ()
|
||||
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m ()
|
||||
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
|
||||
|
||||
-- Msg management
|
||||
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
|
||||
updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> ConnId -> RcvMsgData -> m ()
|
||||
|
||||
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
|
||||
updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> ConnId -> SndMsgData -> m ()
|
||||
|
||||
getMsg :: s -> ConnAlias -> InternalId -> m Msg
|
||||
getMsg :: s -> ConnId -> InternalId -> m Msg
|
||||
|
||||
-- Broadcasts
|
||||
createBcast :: s -> BroadcastId -> m ()
|
||||
addBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
|
||||
removeBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
|
||||
createBcast :: s -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
|
||||
addBcastConn :: s -> BroadcastId -> ConnId -> m ()
|
||||
removeBcastConn :: s -> BroadcastId -> ConnId -> m ()
|
||||
deleteBcast :: s -> BroadcastId -> m ()
|
||||
getBcast :: s -> BroadcastId -> m [ConnAlias]
|
||||
getBcast :: s -> BroadcastId -> m [ConnId]
|
||||
|
||||
-- Introductions
|
||||
createIntro :: s -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
|
||||
getIntro :: s -> IntroId -> m Introduction
|
||||
addIntroInvitation :: s -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
|
||||
setIntroToStatus :: s -> IntroId -> IntroStatus -> m ()
|
||||
setIntroReStatus :: s -> IntroId -> IntroStatus -> m ()
|
||||
createInvitation :: s -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
|
||||
getInvitation :: s -> InvitationId -> m Invitation
|
||||
addInvitationConn :: s -> InvitationId -> ConnId -> m ()
|
||||
getConnInvitation :: s -> ConnId -> m (Maybe (Invitation, Connection CDuplex))
|
||||
setInvitationStatus :: s -> InvitationId -> InvitationStatus -> m ()
|
||||
|
||||
-- * Queue types
|
||||
|
||||
@@ -64,7 +81,6 @@ class Monad m => MonadAgentStore s m where
|
||||
data RcvQueue = RcvQueue
|
||||
{ server :: SMPServer,
|
||||
rcvId :: SMP.RecipientId,
|
||||
connAlias :: ConnAlias,
|
||||
rcvPrivateKey :: RecipientPrivateKey,
|
||||
sndId :: Maybe SMP.SenderId,
|
||||
sndKey :: Maybe SenderPublicKey,
|
||||
@@ -78,7 +94,6 @@ data RcvQueue = RcvQueue
|
||||
data SndQueue = SndQueue
|
||||
{ server :: SMPServer,
|
||||
sndId :: SMP.SenderId,
|
||||
connAlias :: ConnAlias,
|
||||
sndPrivateKey :: SenderPrivateKey,
|
||||
encryptKey :: EncryptionKey,
|
||||
signKey :: SignatureKey,
|
||||
@@ -102,9 +117,9 @@ data ConnType = CRcv | CSnd | CDuplex deriving (Eq, Show)
|
||||
-- - DuplexConnection is a connection that has both receive and send queues set up,
|
||||
-- typically created by upgrading a receive or a send connection with a missing queue.
|
||||
data Connection (d :: ConnType) where
|
||||
RcvConnection :: ConnAlias -> RcvQueue -> Connection CRcv
|
||||
SndConnection :: ConnAlias -> SndQueue -> Connection CSnd
|
||||
DuplexConnection :: ConnAlias -> RcvQueue -> SndQueue -> Connection CDuplex
|
||||
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
|
||||
SndConnection :: ConnData -> SndQueue -> Connection CSnd
|
||||
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
|
||||
|
||||
deriving instance Eq (Connection d)
|
||||
|
||||
@@ -141,6 +156,9 @@ instance Eq SomeConn where
|
||||
|
||||
deriving instance Show SomeConn
|
||||
|
||||
data ConnData = ConnData {connId :: ConnId, viaInv :: Maybe InvitationId, connLevel :: Int}
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- * Message integrity validation types
|
||||
|
||||
type MsgHash = ByteString
|
||||
@@ -263,7 +281,7 @@ type DeliveredTs = UTCTime
|
||||
|
||||
-- | Base message data independent of direction.
|
||||
data MsgBase = MsgBase
|
||||
{ connAlias :: ConnAlias,
|
||||
{ connAlias :: ConnId,
|
||||
-- | Monotonically increasing id of a message per connection, internal to the agent.
|
||||
-- Internal Id preserves ordering between both received and sent messages, and is needed
|
||||
-- to track the order of the conversation (which can be different for the sender / receiver)
|
||||
@@ -281,12 +299,87 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show)
|
||||
|
||||
type InternalTs = UTCTime
|
||||
|
||||
-- * Introduction types
|
||||
|
||||
data NewIntroduction = NewIntroduction
|
||||
{ toConn :: ConnId,
|
||||
reConn :: ConnId,
|
||||
reInfo :: ByteString
|
||||
}
|
||||
|
||||
data Introduction = Introduction
|
||||
{ introId :: IntroId,
|
||||
toConn :: ConnId,
|
||||
toInfo :: Maybe ByteString,
|
||||
toStatus :: IntroStatus,
|
||||
reConn :: ConnId,
|
||||
reInfo :: ByteString,
|
||||
reStatus :: IntroStatus,
|
||||
qInfo :: Maybe SMPQueueInfo
|
||||
}
|
||||
|
||||
data IntroStatus = IntroNew | IntroInv | IntroCon
|
||||
deriving (Eq)
|
||||
|
||||
serializeIntroStatus :: IntroStatus -> Text
|
||||
serializeIntroStatus = \case
|
||||
IntroNew -> ""
|
||||
IntroInv -> "INV"
|
||||
IntroCon -> "CON"
|
||||
|
||||
introStatusT :: Text -> Maybe IntroStatus
|
||||
introStatusT = \case
|
||||
"" -> Just IntroNew
|
||||
"INV" -> Just IntroInv
|
||||
"CON" -> Just IntroCon
|
||||
_ -> Nothing
|
||||
|
||||
type IntroId = ByteString
|
||||
|
||||
data NewInvitation = NewInvitation
|
||||
{ viaConn :: ConnId,
|
||||
externalIntroId :: IntroId,
|
||||
entityInfo :: EntityInfo,
|
||||
qInfo :: Maybe SMPQueueInfo
|
||||
}
|
||||
|
||||
data Invitation = Invitation
|
||||
{ invId :: InvitationId,
|
||||
viaConn :: ConnId,
|
||||
externalIntroId :: IntroId,
|
||||
entityInfo :: EntityInfo,
|
||||
qInfo :: Maybe SMPQueueInfo,
|
||||
connId :: Maybe ConnId,
|
||||
status :: InvitationStatus
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data InvitationStatus = InvNew | InvAcpt | InvCon
|
||||
deriving (Eq, Show)
|
||||
|
||||
serializeInvStatus :: InvitationStatus -> Text
|
||||
serializeInvStatus = \case
|
||||
InvNew -> ""
|
||||
InvAcpt -> "ACPT"
|
||||
InvCon -> "CON"
|
||||
|
||||
invStatusT :: Text -> Maybe InvitationStatus
|
||||
invStatusT = \case
|
||||
"" -> Just InvNew
|
||||
"ACPT" -> Just InvAcpt
|
||||
"CON" -> Just InvCon
|
||||
_ -> Nothing
|
||||
|
||||
type InvitationId = ByteString
|
||||
|
||||
-- * Store errors
|
||||
|
||||
-- | Agent store error.
|
||||
data StoreError
|
||||
= -- | IO exceptions in store actions.
|
||||
SEInternal ByteString
|
||||
| -- | failed to generate unique random ID
|
||||
SEUniqueID
|
||||
| -- | Connection alias not found (or both queues absent).
|
||||
SEConnNotFound
|
||||
| -- | Connection alias already used.
|
||||
@@ -298,6 +391,10 @@ data StoreError
|
||||
SEBcastNotFound
|
||||
| -- | Broadcast ID already used.
|
||||
SEBcastDuplicate
|
||||
| -- | Introduction ID not found.
|
||||
SEIntroNotFound
|
||||
| -- | Invitation ID not found.
|
||||
SEInvitationNotFound
|
||||
| -- | Currently not used. The intention was to pass current expected queue status in methods,
|
||||
-- as we always know what it should be at any stage of the protocol,
|
||||
-- and in case it does not match use this error.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
@@ -11,6 +12,7 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
@@ -22,14 +24,19 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Monad (unless, when)
|
||||
import Control.Concurrent.STM (TVar, atomically, stateTVar)
|
||||
import Control.Monad (join, unless, when)
|
||||
import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO))
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (find)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Text (isPrefixOf)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, field)
|
||||
@@ -66,8 +73,8 @@ createSQLiteStore dbFilePath = do
|
||||
let dbDir = takeDirectory dbFilePath
|
||||
createDirectoryIfMissing False dbDir
|
||||
store <- connectSQLiteStore dbFilePath
|
||||
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
|
||||
let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions)
|
||||
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[Text]]
|
||||
let threadsafeOption = find (T.isPrefixOf "THREADSAFE=") (concat compileOptions)
|
||||
case threadsafeOption of
|
||||
Just "THREADSAFE=0" -> confirmOrExit "SQLite compiled with non-threadsafe code."
|
||||
Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found"
|
||||
@@ -107,13 +114,16 @@ connectSQLiteStore dbFilePath = do
|
||||
|]
|
||||
pure SQLiteStore {dbFilePath, dbConn, dbNew}
|
||||
|
||||
checkConstraint :: StoreError -> IO () -> IO (Either StoreError ())
|
||||
checkConstraint err action = first handleError <$> E.try action
|
||||
where
|
||||
handleError :: SQLError -> StoreError
|
||||
handleError e
|
||||
| DB.sqlError e == DB.ErrorConstraint = err
|
||||
| otherwise = SEInternal $ bshow e
|
||||
checkConstraint :: StoreError -> IO a -> IO (Either StoreError a)
|
||||
checkConstraint err action = first (handleSQLError err) <$> E.try action
|
||||
|
||||
checkConstraint' :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a)
|
||||
checkConstraint' err action = action `E.catch` (pure . Left . handleSQLError err)
|
||||
|
||||
handleSQLError :: StoreError -> SQLError -> StoreError
|
||||
handleSQLError err e
|
||||
| DB.sqlError e == DB.ErrorConstraint = err
|
||||
| otherwise = SEInternal $ bshow e
|
||||
|
||||
withTransaction :: forall a. DB.Connection -> IO a -> IO a
|
||||
withTransaction db a = loop 100 100_000
|
||||
@@ -128,33 +138,43 @@ withTransaction db a = loop 100 100_000
|
||||
else E.throwIO e
|
||||
|
||||
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
|
||||
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
|
||||
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
|
||||
liftIOEither $
|
||||
checkConstraint SEConnDuplicate $
|
||||
withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn q
|
||||
insertRcvConnection_ dbConn q
|
||||
createRcvConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
|
||||
createRcvConn SQLiteStore {dbConn} gVar cData q@RcvQueue {server} =
|
||||
-- TODO if schema has to be restarted, this function can be refactored
|
||||
-- to create connection first using createWithRandomId
|
||||
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
|
||||
getConnId_ dbConn gVar cData >>= traverse create
|
||||
where
|
||||
create :: ConnId -> IO ConnId
|
||||
create connId = do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn connId q
|
||||
insertRcvConnection_ dbConn cData {connId} q
|
||||
pure connId
|
||||
|
||||
createSndConn :: SQLiteStore -> SndQueue -> m ()
|
||||
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
|
||||
liftIOEither $
|
||||
checkConstraint SEConnDuplicate $
|
||||
withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn q
|
||||
insertSndConnection_ dbConn q
|
||||
createSndConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
|
||||
createSndConn SQLiteStore {dbConn} gVar cData q@SndQueue {server} =
|
||||
-- TODO if schema has to be restarted, this function can be refactored
|
||||
-- to create connection first using createWithRandomId
|
||||
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
|
||||
getConnId_ dbConn gVar cData >>= traverse create
|
||||
where
|
||||
create :: ConnId -> IO ConnId
|
||||
create connId = do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn connId q
|
||||
insertSndConnection_ dbConn cData {connId} q
|
||||
pure connId
|
||||
|
||||
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
|
||||
getConn SQLiteStore {dbConn} connAlias =
|
||||
getConn :: SQLiteStore -> ConnId -> m SomeConn
|
||||
getConn SQLiteStore {dbConn} connId =
|
||||
liftIOEither . withTransaction dbConn $
|
||||
getConn_ dbConn connAlias
|
||||
getConn_ dbConn connId
|
||||
|
||||
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
|
||||
getAllConnAliases SQLiteStore {dbConn} =
|
||||
getAllConnIds :: SQLiteStore -> m [ConnId]
|
||||
getAllConnIds SQLiteStore {dbConn} =
|
||||
liftIO $ do
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnId]]
|
||||
return (concat r)
|
||||
|
||||
getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn
|
||||
@@ -169,37 +189,37 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
|]
|
||||
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
>>= \case
|
||||
[Only connAlias] -> getConn_ dbConn connAlias
|
||||
[Only connId] -> getConn_ dbConn connId
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
deleteConn :: SQLiteStore -> ConnAlias -> m ()
|
||||
deleteConn SQLiteStore {dbConn} connAlias =
|
||||
deleteConn :: SQLiteStore -> ConnId -> m ()
|
||||
deleteConn SQLiteStore {dbConn} connId =
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
|
||||
[":conn_alias" := connAlias]
|
||||
[":conn_alias" := connId]
|
||||
|
||||
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
|
||||
upgradeRcvConnToDuplex :: SQLiteStore -> ConnId -> SndQueue -> m ()
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connId sq@SndQueue {server} =
|
||||
liftIOEither . withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
getConn_ dbConn connId >>= \case
|
||||
Right (SomeConn _ RcvConnection {}) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn sq
|
||||
updateConnWithSndQueue_ dbConn connAlias sq
|
||||
insertSndQueue_ dbConn connId sq
|
||||
updateConnWithSndQueue_ dbConn connId sq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
|
||||
upgradeSndConnToDuplex :: SQLiteStore -> ConnId -> RcvQueue -> m ()
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connId rq@RcvQueue {server} =
|
||||
liftIOEither . withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
getConn_ dbConn connId >>= \case
|
||||
Right (SomeConn _ SndConnection {}) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn rq
|
||||
updateConnWithRcvQueue_ dbConn connAlias rq
|
||||
insertRcvQueue_ dbConn connId rq
|
||||
updateConnWithRcvQueue_ dbConn connId rq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
@@ -248,83 +268,233 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
|
||||
|
||||
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
|
||||
updateRcvIds :: SQLiteStore -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
updateRcvIds SQLiteStore {dbConn} connId =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
|
||||
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connId
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
|
||||
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
|
||||
updateLastIdsRcv_ dbConn connId internalId internalRcvId
|
||||
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
|
||||
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
|
||||
createRcvMsg :: SQLiteStore -> ConnId -> RcvMsgData -> m ()
|
||||
createRcvMsg SQLiteStore {dbConn} connId rcvMsgData =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
insertRcvMsgBase_ dbConn connAlias rcvMsgData
|
||||
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
|
||||
updateHashRcv_ dbConn connAlias rcvMsgData
|
||||
insertRcvMsgBase_ dbConn connId rcvMsgData
|
||||
insertRcvMsgDetails_ dbConn connId rcvMsgData
|
||||
updateHashRcv_ dbConn connId rcvMsgData
|
||||
|
||||
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
|
||||
updateSndIds :: SQLiteStore -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
updateSndIds SQLiteStore {dbConn} connId =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
|
||||
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connId
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
|
||||
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
|
||||
updateLastIdsSnd_ dbConn connId internalId internalSndId
|
||||
pure (internalId, internalSndId, prevSndHash)
|
||||
|
||||
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
|
||||
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
|
||||
createSndMsg :: SQLiteStore -> ConnId -> SndMsgData -> m ()
|
||||
createSndMsg SQLiteStore {dbConn} connId sndMsgData =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
insertSndMsgBase_ dbConn connAlias sndMsgData
|
||||
insertSndMsgDetails_ dbConn connAlias sndMsgData
|
||||
updateHashSnd_ dbConn connAlias sndMsgData
|
||||
insertSndMsgBase_ dbConn connId sndMsgData
|
||||
insertSndMsgDetails_ dbConn connId sndMsgData
|
||||
updateHashSnd_ dbConn connId sndMsgData
|
||||
|
||||
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
|
||||
getMsg :: SQLiteStore -> ConnId -> InternalId -> m Msg
|
||||
getMsg _st _connAlias _id = throwError SENotImplemented
|
||||
|
||||
createBcast :: SQLiteStore -> BroadcastId -> m ()
|
||||
createBcast SQLiteStore {dbConn} bId =
|
||||
liftIOEither $
|
||||
checkConstraint SEBcastDuplicate $
|
||||
DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
|
||||
createBcast :: SQLiteStore -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
|
||||
createBcast SQLiteStore {dbConn} gVar bcastId = liftIOEither $ case bcastId of
|
||||
"" -> createWithRandomId gVar create
|
||||
bId -> checkConstraint SEBcastDuplicate $ create bId $> bId
|
||||
where
|
||||
create bId = DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
|
||||
|
||||
addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
|
||||
addBcastConn SQLiteStore {dbConn} bId connAlias =
|
||||
addBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
|
||||
addBcastConn SQLiteStore {dbConn} bId connId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
getConn_ dbConn connId >>= \case
|
||||
Left _ -> pure $ Left SEConnNotFound
|
||||
Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv
|
||||
Right _ ->
|
||||
checkConstraint SEConnDuplicate $
|
||||
DB.execute
|
||||
dbConn
|
||||
"INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);"
|
||||
(bId, connAlias)
|
||||
[sql|
|
||||
INSERT INTO broadcast_connections
|
||||
(broadcast_id, conn_alias) VALUES (?, ?);
|
||||
|]
|
||||
(bId, connId)
|
||||
|
||||
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
|
||||
removeBcastConn SQLiteStore {dbConn} bId connAlias =
|
||||
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
|
||||
removeBcastConn SQLiteStore {dbConn} bId connId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
bcastConnExists_ dbConn bId connAlias >>= \case
|
||||
bcastConnExists_ dbConn bId connId >>= \case
|
||||
False -> pure $ Left SEConnNotFound
|
||||
_ ->
|
||||
Right
|
||||
<$> DB.execute
|
||||
dbConn
|
||||
"DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;"
|
||||
(bId, connAlias)
|
||||
[sql|
|
||||
DELETE FROM broadcast_connections
|
||||
WHERE broadcast_id = ? AND conn_alias = ?;
|
||||
|]
|
||||
(bId, connId)
|
||||
|
||||
deleteBcast :: SQLiteStore -> BroadcastId -> m ()
|
||||
deleteBcast SQLiteStore {dbConn} bId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
|
||||
|
||||
getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias]
|
||||
getBcast :: SQLiteStore -> BroadcastId -> m [ConnId]
|
||||
getBcast SQLiteStore {dbConn} bId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
Right . map fromOnly
|
||||
<$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId)
|
||||
|
||||
createIntro :: SQLiteStore -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
|
||||
createIntro SQLiteStore {dbConn} gVar NewIntroduction {toConn, reConn, reInfo} =
|
||||
liftIOEither . createWithRandomId gVar $ \introId ->
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO conn_intros
|
||||
(intro_id, to_conn, re_conn, re_info) VALUES (?, ?, ?, ?);
|
||||
|]
|
||||
(introId, toConn, reConn, reInfo)
|
||||
|
||||
getIntro :: SQLiteStore -> IntroId -> m Introduction
|
||||
getIntro SQLiteStore {dbConn} introId =
|
||||
liftIOEither $
|
||||
intro
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT to_conn, to_info, to_status, re_conn, re_info, re_status, queue_info
|
||||
FROM conn_intros
|
||||
WHERE intro_id = ?;
|
||||
|]
|
||||
(Only introId)
|
||||
where
|
||||
intro [(toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo)] =
|
||||
Right $ Introduction {introId, toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo}
|
||||
intro _ = Left SEIntroNotFound
|
||||
|
||||
addIntroInvitation :: SQLiteStore -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
|
||||
addIntroInvitation SQLiteStore {dbConn} introId toInfo qInfo =
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_intros
|
||||
SET to_info = :to_info,
|
||||
queue_info = :queue_info,
|
||||
to_status = :to_status
|
||||
WHERE intro_id = :intro_id;
|
||||
|]
|
||||
[ ":to_info" := toInfo,
|
||||
":queue_info" := Just qInfo,
|
||||
":to_status" := IntroInv,
|
||||
":intro_id" := introId
|
||||
]
|
||||
|
||||
setIntroToStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
|
||||
setIntroToStatus SQLiteStore {dbConn} introId toStatus =
|
||||
liftIO $
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_intros
|
||||
SET to_status = ?
|
||||
WHERE intro_id = ?;
|
||||
|]
|
||||
(toStatus, introId)
|
||||
|
||||
setIntroReStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
|
||||
setIntroReStatus SQLiteStore {dbConn} introId reStatus =
|
||||
liftIO $
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_intros
|
||||
SET re_status = ?
|
||||
WHERE intro_id = ?;
|
||||
|]
|
||||
(reStatus, introId)
|
||||
|
||||
createInvitation :: SQLiteStore -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
|
||||
createInvitation SQLiteStore {dbConn} gVar NewInvitation {viaConn, externalIntroId, entityInfo, qInfo} =
|
||||
liftIOEither . createWithRandomId gVar $ \invId ->
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO conn_invitations
|
||||
(inv_id, via_conn, external_intro_id, conn_info, queue_info) VALUES (?, ?, ?, ?, ?);
|
||||
|]
|
||||
(invId, viaConn, externalIntroId, entityInfo, qInfo)
|
||||
|
||||
getInvitation :: SQLiteStore -> InvitationId -> m Invitation
|
||||
getInvitation SQLiteStore {dbConn} invId =
|
||||
liftIOEither $
|
||||
invitation
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT via_conn, external_intro_id, conn_info, queue_info, conn_id, status
|
||||
FROM conn_invitations
|
||||
WHERE inv_id = ?;
|
||||
|]
|
||||
(Only invId)
|
||||
where
|
||||
invitation [(viaConn, externalIntroId, entityInfo, qInfo, connId, status)] =
|
||||
Right $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId, status}
|
||||
invitation _ = Left SEInvitationNotFound
|
||||
|
||||
addInvitationConn :: SQLiteStore -> InvitationId -> ConnId -> m ()
|
||||
addInvitationConn SQLiteStore {dbConn} invId connId =
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_invitations
|
||||
SET conn_id = :conn_id, status = :status
|
||||
WHERE inv_id = :inv_id;
|
||||
|]
|
||||
[":conn_id" := connId, ":status" := InvAcpt, ":inv_id" := invId]
|
||||
|
||||
getConnInvitation :: SQLiteStore -> ConnId -> m (Maybe (Invitation, Connection 'CDuplex))
|
||||
getConnInvitation SQLiteStore {dbConn} cId =
|
||||
liftIO . withTransaction dbConn $
|
||||
DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT inv_id, via_conn, external_intro_id, conn_info, queue_info, status
|
||||
FROM conn_invitations
|
||||
WHERE conn_id = ?;
|
||||
|]
|
||||
(Only cId)
|
||||
>>= fmap join . traverse getViaConn . invitation
|
||||
where
|
||||
invitation [(invId, viaConn, externalIntroId, entityInfo, qInfo, status)] =
|
||||
Just $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId = Just cId, status}
|
||||
invitation _ = Nothing
|
||||
getViaConn :: Invitation -> IO (Maybe (Invitation, Connection 'CDuplex))
|
||||
getViaConn inv@Invitation {viaConn} = fmap (inv,) . duplexConn <$> getConn_ dbConn viaConn
|
||||
duplexConn :: Either StoreError SomeConn -> Maybe (Connection 'CDuplex)
|
||||
duplexConn (Right (SomeConn SCDuplex conn)) = Just conn
|
||||
duplexConn _ = Nothing
|
||||
|
||||
setInvitationStatus :: SQLiteStore -> InvitationId -> InvitationStatus -> m ()
|
||||
setInvitationStatus SQLiteStore {dbConn} invId status =
|
||||
liftIO $
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_invitations
|
||||
SET status = ? WHERE inv_id = ?;
|
||||
|]
|
||||
(status, invId)
|
||||
|
||||
-- * Auxiliary helpers
|
||||
|
||||
-- ? replace with ToField? - it's easy to forget to use this
|
||||
@@ -337,7 +507,7 @@ deserializePort_ port = Just port
|
||||
|
||||
instance ToField QueueStatus where toField = toField . show
|
||||
|
||||
instance FromField QueueStatus where fromField = fromFieldToReadable_
|
||||
instance FromField QueueStatus where fromField = fromTextField_ $ readMaybe . T.unpack
|
||||
|
||||
instance ToField InternalRcvId where toField (InternalRcvId x) = toField x
|
||||
|
||||
@@ -359,13 +529,24 @@ instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity
|
||||
|
||||
instance FromField MsgIntegrity where fromField = blobFieldParser msgIntegrityP
|
||||
|
||||
fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a
|
||||
fromFieldToReadable_ = \case
|
||||
instance ToField IntroStatus where toField = toField . serializeIntroStatus
|
||||
|
||||
instance FromField IntroStatus where fromField = fromTextField_ introStatusT
|
||||
|
||||
instance ToField InvitationStatus where toField = toField . serializeInvStatus
|
||||
|
||||
instance FromField InvitationStatus where fromField = fromTextField_ invStatusT
|
||||
|
||||
instance ToField SMPQueueInfo where toField = toField . serializeSmpQueueInfo
|
||||
|
||||
instance FromField SMPQueueInfo where fromField = blobFieldParser smpQueueInfoP
|
||||
|
||||
fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ fromText = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
let str = T.unpack t
|
||||
in case readMaybe str of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid string: " <> str)
|
||||
case fromText t of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
{- ORMOLU_DISABLE -}
|
||||
@@ -397,8 +578,8 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn connId RcvQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -411,7 +592,7 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
[ ":host" := host server,
|
||||
":port" := port_,
|
||||
":rcv_id" := rcvId,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":rcv_private_key" := rcvPrivateKey,
|
||||
":snd_id" := sndId,
|
||||
":snd_key" := sndKey,
|
||||
@@ -420,26 +601,29 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
":status" := status
|
||||
]
|
||||
|
||||
insertRcvConnection_ :: DB.Connection -> RcvQueue -> IO ()
|
||||
insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
|
||||
insertRcvConnection_ :: DB.Connection -> ConnData -> RcvQueue -> IO ()
|
||||
insertRcvConnection_ dbConn ConnData {connId, viaInv, connLevel} RcvQueue {server, rcvId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
|
||||
0, 0, 0, 0, x'', x'');
|
||||
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, :via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
|
||||
[ ":conn_alias" := connId,
|
||||
":rcv_host" := host server,
|
||||
":rcv_port" := port_,
|
||||
":rcv_id" := rcvId,
|
||||
":via_inv" := viaInv,
|
||||
":conn_level" := connLevel
|
||||
]
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn SndQueue {..} = do
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn connId SndQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -452,85 +636,96 @@ insertSndQueue_ dbConn SndQueue {..} = do
|
||||
[ ":host" := host server,
|
||||
":port" := port_,
|
||||
":snd_id" := sndId,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":snd_private_key" := sndPrivateKey,
|
||||
":encrypt_key" := encryptKey,
|
||||
":sign_key" := signKey,
|
||||
":status" := status
|
||||
]
|
||||
|
||||
insertSndConnection_ :: DB.Connection -> SndQueue -> IO ()
|
||||
insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
|
||||
insertSndConnection_ :: DB.Connection -> ConnData -> SndQueue -> IO ()
|
||||
insertSndConnection_ dbConn ConnData {connId, viaInv, connLevel} SndQueue {server, sndId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
|
||||
0, 0, 0, 0, x'', x'');
|
||||
(:conn_alias, NULL, NULL, NULL, :snd_host,:snd_port,:snd_id,:via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
|
||||
[ ":conn_alias" := connId,
|
||||
":snd_host" := host server,
|
||||
":snd_port" := port_,
|
||||
":snd_id" := sndId,
|
||||
":via_inv" := viaInv,
|
||||
":conn_level" := connLevel
|
||||
]
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
|
||||
getConn_ dbConn connAlias = do
|
||||
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
|
||||
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
|
||||
pure $ case (rQ, sQ) of
|
||||
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
|
||||
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
|
||||
_ -> Left SEConnNotFound
|
||||
getConn_ :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getConn_ dbConn connId =
|
||||
getConnData_ dbConn connId >>= \case
|
||||
Nothing -> pure $ Left SEConnNotFound
|
||||
Just connData -> do
|
||||
rQ <- getRcvQueueByConnAlias_ dbConn connId
|
||||
sQ <- getSndQueueByConnAlias_ dbConn connId
|
||||
pure $ case (rQ, sQ) of
|
||||
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
|
||||
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
|
||||
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
|
||||
r <-
|
||||
DB.queryNamed
|
||||
getConnData_ :: DB.Connection -> ConnId -> IO (Maybe ConnData)
|
||||
getConnData_ dbConn connId =
|
||||
connData
|
||||
<$> DB.query dbConn "SELECT via_inv, conn_level FROM connections WHERE conn_alias = ?;" (Only connId)
|
||||
where
|
||||
connData [(viaInv, connLevel)] = Just ConnData {connId, viaInv, connLevel}
|
||||
connData _ = Nothing
|
||||
|
||||
getRcvQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
|
||||
getRcvQueueByConnAlias_ dbConn connId =
|
||||
rcvQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
|
||||
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key,
|
||||
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.conn_alias = :conn_alias;
|
||||
WHERE q.conn_alias = ?;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
case r of
|
||||
[(keyHash, host, port, rcvId, cAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
|
||||
(Only connId)
|
||||
where
|
||||
rcvQueue [(keyHash, host, port, rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] =
|
||||
let srv = SMPServer host (deserializePort_ port) keyHash
|
||||
return . Just $ RcvQueue srv rcvId cAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
_ -> return Nothing
|
||||
in Just $ RcvQueue srv rcvId rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
rcvQueue _ = Nothing
|
||||
|
||||
retrieveSndQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe SndQueue)
|
||||
retrieveSndQueueByConnAlias_ dbConn connAlias = do
|
||||
r <-
|
||||
DB.queryNamed
|
||||
getSndQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
|
||||
getSndQueueByConnAlias_ dbConn connId =
|
||||
sndQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.snd_id, q.conn_alias,
|
||||
q.snd_private_key, q.encrypt_key, q.sign_key, q.status
|
||||
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status
|
||||
FROM snd_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.conn_alias = :conn_alias;
|
||||
WHERE q.conn_alias = ?;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
case r of
|
||||
[(keyHash, host, port, sndId, cAlias, sndPrivateKey, encryptKey, signKey, status)] -> do
|
||||
(Only connId)
|
||||
where
|
||||
sndQueue [(keyHash, host, port, sndId, sndPrivateKey, encryptKey, signKey, status)] =
|
||||
let srv = SMPServer host (deserializePort_ port) keyHash
|
||||
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
|
||||
_ -> return Nothing
|
||||
in Just $ SndQueue srv sndId sndPrivateKey encryptKey signKey status
|
||||
sndQueue _ = Nothing
|
||||
|
||||
-- * upgradeRcvConnToDuplex helpers
|
||||
|
||||
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
|
||||
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
updateConnWithSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
updateConnWithSndQueue_ dbConn connId SndQueue {server, sndId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -539,12 +734,12 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
SET snd_host = :snd_host, snd_port = :snd_port, snd_id = :snd_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connAlias]
|
||||
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connId]
|
||||
|
||||
-- * upgradeSndConnToDuplex helpers
|
||||
|
||||
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
|
||||
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
updateConnWithRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
updateConnWithRcvQueue_ dbConn connId RcvQueue {server, rcvId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -553,12 +748,12 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
SET rcv_host = :rcv_host, rcv_port = :rcv_port, rcv_id = :rcv_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
|
||||
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connId]
|
||||
|
||||
-- * updateRcvIds helpers
|
||||
|
||||
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
|
||||
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
retrieveLastIdsAndHashRcv_ dbConn connId = do
|
||||
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
@@ -567,11 +762,11 @@ retrieveLastIdsAndHashRcv_ dbConn connAlias = do
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
[":conn_alias" := connId]
|
||||
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -582,13 +777,13 @@ updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_rcv_msg_id" := newInternalRcvId,
|
||||
":conn_alias" := connAlias
|
||||
":conn_alias" := connId
|
||||
]
|
||||
|
||||
-- * createRcvMsg helpers
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connId RcvMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -597,15 +792,15 @@ insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
|
||||
VALUES
|
||||
(:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL,:body);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_id" := internalId,
|
||||
":internal_ts" := internalTs,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connId RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -618,7 +813,7 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
:broker_id,:broker_ts,:rcv_status, NULL, NULL,
|
||||
:internal_hash,:external_prev_snd_hash,:integrity);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":internal_id" := internalId,
|
||||
":external_snd_id" := fst senderMeta,
|
||||
@@ -631,8 +826,8 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
":integrity" := msgIntegrity
|
||||
]
|
||||
|
||||
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|
||||
updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
updateHashRcv_ dbConn connId RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
|
||||
@@ -645,14 +840,14 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|
||||
|]
|
||||
[ ":last_external_snd_msg_id" := fst senderMeta,
|
||||
":last_rcv_msg_hash" := internalHash,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":last_internal_rcv_msg_id" := internalRcvId
|
||||
]
|
||||
|
||||
-- * updateSndIds helpers
|
||||
|
||||
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
|
||||
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
retrieveLastIdsAndHashSnd_ dbConn connId = do
|
||||
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
@@ -661,11 +856,11 @@ retrieveLastIdsAndHashSnd_ dbConn connAlias = do
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
[":conn_alias" := connId]
|
||||
return (lastInternalId, lastInternalSndId, lastSndHash)
|
||||
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -676,13 +871,13 @@ updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_snd_msg_id" := newInternalSndId,
|
||||
":conn_alias" := connAlias
|
||||
":conn_alias" := connId
|
||||
]
|
||||
|
||||
-- * createSndMsg helpers
|
||||
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connId SndMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -691,15 +886,15 @@ insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
|
||||
VALUES
|
||||
(:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id,:body);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_id" := internalId,
|
||||
":internal_ts" := internalTs,
|
||||
":internal_snd_id" := internalSndId,
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connId SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -708,15 +903,15 @@ insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
|
||||
VALUES
|
||||
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_snd_id" := internalSndId,
|
||||
":internal_id" := internalId,
|
||||
":snd_status" := Created,
|
||||
":internal_hash" := internalHash
|
||||
]
|
||||
|
||||
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
updateHashSnd_ dbConn connAlias SndMsgData {..} =
|
||||
updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
updateHashSnd_ dbConn connId SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
|
||||
@@ -727,7 +922,7 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} =
|
||||
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|
||||
|]
|
||||
[ ":last_snd_msg_hash" := internalHash,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":last_internal_snd_msg_id" := internalSndId
|
||||
]
|
||||
|
||||
@@ -745,10 +940,10 @@ bcastExists_ dbConn bId = not . null <$> queryBcast
|
||||
queryBcast :: IO [Only BroadcastId]
|
||||
queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
|
||||
|
||||
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool
|
||||
bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
|
||||
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnId -> IO Bool
|
||||
bcastConnExists_ dbConn bId connId = not . null <$> queryBcastConn
|
||||
where
|
||||
queryBcastConn :: IO [(BroadcastId, ConnAlias)]
|
||||
queryBcastConn :: IO [(BroadcastId, ConnId)]
|
||||
queryBcastConn =
|
||||
DB.query
|
||||
dbConn
|
||||
@@ -757,4 +952,37 @@ bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
|
||||
FROM broadcast_connections
|
||||
WHERE broadcast_id = ? AND conn_alias = ?;
|
||||
|]
|
||||
(bId, connAlias)
|
||||
(bId, connId)
|
||||
|
||||
-- create record with a random ID
|
||||
|
||||
getConnId_ :: DB.Connection -> TVar ChaChaDRG -> ConnData -> IO (Either StoreError ConnId)
|
||||
getConnId_ dbConn gVar ConnData {connId = ""} = getUniqueRandomId gVar $ getConnData_ dbConn
|
||||
getConnId_ _ _ ConnData {connId} = pure $ Right connId
|
||||
|
||||
getUniqueRandomId :: TVar ChaChaDRG -> (ByteString -> IO (Maybe a)) -> IO (Either StoreError ByteString)
|
||||
getUniqueRandomId gVar get = tryGet 3
|
||||
where
|
||||
tryGet :: Int -> IO (Either StoreError ByteString)
|
||||
tryGet 0 = pure $ Left SEUniqueID
|
||||
tryGet n = do
|
||||
id' <- randomId gVar 12
|
||||
get id' >>= \case
|
||||
Nothing -> pure $ Right id'
|
||||
Just _ -> tryGet (n - 1)
|
||||
|
||||
createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString)
|
||||
createWithRandomId gVar create = tryCreate 3
|
||||
where
|
||||
tryCreate :: Int -> IO (Either StoreError ByteString)
|
||||
tryCreate 0 = pure $ Left SEUniqueID
|
||||
tryCreate n = do
|
||||
id' <- randomId gVar 12
|
||||
E.try (create id') >>= \case
|
||||
Right _ -> pure $ Right id'
|
||||
Left e
|
||||
| DB.sqlError e == DB.ErrorConstraint -> tryCreate (n - 1)
|
||||
| otherwise -> pure . Left . SEInternal $ bshow e
|
||||
|
||||
randomId :: TVar ChaChaDRG -> Int -> IO ByteString
|
||||
randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n)
|
||||
|
||||
@@ -30,7 +30,7 @@ base64StringP = do
|
||||
pure $ str <> pad
|
||||
|
||||
tsISO8601P :: Parser UTCTime
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ')
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
|
||||
|
||||
parse :: Parser a -> e -> (ByteString -> Either e a)
|
||||
parse parser err = first (const err) . parseAll parser
|
||||
@@ -42,14 +42,17 @@ parseRead :: Read a => Parser ByteString -> Parser a
|
||||
parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack)
|
||||
|
||||
parseRead1 :: Read a => Parser a
|
||||
parseRead1 = parseRead $ A.takeTill (== ' ')
|
||||
parseRead1 = parseRead $ A.takeTill wordEnd
|
||||
|
||||
parseRead2 :: Read a => Parser a
|
||||
parseRead2 = parseRead $ do
|
||||
w1 <- A.takeTill (== ' ') <* A.char ' '
|
||||
w2 <- A.takeTill (== ' ')
|
||||
w1 <- A.takeTill wordEnd <* A.char ' '
|
||||
w2 <- A.takeTill wordEnd
|
||||
pure $ w1 <> " " <> w2
|
||||
|
||||
wordEnd :: Char -> Bool
|
||||
wordEnd c = c == ' ' || c == '\n'
|
||||
|
||||
parseString :: (ByteString -> Either String a) -> (String -> a)
|
||||
parseString p = either error id . p . B.pack
|
||||
|
||||
|
||||
Reference in New Issue
Block a user