Merge remote-tracking branch 'origin/master' into ab/bench-target

This commit is contained in:
Alexander Bondarenko
2024-04-01 12:24:02 +03:00
50 changed files with 1187 additions and 962 deletions
+10
View File
@@ -1,3 +1,13 @@
# 5.6.1
Version 5.6.1.0.
- Much faster iOS notification server start time (fewer skipped notifications).
- Fix SMP server stored message stats.
- Prevent overwriting uploaded XFTP files with subsequent upload attempts.
- Faster base64 encoding/parsing.
- Control port audit log and authentication.
# 5.6.0
Version 5.6.0.4.
+3 -1
View File
@@ -1,4 +1,5 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
@@ -38,7 +39,8 @@ logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
-- Warning: this SMP agent server is experimental - it does not work correctly with multiple connected TCP clients in some cases.
main :: IO ()
main = do
putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig)
let AgentConfig {tcpPort} = cfg
putStrLn $ maybe (error "no agent port") (\port -> "SMP agent listening on port " ++ port) tcpPort
setLogLevel LogInfo -- LogError
Right st <- createAgentStore agentDbFile agentDbKey False MCConsole
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers st
+6
View File
@@ -14,6 +14,12 @@ source-repository-package
location: https://github.com/simplex-chat/aeson.git
tag: aab7b5a14d6c5ea64c64dcaee418de1bb00dcc2b
-- old bs/text compat for 8.10
source-repository-package
type: git
location: https://github.com/simplex-chat/base64.git
tag: 2d77b6dbcaffc00570a70be8694049f3710e7c94
source-repository-package
type: git
location: https://github.com/simplex-chat/hs-socks.git
+6 -2
View File
@@ -1,5 +1,5 @@
name: simplexmq
version: 5.6.0.4
version: 5.6.1.0
synopsis: SimpleXMQ message broker
description: |
This package includes <./docs/Simplex-Messaging-Server.html server>,
@@ -31,7 +31,7 @@ dependencies:
- async == 2.2.*
- attoparsec == 0.14.*
- base >= 4.14 && < 5
- base64-bytestring >= 1.0 && < 1.3
- base64 == 1.0.*
- case-insensitive == 1.2.*
- composition == 1.0.*
- constraints >= 0.12 && < 0.14
@@ -202,3 +202,7 @@ ghc-options:
- -Wincomplete-record-updates
- -Wincomplete-uni-patterns
- -Wunused-type-patterns
- -O2
default-extensions:
- StrictData
+31 -15
View File
@@ -5,7 +5,7 @@ cabal-version: 1.12
-- see: https://github.com/sol/hpack
name: simplexmq
version: 5.6.0.4
version: 5.6.1.0
synopsis: SimpleXMQ message broker
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
<./docs/Simplex-Messaging-Client.html client> and
@@ -119,6 +119,8 @@ library
Simplex.Messaging.Crypto.SNTRUP761.Bindings.FFI
Simplex.Messaging.Crypto.SNTRUP761.Bindings.RNG
Simplex.Messaging.Encoding
Simplex.Messaging.Encoding.Base64
Simplex.Messaging.Encoding.Base64.URL
Simplex.Messaging.Encoding.String
Simplex.Messaging.Notifications.Client
Simplex.Messaging.Notifications.Protocol
@@ -171,7 +173,9 @@ library
Paths_simplexmq
hs-source-dirs:
src
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2
include-dirs:
cbits
c-sources:
@@ -187,7 +191,7 @@ library
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
@@ -251,7 +255,9 @@ executable ntf-server
Paths_simplexmq
hs-source-dirs:
apps/ntf-server
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
build-depends:
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
@@ -260,7 +266,7 @@ executable ntf-server
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
@@ -325,7 +331,9 @@ executable smp-agent
Paths_simplexmq
hs-source-dirs:
apps/smp-agent
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
build-depends:
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
@@ -334,7 +342,7 @@ executable smp-agent
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
@@ -399,7 +407,9 @@ executable smp-server
Paths_simplexmq
hs-source-dirs:
apps/smp-server
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
build-depends:
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
@@ -408,7 +418,7 @@ executable smp-server
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
@@ -473,7 +483,9 @@ executable xftp
Paths_simplexmq
hs-source-dirs:
apps/xftp
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
build-depends:
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
@@ -482,7 +494,7 @@ executable xftp
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
@@ -547,7 +559,9 @@ executable xftp-server
Paths_simplexmq
hs-source-dirs:
apps/xftp-server
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
build-depends:
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
@@ -556,7 +570,7 @@ executable xftp-server
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
@@ -653,7 +667,9 @@ test-suite simplexmq-test
Paths_simplexmq
hs-source-dirs:
tests
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns
default-extensions:
StrictData
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2
build-depends:
HUnit ==1.6.*
, QuickCheck ==2.14.*
@@ -664,7 +680,7 @@ test-suite simplexmq-test
, async ==2.2.*
, attoparsec ==0.14.*
, base >=4.14 && <5
, base64-bytestring >=1.0 && <1.3
, base64 ==1.0.*
, case-insensitive ==1.2.*
, composition ==1.0.*
, constraints >=0.12 && <0.14
+90 -86
View File
@@ -74,8 +74,9 @@ import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM)
import System.FilePath (takeFileName, (</>))
import UnliftIO
import UnliftIO.Directory
import qualified UnliftIO.Exception as E
startXFTPWorkers :: AgentMonad m => AgentClient -> Maybe FilePath -> m ()
startXFTPWorkers :: AgentClient -> Maybe FilePath -> AM ()
startXFTPWorkers c workDir = do
wd <- asks $ xftpWorkDir . xftpAgent
atomically $ writeTVar wd workDir
@@ -84,23 +85,26 @@ startXFTPWorkers c workDir = do
startSndFiles cfg
startDelFiles cfg
where
startRcvFiles :: AgentConfig -> AM ()
startRcvFiles AgentConfig {rcvFilesTTL} = do
pendingRcvServers <- withStore' c (`getPendingRcvFilesServers` rcvFilesTTL)
forM_ pendingRcvServers $ \s -> resumeXFTPRcvWork c (Just s)
lift . forM_ pendingRcvServers $ \s -> resumeXFTPRcvWork c (Just s)
-- start local worker for files pending decryption,
-- no need to make an extra query for the check
-- as the worker will check the store anyway
resumeXFTPRcvWork c Nothing
lift $ resumeXFTPRcvWork c Nothing
startSndFiles :: AgentConfig -> AM ()
startSndFiles AgentConfig {sndFilesTTL} = do
-- start worker for files pending encryption/creation
resumeXFTPSndWork c Nothing
lift $ resumeXFTPSndWork c Nothing
pendingSndServers <- withStore' c (`getPendingSndFilesServers` sndFilesTTL)
forM_ pendingSndServers $ \s -> resumeXFTPSndWork c (Just s)
lift . forM_ pendingSndServers $ \s -> resumeXFTPSndWork c (Just s)
startDelFiles :: AgentConfig -> AM ()
startDelFiles AgentConfig {rcvFilesTTL} = do
pendingDelServers <- withStore' c (`getPendingDelFilesServers` rcvFilesTTL)
forM_ pendingDelServers $ resumeXFTPDelWork c
lift . forM_ pendingDelServers $ resumeXFTPDelWork c
closeXFTPAgent :: MonadUnliftIO m => XFTPAgent -> m ()
closeXFTPAgent :: XFTPAgent -> IO ()
closeXFTPAgent a = do
stopWorkers $ xftpRcvWorkers a
stopWorkers $ xftpSndWorkers a
@@ -108,16 +112,16 @@ closeXFTPAgent a = do
where
stopWorkers workers = atomically (swapTVar workers M.empty) >>= mapM_ (liftIO . cancelWorker)
xftpReceiveFile' :: AgentMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> m RcvFileId
xftpReceiveFile' :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> AM RcvFileId
xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redirect}) cfArgs = do
g <- asks random
prefixPath <- getPrefixPath "rcv.xftp"
prefixPath <- lift $ getPrefixPath "rcv.xftp"
createDirectory prefixPath
let relPrefixPath = takeFileName prefixPath
relTmpPath = relPrefixPath </> "xftp.encrypted"
relSavePath = relPrefixPath </> "xftp.decrypted"
createDirectory =<< toFSFilePath relTmpPath
createEmptyFile =<< toFSFilePath relSavePath
lift $ createDirectory =<< toFSFilePath relTmpPath
lift $ createEmptyFile =<< toFSFilePath relSavePath
let saveFile = CryptoFile relSavePath cfArgs
fId <- case redirect of
Nothing -> withStore c $ \db -> createRcvFile db g userId fd relPrefixPath relTmpPath saveFile
@@ -125,8 +129,8 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi
-- prepare description paths
let relTmpPathRedirect = relPrefixPath </> "xftp.redirect-encrypted"
relSavePathRedirect = relPrefixPath </> "xftp.redirect-decrypted"
createDirectory =<< toFSFilePath relTmpPathRedirect
createEmptyFile =<< toFSFilePath relSavePathRedirect
lift $ createDirectory =<< toFSFilePath relTmpPathRedirect
lift $ createEmptyFile =<< toFSFilePath relSavePathRedirect
cfArgsRedirect <- atomically $ CF.randomArgs g
let saveFileRedirect = CryptoFile relSavePathRedirect $ Just cfArgsRedirect
-- create download tasks
@@ -134,42 +138,42 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi
forM_ chunks (downloadChunk c)
pure fId
downloadChunk :: AgentMonad m => AgentClient -> FileChunk -> m ()
downloadChunk :: AgentClient -> FileChunk -> AM ()
downloadChunk c FileChunk {replicas = (FileChunkReplica {server} : _)} = do
void $ getXFTPRcvWorker True c (Just server)
lift . void $ getXFTPRcvWorker True c (Just server)
downloadChunk _ _ = throwError $ INTERNAL "no replicas"
getPrefixPath :: AgentMonad m => String -> m FilePath
getPrefixPath :: String -> AM' FilePath
getPrefixPath suffix = do
workPath <- getXFTPWorkPath
ts <- liftIO getCurrentTime
let isoTime = formatTime defaultTimeLocale "%Y%m%d_%H%M%S_%6q" ts
uniqueCombine workPath (isoTime <> "_" <> suffix)
toFSFilePath :: AgentMonad m => FilePath -> m FilePath
toFSFilePath :: FilePath -> AM' FilePath
toFSFilePath f = (</> f) <$> getXFTPWorkPath
createEmptyFile :: AgentMonad m => FilePath -> m ()
createEmptyFile :: FilePath -> AM' ()
createEmptyFile fPath = liftIO $ B.writeFile fPath ""
resumeXFTPRcvWork :: AgentMonad' m => AgentClient -> Maybe XFTPServer -> m ()
resumeXFTPRcvWork :: AgentClient -> Maybe XFTPServer -> AM' ()
resumeXFTPRcvWork = void .: getXFTPRcvWorker False
getXFTPRcvWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe XFTPServer -> m Worker
getXFTPRcvWorker :: Bool -> AgentClient -> Maybe XFTPServer -> AM' Worker
getXFTPRcvWorker hasWork c server = do
ws <- asks $ xftpRcvWorkers . xftpAgent
getAgentWorker "xftp_rcv" hasWork c server ws $
maybe (runXFTPRcvLocalWorker c) (runXFTPRcvWorker c) server
runXFTPRcvWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m ()
runXFTPRcvWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
runXFTPRcvWorker c srv Worker {doWork} = do
cfg <- asks config
forever $ do
waitForWork doWork
lift $ waitForWork doWork
atomically $ assertAgentForeground c
runXFTPOperation cfg
where
runXFTPOperation :: AgentConfig -> m ()
runXFTPOperation :: AgentConfig -> AM ()
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} =
withWork c doWork (\db -> getNextRcvChunkToDownload db srv rcvFilesTTL) $ \case
RcvFileChunk {rcvFileId, rcvFileEntityId, fileTmpPath, replicas = []} -> rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) "chunk has no replicas"
@@ -182,14 +186,14 @@ runXFTPRcvWorker c srv Worker {doWork} = do
retryLoop loop e replicaDelay = do
flip catchAgentError (\_ -> pure ()) $ do
when notifyOnRetry $ notify c rcvFileEntityId $ RFERR e
closeXFTPServerClient c userId server digest
liftIO $ closeXFTPServerClient c userId server digest
withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay
atomically $ assertAgentForeground c
loop
retryDone e = rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) (show e)
downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> m ()
downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> AM ()
downloadFileChunk RcvFileChunk {userId, rcvFileId, rcvFileEntityId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath} replica = do
fsFileTmpPath <- toFSFilePath fileTmpPath
fsFileTmpPath <- lift $ toFSFilePath fileTmpPath
chunkPath <- uniqueCombine fsFileTmpPath $ show chunkNo
let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest)
relChunkPath = fileTmpPath </> takeFileName chunkPath
@@ -206,7 +210,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
liftIO . when complete $ updateRcvFileStatus db rcvFileId RFSReceived
pure (entityId, complete, RFPROG rcvd total)
notify c entityId progress
when complete . void $
when complete . lift . void $
getXFTPRcvWorker True c Nothing
where
receivedSize :: [RcvFileChunk] -> Int64
@@ -222,37 +226,37 @@ withRetryIntervalLimit maxN ri action =
withRetryIntervalCount ri $ \n delay loop ->
when (n < maxN) $ action delay loop
retryOnError :: AgentMonad m => Text -> m a -> m a -> AgentErrorType -> m a
retryOnError :: Text -> AM a -> AM a -> AgentErrorType -> AM a
retryOnError name loop done e = do
logError $ name <> " error: " <> tshow e
if temporaryAgentError e
then loop
else done
rcvWorkerInternalError :: AgentMonad m => AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> m ()
rcvWorkerInternalError :: AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> AM ()
rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath internalErrStr = do
forM_ tmpPath (removePath <=< toFSFilePath)
lift $ forM_ tmpPath (removePath <=< toFSFilePath)
withStore' c $ \db -> updateRcvFileError db rcvFileId internalErrStr
notify c rcvFileEntityId $ RFERR $ INTERNAL internalErrStr
runXFTPRcvLocalWorker :: forall m. AgentMonad m => AgentClient -> Worker -> m ()
runXFTPRcvLocalWorker :: AgentClient -> Worker -> AM ()
runXFTPRcvLocalWorker c Worker {doWork} = do
cfg <- asks config
forever $ do
waitForWork doWork
lift $ waitForWork doWork
atomically $ assertAgentForeground c
runXFTPOperation cfg
where
runXFTPOperation :: AgentConfig -> m ()
runXFTPOperation :: AgentConfig -> AM ()
runXFTPOperation AgentConfig {rcvFilesTTL} =
withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $
\f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath} ->
decryptFile f `catchAgentError` (rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath . show)
decryptFile :: RcvFile -> m ()
decryptFile :: RcvFile -> AM ()
decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do
let CryptoFile savePath cfArgs = saveFile
fsSavePath <- toFSFilePath savePath
when (status == RFSDecrypting) $
fsSavePath <- lift $ toFSFilePath savePath
lift . when (status == RFSDecrypting) $
whenM (doesFileExist fsSavePath) (removeFile fsSavePath >> createEmptyFile fsSavePath)
withStore' c $ \db -> updateRcvFileStatus db rcvFileId RFSDecrypting
chunkPaths <- getChunkPaths chunks
@@ -265,16 +269,16 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
case redirect of
Nothing -> do
notify c rcvFileEntityId $ RFDONE fsSavePath
forM_ tmpPath (removePath <=< toFSFilePath)
lift $ forM_ tmpPath (removePath <=< toFSFilePath)
atomically $ waitUntilForeground c
withStore' c (`updateRcvFileComplete` rcvFileId)
Just RcvFileRedirect {redirectFileInfo, redirectDbId} -> do
let RedirectFileInfo {size = redirectSize, digest = redirectDigest} = redirectFileInfo
forM_ tmpPath (removePath <=< toFSFilePath)
lift $ forM_ tmpPath (removePath <=< toFSFilePath)
atomically $ waitUntilForeground c
withStore' c (`updateRcvFileComplete` rcvFileId)
-- proceed with redirect
yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `finally` (toFSFilePath fsSavePath >>= removePath)
yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of
Left _ -> throwError . XFTP $ XFTP.REDIRECT "decode error"
Right (ValidFileDescription fd@FileDescription {size = dstSize, digest = dstDigest})
@@ -285,19 +289,19 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
withStore c $ \db -> updateRcvFileRedirect db redirectDbId next
forM_ nextChunks (downloadChunk c)
where
getChunkPaths :: [RcvFileChunk] -> m [FilePath]
getChunkPaths :: [RcvFileChunk] -> AM [FilePath]
getChunkPaths [] = pure []
getChunkPaths (RcvFileChunk {chunkTmpPath = Just path} : cs) = do
ps <- getChunkPaths cs
fsPath <- toFSFilePath path
fsPath <- lift $ toFSFilePath path
pure $ fsPath : ps
getChunkPaths (RcvFileChunk {chunkTmpPath = Nothing} : _cs) =
throwError $ INTERNAL "no chunk path"
xftpDeleteRcvFile' :: AgentMonad m => AgentClient -> RcvFileId -> m ()
xftpDeleteRcvFile' :: AgentClient -> RcvFileId -> AM' ()
xftpDeleteRcvFile' c rcvFileEntityId = xftpDeleteRcvFiles' c [rcvFileEntityId]
xftpDeleteRcvFiles' :: forall m. AgentMonad m => AgentClient -> [RcvFileId] -> m ()
xftpDeleteRcvFiles' :: AgentClient -> [RcvFileId] -> AM' ()
xftpDeleteRcvFiles' c rcvFileEntityIds = do
rcvFiles <- rights <$> withStoreBatch c (\db -> map (fmap (first storeError) . getRcvFileByEntityId db) rcvFileEntityIds)
redirects <- rights <$> batchFiles getRcvFileRedirects rcvFiles
@@ -309,29 +313,29 @@ xftpDeleteRcvFiles' c rcvFileEntityIds = do
(removePath . (workPath </>)) prefixPath `catchAll_` pure ()
where
fileComplete RcvFile {status} = status == RFSComplete || status == RFSError
batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> m [Either AgentErrorType a]
batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> AM' [Either AgentErrorType a]
batchFiles f rcvFiles = withStoreBatch' c $ \db -> map (\RcvFile {rcvFileId} -> f db rcvFileId) rcvFiles
notify :: forall m e. (MonadUnliftIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m ()
notify :: forall m e. (MonadIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m ()
notify c entId cmd = atomically $ writeTBQueue (subQ c) ("", entId, APC (sAEntity @e) cmd)
xftpSendFile' :: AgentMonad m => AgentClient -> UserId -> CryptoFile -> Int -> m SndFileId
xftpSendFile' :: AgentClient -> UserId -> CryptoFile -> Int -> AM SndFileId
xftpSendFile' c userId file numRecipients = do
g <- asks random
prefixPath <- getPrefixPath "snd.xftp"
prefixPath <- lift $ getPrefixPath "snd.xftp"
createDirectory prefixPath
let relPrefixPath = takeFileName prefixPath
key <- atomically $ C.randomSbKey g
nonce <- atomically $ C.randomCbNonce g
-- saving absolute filePath will not allow to restore file encryption after app update, but it's a short window
fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce Nothing
void $ getXFTPSndWorker True c Nothing
lift . void $ getXFTPSndWorker True c Nothing
pure fId
xftpSendDescription' :: forall m. AgentMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> m SndFileId
xftpSendDescription' :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> AM SndFileId
xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {size, digest}) numRecipients = do
g <- asks random
prefixPath <- getPrefixPath "snd.xftp"
prefixPath <- lift $ getPrefixPath "snd.xftp"
createDirectory prefixPath
let relPrefixPath = takeFileName prefixPath
let directYaml = prefixPath </> "direct.yaml"
@@ -341,39 +345,39 @@ xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {si
key <- atomically $ C.randomSbKey g
nonce <- atomically $ C.randomCbNonce g
fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce $ Just RedirectFileInfo {size, digest}
void $ getXFTPSndWorker True c Nothing
lift . void $ getXFTPSndWorker True c Nothing
pure fId
resumeXFTPSndWork :: AgentMonad' m => AgentClient -> Maybe XFTPServer -> m ()
resumeXFTPSndWork :: AgentClient -> Maybe XFTPServer -> AM' ()
resumeXFTPSndWork = void .: getXFTPSndWorker False
getXFTPSndWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe XFTPServer -> m Worker
getXFTPSndWorker :: Bool -> AgentClient -> Maybe XFTPServer -> AM' Worker
getXFTPSndWorker hasWork c server = do
ws <- asks $ xftpSndWorkers . xftpAgent
getAgentWorker "xftp_snd" hasWork c server ws $
maybe (runXFTPSndPrepareWorker c) (runXFTPSndWorker c) server
runXFTPSndPrepareWorker :: forall m. AgentMonad m => AgentClient -> Worker -> m ()
runXFTPSndPrepareWorker :: AgentClient -> Worker -> AM ()
runXFTPSndPrepareWorker c Worker {doWork} = do
cfg <- asks config
forever $ do
waitForWork doWork
lift $ waitForWork doWork
atomically $ assertAgentForeground c
runXFTPOperation cfg
where
runXFTPOperation :: AgentConfig -> m ()
runXFTPOperation :: AgentConfig -> AM ()
runXFTPOperation cfg@AgentConfig {sndFilesTTL} =
withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $
\f@SndFile {sndFileId, sndFileEntityId, prefixPath} ->
prepareFile cfg f `catchAgentError` (sndWorkerInternalError c sndFileId sndFileEntityId prefixPath . show)
prepareFile :: AgentConfig -> SndFile -> m ()
prepareFile :: AgentConfig -> SndFile -> AM ()
prepareFile _ SndFile {prefixPath = Nothing} =
throwError $ INTERNAL "no prefix path"
prepareFile cfg sndFile@SndFile {sndFileId, userId, prefixPath = Just ppath, status} = do
SndFile {numRecipients, chunks} <-
if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting
then do
fsEncPath <- toFSFilePath $ sndFileEncPath ppath
fsEncPath <- lift . toFSFilePath $ sndFileEncPath ppath
when (status == SFSEncrypting) . whenM (doesFileExist fsEncPath) $
removeFile fsEncPath
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
@@ -389,7 +393,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading
where
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)])
encryptFileForUpload :: SndFile -> FilePath -> AM (FileDigest, [(XFTPChunkSpec, FileDigest)])
encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do
let CryptoFile {filePath} = srcFile
fileName = takeFileName filePath
@@ -412,12 +416,12 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
chunkCreated :: SndFileChunk -> Bool
chunkCreated SndFileChunk {replicas} =
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas
createChunk :: Int -> SndFileChunk -> m ()
createChunk :: Int -> SndFileChunk -> AM ()
createChunk numRecipients' ch = do
atomically $ assertAgentForeground c
(replica, ProtoServerWithAuth srv _) <- tryCreate
withStore' c $ \db -> createSndFileReplica db ch replica
void $ getXFTPSndWorker True c (Just srv)
lift . void $ getXFTPSndWorker True c (Just srv)
where
tryCreate = do
usedSrvs <- newTVarIO ([] :: [XFTPServer])
@@ -433,21 +437,21 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
replica <- agentXFTPNewChunk c ch numRecipients' srvAuth
pure (replica, srvAuth)
sndWorkerInternalError :: AgentMonad m => AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> m ()
sndWorkerInternalError :: AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> AM ()
sndWorkerInternalError c sndFileId sndFileEntityId prefixPath internalErrStr = do
forM_ prefixPath $ removePath <=< toFSFilePath
lift . forM_ prefixPath $ removePath <=< toFSFilePath
withStore' c $ \db -> updateSndFileError db sndFileId internalErrStr
notify c sndFileEntityId $ SFERR $ INTERNAL internalErrStr
runXFTPSndWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m ()
runXFTPSndWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
runXFTPSndWorker c srv Worker {doWork} = do
cfg <- asks config
forever $ do
waitForWork doWork
lift $ waitForWork doWork
atomically $ assertAgentForeground c
runXFTPOperation cfg
where
runXFTPOperation :: AgentConfig -> m ()
runXFTPOperation :: AgentConfig -> AM ()
runXFTPOperation cfg@AgentConfig {sndFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do
withWork c doWork (\db -> getNextSndChunkToUpload db srv sndFilesTTL) $ \case
SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) "chunk has no replicas"
@@ -460,15 +464,15 @@ runXFTPSndWorker c srv Worker {doWork} = do
retryLoop loop e replicaDelay = do
flip catchAgentError (\_ -> pure ()) $ do
when notifyOnRetry $ notify c sndFileEntityId $ SFERR e
closeXFTPServerClient c userId server digest
liftIO $ closeXFTPServerClient c userId server digest
withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay
atomically $ assertAgentForeground c
loop
retryDone e = sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) (show e)
uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> m ()
uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> AM ()
uploadFileChunk AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients} sndFileChunk@SndFileChunk {sndFileId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath}, digest = chunkDigest} replica = do
replica'@SndFileChunkReplica {sndChunkReplicaId} <- addRecipients sndFileChunk replica
fsFilePath <- toFSFilePath filePath
fsFilePath <- lift $ toFSFilePath filePath
unlessM (doesFileExist fsFilePath) $ throwError $ INTERNAL "encrypted file doesn't exist on upload"
let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec
atomically $ assertAgentForeground c
@@ -484,10 +488,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
when complete $ do
(sndDescr, rcvDescrs) <- sndFileToDescrs sf
notify c sndFileEntityId $ SFDONE sndDescr rcvDescrs
forM_ prefixPath $ removePath <=< toFSFilePath
lift . forM_ prefixPath $ removePath <=< toFSFilePath
withStore' c $ \db -> updateSndFileComplete db sndFileId
where
addRecipients :: SndFileChunk -> SndFileChunkReplica -> m SndFileChunkReplica
addRecipients :: SndFileChunk -> SndFileChunkReplica -> AM SndFileChunkReplica
addRecipients ch@SndFileChunk {numRecipients} cr@SndFileChunkReplica {rcvIdsKeys}
| length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients"
| length rcvIdsKeys == numRecipients = pure cr
@@ -496,7 +500,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients'
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
addRecipients ch cr'
sndFileToDescrs :: SndFile -> m (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
sndFileToDescrs :: SndFile -> AM (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest"
sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks"
sndFileToDescrs SndFile {digest = Just digest, key, nonce, chunks = chunks@(fstChunk : _), redirect} = do
@@ -511,7 +515,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
fdRcvs = createRcvFileDescriptions fdRcv chunks
validFdRcvs <- either (throwError . INTERNAL) pure $ mapM validateFileDescription fdRcvs
pure (validFdSnd, validFdRcvs)
toSndDescrChunk :: SndFileChunk -> m FileChunk
toSndDescrChunk :: SndFileChunk -> AM FileChunk
toSndDescrChunk SndFileChunk {replicas = []} = throwError $ INTERNAL "snd file chunk has no replicas"
toSndDescrChunk ch@SndFileChunk {chunkNo, digest = chDigest, replicas = (SndFileChunkReplica {server, replicaId, replicaKey} : _)} = do
let chunkSize = FileSize $ sndChunkSize ch
@@ -562,10 +566,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
chunkUploaded SndFileChunk {replicas} =
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSUploaded) replicas
deleteSndFileInternal :: AgentMonad m => AgentClient -> SndFileId -> m ()
deleteSndFileInternal :: AgentClient -> SndFileId -> AM' ()
deleteSndFileInternal c sndFileEntityId = deleteSndFilesInternal c [sndFileEntityId]
deleteSndFilesInternal :: forall m. AgentMonad m => AgentClient -> [SndFileId] -> m ()
deleteSndFilesInternal :: AgentClient -> [SndFileId] -> AM' ()
deleteSndFilesInternal c sndFileEntityIds = do
sndFiles <- rights <$> withStoreBatch c (\db -> map (fmap (first storeError) . getSndFileByEntityId db) sndFileEntityIds)
let (toDelete, toMarkDeleted) = partition fileComplete sndFiles
@@ -576,15 +580,15 @@ deleteSndFilesInternal c sndFileEntityIds = do
batchFiles_ updateSndFileDeleted toMarkDeleted
where
fileComplete SndFile {status} = status == SFSComplete || status == SFSError
batchFiles_ :: (DB.Connection -> DBSndFileId -> IO a) -> [SndFile] -> m ()
batchFiles_ :: (DB.Connection -> DBSndFileId -> IO a) -> [SndFile] -> AM' ()
batchFiles_ f sndFiles = void $ withStoreBatch' c $ \db -> map (\SndFile {sndFileId} -> f db sndFileId) sndFiles
deleteSndFileRemote :: forall m. AgentMonad m => AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> m ()
deleteSndFileRemote :: AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> AM' ()
deleteSndFileRemote c userId sndFileEntityId sfd = deleteSndFilesRemote c userId [(sndFileEntityId, sfd)]
deleteSndFilesRemote :: forall m. AgentMonad m => AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> m ()
deleteSndFilesRemote :: AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> AM' ()
deleteSndFilesRemote c userId sndFileIdsDescrs = do
deleteSndFilesInternal c (map fst sndFileIdsDescrs) `catchAgentError` (notify c "" . SFERR)
deleteSndFilesInternal c (map fst sndFileIdsDescrs) `E.catchAny` (notify c "" . SFERR . INTERNAL . show)
let rs = concatMap (mapMaybe chunkReplica . fdChunks . snd) sndFileIdsDescrs
void $ withStoreBatch' c (\db -> map (uncurry $ createDeletedSndChunkReplica db userId) rs)
let servers = S.fromList $ map (\(FileChunkReplica {server}, _) -> server) rs
@@ -596,23 +600,23 @@ deleteSndFilesRemote c userId sndFileIdsDescrs = do
FileChunk {digest, replicas = replica : _} -> Just (replica, digest)
_ -> Nothing
resumeXFTPDelWork :: AgentMonad' m => AgentClient -> XFTPServer -> m ()
resumeXFTPDelWork :: AgentClient -> XFTPServer -> AM' ()
resumeXFTPDelWork = void .: getXFTPDelWorker False
getXFTPDelWorker :: AgentMonad' m => Bool -> AgentClient -> XFTPServer -> m Worker
getXFTPDelWorker :: Bool -> AgentClient -> XFTPServer -> AM' Worker
getXFTPDelWorker hasWork c server = do
ws <- asks $ xftpDelWorkers . xftpAgent
getAgentWorker "xftp_del" hasWork c server ws $ runXFTPDelWorker c server
runXFTPDelWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m ()
runXFTPDelWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
runXFTPDelWorker c srv Worker {doWork} = do
cfg <- asks config
forever $ do
waitForWork doWork
lift $ waitForWork doWork
atomically $ assertAgentForeground c
runXFTPOperation cfg
where
runXFTPOperation :: AgentConfig -> m ()
runXFTPOperation :: AgentConfig -> AM ()
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do
-- no point in deleting files older than rcv ttl, as they will be expired on server
withWork c doWork (\db -> getNextDeletedSndChunkReplica db srv rcvFilesTTL) processDeletedReplica
@@ -626,7 +630,7 @@ runXFTPDelWorker c srv Worker {doWork} = do
retryLoop loop e replicaDelay = do
flip catchAgentError (\_ -> pure ()) $ do
when notifyOnRetry $ notify c "" $ SFERR e
closeXFTPServerClient c userId server chunkDigest
liftIO $ closeXFTPServerClient c userId server chunkDigest
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
atomically $ assertAgentForeground c
loop
@@ -635,7 +639,7 @@ runXFTPDelWorker c srv Worker {doWork} = do
agentXFTPDeleteChunk c userId replica
withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId
delWorkerInternalError :: AgentMonad m => AgentClient -> Int64 -> AgentErrorType -> m ()
delWorkerInternalError :: AgentClient -> Int64 -> AgentErrorType -> AM ()
delWorkerInternalError c deletedSndChunkReplicaId e = do
withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId
notify c "" $ SFERR e
+5 -5
View File
@@ -50,7 +50,7 @@ import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.HTTP2.Client
import Simplex.Messaging.Transport.HTTP2.File
import Simplex.Messaging.Util (bshow, liftEitherError, whenM)
import Simplex.Messaging.Util (bshow, whenM)
import UnliftIO
import UnliftIO.Directory
@@ -98,7 +98,7 @@ getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkC
clientVar <- newTVarIO Nothing
let usePort = if null port then "443" else port
clientDisconnected = readTVarIO clientVar >>= mapM_ disconnected
http2Client <- liftEitherError xftpClientError $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
http2Client <- withExceptT xftpClientError . ExceptT $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
let HTTP2Client {sessionId} = http2Client
thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True}
c = XFTPClient {http2Client, thParams, transportSession, config}
@@ -145,7 +145,7 @@ sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> Excep
sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = do
let req = H.requestStreaming N.methodPost "/" [] streamBody
reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- liftEitherError xftpClientError $ sendRequest http2Client req reqTimeout
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req reqTimeout
when (B.length bodyHead /= xftpBlockSize) $ throwError $ PCEResponseError BLOCK
-- TODO validate that the file ID is the same as in the request?
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
@@ -196,9 +196,9 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
let dhSecret = C.dh' sDhKey rpDhKey
cbState <- liftEither . first PCECryptoError $ LC.cbInit dhSecret cbNonce
let t = chunkTimeout config chunkSize
t `timeout` download cbState >>= maybe (throwError PCEResponseTimeout) pure
ExceptT (sequence <$> (t `timeout` download cbState)) >>= maybe (throwError PCEResponseTimeout) pure
where
download cbState =
download cbState = runExceptT $
withExceptT PCEResponseError $
receiveEncFile chunkPart cbState chunkSpec `catchError` \e ->
whenM (doesFileExist filePath) (removeFile filePath) >> throwError e
+3 -2
View File
@@ -10,6 +10,7 @@ module Simplex.FileTransfer.Client.Agent where
import Control.Logger.Simple (logInfo)
import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans (lift)
import Data.Bifunctor (first)
import qualified Data.ByteString.Char8 as B
import Data.Text (Text)
@@ -86,7 +87,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
waitForXFTPClient :: XFTPClientVar -> ME XFTPClient
waitForXFTPClient clientVar = do
let XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpConnectTimeout}} = xftpConfig config
client_ <- tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
liftEither $ case client_ of
Just (Right c) -> Right c
Just (Left e) -> Left e
@@ -110,7 +111,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
TM.delete srv xftpClients
throwError e
tryConnectAsync :: ME ()
tryConnectAsync = void . async $ do
tryConnectAsync = void . lift . async . runExceptT $ do
withRetryInterval (reconnectInterval config) $ \_ loop -> void $ tryConnectClient loop
showServer :: XFTPServer -> Text
+7 -2
View File
@@ -37,6 +37,7 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Char (toLower)
import Data.Either (partitionEithers)
import Data.Int (Int64)
import Data.List (foldl', sortOn)
import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
@@ -321,7 +322,9 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re
-- the reason we don't do pooled downloads here within one server is that http2 library doesn't handle cleint concurrency, even though
-- upload doesn't allow other requests within the same client until complete (but download does allow).
logInfo $ "uploading " <> tshow (length chunks) <> " chunks..."
map snd . sortOn fst . concat <$> pooledForConcurrentlyN 16 chunks' (mapM $ uploadFileChunk a)
(errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 chunks' . mapM $ runExceptT . uploadFileChunk a)
mapM_ throwError errs
pure $ map snd (sortOn fst rs)
where
uploadFileChunk :: XFTPClientAgent -> (Int, XFTPChunkSpec, XFTPServerWithAuth) -> ExceptT CLIError IO (Int, SentFileChunk)
uploadFileChunk a (chunkNo, chunkSpec@XFTPChunkSpec {chunkSize}, ProtoServerWithAuth xftpServer auth) = do
@@ -433,7 +436,9 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
FileChunkReplica {server} : _ -> server
srvChunks = groupAllOn srv chunks
g <- liftIO C.newRandom
chunkPaths <- map snd . sortOn fst . concat <$> pooledForConcurrentlyN 16 srvChunks (mapM $ downloadFileChunk g a encPath size downloadedChunks)
(errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 srvChunks $ mapM $ runExceptT . downloadFileChunk g a encPath size downloadedChunks)
mapM_ throwError errs
let chunkPaths = map snd $ sortOn fst rs
encDigest <- liftIO $ LC.sha512Hash <$> readChunks chunkPaths
when (encDigest /= unFileDigest digest) $ throwError $ CLIError "File digest mismatch"
encSize <- liftIO $ foldM (\s path -> (s +) . fromIntegral <$> getFileSize path) 0 chunkPaths
+1 -1
View File
@@ -29,7 +29,7 @@ import UnliftIO.Directory (removeFile)
encryptFile :: CryptoFile -> ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT FTCryptoError IO ()
encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do
sb <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce
CF.withFile srcFile ReadMode $ \r -> withFile encFile WriteMode $ \w -> do
CF.withFile srcFile ReadMode $ \r -> ExceptT . withFile encFile WriteMode $ \w -> runExceptT $ do
let lenStr = smpEncode fileSize'
(hdr, !sb') = LC.sbEncryptChunk sb $ lenStr <> fileHdr
padLen = encSize - authTagSize - fileSize' - 8
+31 -21
View File
@@ -4,6 +4,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
@@ -17,7 +18,6 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64.URL as B64
import Data.ByteString.Builder (byteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -26,7 +26,7 @@ import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Maybe (fromMaybe, isJust)
import qualified Data.Text as T
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
@@ -46,14 +46,16 @@ import Simplex.FileTransfer.Server.StoreLog
import Simplex.FileTransfer.Transport
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (CorrId, RcvPublicDhKey, RcvPublicAuthKey, RecipientId, TransmissionAuth)
import Simplex.Messaging.Protocol (CorrId, RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth)
import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Transport (THandleParams (..))
import Simplex.Messaging.Transport.Buffer (trimCR)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.HTTP2.File (fileBlockSize)
import Simplex.Messaging.Transport.HTTP2.Server
import Simplex.Messaging.Transport.Server (runTCPServer)
import Simplex.Messaging.Util
@@ -67,13 +69,12 @@ import qualified UnliftIO.Exception as E
type M a = ReaderT XFTPEnv IO a
data XFTPTransportRequest =
XFTPTransportRequest
{ thParams :: THandleParams XFTPVersion,
reqBody :: HTTP2Body,
request :: H.Request,
sendResponse :: H.Response -> IO ()
}
data XFTPTransportRequest = XFTPTransportRequest
{ thParams :: THandleParams XFTPVersion,
reqBody :: HTTP2Body,
request :: H.Request,
sendResponse :: H.Response -> IO ()
}
runXFTPServer :: XFTPServerConfig -> IO ()
runXFTPServer cfg = do
@@ -222,15 +223,13 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
| Just auth == user = CPRUser
| otherwise = CPRNone
CPStatsRTS -> E.tryAny getRTSStats >>= either (hPrint h) (hPrint h)
CPDelete fileId fKey -> withUserRole $ unliftIO u $ do
CPDelete fileId -> withUserRole $ unliftIO u $ do
fs <- asks store
r <- runExceptT $ do
let asSender = ExceptT . atomically $ getFile fs SFSender fileId
let asRecipient = ExceptT . atomically $ getFile fs SFRecipient fileId
(fr, fKey') <- asSender `catchError` const asRecipient
if fKey == fKey'
then ExceptT $ deleteServerFile_ fr
else throwError AUTH
(fr, _) <- asSender `catchError` const asRecipient
ExceptT $ deleteServerFile_ fr
liftIO . hPutStrLn h $ either (\e -> "error: " <> show e) (\() -> "ok") r
CPHelp -> hPutStrLn h "commands: stats-rts, delete, help, quit"
CPQuit -> pure ()
@@ -331,7 +330,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
-- TODO validate body empty
sId <- ExceptT $ addFileRetry st file 3 ts
rcps <- mapM (ExceptT . addRecipientRetry st 3 sId) rks
withFileLog $ \sl -> do
lift $ withFileLog $ \sl -> do
logAddFile sl sId file ts
logAddRecipients sl sId rcps
stats <- asks serverStats
@@ -363,7 +362,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
st <- asks store
r <- runExceptT $ do
rcps <- mapM (ExceptT . addRecipientRetry st 3 sId) rks
withFileLog $ \sl -> logAddRecipients sl sId rcps
lift $ withFileLog $ \sl -> logAddRecipients sl sId rcps
stats <- asks serverStats
atomically $ modifyTVar' (fileRecipients stats) (+ length rks)
let rIds = L.map (\(FileRecipient rId _) -> rId) rcps
@@ -373,8 +372,19 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
receiveServerFile FileRec {senderId, fileInfo = FileInfo {size, digest}, filePath} = case bodyPart of
Nothing -> pure $ FRErr SIZE
-- TODO validate body size from request before downloading, once it's populated
Just getBody -> ifM reserve receive (pure $ FRErr QUOTA) -- TODO: handle duplicate uploads
Just getBody -> skipCommitted $ ifM reserve receive (pure $ FRErr QUOTA)
where
-- having a filePath means the file is already uploaded and committed, must not change anything
skipCommitted = ifM (isJust <$> readTVarIO filePath) (liftIO $ drain $ fromIntegral size)
where
-- can't send FROk without reading the request body or a client will block on sending it
-- can't send any old error as the client would fail or restart indefinitely
drain s = do
bs <- B.length <$> getBody fileBlockSize
if
| bs == s -> pure FROk
| bs == 0 || bs > s -> pure $ FRErr SIZE
| otherwise -> drain (s - bs)
reserve = do
us <- asks $ usedStorage . store
quota <- asks $ fromMaybe maxBound . fileSizeQuota . config
@@ -382,7 +392,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
\used -> let used' = used + fromIntegral size in if used' <= quota then (True, used') else (False, used)
receive = do
path <- asks $ filesPath . config
let fPath = path </> B.unpack (B64.encode senderId)
let fPath = path </> B.unpack (U.encode senderId)
receiveChunk (XFTPRcvChunkSpec fPath size digest) >>= \case
Right () -> do
stats <- asks serverStats
@@ -447,7 +457,7 @@ deleteServerFile_ FileRec {senderId, fileInfo, filePath} = do
atomically $ modifyTVar' (filesCount stats) (subtract 1)
atomically $ modifyTVar' (filesSize stats) (subtract $ fromIntegral $ size fileInfo)
randomId :: (MonadUnliftIO m, MonadReader XFTPEnv m) => Int -> m ByteString
randomId :: Int -> M ByteString
randomId n = atomically . C.randomBytes n =<< asks random
getFileId :: M XFTPFileId
@@ -455,7 +465,7 @@ getFileId = do
size <- asks (fileIdSize . config)
atomically . C.randomBytes size =<< asks random
withFileLog :: (MonadIO m, MonadReader XFTPEnv m) => (StoreLog 'WriteMode -> IO a) -> m ()
withFileLog :: (StoreLog 'WriteMode -> IO a) -> M ()
withFileLog action = liftIO . mapM_ action =<< asks storeLog
incFileStat :: (FileServerStats -> TVar Int) -> M ()
+3 -4
View File
@@ -5,7 +5,6 @@ module Simplex.FileTransfer.Server.Control where
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString (ByteString)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BasicAuth)
@@ -14,7 +13,7 @@ data CPClientRole = CPRNone | CPRUser | CPRAdmin
data ControlProtocol
= CPAuth BasicAuth
| CPStatsRTS
| CPDelete ByteString C.APublicAuthKey
| CPDelete ByteString
| CPHelp
| CPQuit
| CPSkip
@@ -23,7 +22,7 @@ instance StrEncoding ControlProtocol where
strEncode = \case
CPAuth tok -> "auth " <> strEncode tok
CPStatsRTS -> "stats-rts"
CPDelete fId fKey -> strEncode (Str "delete", fId, fKey)
CPDelete fId -> strEncode (Str "delete", fId)
CPHelp -> "help"
CPQuit -> "quit"
CPSkip -> ""
@@ -31,7 +30,7 @@ instance StrEncoding ControlProtocol where
A.takeTill (== ' ') >>= \case
"auth" -> CPAuth <$> _strP
"stats-rts" -> pure CPStatsRTS
"delete" -> CPDelete <$> _strP <*> _strP
"delete" -> CPDelete <$> _strP
"help" -> pure CPHelp
"quit" -> pure CPQuit
"" -> pure CPSkip
+1 -1
View File
@@ -94,7 +94,7 @@ defaultFileExpiration =
checkInterval = 2 * 3600 -- seconds, 2 hours
}
newXFTPServerEnv :: (MonadUnliftIO m, MonadRandom m) => XFTPServerConfig -> m XFTPEnv
newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv
newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile} = do
random <- liftIO C.newRandom
store <- atomically newFileStore
File diff suppressed because it is too large Load Diff
+169 -121
View File
@@ -142,7 +142,6 @@ import Crypto.Random (ChaChaDRG)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:.))
@@ -182,6 +181,7 @@ import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
@@ -293,10 +293,11 @@ data AgentClient = AgentClient
agentEnv :: Env
}
getAgentWorker :: (AgentMonad' m, Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT AgentErrorType m ()) -> m Worker
getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker
getAgentWorker = getAgentWorker' id pure
{-# INLINE getAgentWorker #-}
getAgentWorker' :: forall a k m. (AgentMonad' m, Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT AgentErrorType m ()) -> m a
getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a
getAgentWorker' toW fromW name hasWork c key ws work = do
atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w
where
@@ -310,9 +311,9 @@ getAgentWorker' toW fromW name hasWork c key ws work = do
| otherwise = pure w
runWorker w = runWorkerAsync (toW w) runWork
where
runWork :: m ()
runWork :: AM' ()
runWork = tryAgentError' (work w) >>= restartOrDelete
restartOrDelete :: Either AgentErrorType () -> m ()
restartOrDelete :: Either AgentErrorType () -> AM' ()
restartOrDelete e_ = do
t <- liftIO getSystemTime
maxRestarts <- asks $ maxWorkerRestartsPerMin . config
@@ -350,7 +351,7 @@ newWorker c = do
restarts <- newTVar $ RestartCount 0 0
pure Worker {workerId, doWork, action, restarts}
runWorkerAsync :: AgentMonad' m => Worker -> m () -> m ()
runWorkerAsync :: Worker -> AM' () -> AM' ()
runWorkerAsync Worker {action} work =
E.bracket
(atomically $ takeTMVar action) -- get current action, locking to avoid race conditions
@@ -394,6 +395,7 @@ data AgentStatsKey = AgentStatsKey
}
deriving (Eq, Ord, Show)
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
newAgentClient :: Int -> InitialAgentServers -> Env -> STM AgentClient
newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
let qSize = tbqSize $ config agentEnv
@@ -469,13 +471,15 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
{-# INLINE agentClientStore #-}
agentDRG :: AgentClient -> TVar ChaChaDRG
agentDRG AgentClient {agentEnv = Env {random}} = random
{-# INLINE agentDRG #-}
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
type Client msg = c | c -> msg
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg)
getProtocolServerClient :: AgentClient -> TransportSession msg -> AM (Client msg)
clientProtocolError :: err -> AgentErrorType
closeProtocolServerClient :: Client msg -> IO ()
clientServer :: Client msg -> String
@@ -509,7 +513,7 @@ instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
clientTransportHost = X.xftpTransportHost
clientSessionTs = X.xftpSessionTs
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient
getSMPServerClient :: AgentClient -> SMPTransportSession -> AM SMPClient
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess smpClients)
@@ -520,13 +524,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
-- make it expensive to check for pending subscriptions.
newClient v =
newProtocolClient c tSess smpClients connectClient v
`catchAgentError` \e -> resubscribeSMPSession c tSess >> throwError e
connectClient :: SMPClientVar -> m SMPClient
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwError e
connectClient :: SMPClientVar -> AM SMPClient
connectClient v = do
cfg <- getClientConfig c smpCfg
cfg <- lift $ getClientConfig c smpCfg
g <- asks random
env <- ask
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v)
liftError' (protocolClientError SMP $ B.unpack $ strEncode srv) $
getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v
clientDisconnected :: Env -> SMPClientVar -> SMPClient -> IO ()
clientDisconnected env v client = do
@@ -557,7 +562,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
resubscribeSMPSession :: AgentMonad' m => AgentClient -> SMPTransportSession -> m ()
resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' ()
resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
atomically getWorkerVar >>= mapM_ (either newSubWorker (\_ -> pure ()))
where
@@ -585,12 +590,12 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
whenM (isEmptyTMVar $ sessionVar v) retry
removeTSessVar v tSess smpSubWorkers
reconnectSMPClient :: forall m. AgentMonad m => TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m ()
reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM ()
reconnectSMPClient tc c tSess@(_, srv, _) qs = do
NetworkConfig {tcpTimeout} <- readTVarIO $ useNetworkConfig c
-- this allows 3x of timeout per batch of subscription (90 queues per batch empirically)
let t = (length qs `div` 90 + 1) * tcpTimeout * 3
t `timeout` resubscribe >>= \case
ExceptT (sequence <$> (t `timeout` runExceptT resubscribe)) >>= \case
Just _ -> atomically $ writeTVar tc 0
Nothing -> do
tc' <- atomically $ stateTVar tc $ \i -> (i + 1, i + 1)
@@ -599,10 +604,10 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
msg = show tc' <> " consecutive subscription timeouts: " <> show (length qs) <> " queues, transport session: " <> show tSess
atomically $ writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ err msg)
where
resubscribe :: m ()
resubscribe :: AM ()
resubscribe = do
cs <- readTVarIO $ RQ.getConnections $ activeSubs c
rs <- subscribeQueues c $ L.toList qs
rs <- lift . subscribeQueues c $ L.toList qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
liftIO $ do
let conns = filter (`M.notMember` cs) okConns
@@ -616,7 +621,7 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient
getNtfServerClient :: AgentClient -> NtfTransportSession -> AM NtfClient
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess ntfClients)
@@ -624,11 +629,12 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
(newProtocolClient c tSess ntfClients connectClient)
(waitForProtocolClient c tSess)
where
connectClient :: NtfClientVar -> m NtfClient
connectClient :: NtfClientVar -> AM NtfClient
connectClient v = do
cfg <- getClientConfig c ntfCfg
cfg <- lift $ getClientConfig c ntfCfg
g <- asks random
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient g tSess cfg Nothing $ clientDisconnected v)
liftError' (protocolClientError NTF $ B.unpack $ strEncode srv) $
getProtocolClient g tSess cfg Nothing $ clientDisconnected v
clientDisconnected :: NtfClientVar -> NtfClient -> IO ()
clientDisconnected v client = do
@@ -637,7 +643,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getXFTPServerClient :: forall m. AgentMonad m => AgentClient -> XFTPTransportSession -> m XFTPClient
getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient
getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess xftpClients)
@@ -645,11 +651,12 @@ getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@
(newProtocolClient c tSess xftpClients connectClient)
(waitForProtocolClient c tSess)
where
connectClient :: XFTPClientVar -> m XFTPClient
connectClient :: XFTPClientVar -> AM XFTPClient
connectClient v = do
cfg <- asks $ xftpCfg . config
xftpNetworkConfig <- readTVarIO useNetworkConfig
liftEitherError (protocolClientError XFTP $ B.unpack $ strEncode srv) (X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v)
liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $
X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v
clientDisconnected :: XFTPClientVar -> XFTPClient -> IO ()
clientDisconnected v client = do
@@ -671,6 +678,7 @@ getTSessVar c tSess vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lo
removeTSessVar :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM ()
removeTSessVar = void .:. removeTSessVar'
{-# INLINE removeTSessVar #-}
removeTSessVar' :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM Bool
removeTSessVar' v tSess vs =
@@ -678,7 +686,7 @@ removeTSessVar' v tSess vs =
Just v' | sessionVarId v == sessionVarId v' -> TM.delete tSess vs $> True
_ -> pure False
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (Client msg)
waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg)
waitForProtocolClient c (_, srv, _) v = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
@@ -689,14 +697,14 @@ waitForProtocolClient c (_, srv, _) v = do
-- clientConnected arg is only passed for SMP server
newProtocolClient ::
forall v err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
forall v err msg.
(ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
AgentClient ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
(ClientVar msg -> m (Client msg)) ->
(ClientVar msg -> AM (Client msg)) ->
ClientVar msg ->
m (Client msg)
AM (Client msg)
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
tryAgentError (connectClient v) >>= \case
Right client -> do
@@ -715,14 +723,14 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v)
getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v)
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
cfg <- asks $ cfgSel . config
networkConfig <- readTVarIO useNetworkConfig
pure cfg {networkConfig}
closeAgentClient :: MonadIO m => AgentClient -> m ()
closeAgentClient c = liftIO $ do
closeAgentClient :: AgentClient -> IO ()
closeAgentClient c = do
atomically $ writeTVar (active c) False
closeProtocolServerClients c smpClients
closeProtocolServerClients c ntfClients
@@ -750,9 +758,11 @@ cancelWorker Worker {doWork, action} = do
waitUntilActive :: AgentClient -> STM ()
waitUntilActive c = unlessM (readTVar $ active c) retry
{-# INLINE waitUntilActive #-}
throwWhenInactive :: AgentClient -> STM ()
throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled
{-# INLINE throwWhenInactive #-}
-- this function is used to remove workers once delivery is complete, not when it is removed from the map
throwWhenNoDelivery :: AgentClient -> SndQueue -> STM ()
@@ -779,82 +789,98 @@ closeClient_ c v = do
Just (Right client) -> closeProtocolServerClient client `catchAll_` pure ()
_ -> pure ()
closeXFTPServerClient :: AgentMonad' m => AgentClient -> UserId -> XFTPServer -> FileDigest -> m ()
closeXFTPServerClient :: AgentClient -> UserId -> XFTPServer -> FileDigest -> IO ()
closeXFTPServerClient c userId server (FileDigest chunkDigest) =
mkTransportSession c userId server chunkDigest >>= liftIO . closeClient c xftpClients
mkTransportSession c userId server chunkDigest >>= closeClient c xftpClients
withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
withConnLock _ "" _ = id
withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
withConnLock :: AgentClient -> ConnId -> String -> AM a -> AM a
withConnLock c connId name = ExceptT . withConnLock' c connId name . runExceptT
{-# INLINE withConnLock #-}
withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
withConnLock' :: AgentClient -> ConnId -> String -> AM' a -> AM' a
withConnLock' _ "" _ = id
withConnLock' AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
{-# INLINE withConnLock' #-}
withConnLocks :: MonadUnliftIO m => AgentClient -> [ConnId] -> String -> m a -> m a
withInvLock :: AgentClient -> ByteString -> String -> AM a -> AM a
withInvLock c key name = ExceptT . withInvLock' c key name . runExceptT
{-# INLINE withInvLock #-}
withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
withInvLock' AgentClient {invLocks} = withLockMap_ invLocks
{-# INLINE withInvLock' #-}
withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
{-# INLINE withConnLocks #-}
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap_ = withGetLock . getMapLock
{-# INLINE withLockMap_ #-}
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
withLocksMap_ = withGetLocks . getMapLock
{-# INLINE withLocksMap_ #-}
getMapLock :: Ord k => TMap k Lock -> k -> STM Lock
getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
where
newLock = createLock >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> AM a) -> AM a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchAgentError` logServerError cl
where
stat cl = liftIO . incClientStat c userId cl statCmd
logServerError :: Client msg -> AgentErrorType -> m a
logServerError :: Client msg -> AgentErrorType -> AM a
logServerError cl e = do
logServer "<--" c srv "" $ strEncode e
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
withLogClient_ :: ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> AM a) -> AM a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient :: forall v err msg a. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> AM a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
{-# INLINE withClient #-}
withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient :: forall v err msg a. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> AM a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
{-# INLINE withLogClient #-}
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMPClient :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> AM a
withSMPClient c q cmdStr action = do
tSess <- mkSMPTransportSession c q
tSess <- liftIO $ mkSMPTransportSession c q
withLogClient c tSess (queueId q) cmdStr action
withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a
withSMPClient_ :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> AM a) -> AM a
withSMPClient_ c q cmdStr action = do
tSess <- mkSMPTransportSession c q
tSess <- liftIO $ mkSMPTransportSession c q
withLogClient_ c tSess (queueId q) cmdStr action
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a
withNtfClient :: AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> AM a
withNtfClient c srv = withLogClient c (0, srv, Nothing)
withXFTPClient ::
(AgentMonad m, ProtocolServerClient v err msg) =>
ProtocolServerClient v err msg =>
AgentClient ->
(UserId, ProtoServer msg, EntityId) ->
ByteString ->
(Client msg -> ExceptT (ProtocolClientError err) IO b) ->
m b
AM b
withXFTPClient c (userId, srv, entityId) cmdStr action = do
tSess <- mkTransportSession c userId srv entityId
tSess <- liftIO $ mkTransportSession c userId srv entityId
withLogClient c tSess entityId cmdStr action
liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a
liftClient :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a
liftClient protocolError_ = liftError . protocolClientError protocolError_
{-# INLINE liftClient #-}
protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType
protocolClientError protocolError_ host = \case
@@ -889,7 +915,7 @@ data ProtocolTestFailure = ProtocolTestFailure
}
deriving (Eq, Show)
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe ProtocolTestFailure)
runSMPServerTest :: AgentClient -> UserId -> SMPServerWithAuth -> AM' (Maybe ProtocolTestFailure)
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
cfg <- getClientConfig c smpCfg
C.AuthAlg ra <- asks $ rcvAuthAlg . config
@@ -915,7 +941,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
testErr :: ProtocolTestStep -> SMPClientError -> ProtocolTestFailure
testErr step = ProtocolTestFailure step . protocolClientError SMP addr
runXFTPServerTest :: forall m. AgentMonad m => AgentClient -> UserId -> XFTPServerWithAuth -> m (Maybe ProtocolTestFailure)
runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe ProtocolTestFailure)
runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
cfg <- asks $ xftpCfg . config
g <- asks random
@@ -949,7 +975,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
testErr step = ProtocolTestFailure step . protocolClientError XFTP addr
chSize :: Integral a => a
chSize = kb 64
getTempFilePath :: FilePath -> m FilePath
getTempFilePath :: FilePath -> AM' FilePath
getTempFilePath workPath = do
ts <- liftIO getCurrentTime
let isoTime = formatTime defaultTimeLocale "%Y-%m-%dT%H%M%S.%6q" ts
@@ -963,7 +989,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
createTestChunk :: FilePath -> IO ()
createTestChunk fp = B.writeFile fp =<< atomically . C.randomBytes chSize =<< C.newRandom
runNTFServerTest :: AgentMonad m => AgentClient -> UserId -> NtfServerWithAuth -> m (Maybe ProtocolTestFailure)
runNTFServerTest :: AgentClient -> UserId -> NtfServerWithAuth -> AM' (Maybe ProtocolTestFailure)
runNTFServerTest c userId (ProtoServerWithAuth srv _) = do
cfg <- getClientConfig c ntfCfg
C.AuthAlg a <- asks $ rcvAuthAlg . config
@@ -987,27 +1013,32 @@ runNTFServerTest c userId (ProtoServerWithAuth srv _) = do
testErr :: ProtocolTestStep -> SMPClientError -> ProtocolTestFailure
testErr step = ProtocolTestFailure step . protocolClientError NTF addr
getXFTPWorkPath :: AgentMonad m => m FilePath
getXFTPWorkPath :: AM' FilePath
getXFTPWorkPath = do
workDir <- readTVarIO =<< asks (xftpWorkDir . xftpAgent)
maybe getTemporaryDirectory pure workDir
mkTransportSession :: AgentMonad' m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
mkTransportSession :: AgentClient -> UserId -> ProtoServer msg -> EntityId -> IO (TransportSession msg)
mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c
{-# INLINE mkTransportSession #-}
mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg
mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
{-# INLINE mkTSession #-}
mkSMPTransportSession :: (AgentMonad' m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
mkSMPTransportSession :: SMPQueueRec q => AgentClient -> q -> IO SMPTransportSession
mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c
{-# INLINE mkSMPTransportSession #-}
mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession
mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
{-# INLINE mkSMPTSession #-}
getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode
getSessionMode :: AgentClient -> IO TransportSessionMode
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
{-# INLINE getSessionMode #-}
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri)
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
C.AuthAlg a <- asks (rcvAuthAlg . config)
g <- asks random
@@ -1015,10 +1046,10 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
(dhKey, privDhKey) <- atomically $ C.generateKeyPair g
(e2eDhKey, e2ePrivKey) <- atomically $ C.generateKeyPair g
logServer "-->" c srv "" "NEW"
tSess <- mkTransportSession c userId srv connId
tSess <- liftIO $ mkTransportSession c userId srv connId
QIK {rcvId, sndId, rcvPublicDhKey} <-
withClient c tSess "NEW" $ \smp -> createSMPQueue smp rKeys dhKey auth subMode
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
liftIO . logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
let rq =
RcvQueue
{ userId,
@@ -1057,14 +1088,16 @@ temporaryAgentError = \case
BROKER _ TIMEOUT -> True
INACTIVE -> True
_ -> False
{-# INLINE temporaryAgentError #-}
temporaryOrHostError :: AgentErrorType -> Bool
temporaryOrHostError = \case
BROKER _ HOST -> True
e -> temporaryAgentError e
{-# INLINE temporaryOrHostError #-}
-- | Subscribe to queues. The list of results can have a different order.
subscribeQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
subscribeQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
subscribeQueues c qs = do
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
atomically $ do
@@ -1088,11 +1121,11 @@ subscribeQueues c qs = do
type BatchResponses e r = (NonEmpty (RcvQueue, Either e r))
-- statBatchSize is not used to batch the commands, only for traffic statistics
sendTSessionBatches :: forall m q r. AgentMonad' m => ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> m [(RcvQueue, Either AgentErrorType r)]
sendTSessionBatches :: forall q r. ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)]
sendTSessionBatches statCmd statBatchSize toRQ action c qs =
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
where
batchQueues :: m [(SMPTransportSession, NonEmpty q)]
batchQueues :: AM' [(SMPTransportSession, NonEmpty q)]
batchQueues = do
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
pure . M.assocs $ foldl' (batch mode) M.empty qs
@@ -1100,7 +1133,7 @@ sendTSessionBatches statCmd statBatchSize toRQ action c qs =
batch mode m q =
let tSess = mkSMPTSession (toRQ q) mode
in M.alter (Just . maybe [q] (q <|)) tSess m
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> m (BatchResponses AgentErrorType r)
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses AgentErrorType r)
sendClientBatch (tSess@(userId, srv, _), qs') =
tryAgentError' (getSMPServerClient c tSess) >>= \case
Left e -> pure $ L.map ((,Left e) . toRQ) qs'
@@ -1120,7 +1153,7 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
where
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
addSubscription :: AgentClient -> RcvQueue -> IO ()
addSubscription c rq@RcvQueue {connId} = atomically $ do
modifyTVar' (subscrConns c) $ S.insert connId
RQ.addQueue rq $ activeSubs c
@@ -1128,6 +1161,7 @@ addSubscription c rq@RcvQueue {connId} = atomically $ do
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c
{-# INLINE hasActiveSubscription #-}
removeSubscription :: AgentClient -> ConnId -> STM ()
removeSubscription c connId = do
@@ -1137,19 +1171,23 @@ removeSubscription c connId = do
getSubscriptions :: AgentClient -> STM (Set ConnId)
getSubscriptions = readTVar . subscrConns
{-# INLINE getSubscriptions #-}
logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m ()
logServer dir AgentClient {clientId} srv qId cmdStr =
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
{-# INLINE logServer #-}
showServer :: ProtocolServer s -> ByteString
showServer ProtocolServer {host, port} =
strEncode host <> B.pack (if null port then "" else ':' : port)
{-# INLINE showServer #-}
logSecret :: ByteString -> ByteString
logSecret bs = encode $ B.take 3 bs
{-# INLINE logSecret #-}
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
sendConfirmation :: AgentClient -> SndQueue -> ByteString -> AM ()
sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
withSMPClient_ c sq "SEND <CONF>" $ \smp -> do
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
@@ -1157,21 +1195,21 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation :: AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> AM ()
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
tSess <- mkTransportSession c userId smpServer senderId
tSess <- liftIO $ mkTransportSession c userId smpServer senderId
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
msg <- mkInvitation
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
where
mkInvitation :: m ByteString
mkInvitation :: AM ByteString
-- this is only encrypted with per-queue E2E, not with double ratchet
mkInvitation = do
let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo}
agentCbEncryptOnce v dhPublicKey . smpEncode $
SMP.ClientMessage SMP.PHEmpty (smpEncode agentEnvelope)
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
getQueueMessage :: AgentClient -> RcvQueue -> AM (Maybe SMPMsgMeta)
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
atomically createTakeGetLock
msg_ <- withSMPClient c rq "GET" $ \smp ->
@@ -1186,23 +1224,23 @@ getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
takeTMVar l
pure $ Just l
decryptSMPMessage :: AgentMonad m => RcvQueue -> SMP.RcvMessage -> m SMP.ClientRcvMsgBody
decryptSMPMessage :: RcvQueue -> SMP.RcvMessage -> AM SMP.ClientRcvMsgBody
decryptSMPMessage rq SMP.RcvMessage {msgId, msgBody = SMP.EncRcvMsgBody body} =
liftEither . parse SMP.clientRcvMsgBodyP (AGENT A_MESSAGE) =<< decrypt body
where
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicAuthKey -> m ()
secureQueue :: AgentClient -> RcvQueue -> SndPublicAuthKey -> AM ()
secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey =
withSMPClient c rq "KEY <key>" $ \smp ->
secureSMPQueue smp rcvPrivateKey rcvId senderKey
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> SMP.NtfPublicAuthKey -> SMP.RcvNtfPublicDhKey -> m (SMP.NotifierId, SMP.RcvNtfPublicDhKey)
enableQueueNotifications :: AgentClient -> RcvQueue -> SMP.NtfPublicAuthKey -> SMP.RcvNtfPublicDhKey -> AM (SMP.NotifierId, SMP.RcvNtfPublicDhKey)
enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
withSMPClient c rq "NKEY <nkey>" $ \smp ->
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
enableQueuesNtfs :: forall m. AgentMonad' m => AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> m [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))]
enableQueuesNtfs :: AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> AM' [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))]
enableQueuesNtfs = sendTSessionBatches "NKEY" 90 fst3 enableQueues_
where
fst3 (x, _, _) = x
@@ -1211,15 +1249,15 @@ enableQueuesNtfs = sendTSessionBatches "NKEY" 90 fst3 enableQueues_
queueCreds :: (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)
queueCreds (RcvQueue {rcvPrivateKey, rcvId}, notifierKey, rcvNtfPublicDhKey) = (rcvPrivateKey, rcvId, notifierKey, rcvNtfPublicDhKey)
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
disableQueueNotifications :: AgentClient -> RcvQueue -> AM ()
disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "NDEL" $ \smp ->
disableSMPQueueNotifications smp rcvPrivateKey rcvId
disableQueuesNtfs :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
disableQueuesNtfs :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
disableQueuesNtfs = sendTSessionBatches "NDEL" 90 id $ sendBatch disableSMPQueuesNtfs
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
sendAck :: AgentClient -> RcvQueue -> MsgId -> AM ()
sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do
withSMPClient c rq ("ACK:" <> logSecret msgId) $ \smp ->
ackSMPMessage smp rcvPrivateKey rcvId msgId
@@ -1233,93 +1271,93 @@ releaseGetLock :: AgentClient -> RcvQueue -> STM ()
releaseGetLock c RcvQueue {server, rcvId} =
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
suspendQueue :: AgentClient -> RcvQueue -> AM ()
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "OFF" $ \smp ->
suspendSMPQueue smp rcvPrivateKey rcvId
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
deleteQueue :: AgentClient -> RcvQueue -> AM ()
deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = do
withSMPClient c rq "DEL" $ \smp ->
deleteSMPQueue smp rcvPrivateKey rcvId
deleteQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
deleteQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
deleteQueues = sendTSessionBatches "DEL" 90 id $ sendBatch deleteSMPQueues
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
sendAgentMessage :: AgentClient -> SndQueue -> MsgFlags -> ByteString -> AM ()
sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg =
withSMPClient_ c sq "SEND <MSG>" $ \smp -> do
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
agentNtfRegisterToken :: AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> AM (NtfTokenId, C.PublicKeyX25519)
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
agentNtfVerifyToken :: AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> AM ()
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
agentNtfCheckToken :: AgentClient -> NtfTokenId -> NtfToken -> AM NtfTknStatus
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
agentNtfReplaceToken :: AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> AM ()
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
agentNtfDeleteToken :: AgentClient -> NtfTokenId -> NtfToken -> AM ()
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
agentNtfEnableCron :: AgentClient -> NtfTokenId -> NtfToken -> Word16 -> AM ()
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> SMP.NtfPrivateAuthKey -> m NtfSubscriptionId
agentNtfCreateSubscription :: AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> SMP.NtfPrivateAuthKey -> AM NtfSubscriptionId
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
agentNtfCheckSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM NtfSubStatus
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
agentNtfDeleteSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM ()
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
agentXFTPDownloadChunk :: AgentMonad m => AgentClient -> UserId -> FileDigest -> RcvFileChunkReplica -> XFTPRcvChunkSpec -> m ()
agentXFTPDownloadChunk :: AgentClient -> UserId -> FileDigest -> RcvFileChunkReplica -> XFTPRcvChunkSpec -> AM ()
agentXFTPDownloadChunk c userId (FileDigest chunkDigest) RcvFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} chunkSpec = do
g <- asks random
withXFTPClient c (userId, server, chunkDigest) "FGET" $ \xftp -> X.downloadXFTPChunk g xftp replicaKey fId chunkSpec
agentXFTPNewChunk :: AgentMonad m => AgentClient -> SndFileChunk -> Int -> XFTPServerWithAuth -> m NewSndChunkReplica
agentXFTPNewChunk :: AgentClient -> SndFileChunk -> Int -> XFTPServerWithAuth -> AM NewSndChunkReplica
agentXFTPNewChunk c SndFileChunk {userId, chunkSpec = XFTPChunkSpec {chunkSize}, digest = FileDigest chunkDigest} n (ProtoServerWithAuth srv auth) = do
rKeys <- xftpRcvKeys n
(sndKey, replicaKey) <- atomically . C.generateAuthKeyPair C.SEd25519 =<< asks random
let fileInfo = FileInfo {sndKey, size = fromIntegral chunkSize, digest = chunkDigest}
logServer "-->" c srv "" "FNEW"
tSess <- mkTransportSession c userId srv chunkDigest
tSess <- liftIO $ mkTransportSession c userId srv chunkDigest
(sndId, rIds) <- withClient c tSess "FNEW" $ \xftp -> X.createXFTPChunk xftp replicaKey fileInfo (L.map fst rKeys) auth
logServer "<--" c srv "" $ B.unwords ["SIDS", logSecret sndId]
pure NewSndChunkReplica {server = srv, replicaId = ChunkReplicaId sndId, replicaKey, rcvIdsKeys = L.toList $ xftpRcvIdsKeys rIds rKeys}
agentXFTPUploadChunk :: AgentMonad m => AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> XFTPChunkSpec -> m ()
agentXFTPUploadChunk :: AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> XFTPChunkSpec -> AM ()
agentXFTPUploadChunk c userId (FileDigest chunkDigest) SndFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} chunkSpec =
withXFTPClient c (userId, server, chunkDigest) "FPUT" $ \xftp -> X.uploadXFTPChunk xftp replicaKey fId chunkSpec
agentXFTPAddRecipients :: AgentMonad m => AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> Int -> m (NonEmpty (ChunkReplicaId, C.APrivateAuthKey))
agentXFTPAddRecipients :: AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> Int -> AM (NonEmpty (ChunkReplicaId, C.APrivateAuthKey))
agentXFTPAddRecipients c userId (FileDigest chunkDigest) SndFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} n = do
rKeys <- xftpRcvKeys n
rIds <- withXFTPClient c (userId, server, chunkDigest) "FADD" $ \xftp -> X.addXFTPRecipients xftp replicaKey fId (L.map fst rKeys)
pure $ xftpRcvIdsKeys rIds rKeys
agentXFTPDeleteChunk :: AgentMonad m => AgentClient -> UserId -> DeletedSndChunkReplica -> m ()
agentXFTPDeleteChunk :: AgentClient -> UserId -> DeletedSndChunkReplica -> AM ()
agentXFTPDeleteChunk c userId DeletedSndChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey, chunkDigest = FileDigest chunkDigest} =
withXFTPClient c (userId, server, chunkDigest) "FDEL" $ \xftp -> X.deleteXFTPChunk xftp replicaKey fId
xftpRcvKeys :: AgentMonad m => Int -> m (NonEmpty C.AAuthKeyPair)
xftpRcvKeys :: Int -> AM (NonEmpty C.AAuthKeyPair)
xftpRcvKeys n = do
rKeys <- atomically . replicateM n . C.generateAuthKeyPair C.SEd25519 =<< asks random
case L.nonEmpty rKeys of
@@ -1329,7 +1367,7 @@ xftpRcvKeys n = do
xftpRcvIdsKeys :: NonEmpty ByteString -> NonEmpty C.AAuthKeyPair -> NonEmpty (ChunkReplicaId, C.APrivateAuthKey)
xftpRcvIdsKeys rIds rKeys = L.map ChunkReplicaId rIds `L.zip` L.map snd rKeys
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncrypt :: SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> AM ByteString
agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
cmNonce <- atomically . C.randomCbNonce =<< asks random
let paddedLen = maybe SMP.e2eEncMessageLength (const SMP.e2eEncConfirmationLength) e2ePubKey
@@ -1340,7 +1378,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody}
-- add encoding as AgentInvitation'?
agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncryptOnce :: VersionSMPC -> C.PublicKeyX25519 -> ByteString -> AM ByteString
agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
g <- asks random
(dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g
@@ -1354,7 +1392,7 @@ agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
-- | NaCl crypto-box decrypt - both for messages received from the server
-- and per-queue E2E encrypted messages from the sender that were inside.
agentCbDecrypt :: AgentMonad m => C.DhSecretX25519 -> C.CbNonce -> ByteString -> m ByteString
agentCbDecrypt :: C.DhSecretX25519 -> C.CbNonce -> ByteString -> AM ByteString
agentCbDecrypt dhSecret nonce msg =
liftEither . first cryptoError $
C.cbDecrypt dhSecret nonce msg
@@ -1373,10 +1411,11 @@ cryptoError = \case
where
c = AGENT . A_CRYPTO
waitForWork :: AgentMonad' m => TMVar () -> m ()
waitForWork :: MonadIO m => TMVar () -> m ()
waitForWork = void . atomically . readTMVar
{-# INLINE waitForWork #-}
withWork :: AgentMonad m => AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> m ()) -> m ()
withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM ()
withWork c doWork getWork action =
withStore' c getWork >>= \case
Right (Just r) -> action r
@@ -1389,12 +1428,15 @@ withWork c doWork getWork action =
noWorkToDo :: TMVar () -> IO ()
noWorkToDo = void . atomically . tryTakeTMVar
{-# INLINE noWorkToDo #-}
hasWorkToDo :: Worker -> STM ()
hasWorkToDo = hasWorkToDo' . doWork
{-# INLINE hasWorkToDo #-}
hasWorkToDo' :: TMVar () -> STM ()
hasWorkToDo' = void . (`tryPutTMVar` ())
{-# INLINE hasWorkToDo' #-}
endAgentOperation :: AgentClient -> AgentOperation -> STM ()
endAgentOperation c op = endOperation c op $ case op of
@@ -1438,6 +1480,7 @@ endOperation c op endedAction = do
whenSuspending :: AgentClient -> STM () -> STM ()
whenSuspending c = whenM ((== ASSuspending) <$> readTVar (agentState c))
{-# INLINE whenSuspending #-}
beginAgentOperation :: AgentClient -> AgentOperation -> STM ()
beginAgentOperation c op = do
@@ -1457,20 +1500,22 @@ agentOperationBracket c op check action =
waitUntilForeground :: AgentClient -> STM ()
waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry
{-# INLINE waitUntilForeground #-}
withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a
withStore' :: AgentClient -> (DB.Connection -> IO a) -> AM a
withStore' c action = withStore c $ fmap Right . action
{-# INLINE withStore' #-}
withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
withStore :: AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> AM a
withStore c action = do
st <- asks store
liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $
withExceptT storeError . ExceptT . liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $
withTransaction st action `E.catch` handleInternal ""
where
handleInternal :: String -> E.SomeException -> IO (Either StoreError a)
handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr
withStoreBatch :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> m (t (Either AgentErrorType a))
withStoreBatch :: Traversable t => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> AM' (t (Either AgentErrorType a))
withStoreBatch c actions = do
st <- asks store
liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $
@@ -1480,8 +1525,9 @@ withStoreBatch c actions = do
handleInternal :: E.SomeException -> IO (Either AgentErrorType a)
handleInternal = pure . Left . INTERNAL . show
withStoreBatch' :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO a)) -> m (t (Either AgentErrorType a))
withStoreBatch' :: Traversable t => AgentClient -> (DB.Connection -> t (IO a)) -> AM' (t (Either AgentErrorType a))
withStoreBatch' c actions = withStoreBatch c (fmap (fmap Right) . actions)
{-# INLINE withStoreBatch' #-}
storeError :: StoreError -> AgentErrorType
storeError = \case
@@ -1505,6 +1551,7 @@ incStat AgentClient {agentStats} n k = do
incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
{-# INLINE incClientStat #-}
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
incServerStat c userId ProtocolServer {host} cmd res = do
@@ -1523,27 +1570,28 @@ userServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> TMa
userServers c = case protocolTypeI @p of
SPSMP -> smpServers c
SPXFTP -> xftpServers c
{-# INLINE userServers #-}
pickServer :: forall p m. AgentMonad' m => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p)
pickServer :: forall p. NonEmpty (ProtoServerWithAuth p) -> AM (ProtoServerWithAuth p)
pickServer = \case
srv :| [] -> pure srv
servers -> do
gen <- asks randomServer
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
getNextServer :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> [ProtocolServer p] -> m (ProtoServerWithAuth p)
getNextServer :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> [ProtocolServer p] -> AM (ProtoServerWithAuth p)
getNextServer c userId usedSrvs = withUserServers c userId $ \srvs ->
case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of
Just srvs' -> pickServer srvs'
_ -> pickServer srvs
withUserServers :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> m a) -> m a
withUserServers :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> AM a) -> AM a
withUserServers c userId action =
atomically (TM.lookup userId $ userServers c) >>= \case
Just srvs -> action srvs
_ -> throwError $ INTERNAL "unknown userId - no user servers"
withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> m a) -> m a
withNextSrv :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> AM a) -> AM a
withNextSrv c userId usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used
@@ -1564,7 +1612,7 @@ data SubscriptionsInfo = SubscriptionsInfo
}
deriving (Show)
getAgentSubscriptions :: MonadIO m => AgentClient -> m SubscriptionsInfo
getAgentSubscriptions :: AgentClient -> IO SubscriptionsInfo
getAgentSubscriptions c = do
activeSubscriptions <- getSubs activeSubs
pendingSubscriptions <- getSubs pendingSubs
@@ -1600,7 +1648,7 @@ data WorkersDetails = WorkersDetails
}
deriving (Show)
getAgentWorkersDetails :: MonadIO m => AgentClient -> m AgentWorkersDetails
getAgentWorkersDetails :: AgentClient -> IO AgentWorkersDetails
getAgentWorkersDetails AgentClient {smpClients, ntfClients, xftpClients, smpDeliveryWorkers, asyncCmdWorkers, smpSubWorkers, agentEnv} = do
smpClients_ <- textKeys <$> readTVarIO smpClients
ntfClients_ <- textKeys <$> readTVarIO ntfClients
@@ -1632,7 +1680,7 @@ getAgentWorkersDetails AgentClient {smpClients, ntfClients, xftpClients, smpDeli
textKeys = map textKey . M.keys
textKey :: StrEncoding k => k -> Text
textKey = decodeASCII . strEncode
workerStats :: (StrEncoding k, MonadIO m) => Map k Worker -> m (Map Text WorkersDetails)
workerStats :: StrEncoding k => Map k Worker -> IO (Map Text WorkersDetails)
workerStats ws = fmap M.fromList . forM (M.toList ws) $ \(qa, Worker {restarts, doWork, action}) -> do
RestartCount {restartCount} <- readTVarIO restarts
hasWork <- atomically $ not <$> isEmptyTMVar doWork
@@ -1664,7 +1712,7 @@ data WorkersSummary = WorkersSummary
}
deriving (Show)
getAgentWorkersSummary :: MonadIO m => AgentClient -> m AgentWorkersSummary
getAgentWorkersSummary :: AgentClient -> IO AgentWorkersSummary
getAgentWorkersSummary AgentClient {smpClients, ntfClients, xftpClients, smpDeliveryWorkers, asyncCmdWorkers, smpSubWorkers, agentEnv} = do
smpClientsCount <- M.size <$> readTVarIO smpClients
ntfClientsCount <- M.size <$> readTVarIO ntfClients
@@ -1695,7 +1743,7 @@ getAgentWorkersSummary AgentClient {smpClients, ntfClients, xftpClients, smpDeli
Env {ntfSupervisor, xftpAgent} = agentEnv
NtfSupervisor {ntfWorkers, ntfSMPWorkers} = ntfSupervisor
XFTPAgent {xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} = xftpAgent
workerSummary :: MonadIO m => M.Map k Worker -> m WorkersSummary
workerSummary :: M.Map k Worker -> IO WorkersSummary
workerSummary = liftIO . foldM byWork WorkersSummary {numActive = 0, numIdle = 0, totalRestarts = 0}
where
byWork WorkersSummary {numActive, numIdle, totalRestarts} Worker {action, restarts} = do
+17 -13
View File
@@ -11,8 +11,8 @@
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
module Simplex.Messaging.Agent.Env.SQLite
( AgentMonad,
AgentMonad',
( AM',
AM,
AgentConfig (..),
InitialAgentServers (..),
NetworkConfig (..),
@@ -21,6 +21,7 @@ module Simplex.Messaging.Agent.Env.SQLite
tryAgentError,
tryAgentError',
catchAgentError,
catchAgentError',
agentFinally,
Env (..),
newSMPAgentEnv,
@@ -34,7 +35,6 @@ module Simplex.Messaging.Agent.Env.SQLite
)
where
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
@@ -65,14 +65,14 @@ import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..))
import Simplex.Messaging.Transport.Client (defaultSMPPort)
import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors)
import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors')
import System.Random (StdGen, newStdGen)
import UnliftIO (Async, SomeException)
import UnliftIO.STM
type AgentMonad' m = (MonadUnliftIO m, MonadReader Env m)
type AM' a = ReaderT Env IO a
type AgentMonad m = (AgentMonad' m, MonadError AgentErrorType m)
type AM a = ExceptT AgentErrorType (ReaderT Env IO) a
data InitialAgentServers = InitialAgentServers
{ smp :: Map UserId (NonEmpty SMPServerWithAuth),
@@ -82,7 +82,7 @@ data InitialAgentServers = InitialAgentServers
}
data AgentConfig = AgentConfig
{ tcpPort :: ServiceName,
{ tcpPort :: Maybe ServiceName,
rcvAuthAlg :: C.AuthAlg,
sndAuthAlg :: C.AuthAlg,
connIdBytes :: Int,
@@ -149,7 +149,7 @@ defaultMessageRetryInterval =
defaultAgentConfig :: AgentConfig
defaultAgentConfig =
AgentConfig
{ tcpPort = "5224",
{ tcpPort = Just "5224",
-- while the current client version supports X25519, it can only be enabled once support for SMP v6 is dropped,
-- and all servers are required to support v7 to be compatible.
rcvAuthAlg = C.AuthAlg C.SEd25519, -- this will stay as Ed25519
@@ -250,20 +250,24 @@ newXFTPAgent = do
xftpDelWorkers <- TM.empty
pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers}
tryAgentError :: AgentMonad m => m a -> m (Either AgentErrorType a)
tryAgentError :: AM a -> AM (Either AgentErrorType a)
tryAgentError = tryAllErrors mkInternal
{-# INLINE tryAgentError #-}
-- unlike runExceptT, this ensures we catch IO exceptions as well
tryAgentError' :: AgentMonad' m => ExceptT AgentErrorType m a -> m (Either AgentErrorType a)
tryAgentError' = fmap join . runExceptT . tryAgentError
tryAgentError' :: AM a -> AM' (Either AgentErrorType a)
tryAgentError' = tryAllErrors' mkInternal
{-# INLINE tryAgentError' #-}
catchAgentError :: AgentMonad m => m a -> (AgentErrorType -> m a) -> m a
catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a
catchAgentError = catchAllErrors mkInternal
{-# INLINE catchAgentError #-}
agentFinally :: AgentMonad m => m a -> m b -> m a
catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a
catchAgentError' = catchAllErrors' mkInternal
{-# INLINE catchAgentError' #-}
agentFinally :: AM a -> AM b -> AM a
agentFinally = allFinally mkInternal
{-# INLINE agentFinally #-}
+8 -4
View File
@@ -1,15 +1,15 @@
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.Agent.Lock
( Lock,
createLock,
withLock,
withLock',
withGetLock,
withGetLocks,
)
where
import Control.Monad (void)
import Control.Monad.Except (ExceptT (..), runExceptT)
import Control.Monad.IO.Unlift
import Data.Functor (($>))
import UnliftIO.Async (forConcurrently)
@@ -22,8 +22,12 @@ createLock :: STM Lock
createLock = newEmptyTMVar
{-# INLINE createLock #-}
withLock :: MonadUnliftIO m => Lock -> String -> m a -> m a
withLock lock name =
withLock :: MonadUnliftIO m => Lock -> String -> ExceptT e m a -> ExceptT e m a
withLock lock name = ExceptT . withLock' lock name . runExceptT
{-# INLINE withLock #-}
withLock' :: MonadUnliftIO m => Lock -> String -> m a -> m a
withLock' lock name =
E.bracket_
(atomically $ putTMVar lock name)
(void . atomically $ takeTMVar lock)
+46 -43
View File
@@ -44,7 +44,7 @@ import UnliftIO
import UnliftIO.Concurrent (forkIO, threadDelay)
import qualified UnliftIO.Exception as E
runNtfSupervisor :: forall m. AgentMonad' m => AgentClient -> m ()
runNtfSupervisor :: AgentClient -> AM' ()
runNtfSupervisor c = do
ns <- asks ntfSupervisor
forever $ do
@@ -54,13 +54,13 @@ runNtfSupervisor c = do
Left e -> notifyErr connId e
Right _ -> return ()
where
handleErr :: ConnId -> m () -> m ()
handleErr :: ConnId -> AM' () -> AM' ()
handleErr connId = E.handle $ \(e :: E.SomeException) -> do
logError $ "runNtfSupervisor error " <> tshow e
notifyErr connId e
notifyErr connId e = notifyInternalError c connId $ "runNtfSupervisor error " <> show e
processNtfSub :: forall m. AgentMonad m => AgentClient -> (ConnId, NtfSupervisorCommand) -> m ()
processNtfSub :: AgentClient -> (ConnId, NtfSupervisorCommand) -> AM ()
processNtfSub c (connId, cmd) = do
logInfo $ "processNtfSub - connId = " <> tshow connId <> " - cmd = " <> tshow cmd
case cmd of
@@ -77,11 +77,11 @@ processNtfSub c (connId, cmd) = do
Just ClientNtfCreds {notifierId} -> do
let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate
void $ getNtfNTFWorker True c ntfServer
lift . void $ getNtfNTFWorker True c ntfServer
Nothing -> do
let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey
void $ getNtfSMPWorker True c smpServer
lift . void $ getNtfSMPWorker True c smpServer
(Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do
case (clientNtfCreds, ntfQueueId) of
(Just ClientNtfCreds {notifierId}, Just ntfQueueId')
@@ -90,7 +90,7 @@ processNtfSub c (connId, cmd) = do
(Nothing, Nothing) -> create
_ -> rotate
where
create :: m ()
create :: AM ()
create = case action_ of
-- action was set to NULL after worker internal error
Nothing -> resetSubscription
@@ -101,60 +101,60 @@ processNtfSub c (connId, cmd) = do
then resetSubscription
else withTokenServer $ \ntfServer -> do
withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate)
void $ getNtfNTFWorker True c ntfServer
lift . void $ getNtfNTFWorker True c ntfServer
| otherwise -> case action of
NtfSubNTFAction _ -> void $ getNtfNTFWorker True c subNtfServer
NtfSubSMPAction _ -> void $ getNtfSMPWorker True c smpServer
rotate :: m ()
NtfSubNTFAction _ -> lift . void $ getNtfNTFWorker True c subNtfServer
NtfSubSMPAction _ -> lift . void $ getNtfSMPWorker True c smpServer
rotate :: AM ()
rotate = do
withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate)
void $ getNtfNTFWorker True c subNtfServer
resetSubscription :: m ()
lift . void $ getNtfNTFWorker True c subNtfServer
resetSubscription :: AM ()
resetSubscription =
withTokenServer $ \ntfServer -> do
let sub' = sub {ntfQueueId = Nothing, ntfServer, ntfSubId = Nothing, ntfSubStatus = NASNew}
withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NtfSubSMPAction NSASmpKey)
void $ getNtfSMPWorker True c smpServer
lift . void $ getNtfSMPWorker True c smpServer
NSCDelete -> do
sub_ <- withStore' c $ \db -> do
supervisorUpdateNtfAction db connId (NtfSubNTFAction NSADelete)
getNtfSubscription db connId
logInfo $ "processNtfSub, NSCDelete - sub_ = " <> tshow sub_
case sub_ of
(Just (NtfSubscription {ntfServer}, _)) -> void $ getNtfNTFWorker True c ntfServer
(Just (NtfSubscription {ntfServer}, _)) -> lift . void $ getNtfNTFWorker True c ntfServer
_ -> pure () -- err "NSCDelete - no subscription"
NSCSmpDelete -> do
withStore' c (`getPrimaryRcvQueue` connId) >>= \case
Right rq@RcvQueue {server = smpServer} -> do
logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq
withStore' c $ \db -> supervisorUpdateNtfAction db connId (NtfSubSMPAction NSASmpDelete)
void $ getNtfSMPWorker True c smpServer
lift . void $ getNtfSMPWorker True c smpServer
_ -> notifyInternalError c connId "NSCSmpDelete - no rcv queue"
NSCNtfWorker ntfServer -> void $ getNtfNTFWorker True c ntfServer
NSCNtfSMPWorker smpServer -> void $ getNtfSMPWorker True c smpServer
NSCNtfWorker ntfServer -> lift . void $ getNtfNTFWorker True c ntfServer
NSCNtfSMPWorker smpServer -> lift . void $ getNtfSMPWorker True c smpServer
getNtfNTFWorker :: AgentMonad' m => Bool -> AgentClient -> NtfServer -> m Worker
getNtfNTFWorker :: Bool -> AgentClient -> NtfServer -> AM' Worker
getNtfNTFWorker hasWork c server = do
ws <- asks $ ntfWorkers . ntfSupervisor
getAgentWorker "ntf_ntf" hasWork c server ws $ runNtfWorker c server
getNtfSMPWorker :: AgentMonad' m => Bool -> AgentClient -> SMPServer -> m Worker
getNtfSMPWorker :: Bool -> AgentClient -> SMPServer -> AM' Worker
getNtfSMPWorker hasWork c server = do
ws <- asks $ ntfSMPWorkers . ntfSupervisor
getAgentWorker "ntf_smp" hasWork c server ws $ runNtfSMPWorker c server
withTokenServer :: AgentMonad' m => (NtfServer -> m ()) -> m ()
withTokenServer action = getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer)
withTokenServer :: (NtfServer -> AM ()) -> AM ()
withTokenServer action = lift getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer)
runNtfWorker :: forall m. AgentMonad m => AgentClient -> NtfServer -> Worker -> m ()
runNtfWorker :: AgentClient -> NtfServer -> Worker -> AM ()
runNtfWorker c srv Worker {doWork} = do
delay <- asks $ ntfWorkerDelay . config
forever $ do
waitForWork doWork
agentOperationBracket c AONtfNetwork throwWhenInactive runNtfOperation
ExceptT $ agentOperationBracket c AONtfNetwork throwWhenInactive $ runExceptT runNtfOperation
threadDelay delay
where
runNtfOperation :: m ()
runNtfOperation :: AM ()
runNtfOperation =
withWork c doWork (`getNextNtfSubNTFAction` srv) $
\nextSub@(NtfSubscription {connId}, _, _) -> do
@@ -163,13 +163,13 @@ runNtfWorker c srv Worker {doWork} = do
withRetryInterval ri $ \_ loop ->
processSub nextSub
`catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show)
processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> m ()
processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM ()
processSub (sub@NtfSubscription {connId, smpServer, ntfSubId}, action, actionTs) = do
ts <- liftIO getCurrentTime
unlessM (rescheduleAction doWork ts actionTs) $
unlessM (lift $ rescheduleAction doWork ts actionTs) $
case action of
NSACreate ->
getNtfToken >>= \case
lift getNtfToken >>= \case
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId)
case clientNtfCreds of
@@ -182,13 +182,13 @@ runNtfWorker c srv Worker {doWork} = do
_ -> workerInternalError c connId "NSACreate - no notifier queue credentials"
_ -> workerInternalError c connId "NSACreate - no active token"
NSACheck ->
getNtfToken >>= \case
lift getNtfToken >>= \case
Just tkn ->
case ntfSubId of
Just nSubId ->
agentNtfCheckSubscription c nSubId tkn >>= \case
NSAuth -> do
getNtfServer c >>= \case
lift (getNtfServer c) >>= \case
Just ntfServer -> do
withStore' c $ \db ->
updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts
@@ -200,7 +200,7 @@ runNtfWorker c srv Worker {doWork} = do
_ -> workerInternalError c connId "NSACheck - no active token"
NSADelete -> case ntfSubId of
Just nSubId ->
(getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
(lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
`agentFinally` continueDeletion
_ -> continueDeletion
where
@@ -211,7 +211,7 @@ runNtfWorker c srv Worker {doWork} = do
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer)
NSARotate -> case ntfSubId of
Just nSubId ->
(getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
(lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
`agentFinally` deleteCreate
_ -> deleteCreate
where
@@ -228,12 +228,14 @@ runNtfWorker c srv Worker {doWork} = do
withStore' c $ \db ->
updateNtfSubscription db sub {ntfSubStatus = toStatus} toAction actionTs'
runNtfSMPWorker :: forall m. AgentMonad m => AgentClient -> SMPServer -> Worker -> m ()
runNtfSMPWorker :: AgentClient -> SMPServer -> Worker -> AM ()
runNtfSMPWorker c srv Worker {doWork} = do
env <- ask
delay <- asks $ ntfSMPWorkerDelay . config
forever $ do
waitForWork doWork
agentOperationBracket c AONtfNetwork throwWhenInactive runNtfSMPOperation
ExceptT . liftIO . agentOperationBracket c AONtfNetwork throwWhenInactive $
runReaderT (runExceptT runNtfSMPOperation) env
threadDelay delay
where
runNtfSMPOperation =
@@ -244,13 +246,13 @@ runNtfSMPWorker c srv Worker {doWork} = do
withRetryInterval ri $ \_ loop ->
processSub nextSub
`catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show)
processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> m ()
processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> AM ()
processSub (sub@NtfSubscription {connId, ntfServer}, smpAction, actionTs) = do
ts <- liftIO getCurrentTime
unlessM (rescheduleAction doWork ts actionTs) $
unlessM (lift $ rescheduleAction doWork ts actionTs) $
case smpAction of
NSASmpKey ->
getNtfToken >>= \case
lift getNtfToken >>= \case
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
rq <- withStore c (`getPrimaryRcvQueue` connId)
C.AuthAlg a <- asks (rcvAuthAlg . config)
@@ -272,7 +274,7 @@ runNtfSMPWorker c srv Worker {doWork} = do
mapM_ (disableQueueNotifications c) rq_
withStore' c $ \db -> deleteNtfSubscription db connId
rescheduleAction :: AgentMonad' m => TMVar () -> UTCTime -> UTCTime -> m Bool
rescheduleAction :: TMVar () -> UTCTime -> UTCTime -> AM' Bool
rescheduleAction doWork ts actionTs
| actionTs <= ts = pure False
| otherwise = do
@@ -282,7 +284,7 @@ rescheduleAction doWork ts actionTs
atomically $ hasWorkToDo' doWork
pure True
retryOnError :: AgentMonad' m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m ()
retryOnError :: AgentClient -> Text -> AM () -> (AgentErrorType -> AM ()) -> AgentErrorType -> AM ()
retryOnError c name loop done e = do
logError $ name <> " error: " <> tshow e
case e of
@@ -296,16 +298,17 @@ retryOnError c name loop done e = do
atomically $ beginAgentOperation c AONtfNetwork
loop
workerInternalError :: AgentMonad m => AgentClient -> ConnId -> String -> m ()
workerInternalError :: AgentClient -> ConnId -> String -> AM ()
workerInternalError c connId internalErrStr = do
withStore' c $ \db -> setNullNtfSubscriptionAction db connId
notifyInternalError c connId internalErrStr
-- TODO change error
notifyInternalError :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m ()
notifyInternalError :: MonadIO m => AgentClient -> ConnId -> String -> m ()
notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, APC SAEConn $ ERR $ INTERNAL internalErrStr)
{-# INLINE notifyInternalError #-}
getNtfToken :: AgentMonad' m => m (Maybe NtfToken)
getNtfToken :: AM' (Maybe NtfToken)
getNtfToken = do
tkn <- asks $ ntfTkn . ntfSupervisor
readTVarIO tkn
@@ -326,14 +329,14 @@ instantNotifications = \case
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> True
_ -> False
closeNtfSupervisor :: MonadUnliftIO m => NtfSupervisor -> m ()
closeNtfSupervisor :: NtfSupervisor -> IO ()
closeNtfSupervisor ns = do
stopWorkers $ ntfWorkers ns
stopWorkers $ ntfSMPWorkers ns
where
stopWorkers workers = atomically (swapTVar workers M.empty) >>= mapM_ (liftIO . cancelWorker)
getNtfServer :: AgentMonad' m => AgentClient -> m (Maybe NtfServer)
getNtfServer :: AgentClient -> AM' (Maybe NtfServer)
getNtfServer c = do
ntfServers <- readTVarIO $ ntfServers c
case ntfServers of
+7 -7
View File
@@ -163,7 +163,6 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
@@ -202,6 +201,7 @@ import Simplex.Messaging.Crypto.Ratchet
SndE2ERatchetParams
)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (base64P, encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol
@@ -1897,15 +1897,15 @@ tGetRaw :: Transport c => c -> IO ARawTransmission
tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h
-- | Send SMP agent protocol command (or response) to TCP connection.
tPut :: (Transport c, MonadIO m) => c -> ATransmission p -> m ()
tPut :: Transport c => c -> ATransmission p -> IO ()
tPut h (corrId, connId, APC _ cmd) =
liftIO $ tPutRaw h (corrId, connId, serializeCommand cmd)
tPutRaw h (corrId, connId, serializeCommand cmd)
-- | Receive client and agent transmissions from TCP connection.
tGet :: forall c m p. (Transport c, MonadIO m) => SAParty p -> c -> m (ATransmissionOrError p)
tGet :: forall c p. Transport c => SAParty p -> c -> IO (ATransmissionOrError p)
tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
where
tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p)
tParseLoadBody :: ARawTransmission -> IO (ATransmissionOrError p)
tParseLoadBody t@(corrId, entId, command) = do
let cmd = parseCommand command >>= fromParty >>= tConnId t
fullCmd <- either (return . Left) cmdWithMsgBody cmd
@@ -1935,7 +1935,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
| B.null entId -> Left $ CMD NO_CONN
| otherwise -> Right cmd
cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p))
cmdWithMsgBody :: APartyCmd p -> IO (Either AgentErrorType (APartyCmd p))
cmdWithMsgBody (APC e cmd) =
APC e <$$> case cmd of
SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body
@@ -1948,7 +1948,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo
_ -> pure $ Right cmd
getBody :: ByteString -> m (Either AgentErrorType ByteString)
getBody :: ByteString -> IO (Either AgentErrorType ByteString)
getBody binary =
case B.unpack binary of
':' : body -> return . Right $ B.pack body
+21 -19
View File
@@ -12,13 +12,13 @@ where
import Control.Logger.Simple (logInfo)
import Control.Monad
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader
import Crypto.Random (MonadRandom)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Text.Encoding (decodeUtf8)
import Network.Socket (ServiceName)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Client (newAgentClient)
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
@@ -32,7 +32,7 @@ import UnliftIO.STM
-- | Runs an SMP agent as a TCP service using passed configuration.
--
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> m ()
runSMPAgent :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> IO ()
runSMPAgent t cfg initServers store =
runSMPAgentBlocking t cfg initServers store 0 =<< newEmptyTMVarIO
@@ -40,44 +40,46 @@ runSMPAgent t cfg initServers store =
--
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> m ()
runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started = do
liftIO (newSMPAgentEnv cfg store) >>= runReaderT (smpAgent t)
runSMPAgentBlocking :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> IO ()
runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started =
case tcpPort of
Just port -> newSMPAgentEnv cfg store >>= smpAgent t port
Nothing -> E.throwIO $ userError "no agent port"
where
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
smpAgent _ = do
smpAgent :: forall c. Transport c => TProxy c -> ServiceName -> Env -> IO ()
smpAgent _ port env = do
-- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile
clientId <- newTVarIO initClientId
runTransportServer started tcpPort tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
runTransportServer started port tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
cId <- atomically $ stateTVar clientId $ \i -> (i + 1, i + 1)
c <- getAgentClient cId initServers
c <- atomically $ newAgentClient cId initServers env
logConnection c True
race_ (connectClient h c) (runAgentClient c)
`E.finally` disconnectAgentClient c
race_ (connectClient h c) (runAgentClient c `runReaderT` env)
`E.finally` (disconnectAgentClient c)
connectClient :: Transport c => MonadUnliftIO m => c -> AgentClient -> m ()
connectClient :: Transport c => c -> AgentClient -> IO ()
connectClient h c = race_ (send h c) (receive h c)
receive :: forall c m. (Transport c, MonadUnliftIO m) => c -> AgentClient -> m ()
receive :: forall c. Transport c => c -> AgentClient -> IO ()
receive h c@AgentClient {rcvQ, subQ} = forever $ do
(corrId, entId, cmdOrErr) <- tGet SClient h
case cmdOrErr of
Right cmd -> write rcvQ (corrId, entId, cmd)
Left e -> write subQ (corrId, entId, APC SAEConn $ ERR e)
where
write :: TBQueue (ATransmission p) -> ATransmission p -> m ()
write :: TBQueue (ATransmission p) -> ATransmission p -> IO ()
write q t = do
logClient c "-->" t
atomically $ writeTBQueue q t
send :: (Transport c, MonadUnliftIO m) => c -> AgentClient -> m ()
send :: Transport c => c -> AgentClient -> IO ()
send h c@AgentClient {subQ} = forever $ do
t <- atomically $ readTBQueue subQ
tPut h t
logClient c "<--" t
logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m ()
logClient :: AgentClient -> ByteString -> ATransmission a -> IO ()
logClient AgentClient {clientId} dir (corrId, connId, APC _ cmd) = do
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd]
+1 -1
View File
@@ -231,7 +231,6 @@ import Data.Bifunctor (first, second)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString.Base64.URL as U
import qualified Data.ByteString.Char8 as B
import Data.Char (toLower)
import Data.Functor (($>))
@@ -271,6 +270,7 @@ import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..))
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
import Simplex.Messaging.Notifications.Types
+26 -7
View File
@@ -97,7 +97,7 @@ import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Maybe (fromMaybe)
import Data.Time.Clock (UTCTime, getCurrentTime)
import Data.Time.Clock (UTCTime (..), getCurrentTime)
import Network.Socket (ServiceName)
import Numeric.Natural
import qualified Simplex.Messaging.Crypto as C
@@ -138,11 +138,12 @@ data PClient v err msg = PClient
msgQ :: Maybe (TBQueue (ServerTransmission v msg))
}
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg)
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM SMPClient
smpClientStub g sessionId thVersion thAuth = do
connected <- newTVar False
clientCorrId <- C.newRandomDRG g
sentCommands <- TM.empty
pingErrorCount <- newTVar 0
sndQ <- newTBQueue 100
rcvQ <- newTBQueue 100
return
@@ -157,15 +158,15 @@ smpClientStub g sessionId thVersion thAuth = do
implySessId = thVersion >= authCmdsSMPVersion,
batch = True
},
sessionTs = undefined,
sessionTs = UTCTime (read "2024-03-31") 0,
client_ =
PClient
{ connected,
transportSession = undefined,
transportHost = undefined,
tcpTimeout = undefined,
transportSession = (1, "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001", Nothing),
transportHost = "localhost",
tcpTimeout = 15_000_000,
batchDelay = Nothing,
pingErrorCount = undefined,
pingErrorCount,
clientCorrId,
sentCommands,
sndQ,
@@ -239,6 +240,7 @@ defaultNetworkConfig =
transportClientConfig :: NetworkConfig -> TransportClientConfig
transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
{-# INLINE transportClientConfig #-}
-- | protocol client configuration.
data ProtocolClientConfig v = ProtocolClientConfig
@@ -264,9 +266,11 @@ defaultClientConfig serverVRange =
serverVRange,
batchDelay = Nothing
}
{-# INLINE defaultClientConfig #-}
defaultSMPClientConfig :: ProtocolClientConfig SMPVersion
defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange
{-# INLINE defaultSMPClientConfig #-}
data Request err msg = Request
{ entityId :: EntityId,
@@ -296,12 +300,15 @@ protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err ms
protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_
where
snd3 (_, s, _) = s
{-# INLINE protocolClientServer #-}
transportHost' :: ProtocolClient v err msg -> TransportHost
transportHost' = transportHost . client_
{-# INLINE transportHost' #-}
transportSession' :: ProtocolClient v err msg -> TransportSession msg
transportSession' = transportSession . client_
{-# INLINE transportSession' #-}
type UserId = Int64
@@ -426,10 +433,12 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
proxyUsername :: TransportSession msg -> ByteString
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
{-# INLINE proxyUsername #-}
-- | Disconnects client from the server and terminates client threads.
closeProtocolClient :: ProtocolClient v err msg -> IO ()
closeProtocolClient = mapM_ uninterruptibleCancel . action
{-# INLINE closeProtocolClient #-}
-- | SMP client error type.
data ProtocolClientError err
@@ -469,6 +478,7 @@ temporaryClientError = \case
PCEResponseTimeout -> True
PCEIOError _ -> True
_ -> False
{-# INLINE temporaryClientError #-}
-- | Create a new SMP queue.
--
@@ -536,16 +546,19 @@ getSMPMessage c rpKey rId =
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateAuthKey -> NotifierId -> ExceptT SMPClientError IO ()
subscribeSMPQueueNotifications = okSMPCommand NSUB
{-# INLINE subscribeSMPQueueNotifications #-}
-- | Subscribe to multiple SMP queues notifications batching commands if supported.
subscribeSMPQueuesNtfs :: SMPClient -> NonEmpty (NtfPrivateAuthKey, NotifierId) -> IO (NonEmpty (Either SMPClientError ()))
subscribeSMPQueuesNtfs = okSMPCommands NSUB
{-# INLINE subscribeSMPQueuesNtfs #-}
-- | Secure the SMP queue by adding a sender public key.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
secureSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> SndPublicAuthKey -> ExceptT SMPClientError IO ()
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
{-# INLINE secureSMPQueue #-}
-- | Enable notifications for the queue for push notifications server.
--
@@ -571,10 +584,12 @@ enableSMPQueuesNtfs c qs = L.map process <$> sendProtocolCommands c cs
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
disableSMPQueueNotifications :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO ()
disableSMPQueueNotifications = okSMPCommand NDEL
{-# INLINE disableSMPQueueNotifications #-}
-- | Disable notifications for multiple queues for push notifications server.
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateAuthKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
disableSMPQueuesNtfs = okSMPCommands NDEL
{-# INLINE disableSMPQueuesNtfs #-}
-- | Send SMP message.
--
@@ -601,16 +616,19 @@ ackSMPMessage c rpKey rId msgId =
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
suspendSMPQueue :: SMPClient -> RcvPrivateAuthKey -> QueueId -> ExceptT SMPClientError IO ()
suspendSMPQueue = okSMPCommand OFF
{-# INLINE suspendSMPQueue #-}
-- | Irreversibly delete SMP queue and all messages in it.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
deleteSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO ()
deleteSMPQueue = okSMPCommand DEL
{-# INLINE deleteSMPQueue #-}
-- | Delete multiple SMP queues batching commands if supported.
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateAuthKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
deleteSMPQueues = okSMPCommands DEL
{-# INLINE deleteSMPQueues #-}
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateAuthKey -> QueueId -> ExceptT SMPClientError IO ()
okSMPCommand cmd c pKey qId =
@@ -631,6 +649,7 @@ okSMPCommands cmd c qs = L.map process <$> sendProtocolCommands c cs
-- | Send SMP command
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateAuthKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
{-# INLINE sendSMPCommand #-}
type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg)
+26 -12
View File
@@ -1,4 +1,5 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -18,6 +19,7 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Trans.Except
import Control.Monad.Trans.Reader
import Crypto.Random (ChaChaDRG)
import Data.Bifunctor (bimap, first)
import Data.ByteString.Char8 (ByteString)
@@ -106,12 +108,24 @@ newtype InternalException e = InternalException {unInternalException :: e}
instance Exception e => Exception (InternalException e)
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
withRunInIO exceptToIO =
instance Exception e => MonadUnliftIO (ExceptT e IO) where
{-# INLINE withRunInIO #-}
withRunInIO :: ((forall a. ExceptT e IO a -> IO a) -> IO b) -> ExceptT e IO b
withRunInIO inner =
ExceptT . fmap (first unInternalException) . E.try $
withRunInIO $ \run ->
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
-- the last two lines could be replaced with:
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
{-# INLINE withRunInIO #-}
withRunInIO :: ((forall a. ExceptT e (ReaderT r IO) a -> IO a) -> IO b) -> ExceptT e (ReaderT r IO) b
withRunInIO inner =
withExceptT unInternalException . ExceptT . E.try $
withRunInIO $ \run ->
exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent
newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do
@@ -147,7 +161,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
Nothing -> Left PCEResponseTimeout
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync)
where
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
tryConnectClient successAction retryAction =
@@ -163,9 +177,9 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
putTMVar smpVar (Left e)
TM.delete srv smpClients
throwE e
tryConnectAsync :: ExceptT SMPClientError IO ()
tryConnectAsync :: IO ()
tryConnectAsync = do
a <- async connectAsync
a <- async $ void $ runExceptT connectAsync
atomically $ modifyTVar' (asyncClients ca) (a :)
connectAsync :: ExceptT SMPClientError IO ()
connectAsync =
@@ -199,11 +213,11 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
serverDown ss = unless (M.null ss) $ do
notify . CADisconnected srv $ M.keysSet ss
void $ runExceptT reconnectServer
reconnectServer
reconnectServer :: ExceptT SMPClientError IO ()
reconnectServer :: IO ()
reconnectServer = do
a <- async tryReconnectClient
a <- async $ void $ runExceptT tryReconnectClient
atomically $ modifyTVar' (reconnections ca) (a :)
tryReconnectClient :: ExceptT SMPClientError IO ()
@@ -247,8 +261,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
notify :: SMPClientAgentEvent -> IO ()
notify evt = atomically $ writeTBQueue (agentQ ca) evt
closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m ()
closeSMPClientAgent c = liftIO $ do
closeSMPClientAgent :: SMPClientAgent -> IO ()
closeSMPClientAgent c = do
closeSMPServerClients c
cancelActions $ reconnections c
cancelActions $ asyncClients c
+2 -2
View File
@@ -211,8 +211,6 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (bimap, first)
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as BA
import Data.ByteString.Base64 (decode, encode)
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.ByteString.Lazy (fromStrict, toStrict)
@@ -230,6 +228,8 @@ import Database.SQLite.Simple.ToField (ToField (..))
import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+))
import Network.Transport.Internal (decodeWord16, encodeWord16)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (decode, encode)
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString)
import Simplex.Messaging.Util ((<$?>))
+1 -1
View File
@@ -76,7 +76,7 @@ withFile :: CryptoFile -> IOMode -> (CryptoFileHandle -> ExceptT FTCryptoError I
withFile (CryptoFile path cfArgs) mode action = do
sb <- forM cfArgs $ \(CFArgs key nonce) ->
liftEitherWith FTCECryptoError (LC.sbInit key nonce) >>= newTVarIO
IO.withFile path mode $ \h -> action $ CFHandle h sb
ExceptT . IO.withFile path mode $ \h -> runExceptT $ action $ CFHandle h sb
hPut :: CryptoFileHandle -> LazyByteString -> IO ()
hPut (CFHandle h sb_) s = LB.hPut h =<< maybe (pure s) encrypt sb_
+29
View File
@@ -0,0 +1,29 @@
{-# LANGUAGE OverloadedStrings #-}
-- | Compatibility wrappers for base64 package, Base64 (padded) variant.
module Simplex.Messaging.Encoding.Base64 where
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Base64.Types (extractBase64)
import Data.Bifunctor (first)
import Data.ByteString.Base64 (decodeBase64Untyped, encodeBase64')
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
encode :: ByteString -> ByteString
encode = extractBase64 . encodeBase64'
{-# INLINE encode #-}
decode :: ByteString -> Either String ByteString
decode = first T.unpack . decodeBase64Untyped
{-# INLINE decode #-}
base64P :: A.Parser ByteString
base64P = do
str <- A.takeWhile1 (`B.elem` base64Alphabet)
pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
either (fail . T.unpack) pure $ decodeBase64Untyped (str <> pad)
base64Alphabet :: ByteString
base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
@@ -0,0 +1,33 @@
{-# LANGUAGE OverloadedStrings #-}
-- | Compatibility wrappers for base64 package, Base64URL-padded variant.
module Simplex.Messaging.Encoding.Base64.URL where
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Base64.Types (extractBase64)
import Data.Bifunctor (first)
import Data.ByteString.Base64.URL (decodeBase64Lenient, decodeBase64UnpaddedUntyped, decodeBase64Untyped, encodeBase64')
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
encode :: ByteString -> ByteString
encode = extractBase64 . encodeBase64'
{-# INLINE encode #-}
decode :: ByteString -> Either String ByteString
decode = first T.unpack . decodeBase64Untyped
{-# INLINE decode #-}
decodeLenient :: ByteString -> ByteString
decodeLenient = decodeBase64Lenient
{-# INLINE decodeLenient #-}
base64urlP :: A.Parser ByteString
base64urlP = do
str <- A.takeWhile1 (`B.elem` base64AlphabetURL)
_pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
either (fail . T.unpack) pure $ decodeBase64UnpaddedUntyped str
base64AlphabetURL :: ByteString
base64AlphabetURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+6 -11
View File
@@ -10,7 +10,6 @@ module Simplex.Messaging.Encoding.String
strToJSON,
strToJEncoding,
strParseJSON,
base64urlP,
strEncodeList,
strListP,
)
@@ -23,10 +22,8 @@ import qualified Data.Aeson.Encoding as JE
import qualified Data.Aeson.Types as JT
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isAlphaNum)
import Data.Int (Int64)
import qualified Data.List.NonEmpty as L
import Data.Set (Set)
@@ -38,6 +35,7 @@ import Data.Time.Clock.System (SystemTime (..))
import Data.Time.Format.ISO8601
import Data.Word (Word16, Word32)
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Util ((<$?>))
@@ -54,19 +52,16 @@ class StrEncoding a where
strDecode :: ByteString -> Either String a
strDecode = parseAll strP
strP :: Parser a
strP = strDecode <$?> base64urlP
strP = strDecode <$?> U.base64urlP
-- base64url encoding/decoding of ByteStrings - the parser only allows non-empty strings
instance StrEncoding ByteString where
strEncode = U.encode
{-# INLINE strEncode #-}
strDecode = U.decode
strP = base64urlP
base64urlP :: Parser ByteString
base64urlP = do
str <- A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
pad <- A.takeWhile (== '=')
either fail pure $ U.decode (str <> pad)
{-# INLINE strDecode #-}
strP = U.base64urlP
{-# INLINE strP #-}
newtype Str = Str {unStr :: ByteString}
deriving (Eq, Show)
@@ -81,7 +81,8 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
runTransportServer started tcpPort serverParams tCfg (runClient serverSignKey t)
env <- ask
liftIO $ runTransportServer started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
@@ -83,7 +83,7 @@ data NtfEnv = NtfEnv
serverStats :: NtfServerStats
}
newNtfServerEnv :: (MonadUnliftIO m, MonadRandom m) => NtfServerConfig -> m NtfEnv
newNtfServerEnv :: NtfServerConfig -> IO NtfEnv
newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, caCertificateFile, certificateFile, privateKeyFile} = do
random <- liftIO C.newRandom
store <- atomically newNtfStore
@@ -27,8 +27,9 @@ import Data.Aeson (ToJSON, (.=))
import qualified Data.Aeson as J
import qualified Data.Aeson.Encoding as JE
import qualified Data.Aeson.TH as JQ
import Data.Base64.Types (extractBase64)
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64.URL as U
import qualified Data.ByteString.Base64.URL as UP
import Data.ByteString.Builder (lazyByteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as LB
@@ -46,6 +47,7 @@ import Network.HTTP2.Client (Request)
import qualified Network.HTTP2.Client as H
import Network.Socket (HostName, ServiceName)
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
@@ -91,8 +93,8 @@ signedJWTToken pk (JWTToken hdr claims) = do
pure $ hc <> "." <> serialize sig
where
jwtEncode :: ToJSON a => a -> ByteString
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
jwtEncode = extractBase64 . UP.encodeBase64Unpadded' . LB.toStrict . J.encode
serialize sig = extractBase64 . UP.encodeBase64Unpadded' $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
readECPrivateKey :: FilePath -> IO EC.PrivateKey
readECPrivateKey f = do
+1 -17
View File
@@ -10,10 +10,9 @@ import qualified Data.Aeson as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isAlphaNum, toLower)
import Data.Char (toLower)
import Data.String
import Data.Text (Text)
import qualified Data.Text as T
@@ -24,23 +23,8 @@ import Database.SQLite.Simple (ResultError (..), SQLData (..))
import Database.SQLite.Simple.FromField (FieldParser, returnError)
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok (Ok (Ok))
import Simplex.Messaging.Util ((<$?>))
import Text.Read (readMaybe)
base64P :: Parser ByteString
base64P = decode <$?> paddedBase64 rawBase64P
paddedBase64 :: Parser ByteString -> Parser ByteString
paddedBase64 raw = (<>) <$> raw <*> pad
where
pad = A.takeWhile (== '=')
rawBase64P :: Parser ByteString
rawBase64P = A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/')
-- rawBase64UriP :: Parser ByteString
-- rawBase64UriP = A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
tsISO8601P :: Parser UTCTime
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
+1 -1
View File
@@ -176,7 +176,6 @@ import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isPrint, isSpace)
@@ -194,6 +193,7 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
import Network.Socket (ServiceName)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64 as B64
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.ServiceScheme
+40 -39
View File
@@ -45,7 +45,6 @@ import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import Data.Bifunctor (first)
import Data.ByteString.Base64 (encode)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (fromRight, partitionEithers)
@@ -68,6 +67,7 @@ import Network.Socket (ServiceName, Socket, socketToHandle)
import Simplex.Messaging.Agent.Lock
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding (Encoding (smpEncode))
import Simplex.Messaging.Encoding.Base64 (encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Control
@@ -128,14 +128,15 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
: serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ())
: map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
)
`finally` withLock (savingLock s) "final" (saveServer False)
`finally` withLock' (savingLock s) "final" (saveServer False)
where
runServer :: (ServiceName, ATransport) -> M ()
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
ss <- asks sockets
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
runTransportServerState ss started tcpPort serverParams tCfg (runClient serverSignKey t)
env <- ask
liftIO $ runTransportServerState ss started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
saveServer :: Bool -> M ()
@@ -387,7 +388,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
withLog (`logDeleteQueue` queueId)
updateDeletedStats q
liftIO . hPutStrLn h $ "ok, " <> show numDeleted <> " messages deleted"
CPSave -> withAdminRole $ withLock (savingLock srv) "control" $ do
CPSave -> withAdminRole $ withLock' (savingLock srv) "control" $ do
hPutStrLn h "saving server state..."
unliftIO u $ saveServer True
hPutStrLn h "server state saved!"
@@ -579,7 +580,7 @@ dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/Xo
dummyKeyX25519 :: C.PublicKey 'C.X25519
dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg="
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
client :: Client -> Server -> M ()
client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Server {subscribedQ, ntfSubscribedQ, notifiers} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
forever $
@@ -587,7 +588,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
>>= mapM processCommand
>>= atomically . writeTBQueue sndQ
where
processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg)
processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Transmission BrokerMsg)
processCommand (qr_, (corrId, queueId, cmd)) = do
st <- asks queueStore
case cmd of
@@ -616,7 +617,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
OFF -> suspendQueue_ st
DEL -> delQueueAndMsgs st
where
createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> m (Transmission BrokerMsg)
createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> M (Transmission BrokerMsg)
createQueue st recipientKey dhKey subMode = time "NEW" $ do
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
let rcvDhSecret = C.dh' dhKey privDhKey
@@ -634,7 +635,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
(corrId,queueId,) <$> addQueueRetry 3 qik qRec
where
addQueueRetry ::
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> m BrokerMsg
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> M BrokerMsg
addQueueRetry 0 _ _ = pure $ ERR INTERNAL
addQueueRetry n qik qRec = do
ids@(rId, _) <- getIds
@@ -659,25 +660,25 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
Right q -> logCreateQueue s q
_ -> pure ()
getIds :: m (RecipientId, SenderId)
getIds :: M (RecipientId, SenderId)
getIds = do
n <- asks $ queueIdBytes . config
liftM2 (,) (randomId n) (randomId n)
secureQueue_ :: QueueStore -> SndPublicAuthKey -> m (Transmission BrokerMsg)
secureQueue_ :: QueueStore -> SndPublicAuthKey -> M (Transmission BrokerMsg)
secureQueue_ st sKey = time "KEY" $ do
withLog $ \s -> logSecureQueue s queueId sKey
stats <- asks serverStats
atomically $ modifyTVar' (qSecured stats) (+ 1)
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg)
addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
let rcvNtfDhSecret = C.dh' dhKey privDhKey
(corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
where
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> m BrokerMsg
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg
addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
notifierId <- randomId =<< asks (queueIdBytes . config)
@@ -689,17 +690,17 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
withLog $ \s -> logAddNotifier s queueId ntfCreds
pure $ NID notifierId rcvPublicDhKey
deleteQueueNotifier_ :: QueueStore -> m (Transmission BrokerMsg)
deleteQueueNotifier_ :: QueueStore -> M (Transmission BrokerMsg)
deleteQueueNotifier_ st = do
withLog (`logDeleteNotifier` queueId)
okResp <$> atomically (deleteQueueNotifier st queueId)
suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg)
suspendQueue_ :: QueueStore -> M (Transmission BrokerMsg)
suspendQueue_ st = do
withLog (`logSuspendQueue` queueId)
okResp <$> atomically (suspendQueue st queueId)
subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg)
subscribeQueue :: QueueRec -> RecipientId -> M (Transmission BrokerMsg)
subscribeQueue qr rId = do
atomically (TM.lookup rId subscriptions) >>= \case
Nothing ->
@@ -712,19 +713,19 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
s ->
atomically (tryTakeTMVar $ delivered s) >> deliver sub
where
newSub :: m (TVar Sub)
newSub :: M (TVar Sub)
newSub = time "SUB newSub" . atomically $ do
writeTQueue subscribedQ (rId, clnt)
sub <- newTVar =<< newSubscription NoSub
TM.insert rId sub subscriptions
pure sub
deliver :: TVar Sub -> m (Transmission BrokerMsg)
deliver :: TVar Sub -> M (Transmission BrokerMsg)
deliver sub = do
q <- getStoreMsgQueue "SUB" rId
msg_ <- atomically $ tryPeekMsg q
deliverMessage "SUB" qr rId sub q msg_
getMessage :: QueueRec -> m (Transmission BrokerMsg)
getMessage :: QueueRec -> M (Transmission BrokerMsg)
getMessage qr = time "GET" $ do
atomically (TM.lookup queueId subscriptions) >>= \case
Nothing ->
@@ -743,7 +744,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
sub <- newTVar s
TM.insert queueId sub subscriptions
pure s
getMessage_ :: Sub -> m (Transmission BrokerMsg)
getMessage_ :: Sub -> M (Transmission BrokerMsg)
getMessage_ s = do
q <- getStoreMsgQueue "GET" queueId
atomically $
@@ -753,17 +754,17 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
in setDelivered s msg $> (corrId, queueId, MSG encMsg)
_ -> pure (corrId, queueId, OK)
withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
withQueue :: (QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg)
withQueue action = maybe (pure $ err AUTH) action qr_
subscribeNotifications :: m (Transmission BrokerMsg)
subscribeNotifications :: M (Transmission BrokerMsg)
subscribeNotifications = time "NSUB" . atomically $ do
unlessM (TM.member queueId ntfSubscriptions) $ do
writeTQueue ntfSubscribedQ (queueId, clnt)
TM.insert queueId () ntfSubscriptions
pure ok
acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg)
acknowledgeMsg :: QueueRec -> MsgId -> M (Transmission BrokerMsg)
acknowledgeMsg qr msgId = time "ACK" $ do
atomically (TM.lookup queueId subscriptions) >>= \case
Nothing -> pure $ err NO_MSG
@@ -789,7 +790,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
if msgId == msgId' || B.null msgId
then pure $ Just s
else putTMVar delivered msgId' $> Nothing
updateStats :: Message -> m ()
updateStats :: Message -> M ()
updateStats = \case
MessageQuota {} -> pure ()
Message {msgFlags} -> do
@@ -801,7 +802,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
atomically $ modifyTVar' (msgRecvNtf stats) (+ 1)
atomically $ updatePeriodStats (activeQueuesNtf stats) queueId
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> M (Transmission BrokerMsg)
sendMessage qr msgFlags msgBody
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
| otherwise = case status qr of
@@ -827,13 +828,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure ok
where
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
mkMessage :: C.MaxLenBS MaxMessageLen -> M Message
mkMessage body = do
msgId <- randomId =<< asks (msgIdBytes . config)
msgTs <- liftIO getSystemTime
pure $ Message msgId msgTs msgFlags body
expireMessages :: MsgQueue -> m ()
expireMessages :: MsgQueue -> M ()
expireMessages q = do
msgExp <- asks $ messageExpiration . config
old <- liftIO $ mapM expireBeforeEpoch msgExp
@@ -861,7 +862,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
pure . (cbNonce,) $ fromRight "" encNMsgMeta
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> M (Transmission BrokerMsg)
deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do
readTVarIO sub >>= \case
s@Sub {subThread = NoSub} ->
@@ -872,7 +873,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
_ -> forkSub $> ok
_ -> pure ok
where
forkSub :: m ()
forkSub :: M ()
forkSub = do
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
t <- mkWeakThreadId =<< forkIO subscriber
@@ -890,7 +891,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
void $ setDelivered s msg
writeTVar sub $! s {subThread = NoSub}
time :: T.Text -> m a -> m a
time :: T.Text -> M a -> M a
time name = timed name queueId
encryptMsg :: QueueRec -> Message -> RcvMessage
@@ -906,13 +907,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
setDelivered :: Sub -> Message -> STM Bool
setDelivered s msg = tryPutTMVar (delivered s) (messageId msg)
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
ms <- asks msgStore
quota <- asks $ msgQueueQuota . config
atomically $ getMsgQueue ms rId quota
delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg)
delQueueAndMsgs :: QueueStore -> M (Transmission BrokerMsg)
delQueueAndMsgs st = do
withLog (`logDeleteQueue` queueId)
ms <- asks msgStore
@@ -929,7 +930,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
okResp :: Either ErrorType () -> Transmission BrokerMsg
okResp = either err $ const ok
updateDeletedStats :: (MonadUnliftIO m, MonadReader Env m) => QueueRec -> m ()
updateDeletedStats :: QueueRec -> M ()
updateDeletedStats q = do
stats <- asks serverStats
let delSel = if isNothing (senderKey q) then qDeletedNew else qDeletedSecured
@@ -937,12 +938,12 @@ updateDeletedStats q = do
atomically $ modifyTVar' (qDeletedAll stats) (+ 1)
atomically $ modifyTVar' (qCount stats) (subtract 1)
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
withLog :: (StoreLog 'WriteMode -> IO a) -> M ()
withLog action = do
env <- ask
liftIO . mapM_ action $ storeLog (env :: Env)
timed :: MonadUnliftIO m => T.Text -> RecipientId -> m a -> m a
timed :: T.Text -> RecipientId -> M a -> M a
timed name qId a = do
t <- liftIO getSystemTime
r <- a
@@ -954,10 +955,10 @@ timed name qId a = do
diff t t' = (systemSeconds t' - systemSeconds t) * sec + fromIntegral (systemNanoseconds t' - systemNanoseconds t)
sec = 1000_000000
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString
randomId :: Int -> M ByteString
randomId n = atomically . C.randomBytes n =<< asks random
saveServerMessages :: (MonadUnliftIO m, MonadReader Env m) => Bool -> m ()
saveServerMessages :: Bool -> M ()
saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessages
where
saveMessages f = do
@@ -972,7 +973,7 @@ saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessag
atomically (getMessages ms rId)
>>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId)
restoreServerMessages :: forall m. (MonadUnliftIO m, MonadReader Env m) => m Int
restoreServerMessages :: M Int
restoreServerMessages = asks (storeMsgsFile . config) >>= \case
Just f -> ifM (doesFileExist f) (restoreMessages f) (pure 0)
Nothing -> pure 0
@@ -1008,7 +1009,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case
msgErr :: Show e => String -> e -> String
msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)
saveServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
saveServerStats :: M ()
saveServerStats =
asks (serverStatsBackupFile . config)
>>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f)
@@ -1018,7 +1019,7 @@ saveServerStats =
B.writeFile f $ strEncode stats
logInfo "server stats saved"
restoreServerStats :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ()
restoreServerStats :: Int -> M ()
restoreServerStats expiredWhileRestoring = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
where
restoreStats f = whenM (doesFileExist f) $ do
+6 -6
View File
@@ -173,25 +173,25 @@ newSubscription subThread = do
delivered <- newEmptyTMVar
return Sub {subThread, delivered}
newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
newEnv :: ServerConfig -> IO Env
newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, storeLogFile} = do
server <- atomically newServer
queueStore <- atomically newQueueStore
msgStore <- atomically newMsgStore
random <- liftIO C.newRandom
storeLog <- restoreQueues queueStore `mapM` storeLogFile
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile
Fingerprint fp <- loadFingerprint caCertificateFile
let serverIdentity = KeyHash fp
serverStats <- atomically . newServerStats =<< liftIO getCurrentTime
serverStats <- atomically . newServerStats =<< getCurrentTime
sockets <- atomically newSocketState
clientSeq <- newTVarIO 0
clients <- newTVarIO mempty
return Env {config, server, serverIdentity, queueStore, msgStore, random, storeLog, tlsServerParams, serverStats, sockets, clientSeq, clients}
where
restoreQueues :: QueueStore -> FilePath -> m (StoreLog 'WriteMode)
restoreQueues :: QueueStore -> FilePath -> IO (StoreLog 'WriteMode)
restoreQueues QueueStore {queues, senders, notifiers} f = do
(qs, s) <- liftIO $ readWriteStoreLog f
(qs, s) <- readWriteStoreLog f
atomically $ do
writeTVar queues =<< mapM newTVar qs
writeTVar senders $! M.foldr' addSender M.empty qs
+4 -5
View File
@@ -23,7 +23,6 @@ where
import Control.Applicative (optional)
import Control.Logger.Simple (logError)
import Control.Monad (when)
import Control.Monad.IO.Unlift
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
@@ -126,10 +125,10 @@ clientTransportConfig TransportClientConfig {logTLSErrors} =
TransportConfig {logTLSErrors, transportTimeout = Nothing}
-- | Connect to passed TCP host:port and pass handle to the client.
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTransportClient :: Transport c => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a
runTransportClient = runTLSTransportClient supportedParameters Nothing
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do
serverCert <- newEmptyTMVarIO
let hostName = B.unpack $ strEncode host
@@ -137,7 +136,7 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
connectTCP = case socksProxy of
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
_ -> connectTCPClient hostName
c <- liftIO $ do
c <- do
sock <- connectTCP port
mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e)
let tCfg = clientTransportConfig cfg
@@ -148,7 +147,7 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
closeTLS tls >> error "onServerCertificate failed"
Just c -> pure c
getClientConnection tCfg chain tls
client c `E.finally` liftIO (closeConnection c)
client c `E.finally` closeConnection c
where
hostAddr = \case
THIPv4 addr -> SocksAddrIPV4 $ tupleToHostAddress addr
+6 -8
View File
@@ -26,7 +26,6 @@ where
import Control.Applicative ((<|>))
import Control.Logger.Simple
import Control.Monad
import Control.Monad.IO.Unlift
import qualified Crypto.Store.X509 as SX
import Data.Default (def)
import Data.List (find)
@@ -70,27 +69,26 @@ serverTransportConfig TransportServerConfig {logTLSErrors} =
-- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar.
--
-- All accepted connections are passed to the passed function.
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
runTransportServer :: forall c. Transport c => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO ()
runTransportServer started port params cfg server = do
ss <- atomically newSocketState
runTransportServerState ss started port params cfg server
runTransportServerState :: forall c m. (Transport c, MonadUnliftIO m) => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
runTransportServerState :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO ()
runTransportServerState ss started port = runTransportServerSocketState ss started (startTCPServer started port) (transportName (TProxy :: TProxy c))
-- | Run a transport server with provided connection setup and handler.
runTransportServerSocket :: (MonadUnliftIO m, Transport a) => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocket :: Transport a => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO ()
runTransportServerSocket started getSocket threadLabel serverParams cfg server = do
ss <- atomically newSocketState
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server
-- | Run a transport server with provided connection setup and handler.
runTransportServerSocketState :: (MonadUnliftIO m, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocketState :: Transport a => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO ()
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server = do
u <- askUnliftIO
labelMyThread $ "transport server for " <> threadLabel
liftIO . runTCPServerSocket ss started getSocket $ \conn ->
E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection (unliftIO u . server)
runTCPServerSocket ss started getSocket $ \conn ->
E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection server
where
tCfg = serverTransportConfig cfg
setup conn = timeout (tlsSetupTimeout cfg) $ do
+16 -16
View File
@@ -50,26 +50,18 @@ maybeWord :: (a -> ByteString) -> Maybe a -> ByteString
maybeWord f = maybe "" $ B.cons ' ' . f
{-# INLINE maybeWord #-}
liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a
liftIOEither a = liftIO a >>= liftEither
{-# INLINE liftIOEither #-}
liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a
liftError f = liftEitherError f . runExceptT
liftError :: MonadIO m => (e -> e') -> ExceptT e IO a -> ExceptT e' m a
liftError f = liftError' f . runExceptT
{-# INLINE liftError #-}
liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a
liftEitherError f a = liftIOEither (first f <$> a)
{-# INLINE liftEitherError #-}
liftError' :: MonadIO m => (e -> e') -> IO (Either e a) -> ExceptT e' m a
liftError' f = ExceptT . fmap (first f) . liftIO
{-# INLINE liftError' #-}
liftEitherWith :: MonadError e' m => (e -> e') -> Either e a -> m a
liftEitherWith :: MonadIO m => (e -> e') -> Either e a -> ExceptT e' m a
liftEitherWith f = liftEither . first f
{-# INLINE liftEitherWith #-}
liftE :: (e -> e') -> ExceptT e IO a -> ExceptT e' IO a
liftE f a = ExceptT $ first f <$> runExceptT a
{-# INLINE liftE #-}
ifM :: Monad m => m Bool -> m a -> m a -> m a
ifM ba t f = ba >>= \b -> if b then t else f
{-# INLINE ifM #-}
@@ -109,10 +101,18 @@ tryAllErrors :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m
tryAllErrors err action = tryError action `UE.catch` (pure . Left . err)
{-# INLINE tryAllErrors #-}
tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a)
tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err)
{-# INLINE tryAllErrors' #-}
catchAllErrors :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m a -> (e -> m a) -> m a
catchAllErrors err action handler = tryAllErrors err action >>= either handler pure
{-# INLINE catchAllErrors #-}
catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a
catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure
{-# INLINE catchAllErrors' #-}
catchThrow :: (MonadUnliftIO m, MonadError e m) => m a -> (E.SomeException -> e) -> m a
catchThrow action err = catchAllErrors err action throwError
{-# INLINE catchThrow #-}
@@ -148,8 +148,8 @@ safeDecodeUtf8 = decodeUtf8With onError
where
onError _ _ = Just '?'
timeoutThrow :: (MonadUnliftIO m, MonadError e m) => e -> Int -> m a -> m a
timeoutThrow e ms action = timeout ms action >>= maybe (throwError e) pure
timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a
timeoutThrow e ms action = ExceptT (sequence <$> (ms `timeout` runExceptT action)) >>= maybe (throwError e) pure
threadDelay' :: Int64 -> IO ()
threadDelay' time
+12 -15
View File
@@ -103,14 +103,14 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
found@(RCCtrlAddress {address} :| _) <- findCtrlAddress
c@RCHClient_ {startedPort, announcer} <- liftIO mkClient
hostKeys <- atomically genHostKeys
action <- runClient c r hostKeys `putRCError` r
action <- liftIO $ runClient c r hostKeys
-- wait for the port to make invitation
portNum <- atomically $ readTMVar startedPort
signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
when multicast $ case knownHost of
Nothing -> throwError RCENewController
Just KnownHostPairing {hostDhPubKey} -> do
ann <- async . liftIO . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
ann <- liftIO . async . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
atomically $ putTMVar announcer ann
pure (found, signedInv, RCHostClient {action, client_ = c}, r)
where
@@ -125,9 +125,9 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
endSession <- newEmptyTMVarIO
hostCAHash <- newEmptyTMVarIO
pure RCHClient_ {startedPort, announcer, hostCAHash, endSession}
runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> ExceptT RCErrorType IO (Async ())
runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> IO (Async ())
runClient RCHClient_ {startedPort, announcer, hostCAHash, endSession} r hostKeys = do
tlsCreds <- liftIO $ genTLSCredentials drg caKey caCert
tlsCreds <- genTLSCredentials drg caKey caCert
startTLSServer port_ startedPort tlsCreds (tlsHooks r knownHost hostCAHash) $ \tls ->
void . runExceptT $ do
r' <- newEmptyTMVarIO
@@ -265,7 +265,7 @@ connectRCCtrl_ :: TVar ChaChaDRG -> RCCtrlPairing -> RCInvitation -> J.Value ->
connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, host, port} hostAppInfo = do
r <- newEmptyTMVarIO
c <- liftIO mkClient
action <- async $ runClient c r `putRCError` r
action <- liftIO . async . void . runExceptT $ runClient c r `putRCError` r
pure (RCCtrlClient {action, client_ = c}, r)
where
mkClient :: IO RCCClient_
@@ -280,7 +280,7 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
TLS.Credentials (creds : _) -> pure $ Just creds
_ -> throwError $ RCEInternal "genTLSCredentials must generate credentials"
let clientConfig = defaultTransportClientConfig {clientCredentials}
runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> do
ExceptT . runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> runExceptT $ do
-- pump socket to detect connection problems
liftIO $ peekBuffered tlsBuffer 100000 (TLS.recvData tlsContext) >>= logDebug . tshow -- should normally be ("", Nothing) here
logDebug "Got TLS connection"
@@ -360,11 +360,11 @@ prepareCtrlSession
-- * Multicast discovery
announceRC :: TVar ChaChaDRG -> Int -> C.PrivateKeyEd25519 -> C.PublicKeyX25519 -> RCHostKeys -> RCInvitation -> ExceptT RCErrorType IO ()
announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = withSender $ \sender -> do
announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = ExceptT $ withSender $ \sender -> runExceptT $ do
replicateM_ maxCount $ do
logDebug "Announcing..."
nonce <- atomically $ C.randomCbNonce drg
encInvitation <- liftEitherWith undefined $ C.cbEncrypt sharedKey nonce sigInvitation encInvitationSize
encInvitation <- liftEitherWith (const RCEEncrypt) $ C.cbEncrypt sharedKey nonce sigInvitation encInvitationSize
liftIO . UDP.send sender $ smpEncode RCEncInvitation {dhPubKey, nonce, encInvitation}
threadDelay 1000000
where
@@ -375,9 +375,9 @@ announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv =
discoverRCCtrl :: TMVar Int -> NonEmpty RCCtrlPairing -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
discoverRCCtrl subscribers pairings =
timeoutThrow RCENotDiscovered 30000000 $ withListener subscribers $ \listener ->
loop $ do
(source, bytes) <- recvAnnounce listener
timeoutThrow RCENotDiscovered 30000000 $ ExceptT $ withListener subscribers $ \listener ->
runExceptT . loop $ do
(source, bytes) <- liftIO $ recvAnnounce listener
encInvitation <- liftEitherWith (const RCEInvitation) $ smpDecode bytes
r@(_, RCVerifiedInvitation RCInvitation {host}) <- findRCCtrlPairing pairings encInvitation
case source of
@@ -386,10 +386,7 @@ discoverRCCtrl subscribers pairings =
pure r
where
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
loop action =
liftIO (runExceptT action) >>= \case
Left err -> logError (tshow err) >> loop action
Right res -> pure res
loop action = action `catchRCError` \e -> logError (tshow e) >> loop action
findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
+10 -11
View File
@@ -68,7 +68,7 @@ preferAddress RCCtrlAddress {address, interface} addrs =
matchAddr RCCtrlAddress {address = a} = a == address
matchIface RCCtrlAddress {interface = i} = i == interface
startTLSServer :: MonadUnliftIO m => Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> m (Async ())
startTLSServer :: Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> IO (Async ())
startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ do
started <- newEmptyTMVarIO
bracketOnError (startTCPServer started $ maybe "0" show port_) (\_e -> setPort Nothing) $ \socket ->
@@ -91,14 +91,14 @@ startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ d
TLS.serverSupported = supportedParameters
}
withSender :: MonadUnliftIO m => (UDP.UDPSocket -> m a) -> m a
withSender = bracket (liftIO $ UDP.clientSocket MULTICAST_ADDR_V4 DISCOVERY_PORT False) (liftIO . UDP.close)
withSender :: (UDP.UDPSocket -> IO a) -> IO a
withSender = bracket (UDP.clientSocket MULTICAST_ADDR_V4 DISCOVERY_PORT False) (UDP.close)
withListener :: MonadUnliftIO m => TMVar Int -> (UDP.ListenSocket -> m a) -> m a
withListener :: TMVar Int -> (UDP.ListenSocket -> IO a) -> IO a
withListener subscribers = bracket (openListener subscribers) (closeListener subscribers)
openListener :: MonadIO m => TMVar Int -> m UDP.ListenSocket
openListener subscribers = liftIO $ do
openListener :: TMVar Int -> IO UDP.ListenSocket
openListener subscribers = do
sock <- UDP.serverSocket (MULTICAST_ADDR_V4, read DISCOVERY_PORT)
logDebug $ "Discovery listener socket: " <> tshow sock
let raw = UDP.listenSocket sock
@@ -106,10 +106,9 @@ openListener subscribers = liftIO $ do
joinMulticast subscribers raw (listenerHostAddr4 sock)
pure sock
closeListener :: MonadIO m => TMVar Int -> UDP.ListenSocket -> m ()
closeListener :: TMVar Int -> UDP.ListenSocket -> IO ()
closeListener subscribers sock =
liftIO $
partMulticast subscribers (UDP.listenSocket sock) (listenerHostAddr4 sock) `finally` UDP.stop sock
partMulticast subscribers (UDP.listenSocket sock) (listenerHostAddr4 sock) `finally` UDP.stop sock
joinMulticast :: TMVar Int -> N.Socket -> N.HostAddress -> IO ()
joinMulticast subscribers sock group = do
@@ -132,7 +131,7 @@ listenerHostAddr4 sock = case UDP.mySockAddr sock of
N.SockAddrInet _port host -> host
_ -> error "MULTICAST_ADDR_V4 is V4"
recvAnnounce :: MonadIO m => UDP.ListenSocket -> m (N.SockAddr, ByteString)
recvAnnounce sock = liftIO $ do
recvAnnounce :: UDP.ListenSocket -> IO (N.SockAddr, ByteString)
recvAnnounce sock = do
(invite, UDP.ClientSockAddr source _cmsg) <- UDP.recvFrom sock
pure (source, invite)
-11
View File
@@ -238,17 +238,6 @@ type SessionCode = ByteString
type RCStepTMVar a = TMVar (Either RCErrorType a)
type Tasks = TVar [Async ()]
asyncRegistered :: MonadUnliftIO m => Tasks -> m () -> m ()
asyncRegistered tasks action = async action >>= registerAsync tasks
registerAsync :: MonadIO m => Tasks -> Async () -> m ()
registerAsync tasks = atomically . modifyTVar tasks . (:)
cancelTasks :: MonadIO m => Tasks -> m ()
cancelTasks tasks = readTVarIO tasks >>= mapM_ cancel
$(JQ.deriveJSON (sumTypeJSON $ dropPrefix "RCE") ''RCErrorType)
$(JQ.deriveJSON defaultJSON ''RCCtrlAddress)
+28 -29
View File
@@ -233,13 +233,13 @@ inAnyOrder g rs = do
expected :: a -> (a -> Bool) -> Bool
expected r rp = rp r
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c)
createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c)
createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn)
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE ConnId
joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQSupportOn
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId
sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId
sendMessage c connId msgFlags msgBody = do
(msgId, pqEnc) <- A.sendMessage c connId PQEncOn msgFlags msgBody
liftIO $ pqEnc `shouldBe` PQEncOn
@@ -664,7 +664,7 @@ testAsyncInitiatingOffline :: HasCallStack => IO ()
testAsyncInitiatingOffline =
withAgentClients2 $ \alice bob -> runRight_ $ do
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
disposeAgentClient alice
liftIO $ disposeAgentClient alice
aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe
alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB
subscribeConnection alice' bobId
@@ -680,7 +680,7 @@ testAsyncJoiningOfflineBeforeActivation =
withAgentClients2 $ \alice bob -> runRight_ $ do
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe
disposeAgentClient bob
liftIO $ disposeAgentClient bob
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
bob' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB2
@@ -694,9 +694,9 @@ testAsyncBothOffline :: HasCallStack => IO ()
testAsyncBothOffline =
withAgentClients2 $ \alice bob -> runRight_ $ do
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
disposeAgentClient alice
liftIO $ disposeAgentClient alice
aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe
disposeAgentClient bob
liftIO $ disposeAgentClient bob
alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB
subscribeConnection alice' bobId
("", _, CONF confId _ "bob's connInfo") <- get alice'
@@ -1067,8 +1067,7 @@ testExpireMessageQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testP
b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2
(aId, bId) <- runRight $ do
(aId, bId) <- makeConnection a b
liftIO $ threadDelay 500000
disposeAgentClient b
liftIO $ threadDelay 500000 >> disposeAgentClient b
4 <- sendMessage a bId SMP.noMsgFlags "1"
get a ##> ("", bId, SENT 4)
5 <- sendMessage a bId SMP.noMsgFlags "2"
@@ -1091,8 +1090,7 @@ testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1}
b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2
(aId, bId) <- runRight $ do
(aId, bId) <- makeConnection a b
liftIO $ threadDelay 500000
disposeAgentClient b
liftIO $ threadDelay 500000 >> disposeAgentClient b
4 <- sendMessage a bId SMP.noMsgFlags "1"
get a ##> ("", bId, SENT 4)
5 <- sendMessage a bId SMP.noMsgFlags "2"
@@ -1161,13 +1159,13 @@ setupDesynchronizedRatchet alice bob = do
runRight_ $ do
subscribeConnection bob2 aliceId
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False
8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5"
get alice ##> ("", bobId, SENT 8)
get bob2 =##> ratchetSyncP aliceId RSRequired
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6"
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6"
pure ()
pure (aliceId, bobId, bob2)
@@ -1224,7 +1222,7 @@ testRatchetSyncClientRestart t = do
("", "", DOWN _ _) <- nGet bob2
ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False
liftIO $ ratchetSyncState `shouldBe` RSStarted
disposeAgentClient bob2
liftIO $ disposeAgentClient bob2
bob3 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB2
withSmpServerStoreMsgLogOn t testPort $ \_ -> do
runRight_ $ do
@@ -1420,12 +1418,12 @@ testSuspendingAgent =
get a ##> ("", bId, SENT 4)
get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False
ackMessage b aId 4 Nothing
suspendAgent b 1000000
liftIO $ suspendAgent b 1000000
get' b ##> ("", "", SUSPENDED)
5 <- sendMessage a bId SMP.noMsgFlags "hello 2"
get a ##> ("", bId, SENT 5)
Nothing <- 100000 `timeout` get b
foregroundAgent b
liftIO $ foregroundAgent b
get b =##> \case ("", c, Msg "hello 2") -> c == aId; _ -> False
testSuspendingAgentCompleteSending :: ATransport -> IO ()
@@ -1444,7 +1442,7 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
6 <- sendMessage b aId SMP.noMsgFlags "how are you?"
liftIO $ threadDelay 100000
suspendAgent b 5000000
liftIO $ suspendAgent b 5000000
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do
pGet b =##> \case ("", c, APC _ (SENT 5)) -> c == aId; ("", "", APC _ UP {}) -> True; _ -> False
@@ -1473,7 +1471,7 @@ testSuspendingAgentTimeout t = withAgentClients2 $ \a b -> do
("", "", DOWN {}) <- nGet b
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
6 <- sendMessage b aId SMP.noMsgFlags "how are you?"
suspendAgent b 100000
liftIO $ suspendAgent b 100000
("", "", SUSPENDED) <- nGet b
pure ()
@@ -2095,7 +2093,7 @@ testSwitchDelete servers = do
runRight_ $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
disposeAgentClient b
liftIO $ disposeAgentClient b
stats <- switchConnectionAsync a "" bId
liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted]
phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing]
@@ -2120,7 +2118,7 @@ testAbortSwitchStarted servers = do
liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted]
phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing]
-- repeat switch is prohibited
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ switchConnectionAsync a "" bId
-- abort current switch
stats' <- abortConnectionSwitch a bId
liftIO $ rcvSwchStatuses' stats' `shouldMatchList` [Nothing]
@@ -2242,7 +2240,7 @@ testCannotAbortSwitchSecured servers = do
withA' $ \a -> do
phaseRcv a bId SPConfirmed [Just RSSendingQADD, Nothing]
phaseRcv a bId SPSecured [Just RSSendingQUSE, Nothing]
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ abortConnectionSwitch a bId
pure ()
withA $ \a -> withB $ \b -> runRight_ $ do
subscribeConnection a bId
@@ -2407,7 +2405,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
testSMPServerConnectionTest t newQueueBasicAuth srv =
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB -- initially passed server is not running
runRight $ testProtocolServer a 1 srv
testProtocolServer a 1 srv
testRatchetAdHash :: HasCallStack => IO ()
testRatchetAdHash =
@@ -2551,7 +2549,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
exchangeGreetings a bId1' b aId1'
a `hasClients` 1
b `hasClients` 1
setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- nGet a
("", "", UP _ _) <- nGet a
@@ -2560,7 +2558,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
exchangeGreetingsMsgId 6 a bId1 b aId1
exchangeGreetingsMsgId 6 a bId1' b aId1'
liftIO $ threadDelay 250000
setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- nGet a
("", "", DOWN _ _) <- nGet a
@@ -2575,7 +2573,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
exchangeGreetings a bId2' b aId2'
a `hasClients` 2
b `hasClients` 1
setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- nGet a
("", "", DOWN _ _) <- nGet a
@@ -2587,7 +2585,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
exchangeGreetingsMsgId 6 a bId2 b aId2
exchangeGreetingsMsgId 6 a bId2' b aId2'
liftIO $ threadDelay 250000
setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- nGet a
("", "", DOWN _ _) <- nGet a
@@ -2625,9 +2623,10 @@ testServerMultipleIdentities =
get bob ##> ("", aliceId, CON)
exchangeGreetings alice bobId bob aliceId
-- this saves queue with second server identity
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
disposeAgentClient bob
bob' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB2
bob' <- liftIO $ do
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
disposeAgentClient bob
getSMPAgentClient' 3 agentCfg initAgentServers testDB2
subscribeConnection bob' aliceId
exchangeGreetingsMsgId 6 alice bobId bob' aliceId
where
+16 -15
View File
@@ -42,7 +42,6 @@ import Control.Monad.Trans.Except
import qualified Data.Aeson as J
import qualified Data.Aeson.Types as JT
import Data.Bifunctor (bimap, first)
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Text.Encoding (encodeUtf8)
@@ -51,6 +50,7 @@ import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, te
import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn)
import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage)
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore')
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers)
import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO)
import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken)
@@ -179,7 +179,7 @@ testNotificationToken APNSMockServer {apnsQ} = do
deleteNtfToken a tkn
-- agent deleted this token
Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn
disposeAgentClient a
liftIO $ disposeAgentClient a
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
v .-> key = do
@@ -211,7 +211,7 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do
-- can still use the first verification code, it is the same after decryption
verifyNtfToken a tkn nonce verification
NTActive <- checkNtfToken a tkn
disposeAgentClient a
liftIO $ disposeAgentClient a
testNtfTokenSecondRegistration :: APNSMockServer -> IO ()
testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
@@ -247,8 +247,9 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
Left (NTF AUTH) <- tryE $ checkNtfToken a tkn
-- and the second is active
NTActive <- checkNtfToken a' tkn
disposeAgentClient a
disposeAgentClient a'
pure ()
disposeAgentClient a
disposeAgentClient a'
testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO ()
testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
@@ -277,11 +278,11 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
liftIO $ sendApnsResponse' APNSRespOk
verifyNtfToken a' tkn nonce' verification'
NTActive <- checkNtfToken a' tkn
disposeAgentClient a'
liftIO $ disposeAgentClient a'
getTestNtfTokenPort :: (MonadUnliftIO m, MonadError AgentErrorType m) => AgentClient -> m String
getTestNtfTokenPort :: AgentClient -> AE String
getTestNtfTokenPort a =
runReaderT (withStore' a getSavedNtfToken) (agentEnv a) >>= \case
ExceptT (runExceptT (withStore' a getSavedNtfToken) `runReaderT` agentEnv a) >>= \case
Just NtfToken {ntfServer = ProtocolServer {port}} -> pure port
Nothing -> error "no active NtfToken"
@@ -317,18 +318,18 @@ testNtfTokenChangeServers t APNSMockServer {apnsQ} =
a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB
tkn <- registerTestToken a "abcd" NMInstant apnsQ
NTActive <- checkNtfToken a tkn
setNtfServers a [testNtfServer2]
liftIO $ setNtfServers a [testNtfServer2]
NTActive <- checkNtfToken a tkn -- still works on old server
disposeAgentClient a
liftIO $ disposeAgentClient a
pure tkn
threadDelay 1000000
a <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB
a <- getSMPAgentClient' 2 agentCfg initAgentServers testDB
runRight_ $ do
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort
NTActive <- checkNtfToken a tkn1
setNtfServers a [testNtfServer2] -- just change configured server list
liftIO $ setNtfServers a [testNtfServer2] -- just change configured server list
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort -- not yet changed
-- trigger token replace
tkn2 <- registerTestToken a "xyzw" NMInstant apnsQ
@@ -345,7 +346,7 @@ testRunNTFServerTests :: ATransport -> NtfServer -> IO (Maybe ProtocolTestFailur
testRunNTFServerTests t srv =
withNtfServerThreadOn t ntfTestPort $ \ntf -> do
a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB
r <- runRight $ testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing
r <- testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing
killThread ntf
pure r
@@ -712,7 +713,7 @@ testNotificationsOldToken APNSMockServer {apnsQ} = do
liftIO $ threadDelay 250000
testMessageAB "hello"
-- change server
setNtfServers a [testNtfServer2] -- server 2 isn't running now, don't use
liftIO $ setNtfServers a [testNtfServer2] -- server 2 isn't running now, don't use
-- replacing token keeps server
_ <- registerTestToken a "xyzw" NMInstant apnsQ
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort
@@ -738,7 +739,7 @@ testNotificationsNewToken APNSMockServer {apnsQ} oldNtf = do
liftIO $ threadDelay 250000
testMessageAB "hello"
-- switch
setNtfServers a [testNtfServer2]
liftIO $ setNtfServers a [testNtfServer2]
deleteNtfToken a tkn
_ <- registerTestToken a "abcd" NMInstant apnsQ
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort2
+3 -4
View File
@@ -16,7 +16,6 @@ module NtfClient where
import Control.Monad
import Control.Monad.Except (runExceptT)
import Control.Monad.IO.Unlift
import Data.Aeson (FromJSON (..), ToJSON (..), (.:))
import qualified Data.Aeson as J
import qualified Data.Aeson.Types as JT
@@ -71,13 +70,13 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
ntfTestStoreLogFile :: FilePath
ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a
testNtfClient :: Transport c => (THandleNTF c -> IO a) -> IO a
testNtfClient client = do
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do
g <- liftIO C.newRandom
g <- C.newRandom
ks <- atomically $ C.generateKeyPair g
liftIO (runExceptT $ ntfClientHandshake h ks testKeyHash supportedClientNTFVRange) >>= \case
runExceptT (ntfClientHandshake h ks testKeyHash supportedClientNTFVRange) >>= \case
Right th -> client th
Left e -> error $ show e
+1 -1
View File
@@ -15,7 +15,6 @@ import Control.Concurrent (threadDelay)
import qualified Data.Aeson as J
import qualified Data.Aeson.Types as JT
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import Data.Text.Encoding (encodeUtf8)
import NtfClient
@@ -35,6 +34,7 @@ import ServerTests
import qualified Simplex.Messaging.Agent.Protocol as AP
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push.APNS
+9 -12
View File
@@ -12,7 +12,6 @@ module SMPAgentClient where
import Control.Monad
import Control.Monad.IO.Unlift
import Crypto.Random
import qualified Data.ByteString.Char8 as B
import Data.List.NonEmpty (NonEmpty)
import Data.Map.Strict (Map)
@@ -202,7 +201,7 @@ initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer
agentCfg :: AgentConfig
agentCfg =
defaultAgentConfig
{ tcpPort = agentTestPort,
{ tcpPort = Just agentTestPort,
tbqSize = 4,
-- database = testDB,
smpCfg = defaultSMPClientConfig {qSize = 1, defaultTransport = (testPort, transport @TLS), networkConfig},
@@ -224,11 +223,9 @@ fastRetryInterval = defaultReconnectInterval {initialInterval = 50_000}
fastMessageRetryInterval :: RetryInterval2
fastMessageRetryInterval = RetryInterval2 {riFast = fastRetryInterval, riSlow = fastRetryInterval}
type AgentTestMonad m = (MonadUnliftIO m, MonadRandom m, MonadFail m)
withSmpAgentThreadOn_ :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> m () -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn_ :: ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> IO () -> (ThreadId -> IO a) -> IO a
withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess =
let cfg' = agentCfg {tcpPort = port'}
let cfg' = agentCfg {tcpPort = Just port'}
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
in serverBracket
( \started -> do
@@ -241,24 +238,24 @@ withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess =
userServers :: NonEmpty (ProtoServerWithAuth p) -> Map UserId (NonEmpty (ProtoServerWithAuth p))
userServers srvs = M.fromList [(1, srvs)]
withSmpAgentThreadOn :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> IO a) -> IO a
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a 0 $ removeFile db'
withSmpAgentOn :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> m a -> m a
withSmpAgentOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> IO a -> IO a
withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const
withSmpAgent :: AgentTestMonad m => ATransport -> m a -> m a
withSmpAgent :: ATransport -> IO a -> IO a
withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a
testSMPAgentClientOn :: Transport c => ServiceName -> (c -> IO a) -> IO a
testSMPAgentClientOn port' client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
line <- liftIO $ getLn h
line <- getLn h
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
then client h
else do
error $ "wrong welcome message: " <> B.unpack line
testSMPAgentClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (c -> m a) -> m a
testSMPAgentClient :: Transport c => (c -> IO a) -> IO a
testSMPAgentClient = testSMPAgentClientOn agentTestPort
+6 -7
View File
@@ -13,7 +13,6 @@
module SMPClient where
import Control.Monad.Except (runExceptT)
import Control.Monad.IO.Unlift
import Data.ByteString.Char8 (ByteString)
import Data.List.NonEmpty (NonEmpty)
import Network.Socket
@@ -68,23 +67,23 @@ xit'' d t = do
ci <- runIO $ lookupEnv "CI"
(if ci == Just "true" then skip "skipped on CI" . it d else it d) t
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a
testSMPClient :: Transport c => (THandleSMP c -> IO a) -> IO a
testSMPClient = testSMPClientVR supportedClientSMPRelayVRange
testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a
testSMPClientVR :: Transport c => VersionRangeSMP -> (THandleSMP c -> IO a) -> IO a
testSMPClientVR vr client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do
g <- liftIO C.newRandom
g <- C.newRandom
ks <- atomically $ C.generateKeyPair g
liftIO (runExceptT $ smpClientHandshake h ks testKeyHash vr) >>= \case
runExceptT (smpClientHandshake h ks testKeyHash vr) >>= \case
Right th -> client th
Left e -> error $ show e
cfg :: ServerConfig
cfg =
ServerConfig
{ transports = undefined,
{ transports = [],
smpHandshakeTimeout = 60000000,
tbqSize = 1,
-- serverTbqSize = 1,
@@ -129,7 +128,7 @@ withSmpServerConfigOn t cfg' port' =
withSmpServerThreadOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerThreadOn t = withSmpServerConfigOn t cfg
serverBracket :: (HasCallStack, MonadUnliftIO m) => (TMVar Bool -> m ()) -> m () -> (HasCallStack => ThreadId -> m a) -> m a
serverBracket :: HasCallStack => (TMVar Bool -> IO ()) -> IO () -> (HasCallStack => ThreadId -> IO a) -> IO a
serverBracket process afterProcess f = do
started <- newEmptyTMVarIO
E.bracket
+2 -2
View File
@@ -22,7 +22,6 @@ import Control.Exception (SomeException, try)
import Control.Monad
import Control.Monad.IO.Class
import Data.Bifunctor (first)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Set as S
@@ -31,6 +30,7 @@ import GHC.Stack (withFrozenCallStack)
import SMPClient
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Protocol
@@ -103,7 +103,7 @@ tPut1 h t = do
[r] <- tPut h [Right t]
pure r
tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd)
tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (SignedTransmission err cmd)
tGet1 h = do
[r] <- liftIO $ tGet h
pure r
+13 -18
View File
@@ -103,7 +103,7 @@ testXFTPAgentSendReceive = withXFTPServer $ do
sndr <- getSMPAgentClient' 1 agentCfg initAgentServers testDB
(rfd1, rfd2) <- runRight $ do
(sfId, _, rfd1, rfd2) <- testSend sndr filePath
xftpDeleteSndFileInternal sndr sfId
liftIO $ xftpDeleteSndFileInternal sndr sfId
pure (rfd1, rfd2)
-- receive file, delete rcv file
@@ -112,9 +112,8 @@ testXFTPAgentSendReceive = withXFTPServer $ do
where
testReceiveDelete clientId rfd originalFilePath = do
rcp <- getSMPAgentClient' clientId agentCfg initAgentServers testDB2
runRight_ $ do
rfId <- testReceive rcp rfd originalFilePath
xftpDeleteRcvFile rcp rfId
rfId <- runRight $ testReceive rcp rfd originalFilePath
xftpDeleteRcvFile rcp rfId
disposeAgentClient rcp
testXFTPAgentSendReceiveEncrypted :: HasCallStack => IO ()
@@ -127,7 +126,7 @@ testXFTPAgentSendReceiveEncrypted = withXFTPServer $ do
sndr <- getSMPAgentClient' 1 agentCfg initAgentServers testDB
(rfd1, rfd2) <- runRight $ do
(sfId, _, rfd1, rfd2) <- testSendCF sndr file
xftpDeleteSndFileInternal sndr sfId
liftIO $ xftpDeleteSndFileInternal sndr sfId
pure (rfd1, rfd2)
-- receive file, delete rcv file
testReceiveDelete 2 rfd1 filePath g
@@ -136,9 +135,8 @@ testXFTPAgentSendReceiveEncrypted = withXFTPServer $ do
testReceiveDelete clientId rfd originalFilePath g = do
rcp <- getSMPAgentClient' clientId agentCfg initAgentServers testDB2
cfArgs <- atomically $ Just <$> CF.randomArgs g
runRight_ $ do
rfId <- testReceiveCF rcp rfd cfArgs originalFilePath
xftpDeleteRcvFile rcp rfId
rfId <- runRight $ testReceiveCF rcp rfd cfArgs originalFilePath
xftpDeleteRcvFile rcp rfId
disposeAgentClient rcp
testXFTPAgentSendReceiveRedirect :: HasCallStack => IO ()
@@ -468,11 +466,9 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $
length <$> listDirectory xftpServerFiles `shouldReturn` 6
-- delete file
runRight $ do
xftpStartWorkers sndr (Just senderFiles)
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
Nothing <- liftIO $ 100000 `timeout` sfGet sndr
pure ()
runRight_ $ xftpStartWorkers sndr (Just senderFiles)
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
Nothing <- 100000 `timeout` sfGet sndr
disposeAgentClient rcp1
threadDelay 1000000
@@ -505,10 +501,9 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do
-- delete file - should not succeed with server down
sndr <- getSMPAgentClient' 3 agentCfg initAgentServers testDB
runRight $ do
xftpStartWorkers sndr (Just senderFiles)
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
liftIO $ timeout 300000 (get sndr) `shouldReturn` Nothing -- wait for worker attempt
runRight_ $ xftpStartWorkers sndr (Just senderFiles)
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
timeout 300000 (get sndr) `shouldReturn` Nothing -- wait for worker attempt
disposeAgentClient sndr
threadDelay 300000
@@ -636,4 +631,4 @@ testXFTPServerTest :: HasCallStack => Maybe BasicAuth -> XFTPServerWithAuth -> I
testXFTPServerTest newFileBasicAuth srv =
withXFTPServerCfg testXFTPServerConfig {newFileBasicAuth, xftpPort = xftpTestPort2} $ \_ -> do
a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB -- initially passed server is not running
runRight $ testProtocolServer a 1 srv
testProtocolServer a 1 srv
+23 -4
View File
@@ -12,7 +12,6 @@ import Control.Exception (SomeException)
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import qualified Data.ByteString.Base64.URL as B64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
@@ -26,9 +25,9 @@ import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..)
import Simplex.Messaging.Client (ProtocolClientError (..))
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Protocol (BasicAuth, SenderId)
import Simplex.Messaging.Server.Expiration (ExpirationConfig (..))
import Simplex.Messaging.Util (liftIOEither)
import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive, removeFile)
import System.FilePath ((</>))
import Test.Hspec
@@ -60,6 +59,7 @@ xftpServerTests =
it "prohibited when FNEW disabled" $ testFileBasicAuth False (Just "pwd") (Just "pwd") False
it "allowed with correct basic auth" $ testFileBasicAuth True (Just "pwd") (Just "pwd") True
it "allowed with auth on server without auth" $ testFileBasicAuth True Nothing (Just "any") True
it "should not change content for uploaded and committed files" testFileSkipCommitted
chSize :: Integral a => a
chSize = kb 128
@@ -75,7 +75,7 @@ createTestChunk fp = do
pure bytes
readChunk :: SenderId -> IO ByteString
readChunk sId = B.readFile (xftpServerFiles </> B.unpack (B64.encode sId))
readChunk sId = B.readFile (xftpServerFiles </> B.unpack (U.encode sId))
testFileChunkDelivery :: Expectation
testFileChunkDelivery = xftpTest $ \c -> runRight_ $ runTestFileChunkDelivery c c
@@ -219,7 +219,7 @@ testFileChunkExpiration = withXFTPServerCfg testXFTPServerConfig {fileExpiration
testInactiveClientExpiration :: Expectation
testInactiveClientExpiration = withXFTPServerCfg testXFTPServerConfig {inactiveClientExpiration} $ \_ -> runRight_ $ do
disconnected <- newEmptyTMVarIO
c <- liftIOEither $ getXFTPClient (1, testXFTPServer, Nothing) testXFTPClientConfig (\_ -> atomically $ putTMVar disconnected ())
c <- ExceptT $ getXFTPClient (1, testXFTPServer, Nothing) testXFTPClientConfig (\_ -> atomically $ putTMVar disconnected ())
pingXFTP c
liftIO $ do
threadDelay 100000
@@ -372,3 +372,22 @@ testFileBasicAuth allowNewFiles newFileBasicAuth clntAuth success =
else do
void (createXFTPChunk c spKey file [rcvKey] clntAuth)
`catchError` (liftIO . (`shouldBe` PCEProtocolError AUTH))
testFileSkipCommitted :: IO ()
testFileSkipCommitted =
withXFTPServerCfg testXFTPServerConfig $
\_ -> testXFTPClient $ \c -> do
g <- C.newRandom
(sndKey, spKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
(rcvKey, rpKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
bytes <- createTestChunk testChunkPath
digest <- LC.sha256Hash <$> LB.readFile testChunkPath
let file = FileInfo {sndKey, size = chSize, digest}
chunkSpec = XFTPChunkSpec {filePath = testChunkPath, chunkOffset = 0, chunkSize = chSize}
runRight_ $ do
(sId, [rId]) <- createXFTPChunk c spKey file [rcvKey] Nothing
uploadXFTPChunk c spKey sId chunkSpec
void . liftIO $ createTestChunk testChunkPath -- trash chunk contents
uploadXFTPChunk c spKey sId chunkSpec -- upload again to get FROk without getting stuck
downloadXFTPChunk g c rpKey rId $ XFTPRcvChunkSpec "tests/tmp/received_chunk" chSize digest
liftIO $ B.readFile "tests/tmp/received_chunk" `shouldReturn` bytes -- new chunk content got ignored