mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 16:26:02 +00:00
random connection and message IDs
This commit is contained in:
@@ -13,8 +13,10 @@ extra-source-files:
|
||||
|
||||
dependencies:
|
||||
- base >= 4.7 && < 5
|
||||
- base64-bytestring >= 1.0 && < 1.3
|
||||
- bytestring
|
||||
- containers
|
||||
- cryptonite == 0.26.*
|
||||
- iso8601-time == 0.1.*
|
||||
- mtl
|
||||
- network
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module ConnStore where
|
||||
|
||||
@@ -9,8 +10,8 @@ import Transmission
|
||||
|
||||
data Connection = Connection
|
||||
{ recipientId :: ConnId,
|
||||
recipientKey :: PublicKey,
|
||||
senderId :: ConnId,
|
||||
recipientKey :: PublicKey,
|
||||
senderKey :: Maybe PublicKey,
|
||||
status :: ConnStatus
|
||||
}
|
||||
@@ -18,19 +19,19 @@ data Connection = Connection
|
||||
data ConnStatus = ConnActive | ConnOff
|
||||
|
||||
class MonadConnStore s m where
|
||||
addConn :: s -> RecipientKey -> m (Either ErrorType Connection)
|
||||
addConn :: s -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection)
|
||||
getConn :: s -> Sing (a :: Party) -> ConnId -> m (Either ErrorType Connection)
|
||||
secureConn :: s -> RecipientId -> SenderKey -> m (Either ErrorType ())
|
||||
suspendConn :: s -> RecipientId -> m (Either ErrorType ())
|
||||
deleteConn :: s -> RecipientId -> m (Either ErrorType ())
|
||||
|
||||
-- TODO stub
|
||||
mkConnection :: RecipientKey -> Connection
|
||||
mkConnection rKey =
|
||||
mkConnection :: (RecipientId, SenderId) -> RecipientKey -> Connection
|
||||
mkConnection (recipientId, senderId) recipientKey =
|
||||
Connection
|
||||
{ recipientId = "1",
|
||||
recipientKey = rKey,
|
||||
senderId = "2",
|
||||
{ recipientId,
|
||||
senderId,
|
||||
recipientKey,
|
||||
senderKey = Nothing,
|
||||
status = ConnActive
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
@@ -29,45 +31,57 @@ newConnStore :: STM STMConnStore
|
||||
newConnStore = newTVar ConnStoreData {connections = M.empty, senders = M.empty}
|
||||
|
||||
instance MonadUnliftIO m => MonadConnStore STMConnStore m where
|
||||
addConn :: STMConnStore -> RecipientKey -> m (Either ErrorType Connection)
|
||||
addConn store rKey = atomically $ do
|
||||
let c@Connection {recipientId = rId, senderId = sId} = mkConnection rKey
|
||||
modifyTVar store $ \db ->
|
||||
db
|
||||
{ connections = M.insert rId c (connections db),
|
||||
senders = M.insert sId rId (senders db)
|
||||
}
|
||||
return $ Right c
|
||||
addConn :: STMConnStore -> m (RecipientId, SenderId) -> RecipientKey -> m (Either ErrorType Connection)
|
||||
addConn = _addConn (3 :: Int)
|
||||
where
|
||||
_addConn 0 _ _ _ = return $ Left INTERNAL
|
||||
_addConn retry store getIds rKey = do
|
||||
getIds >>= atomically . insertConn >>= \case
|
||||
Nothing -> _addConn (retry - 1) store getIds rKey
|
||||
Just c -> return $ Right c
|
||||
where
|
||||
insertConn ids@(rId, sId) = do
|
||||
cs@ConnStoreData {connections, senders} <- readTVar store
|
||||
if M.member rId connections || M.member sId senders
|
||||
then return Nothing
|
||||
else do
|
||||
let c = mkConnection ids rKey
|
||||
writeTVar store $
|
||||
cs
|
||||
{ connections = M.insert rId c connections,
|
||||
senders = M.insert sId rId senders
|
||||
}
|
||||
return $ Just c
|
||||
|
||||
getConn :: STMConnStore -> Sing (p :: Party) -> ConnId -> m (Either ErrorType Connection)
|
||||
getConn store SRecipient rId = atomically $ do
|
||||
db <- readTVar store
|
||||
return $ getRcpConn db rId
|
||||
cs <- readTVar store
|
||||
return $ getRcpConn cs rId
|
||||
getConn store SSender sId = atomically $ do
|
||||
db <- readTVar store
|
||||
let rId = M.lookup sId $ senders db
|
||||
return $ maybe (Left AUTH) (getRcpConn db) rId
|
||||
cs <- readTVar store
|
||||
let rId = M.lookup sId $ senders cs
|
||||
return $ maybe (Left AUTH) (getRcpConn cs) rId
|
||||
getConn _ SBroker _ =
|
||||
return $ Left INTERNAL
|
||||
|
||||
secureConn store rId sKey =
|
||||
updateConnections store rId $ \db c ->
|
||||
updateConnections store rId $ \cs c ->
|
||||
case senderKey c of
|
||||
Just _ -> (Left AUTH, db)
|
||||
_ -> (Right (), db {connections = M.insert rId c {senderKey = Just sKey} (connections db)})
|
||||
Just _ -> (Left AUTH, cs)
|
||||
_ -> (Right (), cs {connections = M.insert rId c {senderKey = Just sKey} (connections cs)})
|
||||
|
||||
suspendConn :: STMConnStore -> RecipientId -> m (Either ErrorType ())
|
||||
suspendConn store rId =
|
||||
updateConnections store rId $ \db c ->
|
||||
(Right (), db {connections = M.insert rId c {status = ConnOff} (connections db)})
|
||||
updateConnections store rId $ \cs c ->
|
||||
(Right (), cs {connections = M.insert rId c {status = ConnOff} (connections cs)})
|
||||
|
||||
deleteConn :: STMConnStore -> RecipientId -> m (Either ErrorType ())
|
||||
deleteConn store rId =
|
||||
updateConnections store rId $ \db c ->
|
||||
updateConnections store rId $ \cs c ->
|
||||
( Right (),
|
||||
db
|
||||
{ connections = M.delete rId (connections db),
|
||||
senders = M.delete (senderId c) (senders db)
|
||||
cs
|
||||
{ connections = M.delete rId (connections cs),
|
||||
senders = M.delete (senderId c) (senders cs)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -78,14 +92,14 @@ updateConnections ::
|
||||
(ConnStoreData -> Connection -> (Either ErrorType (), ConnStoreData)) ->
|
||||
m (Either ErrorType ())
|
||||
updateConnections store rId update = atomically $ do
|
||||
db <- readTVar store
|
||||
let conn = getRcpConn db rId
|
||||
either (return . Left) (_update db) conn
|
||||
cs <- readTVar store
|
||||
let conn = getRcpConn cs rId
|
||||
either (return . Left) (_update cs) conn
|
||||
where
|
||||
_update db c = do
|
||||
let (res, db') = update db c
|
||||
writeTVar store db'
|
||||
_update cs c = do
|
||||
let (res, cs') = update cs c
|
||||
writeTVar store cs'
|
||||
return res
|
||||
|
||||
getRcpConn :: ConnStoreData -> RecipientId -> Either ErrorType Connection
|
||||
getRcpConn db rId = maybe (Left AUTH) Right . M.lookup rId $ connections db
|
||||
getRcpConn cs rId = maybe (Left AUTH) Right . M.lookup rId $ connections cs
|
||||
|
||||
@@ -4,21 +4,24 @@
|
||||
module Env.STM where
|
||||
|
||||
import ConnStore.STM
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.STM
|
||||
import Control.Concurrent (ThreadId)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import MsgStore.STM
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Transmission
|
||||
import UnliftIO.STM
|
||||
|
||||
data Env = Env
|
||||
{ tcpPort :: ServiceName,
|
||||
queueSize :: Natural,
|
||||
server :: Server,
|
||||
connStore :: STMConnStore,
|
||||
msgStore :: STMMsgStore
|
||||
msgStore :: STMMsgStore,
|
||||
idsDrg :: TVar ChaChaDRG
|
||||
}
|
||||
|
||||
data Server = Server
|
||||
@@ -45,9 +48,10 @@ newClient qSize = do
|
||||
sndQ <- newTBQueue qSize
|
||||
return Client {connections, rcvQ, sndQ}
|
||||
|
||||
newEnv :: String -> Natural -> STM Env
|
||||
newEnv :: (MonadUnliftIO m, MonadRandom m) => String -> Natural -> m Env
|
||||
newEnv tcpPort queueSize = do
|
||||
server <- newServer queueSize
|
||||
connStore <- newConnStore
|
||||
msgStore <- newMsgStore
|
||||
return Env {tcpPort, queueSize, server, connStore, msgStore}
|
||||
server <- atomically $ newServer queueSize
|
||||
connStore <- atomically newConnStore
|
||||
msgStore <- atomically newMsgStore
|
||||
idsDrg <- drgNew >>= newTVarIO
|
||||
return Env {tcpPort, queueSize, server, connStore, msgStore, idsDrg}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module MsgStore where
|
||||
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Time.Clock
|
||||
import Transmission
|
||||
|
||||
@@ -22,13 +20,3 @@ class MonadMsgQueue q m where
|
||||
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
|
||||
peekMsg :: q -> m Message -- blocking
|
||||
tryDelPeekMsg :: q -> m (Maybe Message) -- atomic delete (== read) last and peek next message, if available
|
||||
|
||||
newMessage :: MonadIO m => MsgBody -> m Message
|
||||
newMessage msgBody = do
|
||||
ts <- liftIO getCurrentTime
|
||||
return
|
||||
Message
|
||||
{ msgId = "1",
|
||||
ts,
|
||||
msgBody
|
||||
}
|
||||
|
||||
@@ -15,8 +15,13 @@ import ConnStore
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Singletons
|
||||
import Data.Time.Clock
|
||||
import Env.STM
|
||||
import MsgStore
|
||||
import MsgStore.STM (MsgQueue)
|
||||
@@ -29,9 +34,9 @@ import UnliftIO.Concurrent
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM
|
||||
|
||||
runSMPServer :: MonadUnliftIO m => ServiceName -> Natural -> m ()
|
||||
runSMPServer :: (MonadRandom m, MonadUnliftIO m) => ServiceName -> Natural -> m ()
|
||||
runSMPServer port queueSize = do
|
||||
env <- atomically $ newEnv port queueSize
|
||||
env <- newEnv port queueSize
|
||||
runReaderT smpServer env
|
||||
where
|
||||
smpServer :: (MonadUnliftIO m, MonadReader Env m) => m ()
|
||||
@@ -131,12 +136,13 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} =
|
||||
createConn :: MonadConnStore s m => s -> RecipientKey -> m Signed
|
||||
createConn st rKey = mkSigned "" <$> addSubscribe
|
||||
where
|
||||
addSubscribe =
|
||||
addConn st rKey >>= \case
|
||||
addSubscribe = do
|
||||
addConn st getIds rKey >>= \case
|
||||
Right Connection {recipientId = rId, senderId = sId} -> do
|
||||
void $ subscribeConn rId
|
||||
return $ IDS rId sId
|
||||
Left e -> return $ ERR e
|
||||
getIds = liftM2 (,) (randomId 16) (randomId 16)
|
||||
|
||||
subscribeConn :: RecipientId -> m Signed
|
||||
subscribeConn rId = do
|
||||
@@ -165,8 +171,9 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} =
|
||||
ConnActive -> do
|
||||
ms <- asks msgStore
|
||||
q <- getMsgQueue ms (recipientId c)
|
||||
msg <- newMessage msgBody
|
||||
writeMsg q msg
|
||||
msgId <- randomId 8
|
||||
ts <- liftIO getCurrentTime
|
||||
writeMsg q $ Message {msgId, ts, msgBody}
|
||||
return OK
|
||||
ConnOff -> return $ ERR AUTH
|
||||
|
||||
@@ -199,3 +206,15 @@ client clnt@Client {connections, rcvQ, sndQ} Server {subscribedQ} =
|
||||
|
||||
msgResponse :: RecipientId -> Message -> Signed
|
||||
msgResponse rId Message {msgId, ts, msgBody} = mkSigned rId $ MSG msgId ts msgBody
|
||||
|
||||
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded
|
||||
randomId n = do
|
||||
gVar <- asks idsDrg
|
||||
B.unpack . encode <$> atomically (randomBytes n gVar)
|
||||
|
||||
randomBytes :: Int -> TVar ChaChaDRG -> STM ByteString
|
||||
randomBytes n gVar = do
|
||||
g <- readTVar gVar
|
||||
let (bytes, g') = randomBytesGenerate n g
|
||||
writeTVar gVar g'
|
||||
return bytes
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
module SMPClient where
|
||||
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Server
|
||||
@@ -14,7 +15,7 @@ import UnliftIO.IO
|
||||
|
||||
testSMPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a
|
||||
testSMPClient host port client = do
|
||||
threadDelay 1 -- TODO hack: thread delay for SMP server to start
|
||||
threadDelay 100 -- TODO hack: thread delay for SMP server to start
|
||||
runTCPClient host port $ \h -> do
|
||||
line <- getLn h
|
||||
if line == "Welcome to SMP"
|
||||
@@ -32,7 +33,7 @@ queueSize = 2
|
||||
|
||||
type TestTransmission = (Signature, ConnId, String)
|
||||
|
||||
runSmpTest :: MonadUnliftIO m => (Handle -> m a) -> m a
|
||||
runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a
|
||||
runSmpTest test =
|
||||
E.bracket
|
||||
(forkIO $ runSMPServer testPort queueSize)
|
||||
|
||||
@@ -47,7 +47,7 @@ testCreateSecure = do
|
||||
Resp _ (MSG _ _ msg1) <- tGet fromServer h
|
||||
(msg1, "hello") #== "delivers message"
|
||||
|
||||
Resp _ ok4 <- sendRecv h ("123", "1", "ACK")
|
||||
Resp _ ok4 <- sendRecv h ("123", rId, "ACK")
|
||||
(ok4, OK) #== "replies OK when message acknowledged if no more messages"
|
||||
|
||||
Resp sId2 err1 <- sendRecv h ("456", sId, "SEND :hello")
|
||||
@@ -73,7 +73,7 @@ testCreateSecure = do
|
||||
Resp _ (MSG _ _ msg) <- tGet fromServer h
|
||||
(msg, "hello again") #== "delivers message 2"
|
||||
|
||||
Resp _ ok5 <- sendRecv h ("123", "1", "ACK")
|
||||
Resp _ ok5 <- sendRecv h ("123", rId, "ACK")
|
||||
(ok5, OK) #== "replies OK when message acknowledged 2"
|
||||
|
||||
Resp _ err5 <- sendRecv h ("", sId, "SEND :hello")
|
||||
|
||||
Reference in New Issue
Block a user