mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-26 03:14:53 +00:00
Merge remote-tracking branch 'origin/master' into ab/bench-target
This commit is contained in:
@@ -1,3 +1,13 @@
|
||||
# 5.6.1
|
||||
|
||||
Version 5.6.1.0.
|
||||
|
||||
- Much faster iOS notification server start time (fewer skipped notifications).
|
||||
- Fix SMP server stored message stats.
|
||||
- Prevent overwriting uploaded XFTP files with subsequent upload attempts.
|
||||
- Faster base64 encoding/parsing.
|
||||
- Control port audit log and authentication.
|
||||
|
||||
# 5.6.0
|
||||
|
||||
Version 5.6.0.4.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
@@ -38,7 +39,8 @@ logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
-- Warning: this SMP agent server is experimental - it does not work correctly with multiple connected TCP clients in some cases.
|
||||
main :: IO ()
|
||||
main = do
|
||||
putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig)
|
||||
let AgentConfig {tcpPort} = cfg
|
||||
putStrLn $ maybe (error "no agent port") (\port -> "SMP agent listening on port " ++ port) tcpPort
|
||||
setLogLevel LogInfo -- LogError
|
||||
Right st <- createAgentStore agentDbFile agentDbKey False MCConsole
|
||||
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers st
|
||||
|
||||
@@ -14,6 +14,12 @@ source-repository-package
|
||||
location: https://github.com/simplex-chat/aeson.git
|
||||
tag: aab7b5a14d6c5ea64c64dcaee418de1bb00dcc2b
|
||||
|
||||
-- old bs/text compat for 8.10
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/base64.git
|
||||
tag: 2d77b6dbcaffc00570a70be8694049f3710e7c94
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/hs-socks.git
|
||||
|
||||
+6
-2
@@ -1,5 +1,5 @@
|
||||
name: simplexmq
|
||||
version: 5.6.0.4
|
||||
version: 5.6.1.0
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: |
|
||||
This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
@@ -31,7 +31,7 @@ dependencies:
|
||||
- async == 2.2.*
|
||||
- attoparsec == 0.14.*
|
||||
- base >= 4.14 && < 5
|
||||
- base64-bytestring >= 1.0 && < 1.3
|
||||
- base64 == 1.0.*
|
||||
- case-insensitive == 1.2.*
|
||||
- composition == 1.0.*
|
||||
- constraints >= 0.12 && < 0.14
|
||||
@@ -202,3 +202,7 @@ ghc-options:
|
||||
- -Wincomplete-record-updates
|
||||
- -Wincomplete-uni-patterns
|
||||
- -Wunused-type-patterns
|
||||
- -O2
|
||||
|
||||
default-extensions:
|
||||
- StrictData
|
||||
|
||||
+31
-15
@@ -5,7 +5,7 @@ cabal-version: 1.12
|
||||
-- see: https://github.com/sol/hpack
|
||||
|
||||
name: simplexmq
|
||||
version: 5.6.0.4
|
||||
version: 5.6.1.0
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
<./docs/Simplex-Messaging-Client.html client> and
|
||||
@@ -119,6 +119,8 @@ library
|
||||
Simplex.Messaging.Crypto.SNTRUP761.Bindings.FFI
|
||||
Simplex.Messaging.Crypto.SNTRUP761.Bindings.RNG
|
||||
Simplex.Messaging.Encoding
|
||||
Simplex.Messaging.Encoding.Base64
|
||||
Simplex.Messaging.Encoding.Base64.URL
|
||||
Simplex.Messaging.Encoding.String
|
||||
Simplex.Messaging.Notifications.Client
|
||||
Simplex.Messaging.Notifications.Protocol
|
||||
@@ -171,7 +173,9 @@ library
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
src
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2
|
||||
include-dirs:
|
||||
cbits
|
||||
c-sources:
|
||||
@@ -187,7 +191,7 @@ library
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
@@ -251,7 +255,9 @@ executable ntf-server
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/ntf-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
|
||||
build-depends:
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
@@ -260,7 +266,7 @@ executable ntf-server
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
@@ -325,7 +331,9 @@ executable smp-agent
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/smp-agent
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
|
||||
build-depends:
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
@@ -334,7 +342,7 @@ executable smp-agent
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
@@ -399,7 +407,9 @@ executable smp-server
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/smp-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
|
||||
build-depends:
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
@@ -408,7 +418,7 @@ executable smp-server
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
@@ -473,7 +483,9 @@ executable xftp
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/xftp
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
|
||||
build-depends:
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
@@ -482,7 +494,7 @@ executable xftp
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
@@ -547,7 +559,9 @@ executable xftp-server
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/xftp-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts
|
||||
build-depends:
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
@@ -556,7 +570,7 @@ executable xftp-server
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
@@ -653,7 +667,9 @@ test-suite simplexmq-test
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
tests
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2
|
||||
build-depends:
|
||||
HUnit ==1.6.*
|
||||
, QuickCheck ==2.14.*
|
||||
@@ -664,7 +680,7 @@ test-suite simplexmq-test
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, base64 ==1.0.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
|
||||
@@ -74,8 +74,9 @@ import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM)
|
||||
import System.FilePath (takeFileName, (</>))
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory
|
||||
import qualified UnliftIO.Exception as E
|
||||
|
||||
startXFTPWorkers :: AgentMonad m => AgentClient -> Maybe FilePath -> m ()
|
||||
startXFTPWorkers :: AgentClient -> Maybe FilePath -> AM ()
|
||||
startXFTPWorkers c workDir = do
|
||||
wd <- asks $ xftpWorkDir . xftpAgent
|
||||
atomically $ writeTVar wd workDir
|
||||
@@ -84,23 +85,26 @@ startXFTPWorkers c workDir = do
|
||||
startSndFiles cfg
|
||||
startDelFiles cfg
|
||||
where
|
||||
startRcvFiles :: AgentConfig -> AM ()
|
||||
startRcvFiles AgentConfig {rcvFilesTTL} = do
|
||||
pendingRcvServers <- withStore' c (`getPendingRcvFilesServers` rcvFilesTTL)
|
||||
forM_ pendingRcvServers $ \s -> resumeXFTPRcvWork c (Just s)
|
||||
lift . forM_ pendingRcvServers $ \s -> resumeXFTPRcvWork c (Just s)
|
||||
-- start local worker for files pending decryption,
|
||||
-- no need to make an extra query for the check
|
||||
-- as the worker will check the store anyway
|
||||
resumeXFTPRcvWork c Nothing
|
||||
lift $ resumeXFTPRcvWork c Nothing
|
||||
startSndFiles :: AgentConfig -> AM ()
|
||||
startSndFiles AgentConfig {sndFilesTTL} = do
|
||||
-- start worker for files pending encryption/creation
|
||||
resumeXFTPSndWork c Nothing
|
||||
lift $ resumeXFTPSndWork c Nothing
|
||||
pendingSndServers <- withStore' c (`getPendingSndFilesServers` sndFilesTTL)
|
||||
forM_ pendingSndServers $ \s -> resumeXFTPSndWork c (Just s)
|
||||
lift . forM_ pendingSndServers $ \s -> resumeXFTPSndWork c (Just s)
|
||||
startDelFiles :: AgentConfig -> AM ()
|
||||
startDelFiles AgentConfig {rcvFilesTTL} = do
|
||||
pendingDelServers <- withStore' c (`getPendingDelFilesServers` rcvFilesTTL)
|
||||
forM_ pendingDelServers $ resumeXFTPDelWork c
|
||||
lift . forM_ pendingDelServers $ resumeXFTPDelWork c
|
||||
|
||||
closeXFTPAgent :: MonadUnliftIO m => XFTPAgent -> m ()
|
||||
closeXFTPAgent :: XFTPAgent -> IO ()
|
||||
closeXFTPAgent a = do
|
||||
stopWorkers $ xftpRcvWorkers a
|
||||
stopWorkers $ xftpSndWorkers a
|
||||
@@ -108,16 +112,16 @@ closeXFTPAgent a = do
|
||||
where
|
||||
stopWorkers workers = atomically (swapTVar workers M.empty) >>= mapM_ (liftIO . cancelWorker)
|
||||
|
||||
xftpReceiveFile' :: AgentMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> m RcvFileId
|
||||
xftpReceiveFile' :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> AM RcvFileId
|
||||
xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redirect}) cfArgs = do
|
||||
g <- asks random
|
||||
prefixPath <- getPrefixPath "rcv.xftp"
|
||||
prefixPath <- lift $ getPrefixPath "rcv.xftp"
|
||||
createDirectory prefixPath
|
||||
let relPrefixPath = takeFileName prefixPath
|
||||
relTmpPath = relPrefixPath </> "xftp.encrypted"
|
||||
relSavePath = relPrefixPath </> "xftp.decrypted"
|
||||
createDirectory =<< toFSFilePath relTmpPath
|
||||
createEmptyFile =<< toFSFilePath relSavePath
|
||||
lift $ createDirectory =<< toFSFilePath relTmpPath
|
||||
lift $ createEmptyFile =<< toFSFilePath relSavePath
|
||||
let saveFile = CryptoFile relSavePath cfArgs
|
||||
fId <- case redirect of
|
||||
Nothing -> withStore c $ \db -> createRcvFile db g userId fd relPrefixPath relTmpPath saveFile
|
||||
@@ -125,8 +129,8 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi
|
||||
-- prepare description paths
|
||||
let relTmpPathRedirect = relPrefixPath </> "xftp.redirect-encrypted"
|
||||
relSavePathRedirect = relPrefixPath </> "xftp.redirect-decrypted"
|
||||
createDirectory =<< toFSFilePath relTmpPathRedirect
|
||||
createEmptyFile =<< toFSFilePath relSavePathRedirect
|
||||
lift $ createDirectory =<< toFSFilePath relTmpPathRedirect
|
||||
lift $ createEmptyFile =<< toFSFilePath relSavePathRedirect
|
||||
cfArgsRedirect <- atomically $ CF.randomArgs g
|
||||
let saveFileRedirect = CryptoFile relSavePathRedirect $ Just cfArgsRedirect
|
||||
-- create download tasks
|
||||
@@ -134,42 +138,42 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi
|
||||
forM_ chunks (downloadChunk c)
|
||||
pure fId
|
||||
|
||||
downloadChunk :: AgentMonad m => AgentClient -> FileChunk -> m ()
|
||||
downloadChunk :: AgentClient -> FileChunk -> AM ()
|
||||
downloadChunk c FileChunk {replicas = (FileChunkReplica {server} : _)} = do
|
||||
void $ getXFTPRcvWorker True c (Just server)
|
||||
lift . void $ getXFTPRcvWorker True c (Just server)
|
||||
downloadChunk _ _ = throwError $ INTERNAL "no replicas"
|
||||
|
||||
getPrefixPath :: AgentMonad m => String -> m FilePath
|
||||
getPrefixPath :: String -> AM' FilePath
|
||||
getPrefixPath suffix = do
|
||||
workPath <- getXFTPWorkPath
|
||||
ts <- liftIO getCurrentTime
|
||||
let isoTime = formatTime defaultTimeLocale "%Y%m%d_%H%M%S_%6q" ts
|
||||
uniqueCombine workPath (isoTime <> "_" <> suffix)
|
||||
|
||||
toFSFilePath :: AgentMonad m => FilePath -> m FilePath
|
||||
toFSFilePath :: FilePath -> AM' FilePath
|
||||
toFSFilePath f = (</> f) <$> getXFTPWorkPath
|
||||
|
||||
createEmptyFile :: AgentMonad m => FilePath -> m ()
|
||||
createEmptyFile :: FilePath -> AM' ()
|
||||
createEmptyFile fPath = liftIO $ B.writeFile fPath ""
|
||||
|
||||
resumeXFTPRcvWork :: AgentMonad' m => AgentClient -> Maybe XFTPServer -> m ()
|
||||
resumeXFTPRcvWork :: AgentClient -> Maybe XFTPServer -> AM' ()
|
||||
resumeXFTPRcvWork = void .: getXFTPRcvWorker False
|
||||
|
||||
getXFTPRcvWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe XFTPServer -> m Worker
|
||||
getXFTPRcvWorker :: Bool -> AgentClient -> Maybe XFTPServer -> AM' Worker
|
||||
getXFTPRcvWorker hasWork c server = do
|
||||
ws <- asks $ xftpRcvWorkers . xftpAgent
|
||||
getAgentWorker "xftp_rcv" hasWork c server ws $
|
||||
maybe (runXFTPRcvLocalWorker c) (runXFTPRcvWorker c) server
|
||||
|
||||
runXFTPRcvWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m ()
|
||||
runXFTPRcvWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
|
||||
runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
cfg <- asks config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
lift $ waitForWork doWork
|
||||
atomically $ assertAgentForeground c
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> m ()
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} =
|
||||
withWork c doWork (\db -> getNextRcvChunkToDownload db srv rcvFilesTTL) $ \case
|
||||
RcvFileChunk {rcvFileId, rcvFileEntityId, fileTmpPath, replicas = []} -> rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) "chunk has no replicas"
|
||||
@@ -182,14 +186,14 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when notifyOnRetry $ notify c rcvFileEntityId $ RFERR e
|
||||
closeXFTPServerClient c userId server digest
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
loop
|
||||
retryDone e = rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) (show e)
|
||||
downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> m ()
|
||||
downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> AM ()
|
||||
downloadFileChunk RcvFileChunk {userId, rcvFileId, rcvFileEntityId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath} replica = do
|
||||
fsFileTmpPath <- toFSFilePath fileTmpPath
|
||||
fsFileTmpPath <- lift $ toFSFilePath fileTmpPath
|
||||
chunkPath <- uniqueCombine fsFileTmpPath $ show chunkNo
|
||||
let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest)
|
||||
relChunkPath = fileTmpPath </> takeFileName chunkPath
|
||||
@@ -206,7 +210,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
liftIO . when complete $ updateRcvFileStatus db rcvFileId RFSReceived
|
||||
pure (entityId, complete, RFPROG rcvd total)
|
||||
notify c entityId progress
|
||||
when complete . void $
|
||||
when complete . lift . void $
|
||||
getXFTPRcvWorker True c Nothing
|
||||
where
|
||||
receivedSize :: [RcvFileChunk] -> Int64
|
||||
@@ -222,37 +226,37 @@ withRetryIntervalLimit maxN ri action =
|
||||
withRetryIntervalCount ri $ \n delay loop ->
|
||||
when (n < maxN) $ action delay loop
|
||||
|
||||
retryOnError :: AgentMonad m => Text -> m a -> m a -> AgentErrorType -> m a
|
||||
retryOnError :: Text -> AM a -> AM a -> AgentErrorType -> AM a
|
||||
retryOnError name loop done e = do
|
||||
logError $ name <> " error: " <> tshow e
|
||||
if temporaryAgentError e
|
||||
then loop
|
||||
else done
|
||||
|
||||
rcvWorkerInternalError :: AgentMonad m => AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> m ()
|
||||
rcvWorkerInternalError :: AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> AM ()
|
||||
rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath internalErrStr = do
|
||||
forM_ tmpPath (removePath <=< toFSFilePath)
|
||||
lift $ forM_ tmpPath (removePath <=< toFSFilePath)
|
||||
withStore' c $ \db -> updateRcvFileError db rcvFileId internalErrStr
|
||||
notify c rcvFileEntityId $ RFERR $ INTERNAL internalErrStr
|
||||
|
||||
runXFTPRcvLocalWorker :: forall m. AgentMonad m => AgentClient -> Worker -> m ()
|
||||
runXFTPRcvLocalWorker :: AgentClient -> Worker -> AM ()
|
||||
runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
cfg <- asks config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
lift $ waitForWork doWork
|
||||
atomically $ assertAgentForeground c
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> m ()
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL} =
|
||||
withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $
|
||||
\f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath} ->
|
||||
decryptFile f `catchAgentError` (rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath . show)
|
||||
decryptFile :: RcvFile -> m ()
|
||||
decryptFile :: RcvFile -> AM ()
|
||||
decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do
|
||||
let CryptoFile savePath cfArgs = saveFile
|
||||
fsSavePath <- toFSFilePath savePath
|
||||
when (status == RFSDecrypting) $
|
||||
fsSavePath <- lift $ toFSFilePath savePath
|
||||
lift . when (status == RFSDecrypting) $
|
||||
whenM (doesFileExist fsSavePath) (removeFile fsSavePath >> createEmptyFile fsSavePath)
|
||||
withStore' c $ \db -> updateRcvFileStatus db rcvFileId RFSDecrypting
|
||||
chunkPaths <- getChunkPaths chunks
|
||||
@@ -265,16 +269,16 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
case redirect of
|
||||
Nothing -> do
|
||||
notify c rcvFileEntityId $ RFDONE fsSavePath
|
||||
forM_ tmpPath (removePath <=< toFSFilePath)
|
||||
lift $ forM_ tmpPath (removePath <=< toFSFilePath)
|
||||
atomically $ waitUntilForeground c
|
||||
withStore' c (`updateRcvFileComplete` rcvFileId)
|
||||
Just RcvFileRedirect {redirectFileInfo, redirectDbId} -> do
|
||||
let RedirectFileInfo {size = redirectSize, digest = redirectDigest} = redirectFileInfo
|
||||
forM_ tmpPath (removePath <=< toFSFilePath)
|
||||
lift $ forM_ tmpPath (removePath <=< toFSFilePath)
|
||||
atomically $ waitUntilForeground c
|
||||
withStore' c (`updateRcvFileComplete` rcvFileId)
|
||||
-- proceed with redirect
|
||||
yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `finally` (toFSFilePath fsSavePath >>= removePath)
|
||||
yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of
|
||||
Left _ -> throwError . XFTP $ XFTP.REDIRECT "decode error"
|
||||
Right (ValidFileDescription fd@FileDescription {size = dstSize, digest = dstDigest})
|
||||
@@ -285,19 +289,19 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
withStore c $ \db -> updateRcvFileRedirect db redirectDbId next
|
||||
forM_ nextChunks (downloadChunk c)
|
||||
where
|
||||
getChunkPaths :: [RcvFileChunk] -> m [FilePath]
|
||||
getChunkPaths :: [RcvFileChunk] -> AM [FilePath]
|
||||
getChunkPaths [] = pure []
|
||||
getChunkPaths (RcvFileChunk {chunkTmpPath = Just path} : cs) = do
|
||||
ps <- getChunkPaths cs
|
||||
fsPath <- toFSFilePath path
|
||||
fsPath <- lift $ toFSFilePath path
|
||||
pure $ fsPath : ps
|
||||
getChunkPaths (RcvFileChunk {chunkTmpPath = Nothing} : _cs) =
|
||||
throwError $ INTERNAL "no chunk path"
|
||||
|
||||
xftpDeleteRcvFile' :: AgentMonad m => AgentClient -> RcvFileId -> m ()
|
||||
xftpDeleteRcvFile' :: AgentClient -> RcvFileId -> AM' ()
|
||||
xftpDeleteRcvFile' c rcvFileEntityId = xftpDeleteRcvFiles' c [rcvFileEntityId]
|
||||
|
||||
xftpDeleteRcvFiles' :: forall m. AgentMonad m => AgentClient -> [RcvFileId] -> m ()
|
||||
xftpDeleteRcvFiles' :: AgentClient -> [RcvFileId] -> AM' ()
|
||||
xftpDeleteRcvFiles' c rcvFileEntityIds = do
|
||||
rcvFiles <- rights <$> withStoreBatch c (\db -> map (fmap (first storeError) . getRcvFileByEntityId db) rcvFileEntityIds)
|
||||
redirects <- rights <$> batchFiles getRcvFileRedirects rcvFiles
|
||||
@@ -309,29 +313,29 @@ xftpDeleteRcvFiles' c rcvFileEntityIds = do
|
||||
(removePath . (workPath </>)) prefixPath `catchAll_` pure ()
|
||||
where
|
||||
fileComplete RcvFile {status} = status == RFSComplete || status == RFSError
|
||||
batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> m [Either AgentErrorType a]
|
||||
batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> AM' [Either AgentErrorType a]
|
||||
batchFiles f rcvFiles = withStoreBatch' c $ \db -> map (\RcvFile {rcvFileId} -> f db rcvFileId) rcvFiles
|
||||
|
||||
notify :: forall m e. (MonadUnliftIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m ()
|
||||
notify :: forall m e. (MonadIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m ()
|
||||
notify c entId cmd = atomically $ writeTBQueue (subQ c) ("", entId, APC (sAEntity @e) cmd)
|
||||
|
||||
xftpSendFile' :: AgentMonad m => AgentClient -> UserId -> CryptoFile -> Int -> m SndFileId
|
||||
xftpSendFile' :: AgentClient -> UserId -> CryptoFile -> Int -> AM SndFileId
|
||||
xftpSendFile' c userId file numRecipients = do
|
||||
g <- asks random
|
||||
prefixPath <- getPrefixPath "snd.xftp"
|
||||
prefixPath <- lift $ getPrefixPath "snd.xftp"
|
||||
createDirectory prefixPath
|
||||
let relPrefixPath = takeFileName prefixPath
|
||||
key <- atomically $ C.randomSbKey g
|
||||
nonce <- atomically $ C.randomCbNonce g
|
||||
-- saving absolute filePath will not allow to restore file encryption after app update, but it's a short window
|
||||
fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce Nothing
|
||||
void $ getXFTPSndWorker True c Nothing
|
||||
lift . void $ getXFTPSndWorker True c Nothing
|
||||
pure fId
|
||||
|
||||
xftpSendDescription' :: forall m. AgentMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> m SndFileId
|
||||
xftpSendDescription' :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> AM SndFileId
|
||||
xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {size, digest}) numRecipients = do
|
||||
g <- asks random
|
||||
prefixPath <- getPrefixPath "snd.xftp"
|
||||
prefixPath <- lift $ getPrefixPath "snd.xftp"
|
||||
createDirectory prefixPath
|
||||
let relPrefixPath = takeFileName prefixPath
|
||||
let directYaml = prefixPath </> "direct.yaml"
|
||||
@@ -341,39 +345,39 @@ xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {si
|
||||
key <- atomically $ C.randomSbKey g
|
||||
nonce <- atomically $ C.randomCbNonce g
|
||||
fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce $ Just RedirectFileInfo {size, digest}
|
||||
void $ getXFTPSndWorker True c Nothing
|
||||
lift . void $ getXFTPSndWorker True c Nothing
|
||||
pure fId
|
||||
|
||||
resumeXFTPSndWork :: AgentMonad' m => AgentClient -> Maybe XFTPServer -> m ()
|
||||
resumeXFTPSndWork :: AgentClient -> Maybe XFTPServer -> AM' ()
|
||||
resumeXFTPSndWork = void .: getXFTPSndWorker False
|
||||
|
||||
getXFTPSndWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe XFTPServer -> m Worker
|
||||
getXFTPSndWorker :: Bool -> AgentClient -> Maybe XFTPServer -> AM' Worker
|
||||
getXFTPSndWorker hasWork c server = do
|
||||
ws <- asks $ xftpSndWorkers . xftpAgent
|
||||
getAgentWorker "xftp_snd" hasWork c server ws $
|
||||
maybe (runXFTPSndPrepareWorker c) (runXFTPSndWorker c) server
|
||||
|
||||
runXFTPSndPrepareWorker :: forall m. AgentMonad m => AgentClient -> Worker -> m ()
|
||||
runXFTPSndPrepareWorker :: AgentClient -> Worker -> AM ()
|
||||
runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
cfg <- asks config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
lift $ waitForWork doWork
|
||||
atomically $ assertAgentForeground c
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> m ()
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL} =
|
||||
withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $
|
||||
\f@SndFile {sndFileId, sndFileEntityId, prefixPath} ->
|
||||
prepareFile cfg f `catchAgentError` (sndWorkerInternalError c sndFileId sndFileEntityId prefixPath . show)
|
||||
prepareFile :: AgentConfig -> SndFile -> m ()
|
||||
prepareFile :: AgentConfig -> SndFile -> AM ()
|
||||
prepareFile _ SndFile {prefixPath = Nothing} =
|
||||
throwError $ INTERNAL "no prefix path"
|
||||
prepareFile cfg sndFile@SndFile {sndFileId, userId, prefixPath = Just ppath, status} = do
|
||||
SndFile {numRecipients, chunks} <-
|
||||
if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting
|
||||
then do
|
||||
fsEncPath <- toFSFilePath $ sndFileEncPath ppath
|
||||
fsEncPath <- lift . toFSFilePath $ sndFileEncPath ppath
|
||||
when (status == SFSEncrypting) . whenM (doesFileExist fsEncPath) $
|
||||
removeFile fsEncPath
|
||||
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
|
||||
@@ -389,7 +393,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading
|
||||
where
|
||||
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
|
||||
encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)])
|
||||
encryptFileForUpload :: SndFile -> FilePath -> AM (FileDigest, [(XFTPChunkSpec, FileDigest)])
|
||||
encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do
|
||||
let CryptoFile {filePath} = srcFile
|
||||
fileName = takeFileName filePath
|
||||
@@ -412,12 +416,12 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
chunkCreated :: SndFileChunk -> Bool
|
||||
chunkCreated SndFileChunk {replicas} =
|
||||
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas
|
||||
createChunk :: Int -> SndFileChunk -> m ()
|
||||
createChunk :: Int -> SndFileChunk -> AM ()
|
||||
createChunk numRecipients' ch = do
|
||||
atomically $ assertAgentForeground c
|
||||
(replica, ProtoServerWithAuth srv _) <- tryCreate
|
||||
withStore' c $ \db -> createSndFileReplica db ch replica
|
||||
void $ getXFTPSndWorker True c (Just srv)
|
||||
lift . void $ getXFTPSndWorker True c (Just srv)
|
||||
where
|
||||
tryCreate = do
|
||||
usedSrvs <- newTVarIO ([] :: [XFTPServer])
|
||||
@@ -433,21 +437,21 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
replica <- agentXFTPNewChunk c ch numRecipients' srvAuth
|
||||
pure (replica, srvAuth)
|
||||
|
||||
sndWorkerInternalError :: AgentMonad m => AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> m ()
|
||||
sndWorkerInternalError :: AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> AM ()
|
||||
sndWorkerInternalError c sndFileId sndFileEntityId prefixPath internalErrStr = do
|
||||
forM_ prefixPath $ removePath <=< toFSFilePath
|
||||
lift . forM_ prefixPath $ removePath <=< toFSFilePath
|
||||
withStore' c $ \db -> updateSndFileError db sndFileId internalErrStr
|
||||
notify c sndFileEntityId $ SFERR $ INTERNAL internalErrStr
|
||||
|
||||
runXFTPSndWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m ()
|
||||
runXFTPSndWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
|
||||
runXFTPSndWorker c srv Worker {doWork} = do
|
||||
cfg <- asks config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
lift $ waitForWork doWork
|
||||
atomically $ assertAgentForeground c
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> m ()
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do
|
||||
withWork c doWork (\db -> getNextSndChunkToUpload db srv sndFilesTTL) $ \case
|
||||
SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) "chunk has no replicas"
|
||||
@@ -460,15 +464,15 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when notifyOnRetry $ notify c sndFileEntityId $ SFERR e
|
||||
closeXFTPServerClient c userId server digest
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
loop
|
||||
retryDone e = sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) (show e)
|
||||
uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> m ()
|
||||
uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> AM ()
|
||||
uploadFileChunk AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients} sndFileChunk@SndFileChunk {sndFileId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath}, digest = chunkDigest} replica = do
|
||||
replica'@SndFileChunkReplica {sndChunkReplicaId} <- addRecipients sndFileChunk replica
|
||||
fsFilePath <- toFSFilePath filePath
|
||||
fsFilePath <- lift $ toFSFilePath filePath
|
||||
unlessM (doesFileExist fsFilePath) $ throwError $ INTERNAL "encrypted file doesn't exist on upload"
|
||||
let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec
|
||||
atomically $ assertAgentForeground c
|
||||
@@ -484,10 +488,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
when complete $ do
|
||||
(sndDescr, rcvDescrs) <- sndFileToDescrs sf
|
||||
notify c sndFileEntityId $ SFDONE sndDescr rcvDescrs
|
||||
forM_ prefixPath $ removePath <=< toFSFilePath
|
||||
lift . forM_ prefixPath $ removePath <=< toFSFilePath
|
||||
withStore' c $ \db -> updateSndFileComplete db sndFileId
|
||||
where
|
||||
addRecipients :: SndFileChunk -> SndFileChunkReplica -> m SndFileChunkReplica
|
||||
addRecipients :: SndFileChunk -> SndFileChunkReplica -> AM SndFileChunkReplica
|
||||
addRecipients ch@SndFileChunk {numRecipients} cr@SndFileChunkReplica {rcvIdsKeys}
|
||||
| length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients"
|
||||
| length rcvIdsKeys == numRecipients = pure cr
|
||||
@@ -496,7 +500,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients'
|
||||
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
|
||||
addRecipients ch cr'
|
||||
sndFileToDescrs :: SndFile -> m (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
|
||||
sndFileToDescrs :: SndFile -> AM (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
|
||||
sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest"
|
||||
sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks"
|
||||
sndFileToDescrs SndFile {digest = Just digest, key, nonce, chunks = chunks@(fstChunk : _), redirect} = do
|
||||
@@ -511,7 +515,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
fdRcvs = createRcvFileDescriptions fdRcv chunks
|
||||
validFdRcvs <- either (throwError . INTERNAL) pure $ mapM validateFileDescription fdRcvs
|
||||
pure (validFdSnd, validFdRcvs)
|
||||
toSndDescrChunk :: SndFileChunk -> m FileChunk
|
||||
toSndDescrChunk :: SndFileChunk -> AM FileChunk
|
||||
toSndDescrChunk SndFileChunk {replicas = []} = throwError $ INTERNAL "snd file chunk has no replicas"
|
||||
toSndDescrChunk ch@SndFileChunk {chunkNo, digest = chDigest, replicas = (SndFileChunkReplica {server, replicaId, replicaKey} : _)} = do
|
||||
let chunkSize = FileSize $ sndChunkSize ch
|
||||
@@ -562,10 +566,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
chunkUploaded SndFileChunk {replicas} =
|
||||
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSUploaded) replicas
|
||||
|
||||
deleteSndFileInternal :: AgentMonad m => AgentClient -> SndFileId -> m ()
|
||||
deleteSndFileInternal :: AgentClient -> SndFileId -> AM' ()
|
||||
deleteSndFileInternal c sndFileEntityId = deleteSndFilesInternal c [sndFileEntityId]
|
||||
|
||||
deleteSndFilesInternal :: forall m. AgentMonad m => AgentClient -> [SndFileId] -> m ()
|
||||
deleteSndFilesInternal :: AgentClient -> [SndFileId] -> AM' ()
|
||||
deleteSndFilesInternal c sndFileEntityIds = do
|
||||
sndFiles <- rights <$> withStoreBatch c (\db -> map (fmap (first storeError) . getSndFileByEntityId db) sndFileEntityIds)
|
||||
let (toDelete, toMarkDeleted) = partition fileComplete sndFiles
|
||||
@@ -576,15 +580,15 @@ deleteSndFilesInternal c sndFileEntityIds = do
|
||||
batchFiles_ updateSndFileDeleted toMarkDeleted
|
||||
where
|
||||
fileComplete SndFile {status} = status == SFSComplete || status == SFSError
|
||||
batchFiles_ :: (DB.Connection -> DBSndFileId -> IO a) -> [SndFile] -> m ()
|
||||
batchFiles_ :: (DB.Connection -> DBSndFileId -> IO a) -> [SndFile] -> AM' ()
|
||||
batchFiles_ f sndFiles = void $ withStoreBatch' c $ \db -> map (\SndFile {sndFileId} -> f db sndFileId) sndFiles
|
||||
|
||||
deleteSndFileRemote :: forall m. AgentMonad m => AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> m ()
|
||||
deleteSndFileRemote :: AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> AM' ()
|
||||
deleteSndFileRemote c userId sndFileEntityId sfd = deleteSndFilesRemote c userId [(sndFileEntityId, sfd)]
|
||||
|
||||
deleteSndFilesRemote :: forall m. AgentMonad m => AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> m ()
|
||||
deleteSndFilesRemote :: AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> AM' ()
|
||||
deleteSndFilesRemote c userId sndFileIdsDescrs = do
|
||||
deleteSndFilesInternal c (map fst sndFileIdsDescrs) `catchAgentError` (notify c "" . SFERR)
|
||||
deleteSndFilesInternal c (map fst sndFileIdsDescrs) `E.catchAny` (notify c "" . SFERR . INTERNAL . show)
|
||||
let rs = concatMap (mapMaybe chunkReplica . fdChunks . snd) sndFileIdsDescrs
|
||||
void $ withStoreBatch' c (\db -> map (uncurry $ createDeletedSndChunkReplica db userId) rs)
|
||||
let servers = S.fromList $ map (\(FileChunkReplica {server}, _) -> server) rs
|
||||
@@ -596,23 +600,23 @@ deleteSndFilesRemote c userId sndFileIdsDescrs = do
|
||||
FileChunk {digest, replicas = replica : _} -> Just (replica, digest)
|
||||
_ -> Nothing
|
||||
|
||||
resumeXFTPDelWork :: AgentMonad' m => AgentClient -> XFTPServer -> m ()
|
||||
resumeXFTPDelWork :: AgentClient -> XFTPServer -> AM' ()
|
||||
resumeXFTPDelWork = void .: getXFTPDelWorker False
|
||||
|
||||
getXFTPDelWorker :: AgentMonad' m => Bool -> AgentClient -> XFTPServer -> m Worker
|
||||
getXFTPDelWorker :: Bool -> AgentClient -> XFTPServer -> AM' Worker
|
||||
getXFTPDelWorker hasWork c server = do
|
||||
ws <- asks $ xftpDelWorkers . xftpAgent
|
||||
getAgentWorker "xftp_del" hasWork c server ws $ runXFTPDelWorker c server
|
||||
|
||||
runXFTPDelWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m ()
|
||||
runXFTPDelWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
|
||||
runXFTPDelWorker c srv Worker {doWork} = do
|
||||
cfg <- asks config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
lift $ waitForWork doWork
|
||||
atomically $ assertAgentForeground c
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> m ()
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do
|
||||
-- no point in deleting files older than rcv ttl, as they will be expired on server
|
||||
withWork c doWork (\db -> getNextDeletedSndChunkReplica db srv rcvFilesTTL) processDeletedReplica
|
||||
@@ -626,7 +630,7 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when notifyOnRetry $ notify c "" $ SFERR e
|
||||
closeXFTPServerClient c userId server chunkDigest
|
||||
liftIO $ closeXFTPServerClient c userId server chunkDigest
|
||||
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
loop
|
||||
@@ -635,7 +639,7 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
agentXFTPDeleteChunk c userId replica
|
||||
withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId
|
||||
|
||||
delWorkerInternalError :: AgentMonad m => AgentClient -> Int64 -> AgentErrorType -> m ()
|
||||
delWorkerInternalError :: AgentClient -> Int64 -> AgentErrorType -> AM ()
|
||||
delWorkerInternalError c deletedSndChunkReplicaId e = do
|
||||
withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId
|
||||
notify c "" $ SFERR e
|
||||
|
||||
@@ -50,7 +50,7 @@ import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import Simplex.Messaging.Transport.HTTP2.File
|
||||
import Simplex.Messaging.Util (bshow, liftEitherError, whenM)
|
||||
import Simplex.Messaging.Util (bshow, whenM)
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory
|
||||
|
||||
@@ -98,7 +98,7 @@ getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkC
|
||||
clientVar <- newTVarIO Nothing
|
||||
let usePort = if null port then "443" else port
|
||||
clientDisconnected = readTVarIO clientVar >>= mapM_ disconnected
|
||||
http2Client <- liftEitherError xftpClientError $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
|
||||
http2Client <- withExceptT xftpClientError . ExceptT $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
|
||||
let HTTP2Client {sessionId} = http2Client
|
||||
thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True}
|
||||
c = XFTPClient {http2Client, thParams, transportSession, config}
|
||||
@@ -145,7 +145,7 @@ sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> Excep
|
||||
sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = do
|
||||
let req = H.requestStreaming N.methodPost "/" [] streamBody
|
||||
reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_
|
||||
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- liftEitherError xftpClientError $ sendRequest http2Client req reqTimeout
|
||||
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req reqTimeout
|
||||
when (B.length bodyHead /= xftpBlockSize) $ throwError $ PCEResponseError BLOCK
|
||||
-- TODO validate that the file ID is the same as in the request?
|
||||
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
|
||||
@@ -196,9 +196,9 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
|
||||
let dhSecret = C.dh' sDhKey rpDhKey
|
||||
cbState <- liftEither . first PCECryptoError $ LC.cbInit dhSecret cbNonce
|
||||
let t = chunkTimeout config chunkSize
|
||||
t `timeout` download cbState >>= maybe (throwError PCEResponseTimeout) pure
|
||||
ExceptT (sequence <$> (t `timeout` download cbState)) >>= maybe (throwError PCEResponseTimeout) pure
|
||||
where
|
||||
download cbState =
|
||||
download cbState = runExceptT $
|
||||
withExceptT PCEResponseError $
|
||||
receiveEncFile chunkPart cbState chunkSpec `catchError` \e ->
|
||||
whenM (doesFileExist filePath) (removeFile filePath) >> throwError e
|
||||
|
||||
@@ -10,6 +10,7 @@ module Simplex.FileTransfer.Client.Agent where
|
||||
import Control.Logger.Simple (logInfo)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans (lift)
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text (Text)
|
||||
@@ -86,7 +87,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
|
||||
waitForXFTPClient :: XFTPClientVar -> ME XFTPClient
|
||||
waitForXFTPClient clientVar = do
|
||||
let XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpConnectTimeout}} = xftpConfig config
|
||||
client_ <- tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
|
||||
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
|
||||
liftEither $ case client_ of
|
||||
Just (Right c) -> Right c
|
||||
Just (Left e) -> Left e
|
||||
@@ -110,7 +111,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
|
||||
TM.delete srv xftpClients
|
||||
throwError e
|
||||
tryConnectAsync :: ME ()
|
||||
tryConnectAsync = void . async $ do
|
||||
tryConnectAsync = void . lift . async . runExceptT $ do
|
||||
withRetryInterval (reconnectInterval config) $ \_ loop -> void $ tryConnectClient loop
|
||||
|
||||
showServer :: XFTPServer -> Text
|
||||
|
||||
@@ -37,6 +37,7 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Char (toLower)
|
||||
import Data.Either (partitionEithers)
|
||||
import Data.Int (Int64)
|
||||
import Data.List (foldl', sortOn)
|
||||
import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
|
||||
@@ -321,7 +322,9 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re
|
||||
-- the reason we don't do pooled downloads here within one server is that http2 library doesn't handle cleint concurrency, even though
|
||||
-- upload doesn't allow other requests within the same client until complete (but download does allow).
|
||||
logInfo $ "uploading " <> tshow (length chunks) <> " chunks..."
|
||||
map snd . sortOn fst . concat <$> pooledForConcurrentlyN 16 chunks' (mapM $ uploadFileChunk a)
|
||||
(errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 chunks' . mapM $ runExceptT . uploadFileChunk a)
|
||||
mapM_ throwError errs
|
||||
pure $ map snd (sortOn fst rs)
|
||||
where
|
||||
uploadFileChunk :: XFTPClientAgent -> (Int, XFTPChunkSpec, XFTPServerWithAuth) -> ExceptT CLIError IO (Int, SentFileChunk)
|
||||
uploadFileChunk a (chunkNo, chunkSpec@XFTPChunkSpec {chunkSize}, ProtoServerWithAuth xftpServer auth) = do
|
||||
@@ -433,7 +436,9 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
|
||||
FileChunkReplica {server} : _ -> server
|
||||
srvChunks = groupAllOn srv chunks
|
||||
g <- liftIO C.newRandom
|
||||
chunkPaths <- map snd . sortOn fst . concat <$> pooledForConcurrentlyN 16 srvChunks (mapM $ downloadFileChunk g a encPath size downloadedChunks)
|
||||
(errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 srvChunks $ mapM $ runExceptT . downloadFileChunk g a encPath size downloadedChunks)
|
||||
mapM_ throwError errs
|
||||
let chunkPaths = map snd $ sortOn fst rs
|
||||
encDigest <- liftIO $ LC.sha512Hash <$> readChunks chunkPaths
|
||||
when (encDigest /= unFileDigest digest) $ throwError $ CLIError "File digest mismatch"
|
||||
encSize <- liftIO $ foldM (\s path -> (s +) . fromIntegral <$> getFileSize path) 0 chunkPaths
|
||||
|
||||
@@ -29,7 +29,7 @@ import UnliftIO.Directory (removeFile)
|
||||
encryptFile :: CryptoFile -> ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT FTCryptoError IO ()
|
||||
encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do
|
||||
sb <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce
|
||||
CF.withFile srcFile ReadMode $ \r -> withFile encFile WriteMode $ \w -> do
|
||||
CF.withFile srcFile ReadMode $ \r -> ExceptT . withFile encFile WriteMode $ \w -> runExceptT $ do
|
||||
let lenStr = smpEncode fileSize'
|
||||
(hdr, !sb') = LC.sbEncryptChunk sb $ lenStr <> fileHdr
|
||||
padLen = encSize - authTagSize - fileSize' - 8
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
@@ -17,7 +18,6 @@ import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as B64
|
||||
import Data.ByteString.Builder (byteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -26,7 +26,7 @@ import Data.List (intercalate)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Maybe (fromMaybe, isJust)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
@@ -46,14 +46,16 @@ import Simplex.FileTransfer.Server.StoreLog
|
||||
import Simplex.FileTransfer.Transport
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (CorrId, RcvPublicDhKey, RcvPublicAuthKey, RecipientId, TransmissionAuth)
|
||||
import Simplex.Messaging.Protocol (CorrId, RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth)
|
||||
import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.Transport (THandleParams (..))
|
||||
import Simplex.Messaging.Transport.Buffer (trimCR)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.HTTP2.File (fileBlockSize)
|
||||
import Simplex.Messaging.Transport.HTTP2.Server
|
||||
import Simplex.Messaging.Transport.Server (runTCPServer)
|
||||
import Simplex.Messaging.Util
|
||||
@@ -67,13 +69,12 @@ import qualified UnliftIO.Exception as E
|
||||
|
||||
type M a = ReaderT XFTPEnv IO a
|
||||
|
||||
data XFTPTransportRequest =
|
||||
XFTPTransportRequest
|
||||
{ thParams :: THandleParams XFTPVersion,
|
||||
reqBody :: HTTP2Body,
|
||||
request :: H.Request,
|
||||
sendResponse :: H.Response -> IO ()
|
||||
}
|
||||
data XFTPTransportRequest = XFTPTransportRequest
|
||||
{ thParams :: THandleParams XFTPVersion,
|
||||
reqBody :: HTTP2Body,
|
||||
request :: H.Request,
|
||||
sendResponse :: H.Response -> IO ()
|
||||
}
|
||||
|
||||
runXFTPServer :: XFTPServerConfig -> IO ()
|
||||
runXFTPServer cfg = do
|
||||
@@ -222,15 +223,13 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
| Just auth == user = CPRUser
|
||||
| otherwise = CPRNone
|
||||
CPStatsRTS -> E.tryAny getRTSStats >>= either (hPrint h) (hPrint h)
|
||||
CPDelete fileId fKey -> withUserRole $ unliftIO u $ do
|
||||
CPDelete fileId -> withUserRole $ unliftIO u $ do
|
||||
fs <- asks store
|
||||
r <- runExceptT $ do
|
||||
let asSender = ExceptT . atomically $ getFile fs SFSender fileId
|
||||
let asRecipient = ExceptT . atomically $ getFile fs SFRecipient fileId
|
||||
(fr, fKey') <- asSender `catchError` const asRecipient
|
||||
if fKey == fKey'
|
||||
then ExceptT $ deleteServerFile_ fr
|
||||
else throwError AUTH
|
||||
(fr, _) <- asSender `catchError` const asRecipient
|
||||
ExceptT $ deleteServerFile_ fr
|
||||
liftIO . hPutStrLn h $ either (\e -> "error: " <> show e) (\() -> "ok") r
|
||||
CPHelp -> hPutStrLn h "commands: stats-rts, delete, help, quit"
|
||||
CPQuit -> pure ()
|
||||
@@ -331,7 +330,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
-- TODO validate body empty
|
||||
sId <- ExceptT $ addFileRetry st file 3 ts
|
||||
rcps <- mapM (ExceptT . addRecipientRetry st 3 sId) rks
|
||||
withFileLog $ \sl -> do
|
||||
lift $ withFileLog $ \sl -> do
|
||||
logAddFile sl sId file ts
|
||||
logAddRecipients sl sId rcps
|
||||
stats <- asks serverStats
|
||||
@@ -363,7 +362,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
st <- asks store
|
||||
r <- runExceptT $ do
|
||||
rcps <- mapM (ExceptT . addRecipientRetry st 3 sId) rks
|
||||
withFileLog $ \sl -> logAddRecipients sl sId rcps
|
||||
lift $ withFileLog $ \sl -> logAddRecipients sl sId rcps
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar' (fileRecipients stats) (+ length rks)
|
||||
let rIds = L.map (\(FileRecipient rId _) -> rId) rcps
|
||||
@@ -373,8 +372,19 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
receiveServerFile FileRec {senderId, fileInfo = FileInfo {size, digest}, filePath} = case bodyPart of
|
||||
Nothing -> pure $ FRErr SIZE
|
||||
-- TODO validate body size from request before downloading, once it's populated
|
||||
Just getBody -> ifM reserve receive (pure $ FRErr QUOTA) -- TODO: handle duplicate uploads
|
||||
Just getBody -> skipCommitted $ ifM reserve receive (pure $ FRErr QUOTA)
|
||||
where
|
||||
-- having a filePath means the file is already uploaded and committed, must not change anything
|
||||
skipCommitted = ifM (isJust <$> readTVarIO filePath) (liftIO $ drain $ fromIntegral size)
|
||||
where
|
||||
-- can't send FROk without reading the request body or a client will block on sending it
|
||||
-- can't send any old error as the client would fail or restart indefinitely
|
||||
drain s = do
|
||||
bs <- B.length <$> getBody fileBlockSize
|
||||
if
|
||||
| bs == s -> pure FROk
|
||||
| bs == 0 || bs > s -> pure $ FRErr SIZE
|
||||
| otherwise -> drain (s - bs)
|
||||
reserve = do
|
||||
us <- asks $ usedStorage . store
|
||||
quota <- asks $ fromMaybe maxBound . fileSizeQuota . config
|
||||
@@ -382,7 +392,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
\used -> let used' = used + fromIntegral size in if used' <= quota then (True, used') else (False, used)
|
||||
receive = do
|
||||
path <- asks $ filesPath . config
|
||||
let fPath = path </> B.unpack (B64.encode senderId)
|
||||
let fPath = path </> B.unpack (U.encode senderId)
|
||||
receiveChunk (XFTPRcvChunkSpec fPath size digest) >>= \case
|
||||
Right () -> do
|
||||
stats <- asks serverStats
|
||||
@@ -447,7 +457,7 @@ deleteServerFile_ FileRec {senderId, fileInfo, filePath} = do
|
||||
atomically $ modifyTVar' (filesCount stats) (subtract 1)
|
||||
atomically $ modifyTVar' (filesSize stats) (subtract $ fromIntegral $ size fileInfo)
|
||||
|
||||
randomId :: (MonadUnliftIO m, MonadReader XFTPEnv m) => Int -> m ByteString
|
||||
randomId :: Int -> M ByteString
|
||||
randomId n = atomically . C.randomBytes n =<< asks random
|
||||
|
||||
getFileId :: M XFTPFileId
|
||||
@@ -455,7 +465,7 @@ getFileId = do
|
||||
size <- asks (fileIdSize . config)
|
||||
atomically . C.randomBytes size =<< asks random
|
||||
|
||||
withFileLog :: (MonadIO m, MonadReader XFTPEnv m) => (StoreLog 'WriteMode -> IO a) -> m ()
|
||||
withFileLog :: (StoreLog 'WriteMode -> IO a) -> M ()
|
||||
withFileLog action = liftIO . mapM_ action =<< asks storeLog
|
||||
|
||||
incFileStat :: (FileServerStats -> TVar Int) -> M ()
|
||||
|
||||
@@ -5,7 +5,6 @@ module Simplex.FileTransfer.Server.Control where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (BasicAuth)
|
||||
|
||||
@@ -14,7 +13,7 @@ data CPClientRole = CPRNone | CPRUser | CPRAdmin
|
||||
data ControlProtocol
|
||||
= CPAuth BasicAuth
|
||||
| CPStatsRTS
|
||||
| CPDelete ByteString C.APublicAuthKey
|
||||
| CPDelete ByteString
|
||||
| CPHelp
|
||||
| CPQuit
|
||||
| CPSkip
|
||||
@@ -23,7 +22,7 @@ instance StrEncoding ControlProtocol where
|
||||
strEncode = \case
|
||||
CPAuth tok -> "auth " <> strEncode tok
|
||||
CPStatsRTS -> "stats-rts"
|
||||
CPDelete fId fKey -> strEncode (Str "delete", fId, fKey)
|
||||
CPDelete fId -> strEncode (Str "delete", fId)
|
||||
CPHelp -> "help"
|
||||
CPQuit -> "quit"
|
||||
CPSkip -> ""
|
||||
@@ -31,7 +30,7 @@ instance StrEncoding ControlProtocol where
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"auth" -> CPAuth <$> _strP
|
||||
"stats-rts" -> pure CPStatsRTS
|
||||
"delete" -> CPDelete <$> _strP <*> _strP
|
||||
"delete" -> CPDelete <$> _strP
|
||||
"help" -> pure CPHelp
|
||||
"quit" -> pure CPQuit
|
||||
"" -> pure CPSkip
|
||||
|
||||
@@ -94,7 +94,7 @@ defaultFileExpiration =
|
||||
checkInterval = 2 * 3600 -- seconds, 2 hours
|
||||
}
|
||||
|
||||
newXFTPServerEnv :: (MonadUnliftIO m, MonadRandom m) => XFTPServerConfig -> m XFTPEnv
|
||||
newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv
|
||||
newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
random <- liftIO C.newRandom
|
||||
store <- atomically newFileStore
|
||||
|
||||
+393
-355
File diff suppressed because it is too large
Load Diff
@@ -142,7 +142,6 @@ import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:.))
|
||||
@@ -182,6 +181,7 @@ import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
@@ -293,10 +293,11 @@ data AgentClient = AgentClient
|
||||
agentEnv :: Env
|
||||
}
|
||||
|
||||
getAgentWorker :: (AgentMonad' m, Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT AgentErrorType m ()) -> m Worker
|
||||
getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker
|
||||
getAgentWorker = getAgentWorker' id pure
|
||||
{-# INLINE getAgentWorker #-}
|
||||
|
||||
getAgentWorker' :: forall a k m. (AgentMonad' m, Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT AgentErrorType m ()) -> m a
|
||||
getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a
|
||||
getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w
|
||||
where
|
||||
@@ -310,9 +311,9 @@ getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
| otherwise = pure w
|
||||
runWorker w = runWorkerAsync (toW w) runWork
|
||||
where
|
||||
runWork :: m ()
|
||||
runWork :: AM' ()
|
||||
runWork = tryAgentError' (work w) >>= restartOrDelete
|
||||
restartOrDelete :: Either AgentErrorType () -> m ()
|
||||
restartOrDelete :: Either AgentErrorType () -> AM' ()
|
||||
restartOrDelete e_ = do
|
||||
t <- liftIO getSystemTime
|
||||
maxRestarts <- asks $ maxWorkerRestartsPerMin . config
|
||||
@@ -350,7 +351,7 @@ newWorker c = do
|
||||
restarts <- newTVar $ RestartCount 0 0
|
||||
pure Worker {workerId, doWork, action, restarts}
|
||||
|
||||
runWorkerAsync :: AgentMonad' m => Worker -> m () -> m ()
|
||||
runWorkerAsync :: Worker -> AM' () -> AM' ()
|
||||
runWorkerAsync Worker {action} work =
|
||||
E.bracket
|
||||
(atomically $ takeTMVar action) -- get current action, locking to avoid race conditions
|
||||
@@ -394,6 +395,7 @@ data AgentStatsKey = AgentStatsKey
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
|
||||
newAgentClient :: Int -> InitialAgentServers -> Env -> STM AgentClient
|
||||
newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
|
||||
let qSize = tbqSize $ config agentEnv
|
||||
@@ -469,13 +471,15 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
|
||||
|
||||
agentClientStore :: AgentClient -> SQLiteStore
|
||||
agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
{-# INLINE agentClientStore #-}
|
||||
|
||||
agentDRG :: AgentClient -> TVar ChaChaDRG
|
||||
agentDRG AgentClient {agentEnv = Env {random}} = random
|
||||
{-# INLINE agentDRG #-}
|
||||
|
||||
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
|
||||
type Client msg = c | c -> msg
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg)
|
||||
getProtocolServerClient :: AgentClient -> TransportSession msg -> AM (Client msg)
|
||||
clientProtocolError :: err -> AgentErrorType
|
||||
closeProtocolServerClient :: Client msg -> IO ()
|
||||
clientServer :: Client msg -> String
|
||||
@@ -509,7 +513,7 @@ instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
|
||||
clientTransportHost = X.xftpTransportHost
|
||||
clientSessionTs = X.xftpSessionTs
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient
|
||||
getSMPServerClient :: AgentClient -> SMPTransportSession -> AM SMPClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
atomically (getTSessVar c tSess smpClients)
|
||||
@@ -520,13 +524,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
|
||||
-- make it expensive to check for pending subscriptions.
|
||||
newClient v =
|
||||
newProtocolClient c tSess smpClients connectClient v
|
||||
`catchAgentError` \e -> resubscribeSMPSession c tSess >> throwError e
|
||||
connectClient :: SMPClientVar -> m SMPClient
|
||||
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwError e
|
||||
connectClient :: SMPClientVar -> AM SMPClient
|
||||
connectClient v = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
cfg <- lift $ getClientConfig c smpCfg
|
||||
g <- asks random
|
||||
env <- ask
|
||||
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v)
|
||||
liftError' (protocolClientError SMP $ B.unpack $ strEncode srv) $
|
||||
getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v
|
||||
|
||||
clientDisconnected :: Env -> SMPClientVar -> SMPClient -> IO ()
|
||||
clientDisconnected env v client = do
|
||||
@@ -557,7 +562,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
|
||||
notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
|
||||
|
||||
resubscribeSMPSession :: AgentMonad' m => AgentClient -> SMPTransportSession -> m ()
|
||||
resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' ()
|
||||
resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
|
||||
atomically getWorkerVar >>= mapM_ (either newSubWorker (\_ -> pure ()))
|
||||
where
|
||||
@@ -585,12 +590,12 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
|
||||
whenM (isEmptyTMVar $ sessionVar v) retry
|
||||
removeTSessVar v tSess smpSubWorkers
|
||||
|
||||
reconnectSMPClient :: forall m. AgentMonad m => TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m ()
|
||||
reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM ()
|
||||
reconnectSMPClient tc c tSess@(_, srv, _) qs = do
|
||||
NetworkConfig {tcpTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
-- this allows 3x of timeout per batch of subscription (90 queues per batch empirically)
|
||||
let t = (length qs `div` 90 + 1) * tcpTimeout * 3
|
||||
t `timeout` resubscribe >>= \case
|
||||
ExceptT (sequence <$> (t `timeout` runExceptT resubscribe)) >>= \case
|
||||
Just _ -> atomically $ writeTVar tc 0
|
||||
Nothing -> do
|
||||
tc' <- atomically $ stateTVar tc $ \i -> (i + 1, i + 1)
|
||||
@@ -599,10 +604,10 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
|
||||
msg = show tc' <> " consecutive subscription timeouts: " <> show (length qs) <> " queues, transport session: " <> show tSess
|
||||
atomically $ writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ err msg)
|
||||
where
|
||||
resubscribe :: m ()
|
||||
resubscribe :: AM ()
|
||||
resubscribe = do
|
||||
cs <- readTVarIO $ RQ.getConnections $ activeSubs c
|
||||
rs <- subscribeQueues c $ L.toList qs
|
||||
rs <- lift . subscribeQueues c $ L.toList qs
|
||||
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
|
||||
liftIO $ do
|
||||
let conns = filter (`M.notMember` cs) okConns
|
||||
@@ -616,7 +621,7 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
|
||||
notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
|
||||
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient
|
||||
getNtfServerClient :: AgentClient -> NtfTransportSession -> AM NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
atomically (getTSessVar c tSess ntfClients)
|
||||
@@ -624,11 +629,12 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
|
||||
(newProtocolClient c tSess ntfClients connectClient)
|
||||
(waitForProtocolClient c tSess)
|
||||
where
|
||||
connectClient :: NtfClientVar -> m NtfClient
|
||||
connectClient :: NtfClientVar -> AM NtfClient
|
||||
connectClient v = do
|
||||
cfg <- getClientConfig c ntfCfg
|
||||
cfg <- lift $ getClientConfig c ntfCfg
|
||||
g <- asks random
|
||||
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient g tSess cfg Nothing $ clientDisconnected v)
|
||||
liftError' (protocolClientError NTF $ B.unpack $ strEncode srv) $
|
||||
getProtocolClient g tSess cfg Nothing $ clientDisconnected v
|
||||
|
||||
clientDisconnected :: NtfClientVar -> NtfClient -> IO ()
|
||||
clientDisconnected v client = do
|
||||
@@ -637,7 +643,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
|
||||
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getXFTPServerClient :: forall m. AgentMonad m => AgentClient -> XFTPTransportSession -> m XFTPClient
|
||||
getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient
|
||||
getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
atomically (getTSessVar c tSess xftpClients)
|
||||
@@ -645,11 +651,12 @@ getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@
|
||||
(newProtocolClient c tSess xftpClients connectClient)
|
||||
(waitForProtocolClient c tSess)
|
||||
where
|
||||
connectClient :: XFTPClientVar -> m XFTPClient
|
||||
connectClient :: XFTPClientVar -> AM XFTPClient
|
||||
connectClient v = do
|
||||
cfg <- asks $ xftpCfg . config
|
||||
xftpNetworkConfig <- readTVarIO useNetworkConfig
|
||||
liftEitherError (protocolClientError XFTP $ B.unpack $ strEncode srv) (X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v)
|
||||
liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $
|
||||
X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v
|
||||
|
||||
clientDisconnected :: XFTPClientVar -> XFTPClient -> IO ()
|
||||
clientDisconnected v client = do
|
||||
@@ -671,6 +678,7 @@ getTSessVar c tSess vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lo
|
||||
|
||||
removeTSessVar :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM ()
|
||||
removeTSessVar = void .:. removeTSessVar'
|
||||
{-# INLINE removeTSessVar #-}
|
||||
|
||||
removeTSessVar' :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM Bool
|
||||
removeTSessVar' v tSess vs =
|
||||
@@ -678,7 +686,7 @@ removeTSessVar' v tSess vs =
|
||||
Just v' | sessionVarId v == sessionVarId v' -> TM.delete tSess vs $> True
|
||||
_ -> pure False
|
||||
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (Client msg)
|
||||
waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg)
|
||||
waitForProtocolClient c (_, srv, _) v = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
|
||||
@@ -689,14 +697,14 @@ waitForProtocolClient c (_, srv, _) v = do
|
||||
|
||||
-- clientConnected arg is only passed for SMP server
|
||||
newProtocolClient ::
|
||||
forall v err msg m.
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
|
||||
forall v err msg.
|
||||
(ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
|
||||
AgentClient ->
|
||||
TransportSession msg ->
|
||||
TMap (TransportSession msg) (ClientVar msg) ->
|
||||
(ClientVar msg -> m (Client msg)) ->
|
||||
(ClientVar msg -> AM (Client msg)) ->
|
||||
ClientVar msg ->
|
||||
m (Client msg)
|
||||
AM (Client msg)
|
||||
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
tryAgentError (connectClient v) >>= \case
|
||||
Right client -> do
|
||||
@@ -715,14 +723,14 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
|
||||
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
|
||||
|
||||
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v)
|
||||
getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v)
|
||||
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
|
||||
cfg <- asks $ cfgSel . config
|
||||
networkConfig <- readTVarIO useNetworkConfig
|
||||
pure cfg {networkConfig}
|
||||
|
||||
closeAgentClient :: MonadIO m => AgentClient -> m ()
|
||||
closeAgentClient c = liftIO $ do
|
||||
closeAgentClient :: AgentClient -> IO ()
|
||||
closeAgentClient c = do
|
||||
atomically $ writeTVar (active c) False
|
||||
closeProtocolServerClients c smpClients
|
||||
closeProtocolServerClients c ntfClients
|
||||
@@ -750,9 +758,11 @@ cancelWorker Worker {doWork, action} = do
|
||||
|
||||
waitUntilActive :: AgentClient -> STM ()
|
||||
waitUntilActive c = unlessM (readTVar $ active c) retry
|
||||
{-# INLINE waitUntilActive #-}
|
||||
|
||||
throwWhenInactive :: AgentClient -> STM ()
|
||||
throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled
|
||||
{-# INLINE throwWhenInactive #-}
|
||||
|
||||
-- this function is used to remove workers once delivery is complete, not when it is removed from the map
|
||||
throwWhenNoDelivery :: AgentClient -> SndQueue -> STM ()
|
||||
@@ -779,82 +789,98 @@ closeClient_ c v = do
|
||||
Just (Right client) -> closeProtocolServerClient client `catchAll_` pure ()
|
||||
_ -> pure ()
|
||||
|
||||
closeXFTPServerClient :: AgentMonad' m => AgentClient -> UserId -> XFTPServer -> FileDigest -> m ()
|
||||
closeXFTPServerClient :: AgentClient -> UserId -> XFTPServer -> FileDigest -> IO ()
|
||||
closeXFTPServerClient c userId server (FileDigest chunkDigest) =
|
||||
mkTransportSession c userId server chunkDigest >>= liftIO . closeClient c xftpClients
|
||||
mkTransportSession c userId server chunkDigest >>= closeClient c xftpClients
|
||||
|
||||
withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
|
||||
withConnLock _ "" _ = id
|
||||
withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
|
||||
withConnLock :: AgentClient -> ConnId -> String -> AM a -> AM a
|
||||
withConnLock c connId name = ExceptT . withConnLock' c connId name . runExceptT
|
||||
{-# INLINE withConnLock #-}
|
||||
|
||||
withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a
|
||||
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
|
||||
withConnLock' :: AgentClient -> ConnId -> String -> AM' a -> AM' a
|
||||
withConnLock' _ "" _ = id
|
||||
withConnLock' AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
|
||||
{-# INLINE withConnLock' #-}
|
||||
|
||||
withConnLocks :: MonadUnliftIO m => AgentClient -> [ConnId] -> String -> m a -> m a
|
||||
withInvLock :: AgentClient -> ByteString -> String -> AM a -> AM a
|
||||
withInvLock c key name = ExceptT . withInvLock' c key name . runExceptT
|
||||
{-# INLINE withInvLock #-}
|
||||
|
||||
withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
|
||||
withInvLock' AgentClient {invLocks} = withLockMap_ invLocks
|
||||
{-# INLINE withInvLock' #-}
|
||||
|
||||
withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
|
||||
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
|
||||
{-# INLINE withConnLocks #-}
|
||||
|
||||
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap_ = withGetLock . getMapLock
|
||||
{-# INLINE withLockMap_ #-}
|
||||
|
||||
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
|
||||
withLocksMap_ = withGetLocks . getMapLock
|
||||
{-# INLINE withLocksMap_ #-}
|
||||
|
||||
getMapLock :: Ord k => TMap k Lock -> k -> STM Lock
|
||||
getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
|
||||
where
|
||||
newLock = createLock >>= \l -> TM.insert key l locks $> l
|
||||
|
||||
withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
|
||||
withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> AM a) -> AM a
|
||||
withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
cl <- getProtocolServerClient c tSess
|
||||
(action cl <* stat cl "OK") `catchAgentError` logServerError cl
|
||||
where
|
||||
stat cl = liftIO . incClientStat c userId cl statCmd
|
||||
logServerError :: Client msg -> AgentErrorType -> m a
|
||||
logServerError :: Client msg -> AgentErrorType -> AM a
|
||||
logServerError cl e = do
|
||||
logServer "<--" c srv "" $ strEncode e
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
|
||||
withLogClient_ :: ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> AM a) -> AM a
|
||||
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
|
||||
logServer "-->" c srv entId cmdStr
|
||||
res <- withClient_ c tSess cmdStr action
|
||||
logServer "<--" c srv entId "OK"
|
||||
return res
|
||||
|
||||
withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withClient :: forall v err msg a. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> AM a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
|
||||
{-# INLINE withClient #-}
|
||||
|
||||
withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withLogClient :: forall v err msg a. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> AM a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
|
||||
{-# INLINE withLogClient #-}
|
||||
|
||||
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMPClient :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> AM a
|
||||
withSMPClient c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
tSess <- liftIO $ mkSMPTransportSession c q
|
||||
withLogClient c tSess (queueId q) cmdStr action
|
||||
|
||||
withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a
|
||||
withSMPClient_ :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> AM a) -> AM a
|
||||
withSMPClient_ c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
tSess <- liftIO $ mkSMPTransportSession c q
|
||||
withLogClient_ c tSess (queueId q) cmdStr action
|
||||
|
||||
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a
|
||||
withNtfClient :: AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> AM a
|
||||
withNtfClient c srv = withLogClient c (0, srv, Nothing)
|
||||
|
||||
withXFTPClient ::
|
||||
(AgentMonad m, ProtocolServerClient v err msg) =>
|
||||
ProtocolServerClient v err msg =>
|
||||
AgentClient ->
|
||||
(UserId, ProtoServer msg, EntityId) ->
|
||||
ByteString ->
|
||||
(Client msg -> ExceptT (ProtocolClientError err) IO b) ->
|
||||
m b
|
||||
AM b
|
||||
withXFTPClient c (userId, srv, entityId) cmdStr action = do
|
||||
tSess <- mkTransportSession c userId srv entityId
|
||||
tSess <- liftIO $ mkTransportSession c userId srv entityId
|
||||
withLogClient c tSess entityId cmdStr action
|
||||
|
||||
liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a
|
||||
liftClient :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a
|
||||
liftClient protocolError_ = liftError . protocolClientError protocolError_
|
||||
{-# INLINE liftClient #-}
|
||||
|
||||
protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType
|
||||
protocolClientError protocolError_ host = \case
|
||||
@@ -889,7 +915,7 @@ data ProtocolTestFailure = ProtocolTestFailure
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe ProtocolTestFailure)
|
||||
runSMPServerTest :: AgentClient -> UserId -> SMPServerWithAuth -> AM' (Maybe ProtocolTestFailure)
|
||||
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
C.AuthAlg ra <- asks $ rcvAuthAlg . config
|
||||
@@ -915,7 +941,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
testErr :: ProtocolTestStep -> SMPClientError -> ProtocolTestFailure
|
||||
testErr step = ProtocolTestFailure step . protocolClientError SMP addr
|
||||
|
||||
runXFTPServerTest :: forall m. AgentMonad m => AgentClient -> UserId -> XFTPServerWithAuth -> m (Maybe ProtocolTestFailure)
|
||||
runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe ProtocolTestFailure)
|
||||
runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
cfg <- asks $ xftpCfg . config
|
||||
g <- asks random
|
||||
@@ -949,7 +975,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
testErr step = ProtocolTestFailure step . protocolClientError XFTP addr
|
||||
chSize :: Integral a => a
|
||||
chSize = kb 64
|
||||
getTempFilePath :: FilePath -> m FilePath
|
||||
getTempFilePath :: FilePath -> AM' FilePath
|
||||
getTempFilePath workPath = do
|
||||
ts <- liftIO getCurrentTime
|
||||
let isoTime = formatTime defaultTimeLocale "%Y-%m-%dT%H%M%S.%6q" ts
|
||||
@@ -963,7 +989,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
createTestChunk :: FilePath -> IO ()
|
||||
createTestChunk fp = B.writeFile fp =<< atomically . C.randomBytes chSize =<< C.newRandom
|
||||
|
||||
runNTFServerTest :: AgentMonad m => AgentClient -> UserId -> NtfServerWithAuth -> m (Maybe ProtocolTestFailure)
|
||||
runNTFServerTest :: AgentClient -> UserId -> NtfServerWithAuth -> AM' (Maybe ProtocolTestFailure)
|
||||
runNTFServerTest c userId (ProtoServerWithAuth srv _) = do
|
||||
cfg <- getClientConfig c ntfCfg
|
||||
C.AuthAlg a <- asks $ rcvAuthAlg . config
|
||||
@@ -987,27 +1013,32 @@ runNTFServerTest c userId (ProtoServerWithAuth srv _) = do
|
||||
testErr :: ProtocolTestStep -> SMPClientError -> ProtocolTestFailure
|
||||
testErr step = ProtocolTestFailure step . protocolClientError NTF addr
|
||||
|
||||
getXFTPWorkPath :: AgentMonad m => m FilePath
|
||||
getXFTPWorkPath :: AM' FilePath
|
||||
getXFTPWorkPath = do
|
||||
workDir <- readTVarIO =<< asks (xftpWorkDir . xftpAgent)
|
||||
maybe getTemporaryDirectory pure workDir
|
||||
|
||||
mkTransportSession :: AgentMonad' m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
|
||||
mkTransportSession :: AgentClient -> UserId -> ProtoServer msg -> EntityId -> IO (TransportSession msg)
|
||||
mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c
|
||||
{-# INLINE mkTransportSession #-}
|
||||
|
||||
mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg
|
||||
mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
|
||||
{-# INLINE mkTSession #-}
|
||||
|
||||
mkSMPTransportSession :: (AgentMonad' m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
|
||||
mkSMPTransportSession :: SMPQueueRec q => AgentClient -> q -> IO SMPTransportSession
|
||||
mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c
|
||||
{-# INLINE mkSMPTransportSession #-}
|
||||
|
||||
mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession
|
||||
mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
|
||||
{-# INLINE mkSMPTSession #-}
|
||||
|
||||
getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode
|
||||
getSessionMode :: AgentClient -> IO TransportSessionMode
|
||||
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
|
||||
{-# INLINE getSessionMode #-}
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
|
||||
newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri)
|
||||
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
|
||||
C.AuthAlg a <- asks (rcvAuthAlg . config)
|
||||
g <- asks random
|
||||
@@ -1015,10 +1046,10 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
|
||||
(dhKey, privDhKey) <- atomically $ C.generateKeyPair g
|
||||
(e2eDhKey, e2ePrivKey) <- atomically $ C.generateKeyPair g
|
||||
logServer "-->" c srv "" "NEW"
|
||||
tSess <- mkTransportSession c userId srv connId
|
||||
tSess <- liftIO $ mkTransportSession c userId srv connId
|
||||
QIK {rcvId, sndId, rcvPublicDhKey} <-
|
||||
withClient c tSess "NEW" $ \smp -> createSMPQueue smp rKeys dhKey auth subMode
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
liftIO . logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
let rq =
|
||||
RcvQueue
|
||||
{ userId,
|
||||
@@ -1057,14 +1088,16 @@ temporaryAgentError = \case
|
||||
BROKER _ TIMEOUT -> True
|
||||
INACTIVE -> True
|
||||
_ -> False
|
||||
{-# INLINE temporaryAgentError #-}
|
||||
|
||||
temporaryOrHostError :: AgentErrorType -> Bool
|
||||
temporaryOrHostError = \case
|
||||
BROKER _ HOST -> True
|
||||
e -> temporaryAgentError e
|
||||
{-# INLINE temporaryOrHostError #-}
|
||||
|
||||
-- | Subscribe to queues. The list of results can have a different order.
|
||||
subscribeQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
subscribeQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
|
||||
subscribeQueues c qs = do
|
||||
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
|
||||
atomically $ do
|
||||
@@ -1088,11 +1121,11 @@ subscribeQueues c qs = do
|
||||
type BatchResponses e r = (NonEmpty (RcvQueue, Either e r))
|
||||
|
||||
-- statBatchSize is not used to batch the commands, only for traffic statistics
|
||||
sendTSessionBatches :: forall m q r. AgentMonad' m => ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> m [(RcvQueue, Either AgentErrorType r)]
|
||||
sendTSessionBatches :: forall q r. ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)]
|
||||
sendTSessionBatches statCmd statBatchSize toRQ action c qs =
|
||||
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
|
||||
where
|
||||
batchQueues :: m [(SMPTransportSession, NonEmpty q)]
|
||||
batchQueues :: AM' [(SMPTransportSession, NonEmpty q)]
|
||||
batchQueues = do
|
||||
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
|
||||
pure . M.assocs $ foldl' (batch mode) M.empty qs
|
||||
@@ -1100,7 +1133,7 @@ sendTSessionBatches statCmd statBatchSize toRQ action c qs =
|
||||
batch mode m q =
|
||||
let tSess = mkSMPTSession (toRQ q) mode
|
||||
in M.alter (Just . maybe [q] (q <|)) tSess m
|
||||
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> m (BatchResponses AgentErrorType r)
|
||||
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses AgentErrorType r)
|
||||
sendClientBatch (tSess@(userId, srv, _), qs') =
|
||||
tryAgentError' (getSMPServerClient c tSess) >>= \case
|
||||
Left e -> pure $ L.map ((,Left e) . toRQ) qs'
|
||||
@@ -1120,7 +1153,7 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
|
||||
where
|
||||
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
|
||||
addSubscription :: AgentClient -> RcvQueue -> IO ()
|
||||
addSubscription c rq@RcvQueue {connId} = atomically $ do
|
||||
modifyTVar' (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ activeSubs c
|
||||
@@ -1128,6 +1161,7 @@ addSubscription c rq@RcvQueue {connId} = atomically $ do
|
||||
|
||||
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
|
||||
hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c
|
||||
{-# INLINE hasActiveSubscription #-}
|
||||
|
||||
removeSubscription :: AgentClient -> ConnId -> STM ()
|
||||
removeSubscription c connId = do
|
||||
@@ -1137,19 +1171,23 @@ removeSubscription c connId = do
|
||||
|
||||
getSubscriptions :: AgentClient -> STM (Set ConnId)
|
||||
getSubscriptions = readTVar . subscrConns
|
||||
{-# INLINE getSubscriptions #-}
|
||||
|
||||
logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} srv qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
||||
{-# INLINE logServer #-}
|
||||
|
||||
showServer :: ProtocolServer s -> ByteString
|
||||
showServer ProtocolServer {host, port} =
|
||||
strEncode host <> B.pack (if null port then "" else ':' : port)
|
||||
{-# INLINE showServer #-}
|
||||
|
||||
logSecret :: ByteString -> ByteString
|
||||
logSecret bs = encode $ B.take 3 bs
|
||||
{-# INLINE logSecret #-}
|
||||
|
||||
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendConfirmation :: AgentClient -> SndQueue -> ByteString -> AM ()
|
||||
sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withSMPClient_ c sq "SEND <CONF>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
|
||||
@@ -1157,21 +1195,21 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation :: AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> AM ()
|
||||
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
|
||||
tSess <- mkTransportSession c userId smpServer senderId
|
||||
tSess <- liftIO $ mkTransportSession c userId smpServer senderId
|
||||
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
|
||||
msg <- mkInvitation
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
|
||||
where
|
||||
mkInvitation :: m ByteString
|
||||
mkInvitation :: AM ByteString
|
||||
-- this is only encrypted with per-queue E2E, not with double ratchet
|
||||
mkInvitation = do
|
||||
let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo}
|
||||
agentCbEncryptOnce v dhPublicKey . smpEncode $
|
||||
SMP.ClientMessage SMP.PHEmpty (smpEncode agentEnvelope)
|
||||
|
||||
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
|
||||
getQueueMessage :: AgentClient -> RcvQueue -> AM (Maybe SMPMsgMeta)
|
||||
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
|
||||
atomically createTakeGetLock
|
||||
msg_ <- withSMPClient c rq "GET" $ \smp ->
|
||||
@@ -1186,23 +1224,23 @@ getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
|
||||
takeTMVar l
|
||||
pure $ Just l
|
||||
|
||||
decryptSMPMessage :: AgentMonad m => RcvQueue -> SMP.RcvMessage -> m SMP.ClientRcvMsgBody
|
||||
decryptSMPMessage :: RcvQueue -> SMP.RcvMessage -> AM SMP.ClientRcvMsgBody
|
||||
decryptSMPMessage rq SMP.RcvMessage {msgId, msgBody = SMP.EncRcvMsgBody body} =
|
||||
liftEither . parse SMP.clientRcvMsgBodyP (AGENT A_MESSAGE) =<< decrypt body
|
||||
where
|
||||
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
|
||||
|
||||
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicAuthKey -> m ()
|
||||
secureQueue :: AgentClient -> RcvQueue -> SndPublicAuthKey -> AM ()
|
||||
secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey =
|
||||
withSMPClient c rq "KEY <key>" $ \smp ->
|
||||
secureSMPQueue smp rcvPrivateKey rcvId senderKey
|
||||
|
||||
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> SMP.NtfPublicAuthKey -> SMP.RcvNtfPublicDhKey -> m (SMP.NotifierId, SMP.RcvNtfPublicDhKey)
|
||||
enableQueueNotifications :: AgentClient -> RcvQueue -> SMP.NtfPublicAuthKey -> SMP.RcvNtfPublicDhKey -> AM (SMP.NotifierId, SMP.RcvNtfPublicDhKey)
|
||||
enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
|
||||
withSMPClient c rq "NKEY <nkey>" $ \smp ->
|
||||
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
|
||||
|
||||
enableQueuesNtfs :: forall m. AgentMonad' m => AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> m [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))]
|
||||
enableQueuesNtfs :: AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> AM' [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))]
|
||||
enableQueuesNtfs = sendTSessionBatches "NKEY" 90 fst3 enableQueues_
|
||||
where
|
||||
fst3 (x, _, _) = x
|
||||
@@ -1211,15 +1249,15 @@ enableQueuesNtfs = sendTSessionBatches "NKEY" 90 fst3 enableQueues_
|
||||
queueCreds :: (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)
|
||||
queueCreds (RcvQueue {rcvPrivateKey, rcvId}, notifierKey, rcvNtfPublicDhKey) = (rcvPrivateKey, rcvId, notifierKey, rcvNtfPublicDhKey)
|
||||
|
||||
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
disableQueueNotifications :: AgentClient -> RcvQueue -> AM ()
|
||||
disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "NDEL" $ \smp ->
|
||||
disableSMPQueueNotifications smp rcvPrivateKey rcvId
|
||||
|
||||
disableQueuesNtfs :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
disableQueuesNtfs :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
|
||||
disableQueuesNtfs = sendTSessionBatches "NDEL" 90 id $ sendBatch disableSMPQueuesNtfs
|
||||
|
||||
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
|
||||
sendAck :: AgentClient -> RcvQueue -> MsgId -> AM ()
|
||||
sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do
|
||||
withSMPClient c rq ("ACK:" <> logSecret msgId) $ \smp ->
|
||||
ackSMPMessage smp rcvPrivateKey rcvId msgId
|
||||
@@ -1233,93 +1271,93 @@ releaseGetLock :: AgentClient -> RcvQueue -> STM ()
|
||||
releaseGetLock c RcvQueue {server, rcvId} =
|
||||
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
|
||||
|
||||
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
suspendQueue :: AgentClient -> RcvQueue -> AM ()
|
||||
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "OFF" $ \smp ->
|
||||
suspendSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
deleteQueue :: AgentClient -> RcvQueue -> AM ()
|
||||
deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = do
|
||||
withSMPClient c rq "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
deleteQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
deleteQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
|
||||
deleteQueues = sendTSessionBatches "DEL" 90 id $ sendBatch deleteSMPQueues
|
||||
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
|
||||
sendAgentMessage :: AgentClient -> SndQueue -> MsgFlags -> ByteString -> AM ()
|
||||
sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg =
|
||||
withSMPClient_ c sq "SEND <MSG>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
|
||||
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
|
||||
|
||||
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
|
||||
agentNtfRegisterToken :: AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> AM (NtfTokenId, C.PublicKeyX25519)
|
||||
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
|
||||
withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
|
||||
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
|
||||
agentNtfVerifyToken :: AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> AM ()
|
||||
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
|
||||
withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
|
||||
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
|
||||
agentNtfCheckToken :: AgentClient -> NtfTokenId -> NtfToken -> AM NtfTknStatus
|
||||
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
|
||||
agentNtfReplaceToken :: AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> AM ()
|
||||
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
|
||||
withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
|
||||
|
||||
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
|
||||
agentNtfDeleteToken :: AgentClient -> NtfTokenId -> NtfToken -> AM ()
|
||||
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
|
||||
agentNtfEnableCron :: AgentClient -> NtfTokenId -> NtfToken -> Word16 -> AM ()
|
||||
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
|
||||
withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
|
||||
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> SMP.NtfPrivateAuthKey -> m NtfSubscriptionId
|
||||
agentNtfCreateSubscription :: AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> SMP.NtfPrivateAuthKey -> AM NtfSubscriptionId
|
||||
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
|
||||
withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
|
||||
|
||||
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
|
||||
agentNtfCheckSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM NtfSubStatus
|
||||
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
|
||||
agentNtfDeleteSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM ()
|
||||
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentXFTPDownloadChunk :: AgentMonad m => AgentClient -> UserId -> FileDigest -> RcvFileChunkReplica -> XFTPRcvChunkSpec -> m ()
|
||||
agentXFTPDownloadChunk :: AgentClient -> UserId -> FileDigest -> RcvFileChunkReplica -> XFTPRcvChunkSpec -> AM ()
|
||||
agentXFTPDownloadChunk c userId (FileDigest chunkDigest) RcvFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} chunkSpec = do
|
||||
g <- asks random
|
||||
withXFTPClient c (userId, server, chunkDigest) "FGET" $ \xftp -> X.downloadXFTPChunk g xftp replicaKey fId chunkSpec
|
||||
|
||||
agentXFTPNewChunk :: AgentMonad m => AgentClient -> SndFileChunk -> Int -> XFTPServerWithAuth -> m NewSndChunkReplica
|
||||
agentXFTPNewChunk :: AgentClient -> SndFileChunk -> Int -> XFTPServerWithAuth -> AM NewSndChunkReplica
|
||||
agentXFTPNewChunk c SndFileChunk {userId, chunkSpec = XFTPChunkSpec {chunkSize}, digest = FileDigest chunkDigest} n (ProtoServerWithAuth srv auth) = do
|
||||
rKeys <- xftpRcvKeys n
|
||||
(sndKey, replicaKey) <- atomically . C.generateAuthKeyPair C.SEd25519 =<< asks random
|
||||
let fileInfo = FileInfo {sndKey, size = fromIntegral chunkSize, digest = chunkDigest}
|
||||
logServer "-->" c srv "" "FNEW"
|
||||
tSess <- mkTransportSession c userId srv chunkDigest
|
||||
tSess <- liftIO $ mkTransportSession c userId srv chunkDigest
|
||||
(sndId, rIds) <- withClient c tSess "FNEW" $ \xftp -> X.createXFTPChunk xftp replicaKey fileInfo (L.map fst rKeys) auth
|
||||
logServer "<--" c srv "" $ B.unwords ["SIDS", logSecret sndId]
|
||||
pure NewSndChunkReplica {server = srv, replicaId = ChunkReplicaId sndId, replicaKey, rcvIdsKeys = L.toList $ xftpRcvIdsKeys rIds rKeys}
|
||||
|
||||
agentXFTPUploadChunk :: AgentMonad m => AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> XFTPChunkSpec -> m ()
|
||||
agentXFTPUploadChunk :: AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> XFTPChunkSpec -> AM ()
|
||||
agentXFTPUploadChunk c userId (FileDigest chunkDigest) SndFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} chunkSpec =
|
||||
withXFTPClient c (userId, server, chunkDigest) "FPUT" $ \xftp -> X.uploadXFTPChunk xftp replicaKey fId chunkSpec
|
||||
|
||||
agentXFTPAddRecipients :: AgentMonad m => AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> Int -> m (NonEmpty (ChunkReplicaId, C.APrivateAuthKey))
|
||||
agentXFTPAddRecipients :: AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> Int -> AM (NonEmpty (ChunkReplicaId, C.APrivateAuthKey))
|
||||
agentXFTPAddRecipients c userId (FileDigest chunkDigest) SndFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} n = do
|
||||
rKeys <- xftpRcvKeys n
|
||||
rIds <- withXFTPClient c (userId, server, chunkDigest) "FADD" $ \xftp -> X.addXFTPRecipients xftp replicaKey fId (L.map fst rKeys)
|
||||
pure $ xftpRcvIdsKeys rIds rKeys
|
||||
|
||||
agentXFTPDeleteChunk :: AgentMonad m => AgentClient -> UserId -> DeletedSndChunkReplica -> m ()
|
||||
agentXFTPDeleteChunk :: AgentClient -> UserId -> DeletedSndChunkReplica -> AM ()
|
||||
agentXFTPDeleteChunk c userId DeletedSndChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey, chunkDigest = FileDigest chunkDigest} =
|
||||
withXFTPClient c (userId, server, chunkDigest) "FDEL" $ \xftp -> X.deleteXFTPChunk xftp replicaKey fId
|
||||
|
||||
xftpRcvKeys :: AgentMonad m => Int -> m (NonEmpty C.AAuthKeyPair)
|
||||
xftpRcvKeys :: Int -> AM (NonEmpty C.AAuthKeyPair)
|
||||
xftpRcvKeys n = do
|
||||
rKeys <- atomically . replicateM n . C.generateAuthKeyPair C.SEd25519 =<< asks random
|
||||
case L.nonEmpty rKeys of
|
||||
@@ -1329,7 +1367,7 @@ xftpRcvKeys n = do
|
||||
xftpRcvIdsKeys :: NonEmpty ByteString -> NonEmpty C.AAuthKeyPair -> NonEmpty (ChunkReplicaId, C.APrivateAuthKey)
|
||||
xftpRcvIdsKeys rIds rKeys = L.map ChunkReplicaId rIds `L.zip` L.map snd rKeys
|
||||
|
||||
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncrypt :: SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> AM ByteString
|
||||
agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
|
||||
cmNonce <- atomically . C.randomCbNonce =<< asks random
|
||||
let paddedLen = maybe SMP.e2eEncMessageLength (const SMP.e2eEncConfirmationLength) e2ePubKey
|
||||
@@ -1340,7 +1378,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
|
||||
pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody}
|
||||
|
||||
-- add encoding as AgentInvitation'?
|
||||
agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncryptOnce :: VersionSMPC -> C.PublicKeyX25519 -> ByteString -> AM ByteString
|
||||
agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
|
||||
g <- asks random
|
||||
(dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g
|
||||
@@ -1354,7 +1392,7 @@ agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
|
||||
|
||||
-- | NaCl crypto-box decrypt - both for messages received from the server
|
||||
-- and per-queue E2E encrypted messages from the sender that were inside.
|
||||
agentCbDecrypt :: AgentMonad m => C.DhSecretX25519 -> C.CbNonce -> ByteString -> m ByteString
|
||||
agentCbDecrypt :: C.DhSecretX25519 -> C.CbNonce -> ByteString -> AM ByteString
|
||||
agentCbDecrypt dhSecret nonce msg =
|
||||
liftEither . first cryptoError $
|
||||
C.cbDecrypt dhSecret nonce msg
|
||||
@@ -1373,10 +1411,11 @@ cryptoError = \case
|
||||
where
|
||||
c = AGENT . A_CRYPTO
|
||||
|
||||
waitForWork :: AgentMonad' m => TMVar () -> m ()
|
||||
waitForWork :: MonadIO m => TMVar () -> m ()
|
||||
waitForWork = void . atomically . readTMVar
|
||||
{-# INLINE waitForWork #-}
|
||||
|
||||
withWork :: AgentMonad m => AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> m ()) -> m ()
|
||||
withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM ()
|
||||
withWork c doWork getWork action =
|
||||
withStore' c getWork >>= \case
|
||||
Right (Just r) -> action r
|
||||
@@ -1389,12 +1428,15 @@ withWork c doWork getWork action =
|
||||
|
||||
noWorkToDo :: TMVar () -> IO ()
|
||||
noWorkToDo = void . atomically . tryTakeTMVar
|
||||
{-# INLINE noWorkToDo #-}
|
||||
|
||||
hasWorkToDo :: Worker -> STM ()
|
||||
hasWorkToDo = hasWorkToDo' . doWork
|
||||
{-# INLINE hasWorkToDo #-}
|
||||
|
||||
hasWorkToDo' :: TMVar () -> STM ()
|
||||
hasWorkToDo' = void . (`tryPutTMVar` ())
|
||||
{-# INLINE hasWorkToDo' #-}
|
||||
|
||||
endAgentOperation :: AgentClient -> AgentOperation -> STM ()
|
||||
endAgentOperation c op = endOperation c op $ case op of
|
||||
@@ -1438,6 +1480,7 @@ endOperation c op endedAction = do
|
||||
|
||||
whenSuspending :: AgentClient -> STM () -> STM ()
|
||||
whenSuspending c = whenM ((== ASSuspending) <$> readTVar (agentState c))
|
||||
{-# INLINE whenSuspending #-}
|
||||
|
||||
beginAgentOperation :: AgentClient -> AgentOperation -> STM ()
|
||||
beginAgentOperation c op = do
|
||||
@@ -1457,20 +1500,22 @@ agentOperationBracket c op check action =
|
||||
|
||||
waitUntilForeground :: AgentClient -> STM ()
|
||||
waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry
|
||||
{-# INLINE waitUntilForeground #-}
|
||||
|
||||
withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a
|
||||
withStore' :: AgentClient -> (DB.Connection -> IO a) -> AM a
|
||||
withStore' c action = withStore c $ fmap Right . action
|
||||
{-# INLINE withStore' #-}
|
||||
|
||||
withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
|
||||
withStore :: AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> AM a
|
||||
withStore c action = do
|
||||
st <- asks store
|
||||
liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $
|
||||
withExceptT storeError . ExceptT . liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $
|
||||
withTransaction st action `E.catch` handleInternal ""
|
||||
where
|
||||
handleInternal :: String -> E.SomeException -> IO (Either StoreError a)
|
||||
handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr
|
||||
|
||||
withStoreBatch :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> m (t (Either AgentErrorType a))
|
||||
withStoreBatch :: Traversable t => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> AM' (t (Either AgentErrorType a))
|
||||
withStoreBatch c actions = do
|
||||
st <- asks store
|
||||
liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $
|
||||
@@ -1480,8 +1525,9 @@ withStoreBatch c actions = do
|
||||
handleInternal :: E.SomeException -> IO (Either AgentErrorType a)
|
||||
handleInternal = pure . Left . INTERNAL . show
|
||||
|
||||
withStoreBatch' :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO a)) -> m (t (Either AgentErrorType a))
|
||||
withStoreBatch' :: Traversable t => AgentClient -> (DB.Connection -> t (IO a)) -> AM' (t (Either AgentErrorType a))
|
||||
withStoreBatch' c actions = withStoreBatch c (fmap (fmap Right) . actions)
|
||||
{-# INLINE withStoreBatch' #-}
|
||||
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
storeError = \case
|
||||
@@ -1505,6 +1551,7 @@ incStat AgentClient {agentStats} n k = do
|
||||
|
||||
incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c userId pc = incClientStatN c userId pc 1
|
||||
{-# INLINE incClientStat #-}
|
||||
|
||||
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
incServerStat c userId ProtocolServer {host} cmd res = do
|
||||
@@ -1523,27 +1570,28 @@ userServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> TMa
|
||||
userServers c = case protocolTypeI @p of
|
||||
SPSMP -> smpServers c
|
||||
SPXFTP -> xftpServers c
|
||||
{-# INLINE userServers #-}
|
||||
|
||||
pickServer :: forall p m. AgentMonad' m => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p)
|
||||
pickServer :: forall p. NonEmpty (ProtoServerWithAuth p) -> AM (ProtoServerWithAuth p)
|
||||
pickServer = \case
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
|
||||
|
||||
getNextServer :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> [ProtocolServer p] -> m (ProtoServerWithAuth p)
|
||||
getNextServer :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> [ProtocolServer p] -> AM (ProtoServerWithAuth p)
|
||||
getNextServer c userId usedSrvs = withUserServers c userId $ \srvs ->
|
||||
case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of
|
||||
Just srvs' -> pickServer srvs'
|
||||
_ -> pickServer srvs
|
||||
|
||||
withUserServers :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> m a) -> m a
|
||||
withUserServers :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> AM a) -> AM a
|
||||
withUserServers c userId action =
|
||||
atomically (TM.lookup userId $ userServers c) >>= \case
|
||||
Just srvs -> action srvs
|
||||
_ -> throwError $ INTERNAL "unknown userId - no user servers"
|
||||
|
||||
withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> m a) -> m a
|
||||
withNextSrv :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> AM a) -> AM a
|
||||
withNextSrv c userId usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used
|
||||
@@ -1564,7 +1612,7 @@ data SubscriptionsInfo = SubscriptionsInfo
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
getAgentSubscriptions :: MonadIO m => AgentClient -> m SubscriptionsInfo
|
||||
getAgentSubscriptions :: AgentClient -> IO SubscriptionsInfo
|
||||
getAgentSubscriptions c = do
|
||||
activeSubscriptions <- getSubs activeSubs
|
||||
pendingSubscriptions <- getSubs pendingSubs
|
||||
@@ -1600,7 +1648,7 @@ data WorkersDetails = WorkersDetails
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
getAgentWorkersDetails :: MonadIO m => AgentClient -> m AgentWorkersDetails
|
||||
getAgentWorkersDetails :: AgentClient -> IO AgentWorkersDetails
|
||||
getAgentWorkersDetails AgentClient {smpClients, ntfClients, xftpClients, smpDeliveryWorkers, asyncCmdWorkers, smpSubWorkers, agentEnv} = do
|
||||
smpClients_ <- textKeys <$> readTVarIO smpClients
|
||||
ntfClients_ <- textKeys <$> readTVarIO ntfClients
|
||||
@@ -1632,7 +1680,7 @@ getAgentWorkersDetails AgentClient {smpClients, ntfClients, xftpClients, smpDeli
|
||||
textKeys = map textKey . M.keys
|
||||
textKey :: StrEncoding k => k -> Text
|
||||
textKey = decodeASCII . strEncode
|
||||
workerStats :: (StrEncoding k, MonadIO m) => Map k Worker -> m (Map Text WorkersDetails)
|
||||
workerStats :: StrEncoding k => Map k Worker -> IO (Map Text WorkersDetails)
|
||||
workerStats ws = fmap M.fromList . forM (M.toList ws) $ \(qa, Worker {restarts, doWork, action}) -> do
|
||||
RestartCount {restartCount} <- readTVarIO restarts
|
||||
hasWork <- atomically $ not <$> isEmptyTMVar doWork
|
||||
@@ -1664,7 +1712,7 @@ data WorkersSummary = WorkersSummary
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
getAgentWorkersSummary :: MonadIO m => AgentClient -> m AgentWorkersSummary
|
||||
getAgentWorkersSummary :: AgentClient -> IO AgentWorkersSummary
|
||||
getAgentWorkersSummary AgentClient {smpClients, ntfClients, xftpClients, smpDeliveryWorkers, asyncCmdWorkers, smpSubWorkers, agentEnv} = do
|
||||
smpClientsCount <- M.size <$> readTVarIO smpClients
|
||||
ntfClientsCount <- M.size <$> readTVarIO ntfClients
|
||||
@@ -1695,7 +1743,7 @@ getAgentWorkersSummary AgentClient {smpClients, ntfClients, xftpClients, smpDeli
|
||||
Env {ntfSupervisor, xftpAgent} = agentEnv
|
||||
NtfSupervisor {ntfWorkers, ntfSMPWorkers} = ntfSupervisor
|
||||
XFTPAgent {xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} = xftpAgent
|
||||
workerSummary :: MonadIO m => M.Map k Worker -> m WorkersSummary
|
||||
workerSummary :: M.Map k Worker -> IO WorkersSummary
|
||||
workerSummary = liftIO . foldM byWork WorkersSummary {numActive = 0, numIdle = 0, totalRestarts = 0}
|
||||
where
|
||||
byWork WorkersSummary {numActive, numIdle, totalRestarts} Worker {action, restarts} = do
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Env.SQLite
|
||||
( AgentMonad,
|
||||
AgentMonad',
|
||||
( AM',
|
||||
AM,
|
||||
AgentConfig (..),
|
||||
InitialAgentServers (..),
|
||||
NetworkConfig (..),
|
||||
@@ -21,6 +21,7 @@ module Simplex.Messaging.Agent.Env.SQLite
|
||||
tryAgentError,
|
||||
tryAgentError',
|
||||
catchAgentError,
|
||||
catchAgentError',
|
||||
agentFinally,
|
||||
Env (..),
|
||||
newSMPAgentEnv,
|
||||
@@ -34,7 +35,6 @@ module Simplex.Messaging.Agent.Env.SQLite
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
@@ -65,14 +65,14 @@ import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport.Client (defaultSMPPort)
|
||||
import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors)
|
||||
import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors')
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO (Async, SomeException)
|
||||
import UnliftIO.STM
|
||||
|
||||
type AgentMonad' m = (MonadUnliftIO m, MonadReader Env m)
|
||||
type AM' a = ReaderT Env IO a
|
||||
|
||||
type AgentMonad m = (AgentMonad' m, MonadError AgentErrorType m)
|
||||
type AM a = ExceptT AgentErrorType (ReaderT Env IO) a
|
||||
|
||||
data InitialAgentServers = InitialAgentServers
|
||||
{ smp :: Map UserId (NonEmpty SMPServerWithAuth),
|
||||
@@ -82,7 +82,7 @@ data InitialAgentServers = InitialAgentServers
|
||||
}
|
||||
|
||||
data AgentConfig = AgentConfig
|
||||
{ tcpPort :: ServiceName,
|
||||
{ tcpPort :: Maybe ServiceName,
|
||||
rcvAuthAlg :: C.AuthAlg,
|
||||
sndAuthAlg :: C.AuthAlg,
|
||||
connIdBytes :: Int,
|
||||
@@ -149,7 +149,7 @@ defaultMessageRetryInterval =
|
||||
defaultAgentConfig :: AgentConfig
|
||||
defaultAgentConfig =
|
||||
AgentConfig
|
||||
{ tcpPort = "5224",
|
||||
{ tcpPort = Just "5224",
|
||||
-- while the current client version supports X25519, it can only be enabled once support for SMP v6 is dropped,
|
||||
-- and all servers are required to support v7 to be compatible.
|
||||
rcvAuthAlg = C.AuthAlg C.SEd25519, -- this will stay as Ed25519
|
||||
@@ -250,20 +250,24 @@ newXFTPAgent = do
|
||||
xftpDelWorkers <- TM.empty
|
||||
pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers}
|
||||
|
||||
tryAgentError :: AgentMonad m => m a -> m (Either AgentErrorType a)
|
||||
tryAgentError :: AM a -> AM (Either AgentErrorType a)
|
||||
tryAgentError = tryAllErrors mkInternal
|
||||
{-# INLINE tryAgentError #-}
|
||||
|
||||
-- unlike runExceptT, this ensures we catch IO exceptions as well
|
||||
tryAgentError' :: AgentMonad' m => ExceptT AgentErrorType m a -> m (Either AgentErrorType a)
|
||||
tryAgentError' = fmap join . runExceptT . tryAgentError
|
||||
tryAgentError' :: AM a -> AM' (Either AgentErrorType a)
|
||||
tryAgentError' = tryAllErrors' mkInternal
|
||||
{-# INLINE tryAgentError' #-}
|
||||
|
||||
catchAgentError :: AgentMonad m => m a -> (AgentErrorType -> m a) -> m a
|
||||
catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a
|
||||
catchAgentError = catchAllErrors mkInternal
|
||||
{-# INLINE catchAgentError #-}
|
||||
|
||||
agentFinally :: AgentMonad m => m a -> m b -> m a
|
||||
catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a
|
||||
catchAgentError' = catchAllErrors' mkInternal
|
||||
{-# INLINE catchAgentError' #-}
|
||||
|
||||
agentFinally :: AM a -> AM b -> AM a
|
||||
agentFinally = allFinally mkInternal
|
||||
{-# INLINE agentFinally #-}
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Lock
|
||||
( Lock,
|
||||
createLock,
|
||||
withLock,
|
||||
withLock',
|
||||
withGetLock,
|
||||
withGetLocks,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.Except (ExceptT (..), runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Functor (($>))
|
||||
import UnliftIO.Async (forConcurrently)
|
||||
@@ -22,8 +22,12 @@ createLock :: STM Lock
|
||||
createLock = newEmptyTMVar
|
||||
{-# INLINE createLock #-}
|
||||
|
||||
withLock :: MonadUnliftIO m => Lock -> String -> m a -> m a
|
||||
withLock lock name =
|
||||
withLock :: MonadUnliftIO m => Lock -> String -> ExceptT e m a -> ExceptT e m a
|
||||
withLock lock name = ExceptT . withLock' lock name . runExceptT
|
||||
{-# INLINE withLock #-}
|
||||
|
||||
withLock' :: MonadUnliftIO m => Lock -> String -> m a -> m a
|
||||
withLock' lock name =
|
||||
E.bracket_
|
||||
(atomically $ putTMVar lock name)
|
||||
(void . atomically $ takeTMVar lock)
|
||||
|
||||
@@ -44,7 +44,7 @@ import UnliftIO
|
||||
import UnliftIO.Concurrent (forkIO, threadDelay)
|
||||
import qualified UnliftIO.Exception as E
|
||||
|
||||
runNtfSupervisor :: forall m. AgentMonad' m => AgentClient -> m ()
|
||||
runNtfSupervisor :: AgentClient -> AM' ()
|
||||
runNtfSupervisor c = do
|
||||
ns <- asks ntfSupervisor
|
||||
forever $ do
|
||||
@@ -54,13 +54,13 @@ runNtfSupervisor c = do
|
||||
Left e -> notifyErr connId e
|
||||
Right _ -> return ()
|
||||
where
|
||||
handleErr :: ConnId -> m () -> m ()
|
||||
handleErr :: ConnId -> AM' () -> AM' ()
|
||||
handleErr connId = E.handle $ \(e :: E.SomeException) -> do
|
||||
logError $ "runNtfSupervisor error " <> tshow e
|
||||
notifyErr connId e
|
||||
notifyErr connId e = notifyInternalError c connId $ "runNtfSupervisor error " <> show e
|
||||
|
||||
processNtfSub :: forall m. AgentMonad m => AgentClient -> (ConnId, NtfSupervisorCommand) -> m ()
|
||||
processNtfSub :: AgentClient -> (ConnId, NtfSupervisorCommand) -> AM ()
|
||||
processNtfSub c (connId, cmd) = do
|
||||
logInfo $ "processNtfSub - connId = " <> tshow connId <> " - cmd = " <> tshow cmd
|
||||
case cmd of
|
||||
@@ -77,11 +77,11 @@ processNtfSub c (connId, cmd) = do
|
||||
Just ClientNtfCreds {notifierId} -> do
|
||||
let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey
|
||||
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate
|
||||
void $ getNtfNTFWorker True c ntfServer
|
||||
lift . void $ getNtfNTFWorker True c ntfServer
|
||||
Nothing -> do
|
||||
let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew
|
||||
withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey
|
||||
void $ getNtfSMPWorker True c smpServer
|
||||
lift . void $ getNtfSMPWorker True c smpServer
|
||||
(Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do
|
||||
case (clientNtfCreds, ntfQueueId) of
|
||||
(Just ClientNtfCreds {notifierId}, Just ntfQueueId')
|
||||
@@ -90,7 +90,7 @@ processNtfSub c (connId, cmd) = do
|
||||
(Nothing, Nothing) -> create
|
||||
_ -> rotate
|
||||
where
|
||||
create :: m ()
|
||||
create :: AM ()
|
||||
create = case action_ of
|
||||
-- action was set to NULL after worker internal error
|
||||
Nothing -> resetSubscription
|
||||
@@ -101,60 +101,60 @@ processNtfSub c (connId, cmd) = do
|
||||
then resetSubscription
|
||||
else withTokenServer $ \ntfServer -> do
|
||||
withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate)
|
||||
void $ getNtfNTFWorker True c ntfServer
|
||||
lift . void $ getNtfNTFWorker True c ntfServer
|
||||
| otherwise -> case action of
|
||||
NtfSubNTFAction _ -> void $ getNtfNTFWorker True c subNtfServer
|
||||
NtfSubSMPAction _ -> void $ getNtfSMPWorker True c smpServer
|
||||
rotate :: m ()
|
||||
NtfSubNTFAction _ -> lift . void $ getNtfNTFWorker True c subNtfServer
|
||||
NtfSubSMPAction _ -> lift . void $ getNtfSMPWorker True c smpServer
|
||||
rotate :: AM ()
|
||||
rotate = do
|
||||
withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate)
|
||||
void $ getNtfNTFWorker True c subNtfServer
|
||||
resetSubscription :: m ()
|
||||
lift . void $ getNtfNTFWorker True c subNtfServer
|
||||
resetSubscription :: AM ()
|
||||
resetSubscription =
|
||||
withTokenServer $ \ntfServer -> do
|
||||
let sub' = sub {ntfQueueId = Nothing, ntfServer, ntfSubId = Nothing, ntfSubStatus = NASNew}
|
||||
withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NtfSubSMPAction NSASmpKey)
|
||||
void $ getNtfSMPWorker True c smpServer
|
||||
lift . void $ getNtfSMPWorker True c smpServer
|
||||
NSCDelete -> do
|
||||
sub_ <- withStore' c $ \db -> do
|
||||
supervisorUpdateNtfAction db connId (NtfSubNTFAction NSADelete)
|
||||
getNtfSubscription db connId
|
||||
logInfo $ "processNtfSub, NSCDelete - sub_ = " <> tshow sub_
|
||||
case sub_ of
|
||||
(Just (NtfSubscription {ntfServer}, _)) -> void $ getNtfNTFWorker True c ntfServer
|
||||
(Just (NtfSubscription {ntfServer}, _)) -> lift . void $ getNtfNTFWorker True c ntfServer
|
||||
_ -> pure () -- err "NSCDelete - no subscription"
|
||||
NSCSmpDelete -> do
|
||||
withStore' c (`getPrimaryRcvQueue` connId) >>= \case
|
||||
Right rq@RcvQueue {server = smpServer} -> do
|
||||
logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq
|
||||
withStore' c $ \db -> supervisorUpdateNtfAction db connId (NtfSubSMPAction NSASmpDelete)
|
||||
void $ getNtfSMPWorker True c smpServer
|
||||
lift . void $ getNtfSMPWorker True c smpServer
|
||||
_ -> notifyInternalError c connId "NSCSmpDelete - no rcv queue"
|
||||
NSCNtfWorker ntfServer -> void $ getNtfNTFWorker True c ntfServer
|
||||
NSCNtfSMPWorker smpServer -> void $ getNtfSMPWorker True c smpServer
|
||||
NSCNtfWorker ntfServer -> lift . void $ getNtfNTFWorker True c ntfServer
|
||||
NSCNtfSMPWorker smpServer -> lift . void $ getNtfSMPWorker True c smpServer
|
||||
|
||||
getNtfNTFWorker :: AgentMonad' m => Bool -> AgentClient -> NtfServer -> m Worker
|
||||
getNtfNTFWorker :: Bool -> AgentClient -> NtfServer -> AM' Worker
|
||||
getNtfNTFWorker hasWork c server = do
|
||||
ws <- asks $ ntfWorkers . ntfSupervisor
|
||||
getAgentWorker "ntf_ntf" hasWork c server ws $ runNtfWorker c server
|
||||
|
||||
getNtfSMPWorker :: AgentMonad' m => Bool -> AgentClient -> SMPServer -> m Worker
|
||||
getNtfSMPWorker :: Bool -> AgentClient -> SMPServer -> AM' Worker
|
||||
getNtfSMPWorker hasWork c server = do
|
||||
ws <- asks $ ntfSMPWorkers . ntfSupervisor
|
||||
getAgentWorker "ntf_smp" hasWork c server ws $ runNtfSMPWorker c server
|
||||
|
||||
withTokenServer :: AgentMonad' m => (NtfServer -> m ()) -> m ()
|
||||
withTokenServer action = getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer)
|
||||
withTokenServer :: (NtfServer -> AM ()) -> AM ()
|
||||
withTokenServer action = lift getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer)
|
||||
|
||||
runNtfWorker :: forall m. AgentMonad m => AgentClient -> NtfServer -> Worker -> m ()
|
||||
runNtfWorker :: AgentClient -> NtfServer -> Worker -> AM ()
|
||||
runNtfWorker c srv Worker {doWork} = do
|
||||
delay <- asks $ ntfWorkerDelay . config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
agentOperationBracket c AONtfNetwork throwWhenInactive runNtfOperation
|
||||
ExceptT $ agentOperationBracket c AONtfNetwork throwWhenInactive $ runExceptT runNtfOperation
|
||||
threadDelay delay
|
||||
where
|
||||
runNtfOperation :: m ()
|
||||
runNtfOperation :: AM ()
|
||||
runNtfOperation =
|
||||
withWork c doWork (`getNextNtfSubNTFAction` srv) $
|
||||
\nextSub@(NtfSubscription {connId}, _, _) -> do
|
||||
@@ -163,13 +163,13 @@ runNtfWorker c srv Worker {doWork} = do
|
||||
withRetryInterval ri $ \_ loop ->
|
||||
processSub nextSub
|
||||
`catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show)
|
||||
processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> m ()
|
||||
processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM ()
|
||||
processSub (sub@NtfSubscription {connId, smpServer, ntfSubId}, action, actionTs) = do
|
||||
ts <- liftIO getCurrentTime
|
||||
unlessM (rescheduleAction doWork ts actionTs) $
|
||||
unlessM (lift $ rescheduleAction doWork ts actionTs) $
|
||||
case action of
|
||||
NSACreate ->
|
||||
getNtfToken >>= \case
|
||||
lift getNtfToken >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
|
||||
RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId)
|
||||
case clientNtfCreds of
|
||||
@@ -182,13 +182,13 @@ runNtfWorker c srv Worker {doWork} = do
|
||||
_ -> workerInternalError c connId "NSACreate - no notifier queue credentials"
|
||||
_ -> workerInternalError c connId "NSACreate - no active token"
|
||||
NSACheck ->
|
||||
getNtfToken >>= \case
|
||||
lift getNtfToken >>= \case
|
||||
Just tkn ->
|
||||
case ntfSubId of
|
||||
Just nSubId ->
|
||||
agentNtfCheckSubscription c nSubId tkn >>= \case
|
||||
NSAuth -> do
|
||||
getNtfServer c >>= \case
|
||||
lift (getNtfServer c) >>= \case
|
||||
Just ntfServer -> do
|
||||
withStore' c $ \db ->
|
||||
updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts
|
||||
@@ -200,7 +200,7 @@ runNtfWorker c srv Worker {doWork} = do
|
||||
_ -> workerInternalError c connId "NSACheck - no active token"
|
||||
NSADelete -> case ntfSubId of
|
||||
Just nSubId ->
|
||||
(getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
|
||||
(lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
|
||||
`agentFinally` continueDeletion
|
||||
_ -> continueDeletion
|
||||
where
|
||||
@@ -211,7 +211,7 @@ runNtfWorker c srv Worker {doWork} = do
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer)
|
||||
NSARotate -> case ntfSubId of
|
||||
Just nSubId ->
|
||||
(getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
|
||||
(lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId))
|
||||
`agentFinally` deleteCreate
|
||||
_ -> deleteCreate
|
||||
where
|
||||
@@ -228,12 +228,14 @@ runNtfWorker c srv Worker {doWork} = do
|
||||
withStore' c $ \db ->
|
||||
updateNtfSubscription db sub {ntfSubStatus = toStatus} toAction actionTs'
|
||||
|
||||
runNtfSMPWorker :: forall m. AgentMonad m => AgentClient -> SMPServer -> Worker -> m ()
|
||||
runNtfSMPWorker :: AgentClient -> SMPServer -> Worker -> AM ()
|
||||
runNtfSMPWorker c srv Worker {doWork} = do
|
||||
env <- ask
|
||||
delay <- asks $ ntfSMPWorkerDelay . config
|
||||
forever $ do
|
||||
waitForWork doWork
|
||||
agentOperationBracket c AONtfNetwork throwWhenInactive runNtfSMPOperation
|
||||
ExceptT . liftIO . agentOperationBracket c AONtfNetwork throwWhenInactive $
|
||||
runReaderT (runExceptT runNtfSMPOperation) env
|
||||
threadDelay delay
|
||||
where
|
||||
runNtfSMPOperation =
|
||||
@@ -244,13 +246,13 @@ runNtfSMPWorker c srv Worker {doWork} = do
|
||||
withRetryInterval ri $ \_ loop ->
|
||||
processSub nextSub
|
||||
`catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show)
|
||||
processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> m ()
|
||||
processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> AM ()
|
||||
processSub (sub@NtfSubscription {connId, ntfServer}, smpAction, actionTs) = do
|
||||
ts <- liftIO getCurrentTime
|
||||
unlessM (rescheduleAction doWork ts actionTs) $
|
||||
unlessM (lift $ rescheduleAction doWork ts actionTs) $
|
||||
case smpAction of
|
||||
NSASmpKey ->
|
||||
getNtfToken >>= \case
|
||||
lift getNtfToken >>= \case
|
||||
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
|
||||
rq <- withStore c (`getPrimaryRcvQueue` connId)
|
||||
C.AuthAlg a <- asks (rcvAuthAlg . config)
|
||||
@@ -272,7 +274,7 @@ runNtfSMPWorker c srv Worker {doWork} = do
|
||||
mapM_ (disableQueueNotifications c) rq_
|
||||
withStore' c $ \db -> deleteNtfSubscription db connId
|
||||
|
||||
rescheduleAction :: AgentMonad' m => TMVar () -> UTCTime -> UTCTime -> m Bool
|
||||
rescheduleAction :: TMVar () -> UTCTime -> UTCTime -> AM' Bool
|
||||
rescheduleAction doWork ts actionTs
|
||||
| actionTs <= ts = pure False
|
||||
| otherwise = do
|
||||
@@ -282,7 +284,7 @@ rescheduleAction doWork ts actionTs
|
||||
atomically $ hasWorkToDo' doWork
|
||||
pure True
|
||||
|
||||
retryOnError :: AgentMonad' m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m ()
|
||||
retryOnError :: AgentClient -> Text -> AM () -> (AgentErrorType -> AM ()) -> AgentErrorType -> AM ()
|
||||
retryOnError c name loop done e = do
|
||||
logError $ name <> " error: " <> tshow e
|
||||
case e of
|
||||
@@ -296,16 +298,17 @@ retryOnError c name loop done e = do
|
||||
atomically $ beginAgentOperation c AONtfNetwork
|
||||
loop
|
||||
|
||||
workerInternalError :: AgentMonad m => AgentClient -> ConnId -> String -> m ()
|
||||
workerInternalError :: AgentClient -> ConnId -> String -> AM ()
|
||||
workerInternalError c connId internalErrStr = do
|
||||
withStore' c $ \db -> setNullNtfSubscriptionAction db connId
|
||||
notifyInternalError c connId internalErrStr
|
||||
|
||||
-- TODO change error
|
||||
notifyInternalError :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m ()
|
||||
notifyInternalError :: MonadIO m => AgentClient -> ConnId -> String -> m ()
|
||||
notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, APC SAEConn $ ERR $ INTERNAL internalErrStr)
|
||||
{-# INLINE notifyInternalError #-}
|
||||
|
||||
getNtfToken :: AgentMonad' m => m (Maybe NtfToken)
|
||||
getNtfToken :: AM' (Maybe NtfToken)
|
||||
getNtfToken = do
|
||||
tkn <- asks $ ntfTkn . ntfSupervisor
|
||||
readTVarIO tkn
|
||||
@@ -326,14 +329,14 @@ instantNotifications = \case
|
||||
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> True
|
||||
_ -> False
|
||||
|
||||
closeNtfSupervisor :: MonadUnliftIO m => NtfSupervisor -> m ()
|
||||
closeNtfSupervisor :: NtfSupervisor -> IO ()
|
||||
closeNtfSupervisor ns = do
|
||||
stopWorkers $ ntfWorkers ns
|
||||
stopWorkers $ ntfSMPWorkers ns
|
||||
where
|
||||
stopWorkers workers = atomically (swapTVar workers M.empty) >>= mapM_ (liftIO . cancelWorker)
|
||||
|
||||
getNtfServer :: AgentMonad' m => AgentClient -> m (Maybe NtfServer)
|
||||
getNtfServer :: AgentClient -> AM' (Maybe NtfServer)
|
||||
getNtfServer c = do
|
||||
ntfServers <- readTVarIO $ ntfServers c
|
||||
case ntfServers of
|
||||
|
||||
@@ -163,7 +163,6 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
@@ -202,6 +201,7 @@ import Simplex.Messaging.Crypto.Ratchet
|
||||
SndE2ERatchetParams
|
||||
)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (base64P, encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol
|
||||
@@ -1897,15 +1897,15 @@ tGetRaw :: Transport c => c -> IO ARawTransmission
|
||||
tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h
|
||||
|
||||
-- | Send SMP agent protocol command (or response) to TCP connection.
|
||||
tPut :: (Transport c, MonadIO m) => c -> ATransmission p -> m ()
|
||||
tPut :: Transport c => c -> ATransmission p -> IO ()
|
||||
tPut h (corrId, connId, APC _ cmd) =
|
||||
liftIO $ tPutRaw h (corrId, connId, serializeCommand cmd)
|
||||
tPutRaw h (corrId, connId, serializeCommand cmd)
|
||||
|
||||
-- | Receive client and agent transmissions from TCP connection.
|
||||
tGet :: forall c m p. (Transport c, MonadIO m) => SAParty p -> c -> m (ATransmissionOrError p)
|
||||
tGet :: forall c p. Transport c => SAParty p -> c -> IO (ATransmissionOrError p)
|
||||
tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
where
|
||||
tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p)
|
||||
tParseLoadBody :: ARawTransmission -> IO (ATransmissionOrError p)
|
||||
tParseLoadBody t@(corrId, entId, command) = do
|
||||
let cmd = parseCommand command >>= fromParty >>= tConnId t
|
||||
fullCmd <- either (return . Left) cmdWithMsgBody cmd
|
||||
@@ -1935,7 +1935,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
| B.null entId -> Left $ CMD NO_CONN
|
||||
| otherwise -> Right cmd
|
||||
|
||||
cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p))
|
||||
cmdWithMsgBody :: APartyCmd p -> IO (Either AgentErrorType (APartyCmd p))
|
||||
cmdWithMsgBody (APC e cmd) =
|
||||
APC e <$$> case cmd of
|
||||
SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body
|
||||
@@ -1948,7 +1948,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo
|
||||
_ -> pure $ Right cmd
|
||||
|
||||
getBody :: ByteString -> m (Either AgentErrorType ByteString)
|
||||
getBody :: ByteString -> IO (Either AgentErrorType ByteString)
|
||||
getBody binary =
|
||||
case B.unpack binary of
|
||||
':' : body -> return . Right $ B.pack body
|
||||
|
||||
@@ -12,13 +12,13 @@ where
|
||||
|
||||
import Control.Logger.Simple (logInfo)
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Client (newAgentClient)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
@@ -32,7 +32,7 @@ import UnliftIO.STM
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> m ()
|
||||
runSMPAgent :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> IO ()
|
||||
runSMPAgent t cfg initServers store =
|
||||
runSMPAgentBlocking t cfg initServers store 0 =<< newEmptyTMVarIO
|
||||
|
||||
@@ -40,44 +40,46 @@ runSMPAgent t cfg initServers store =
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> m ()
|
||||
runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started = do
|
||||
liftIO (newSMPAgentEnv cfg store) >>= runReaderT (smpAgent t)
|
||||
runSMPAgentBlocking :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> IO ()
|
||||
runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started =
|
||||
case tcpPort of
|
||||
Just port -> newSMPAgentEnv cfg store >>= smpAgent t port
|
||||
Nothing -> E.throwIO $ userError "no agent port"
|
||||
where
|
||||
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent _ = do
|
||||
smpAgent :: forall c. Transport c => TProxy c -> ServiceName -> Env -> IO ()
|
||||
smpAgent _ port env = do
|
||||
-- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
clientId <- newTVarIO initClientId
|
||||
runTransportServer started tcpPort tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
|
||||
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
runTransportServer started port tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
|
||||
putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
cId <- atomically $ stateTVar clientId $ \i -> (i + 1, i + 1)
|
||||
c <- getAgentClient cId initServers
|
||||
c <- atomically $ newAgentClient cId initServers env
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runAgentClient c)
|
||||
`E.finally` disconnectAgentClient c
|
||||
race_ (connectClient h c) (runAgentClient c `runReaderT` env)
|
||||
`E.finally` (disconnectAgentClient c)
|
||||
|
||||
connectClient :: Transport c => MonadUnliftIO m => c -> AgentClient -> m ()
|
||||
connectClient :: Transport c => c -> AgentClient -> IO ()
|
||||
connectClient h c = race_ (send h c) (receive h c)
|
||||
|
||||
receive :: forall c m. (Transport c, MonadUnliftIO m) => c -> AgentClient -> m ()
|
||||
receive :: forall c. Transport c => c -> AgentClient -> IO ()
|
||||
receive h c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
(corrId, entId, cmdOrErr) <- tGet SClient h
|
||||
case cmdOrErr of
|
||||
Right cmd -> write rcvQ (corrId, entId, cmd)
|
||||
Left e -> write subQ (corrId, entId, APC SAEConn $ ERR e)
|
||||
where
|
||||
write :: TBQueue (ATransmission p) -> ATransmission p -> m ()
|
||||
write :: TBQueue (ATransmission p) -> ATransmission p -> IO ()
|
||||
write q t = do
|
||||
logClient c "-->" t
|
||||
atomically $ writeTBQueue q t
|
||||
|
||||
send :: (Transport c, MonadUnliftIO m) => c -> AgentClient -> m ()
|
||||
send :: Transport c => c -> AgentClient -> IO ()
|
||||
send h c@AgentClient {subQ} = forever $ do
|
||||
t <- atomically $ readTBQueue subQ
|
||||
tPut h t
|
||||
logClient c "<--" t
|
||||
|
||||
logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m ()
|
||||
logClient :: AgentClient -> ByteString -> ATransmission a -> IO ()
|
||||
logClient AgentClient {clientId} dir (corrId, connId, APC _ cmd) = do
|
||||
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
|
||||
@@ -231,7 +231,6 @@ import Data.Bifunctor (first, second)
|
||||
import Data.ByteArray (ScrubbedBytes)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
@@ -271,6 +270,7 @@ import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..))
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
|
||||
@@ -97,7 +97,7 @@ import Data.List (find)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Time.Clock (UTCTime, getCurrentTime)
|
||||
import Data.Time.Clock (UTCTime (..), getCurrentTime)
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -138,11 +138,12 @@ data PClient v err msg = PClient
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission v msg))
|
||||
}
|
||||
|
||||
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg)
|
||||
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM SMPClient
|
||||
smpClientStub g sessionId thVersion thAuth = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- C.newRandomDRG g
|
||||
sentCommands <- TM.empty
|
||||
pingErrorCount <- newTVar 0
|
||||
sndQ <- newTBQueue 100
|
||||
rcvQ <- newTBQueue 100
|
||||
return
|
||||
@@ -157,15 +158,15 @@ smpClientStub g sessionId thVersion thAuth = do
|
||||
implySessId = thVersion >= authCmdsSMPVersion,
|
||||
batch = True
|
||||
},
|
||||
sessionTs = undefined,
|
||||
sessionTs = UTCTime (read "2024-03-31") 0,
|
||||
client_ =
|
||||
PClient
|
||||
{ connected,
|
||||
transportSession = undefined,
|
||||
transportHost = undefined,
|
||||
tcpTimeout = undefined,
|
||||
transportSession = (1, "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001", Nothing),
|
||||
transportHost = "localhost",
|
||||
tcpTimeout = 15_000_000,
|
||||
batchDelay = Nothing,
|
||||
pingErrorCount = undefined,
|
||||
pingErrorCount,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
sndQ,
|
||||
@@ -239,6 +240,7 @@ defaultNetworkConfig =
|
||||
transportClientConfig :: NetworkConfig -> TransportClientConfig
|
||||
transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
|
||||
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
|
||||
{-# INLINE transportClientConfig #-}
|
||||
|
||||
-- | protocol client configuration.
|
||||
data ProtocolClientConfig v = ProtocolClientConfig
|
||||
@@ -264,9 +266,11 @@ defaultClientConfig serverVRange =
|
||||
serverVRange,
|
||||
batchDelay = Nothing
|
||||
}
|
||||
{-# INLINE defaultClientConfig #-}
|
||||
|
||||
defaultSMPClientConfig :: ProtocolClientConfig SMPVersion
|
||||
defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange
|
||||
{-# INLINE defaultSMPClientConfig #-}
|
||||
|
||||
data Request err msg = Request
|
||||
{ entityId :: EntityId,
|
||||
@@ -296,12 +300,15 @@ protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err ms
|
||||
protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_
|
||||
where
|
||||
snd3 (_, s, _) = s
|
||||
{-# INLINE protocolClientServer #-}
|
||||
|
||||
transportHost' :: ProtocolClient v err msg -> TransportHost
|
||||
transportHost' = transportHost . client_
|
||||
{-# INLINE transportHost' #-}
|
||||
|
||||
transportSession' :: ProtocolClient v err msg -> TransportSession msg
|
||||
transportSession' = transportSession . client_
|
||||
{-# INLINE transportSession' #-}
|
||||
|
||||
type UserId = Int64
|
||||
|
||||
@@ -426,10 +433,12 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
|
||||
proxyUsername :: TransportSession msg -> ByteString
|
||||
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
|
||||
{-# INLINE proxyUsername #-}
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeProtocolClient :: ProtocolClient v err msg -> IO ()
|
||||
closeProtocolClient = mapM_ uninterruptibleCancel . action
|
||||
{-# INLINE closeProtocolClient #-}
|
||||
|
||||
-- | SMP client error type.
|
||||
data ProtocolClientError err
|
||||
@@ -469,6 +478,7 @@ temporaryClientError = \case
|
||||
PCEResponseTimeout -> True
|
||||
PCEIOError _ -> True
|
||||
_ -> False
|
||||
{-# INLINE temporaryClientError #-}
|
||||
|
||||
-- | Create a new SMP queue.
|
||||
--
|
||||
@@ -536,16 +546,19 @@ getSMPMessage c rpKey rId =
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateAuthKey -> NotifierId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueueNotifications = okSMPCommand NSUB
|
||||
{-# INLINE subscribeSMPQueueNotifications #-}
|
||||
|
||||
-- | Subscribe to multiple SMP queues notifications batching commands if supported.
|
||||
subscribeSMPQueuesNtfs :: SMPClient -> NonEmpty (NtfPrivateAuthKey, NotifierId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
subscribeSMPQueuesNtfs = okSMPCommands NSUB
|
||||
{-# INLINE subscribeSMPQueuesNtfs #-}
|
||||
|
||||
-- | Secure the SMP queue by adding a sender public key.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> SndPublicAuthKey -> ExceptT SMPClientError IO ()
|
||||
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
|
||||
{-# INLINE secureSMPQueue #-}
|
||||
|
||||
-- | Enable notifications for the queue for push notifications server.
|
||||
--
|
||||
@@ -571,10 +584,12 @@ enableSMPQueuesNtfs c qs = L.map process <$> sendProtocolCommands c cs
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
|
||||
disableSMPQueueNotifications :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
disableSMPQueueNotifications = okSMPCommand NDEL
|
||||
{-# INLINE disableSMPQueueNotifications #-}
|
||||
|
||||
-- | Disable notifications for multiple queues for push notifications server.
|
||||
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateAuthKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
disableSMPQueuesNtfs = okSMPCommands NDEL
|
||||
{-# INLINE disableSMPQueuesNtfs #-}
|
||||
|
||||
-- | Send SMP message.
|
||||
--
|
||||
@@ -601,16 +616,19 @@ ackSMPMessage c rpKey rId msgId =
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateAuthKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
suspendSMPQueue = okSMPCommand OFF
|
||||
{-# INLINE suspendSMPQueue #-}
|
||||
|
||||
-- | Irreversibly delete SMP queue and all messages in it.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand DEL
|
||||
{-# INLINE deleteSMPQueue #-}
|
||||
|
||||
-- | Delete multiple SMP queues batching commands if supported.
|
||||
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateAuthKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ()))
|
||||
deleteSMPQueues = okSMPCommands DEL
|
||||
{-# INLINE deleteSMPQueues #-}
|
||||
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateAuthKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
@@ -631,6 +649,7 @@ okSMPCommands cmd c qs = L.map process <$> sendProtocolCommands c cs
|
||||
-- | Send SMP command
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateAuthKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
|
||||
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
|
||||
{-# INLINE sendSMPCommand #-}
|
||||
|
||||
type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -18,6 +19,7 @@ import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Control.Monad.Trans.Reader
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
@@ -106,12 +108,24 @@ newtype InternalException e = InternalException {unInternalException :: e}
|
||||
|
||||
instance Exception e => Exception (InternalException e)
|
||||
|
||||
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||
withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
|
||||
withRunInIO exceptToIO =
|
||||
instance Exception e => MonadUnliftIO (ExceptT e IO) where
|
||||
{-# INLINE withRunInIO #-}
|
||||
withRunInIO :: ((forall a. ExceptT e IO a -> IO a) -> IO b) -> ExceptT e IO b
|
||||
withRunInIO inner =
|
||||
ExceptT . fmap (first unInternalException) . E.try $
|
||||
withRunInIO $ \run ->
|
||||
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
|
||||
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
|
||||
-- the last two lines could be replaced with:
|
||||
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
|
||||
|
||||
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
|
||||
{-# INLINE withRunInIO #-}
|
||||
withRunInIO :: ((forall a. ExceptT e (ReaderT r IO) a -> IO a) -> IO b) -> ExceptT e (ReaderT r IO) b
|
||||
withRunInIO inner =
|
||||
withExceptT unInternalException . ExceptT . E.try $
|
||||
withRunInIO $ \run ->
|
||||
exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
|
||||
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
|
||||
|
||||
newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent
|
||||
newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do
|
||||
@@ -147,7 +161,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
|
||||
Nothing -> Left PCEResponseTimeout
|
||||
|
||||
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
|
||||
newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync)
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
|
||||
tryConnectClient successAction retryAction =
|
||||
@@ -163,9 +177,9 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
throwE e
|
||||
tryConnectAsync :: ExceptT SMPClientError IO ()
|
||||
tryConnectAsync :: IO ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
a <- async $ void $ runExceptT connectAsync
|
||||
atomically $ modifyTVar' (asyncClients ca) (a :)
|
||||
connectAsync :: ExceptT SMPClientError IO ()
|
||||
connectAsync =
|
||||
@@ -199,11 +213,11 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
|
||||
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
|
||||
serverDown ss = unless (M.null ss) $ do
|
||||
notify . CADisconnected srv $ M.keysSet ss
|
||||
void $ runExceptT reconnectServer
|
||||
reconnectServer
|
||||
|
||||
reconnectServer :: ExceptT SMPClientError IO ()
|
||||
reconnectServer :: IO ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
a <- async $ void $ runExceptT tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections ca) (a :)
|
||||
|
||||
tryReconnectClient :: ExceptT SMPClientError IO ()
|
||||
@@ -247,8 +261,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
|
||||
notify :: SMPClientAgentEvent -> IO ()
|
||||
notify evt = atomically $ writeTBQueue (agentQ ca) evt
|
||||
|
||||
closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m ()
|
||||
closeSMPClientAgent c = liftIO $ do
|
||||
closeSMPClientAgent :: SMPClientAgent -> IO ()
|
||||
closeSMPClientAgent c = do
|
||||
closeSMPServerClients c
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
|
||||
@@ -211,8 +211,6 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.ByteArray (ByteArrayAccess)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Base64 (decode, encode)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.ByteString.Lazy (fromStrict, toStrict)
|
||||
@@ -230,6 +228,8 @@ import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+))
|
||||
import Network.Transport.Internal (decodeWord16, encodeWord16)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (decode, encode)
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
@@ -76,7 +76,7 @@ withFile :: CryptoFile -> IOMode -> (CryptoFileHandle -> ExceptT FTCryptoError I
|
||||
withFile (CryptoFile path cfArgs) mode action = do
|
||||
sb <- forM cfArgs $ \(CFArgs key nonce) ->
|
||||
liftEitherWith FTCECryptoError (LC.sbInit key nonce) >>= newTVarIO
|
||||
IO.withFile path mode $ \h -> action $ CFHandle h sb
|
||||
ExceptT . IO.withFile path mode $ \h -> runExceptT $ action $ CFHandle h sb
|
||||
|
||||
hPut :: CryptoFileHandle -> LazyByteString -> IO ()
|
||||
hPut (CFHandle h sb_) s = LB.hPut h =<< maybe (pure s) encrypt sb_
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Compatibility wrappers for base64 package, Base64 (padded) variant.
|
||||
module Simplex.Messaging.Encoding.Base64 where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Base64.Types (extractBase64)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64 (decodeBase64Untyped, encodeBase64')
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
|
||||
encode :: ByteString -> ByteString
|
||||
encode = extractBase64 . encodeBase64'
|
||||
{-# INLINE encode #-}
|
||||
|
||||
decode :: ByteString -> Either String ByteString
|
||||
decode = first T.unpack . decodeBase64Untyped
|
||||
{-# INLINE decode #-}
|
||||
|
||||
base64P :: A.Parser ByteString
|
||||
base64P = do
|
||||
str <- A.takeWhile1 (`B.elem` base64Alphabet)
|
||||
pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
|
||||
either (fail . T.unpack) pure $ decodeBase64Untyped (str <> pad)
|
||||
|
||||
base64Alphabet :: ByteString
|
||||
base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
|
||||
@@ -0,0 +1,33 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Compatibility wrappers for base64 package, Base64URL-padded variant.
|
||||
module Simplex.Messaging.Encoding.Base64.URL where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Base64.Types (extractBase64)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64.URL (decodeBase64Lenient, decodeBase64UnpaddedUntyped, decodeBase64Untyped, encodeBase64')
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
|
||||
encode :: ByteString -> ByteString
|
||||
encode = extractBase64 . encodeBase64'
|
||||
{-# INLINE encode #-}
|
||||
|
||||
decode :: ByteString -> Either String ByteString
|
||||
decode = first T.unpack . decodeBase64Untyped
|
||||
{-# INLINE decode #-}
|
||||
|
||||
decodeLenient :: ByteString -> ByteString
|
||||
decodeLenient = decodeBase64Lenient
|
||||
{-# INLINE decodeLenient #-}
|
||||
|
||||
base64urlP :: A.Parser ByteString
|
||||
base64urlP = do
|
||||
str <- A.takeWhile1 (`B.elem` base64AlphabetURL)
|
||||
_pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
|
||||
either (fail . T.unpack) pure $ decodeBase64UnpaddedUntyped str
|
||||
|
||||
base64AlphabetURL :: ByteString
|
||||
base64AlphabetURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
@@ -10,7 +10,6 @@ module Simplex.Messaging.Encoding.String
|
||||
strToJSON,
|
||||
strToJEncoding,
|
||||
strParseJSON,
|
||||
base64urlP,
|
||||
strEncodeList,
|
||||
strListP,
|
||||
)
|
||||
@@ -23,10 +22,8 @@ import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum)
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Set (Set)
|
||||
@@ -38,6 +35,7 @@ import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Time.Format.ISO8601
|
||||
import Data.Word (Word16, Word32)
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
@@ -54,19 +52,16 @@ class StrEncoding a where
|
||||
strDecode :: ByteString -> Either String a
|
||||
strDecode = parseAll strP
|
||||
strP :: Parser a
|
||||
strP = strDecode <$?> base64urlP
|
||||
strP = strDecode <$?> U.base64urlP
|
||||
|
||||
-- base64url encoding/decoding of ByteStrings - the parser only allows non-empty strings
|
||||
instance StrEncoding ByteString where
|
||||
strEncode = U.encode
|
||||
{-# INLINE strEncode #-}
|
||||
strDecode = U.decode
|
||||
strP = base64urlP
|
||||
|
||||
base64urlP :: Parser ByteString
|
||||
base64urlP = do
|
||||
str <- A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
|
||||
pad <- A.takeWhile (== '=')
|
||||
either fail pure $ U.decode (str <> pad)
|
||||
{-# INLINE strDecode #-}
|
||||
strP = U.base64urlP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
newtype Str = Str {unStr :: ByteString}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -81,7 +81,8 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
|
||||
runTransportServer started tcpPort serverParams tCfg (runClient serverSignKey t)
|
||||
env <- ask
|
||||
liftIO $ runTransportServer started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env
|
||||
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
|
||||
|
||||
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
|
||||
|
||||
@@ -83,7 +83,7 @@ data NtfEnv = NtfEnv
|
||||
serverStats :: NtfServerStats
|
||||
}
|
||||
|
||||
newNtfServerEnv :: (MonadUnliftIO m, MonadRandom m) => NtfServerConfig -> m NtfEnv
|
||||
newNtfServerEnv :: NtfServerConfig -> IO NtfEnv
|
||||
newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
random <- liftIO C.newRandom
|
||||
store <- atomically newNtfStore
|
||||
|
||||
@@ -27,8 +27,9 @@ import Data.Aeson (ToJSON, (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Aeson.TH as JQ
|
||||
import Data.Base64.Types (extractBase64)
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import qualified Data.ByteString.Base64.URL as UP
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
@@ -46,6 +47,7 @@ import Network.HTTP2.Client (Request)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
|
||||
@@ -91,8 +93,8 @@ signedJWTToken pk (JWTToken hdr claims) = do
|
||||
pure $ hc <> "." <> serialize sig
|
||||
where
|
||||
jwtEncode :: ToJSON a => a -> ByteString
|
||||
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
|
||||
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
|
||||
jwtEncode = extractBase64 . UP.encodeBase64Unpadded' . LB.toStrict . J.encode
|
||||
serialize sig = extractBase64 . UP.encodeBase64Unpadded' $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
|
||||
|
||||
readECPrivateKey :: FilePath -> IO EC.PrivateKey
|
||||
readECPrivateKey f = do
|
||||
|
||||
@@ -10,10 +10,9 @@ import qualified Data.Aeson as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum, toLower)
|
||||
import Data.Char (toLower)
|
||||
import Data.String
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
@@ -24,23 +23,8 @@ import Database.SQLite.Simple (ResultError (..), SQLData (..))
|
||||
import Database.SQLite.Simple.FromField (FieldParser, returnError)
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
base64P :: Parser ByteString
|
||||
base64P = decode <$?> paddedBase64 rawBase64P
|
||||
|
||||
paddedBase64 :: Parser ByteString -> Parser ByteString
|
||||
paddedBase64 raw = (<>) <$> raw <*> pad
|
||||
where
|
||||
pad = A.takeWhile (== '=')
|
||||
|
||||
rawBase64P :: Parser ByteString
|
||||
rawBase64P = A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/')
|
||||
|
||||
-- rawBase64UriP :: Parser ByteString
|
||||
-- rawBase64UriP = A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
|
||||
|
||||
tsISO8601P :: Parser UTCTime
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
|
||||
|
||||
|
||||
@@ -176,7 +176,6 @@ import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isPrint, isSpace)
|
||||
@@ -194,6 +193,7 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64 as B64
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.ServiceScheme
|
||||
|
||||
@@ -45,7 +45,6 @@ import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight, partitionEithers)
|
||||
@@ -68,6 +67,7 @@ import Network.Socket (ServiceName, Socket, socketToHandle)
|
||||
import Simplex.Messaging.Agent.Lock
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding (Encoding (smpEncode))
|
||||
import Simplex.Messaging.Encoding.Base64 (encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Control
|
||||
@@ -128,14 +128,15 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
: serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ())
|
||||
: map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
|
||||
)
|
||||
`finally` withLock (savingLock s) "final" (saveServer False)
|
||||
`finally` withLock' (savingLock s) "final" (saveServer False)
|
||||
where
|
||||
runServer :: (ServiceName, ATransport) -> M ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
ss <- asks sockets
|
||||
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
|
||||
runTransportServerState ss started tcpPort serverParams tCfg (runClient serverSignKey t)
|
||||
env <- ask
|
||||
liftIO $ runTransportServerState ss started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env
|
||||
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
|
||||
|
||||
saveServer :: Bool -> M ()
|
||||
@@ -387,7 +388,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
updateDeletedStats q
|
||||
liftIO . hPutStrLn h $ "ok, " <> show numDeleted <> " messages deleted"
|
||||
CPSave -> withAdminRole $ withLock (savingLock srv) "control" $ do
|
||||
CPSave -> withAdminRole $ withLock' (savingLock srv) "control" $ do
|
||||
hPutStrLn h "saving server state..."
|
||||
unliftIO u $ saveServer True
|
||||
hPutStrLn h "server state saved!"
|
||||
@@ -579,7 +580,7 @@ dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/Xo
|
||||
dummyKeyX25519 :: C.PublicKey 'C.X25519
|
||||
dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg="
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
|
||||
client :: Client -> Server -> M ()
|
||||
client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Server {subscribedQ, ntfSubscribedQ, notifiers} = do
|
||||
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
|
||||
forever $
|
||||
@@ -587,7 +588,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
>>= mapM processCommand
|
||||
>>= atomically . writeTBQueue sndQ
|
||||
where
|
||||
processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg)
|
||||
processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Transmission BrokerMsg)
|
||||
processCommand (qr_, (corrId, queueId, cmd)) = do
|
||||
st <- asks queueStore
|
||||
case cmd of
|
||||
@@ -616,7 +617,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
OFF -> suspendQueue_ st
|
||||
DEL -> delQueueAndMsgs st
|
||||
where
|
||||
createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> m (Transmission BrokerMsg)
|
||||
createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> M (Transmission BrokerMsg)
|
||||
createQueue st recipientKey dhKey subMode = time "NEW" $ do
|
||||
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
|
||||
let rcvDhSecret = C.dh' dhKey privDhKey
|
||||
@@ -634,7 +635,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
(corrId,queueId,) <$> addQueueRetry 3 qik qRec
|
||||
where
|
||||
addQueueRetry ::
|
||||
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> m BrokerMsg
|
||||
Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> M BrokerMsg
|
||||
addQueueRetry 0 _ _ = pure $ ERR INTERNAL
|
||||
addQueueRetry n qik qRec = do
|
||||
ids@(rId, _) <- getIds
|
||||
@@ -659,25 +660,25 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
Right q -> logCreateQueue s q
|
||||
_ -> pure ()
|
||||
|
||||
getIds :: m (RecipientId, SenderId)
|
||||
getIds :: M (RecipientId, SenderId)
|
||||
getIds = do
|
||||
n <- asks $ queueIdBytes . config
|
||||
liftM2 (,) (randomId n) (randomId n)
|
||||
|
||||
secureQueue_ :: QueueStore -> SndPublicAuthKey -> m (Transmission BrokerMsg)
|
||||
secureQueue_ :: QueueStore -> SndPublicAuthKey -> M (Transmission BrokerMsg)
|
||||
secureQueue_ st sKey = time "KEY" $ do
|
||||
withLog $ \s -> logSecureQueue s queueId sKey
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar' (qSecured stats) (+ 1)
|
||||
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
|
||||
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg)
|
||||
addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do
|
||||
(rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random
|
||||
let rcvNtfDhSecret = C.dh' dhKey privDhKey
|
||||
(corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
|
||||
where
|
||||
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> m BrokerMsg
|
||||
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg
|
||||
addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
|
||||
addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
|
||||
notifierId <- randomId =<< asks (queueIdBytes . config)
|
||||
@@ -689,17 +690,17 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
withLog $ \s -> logAddNotifier s queueId ntfCreds
|
||||
pure $ NID notifierId rcvPublicDhKey
|
||||
|
||||
deleteQueueNotifier_ :: QueueStore -> m (Transmission BrokerMsg)
|
||||
deleteQueueNotifier_ :: QueueStore -> M (Transmission BrokerMsg)
|
||||
deleteQueueNotifier_ st = do
|
||||
withLog (`logDeleteNotifier` queueId)
|
||||
okResp <$> atomically (deleteQueueNotifier st queueId)
|
||||
|
||||
suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg)
|
||||
suspendQueue_ :: QueueStore -> M (Transmission BrokerMsg)
|
||||
suspendQueue_ st = do
|
||||
withLog (`logSuspendQueue` queueId)
|
||||
okResp <$> atomically (suspendQueue st queueId)
|
||||
|
||||
subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg)
|
||||
subscribeQueue :: QueueRec -> RecipientId -> M (Transmission BrokerMsg)
|
||||
subscribeQueue qr rId = do
|
||||
atomically (TM.lookup rId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
@@ -712,19 +713,19 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
s ->
|
||||
atomically (tryTakeTMVar $ delivered s) >> deliver sub
|
||||
where
|
||||
newSub :: m (TVar Sub)
|
||||
newSub :: M (TVar Sub)
|
||||
newSub = time "SUB newSub" . atomically $ do
|
||||
writeTQueue subscribedQ (rId, clnt)
|
||||
sub <- newTVar =<< newSubscription NoSub
|
||||
TM.insert rId sub subscriptions
|
||||
pure sub
|
||||
deliver :: TVar Sub -> m (Transmission BrokerMsg)
|
||||
deliver :: TVar Sub -> M (Transmission BrokerMsg)
|
||||
deliver sub = do
|
||||
q <- getStoreMsgQueue "SUB" rId
|
||||
msg_ <- atomically $ tryPeekMsg q
|
||||
deliverMessage "SUB" qr rId sub q msg_
|
||||
|
||||
getMessage :: QueueRec -> m (Transmission BrokerMsg)
|
||||
getMessage :: QueueRec -> M (Transmission BrokerMsg)
|
||||
getMessage qr = time "GET" $ do
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
@@ -743,7 +744,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
sub <- newTVar s
|
||||
TM.insert queueId sub subscriptions
|
||||
pure s
|
||||
getMessage_ :: Sub -> m (Transmission BrokerMsg)
|
||||
getMessage_ :: Sub -> M (Transmission BrokerMsg)
|
||||
getMessage_ s = do
|
||||
q <- getStoreMsgQueue "GET" queueId
|
||||
atomically $
|
||||
@@ -753,17 +754,17 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
in setDelivered s msg $> (corrId, queueId, MSG encMsg)
|
||||
_ -> pure (corrId, queueId, OK)
|
||||
|
||||
withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
|
||||
withQueue :: (QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg)
|
||||
withQueue action = maybe (pure $ err AUTH) action qr_
|
||||
|
||||
subscribeNotifications :: m (Transmission BrokerMsg)
|
||||
subscribeNotifications :: M (Transmission BrokerMsg)
|
||||
subscribeNotifications = time "NSUB" . atomically $ do
|
||||
unlessM (TM.member queueId ntfSubscriptions) $ do
|
||||
writeTQueue ntfSubscribedQ (queueId, clnt)
|
||||
TM.insert queueId () ntfSubscriptions
|
||||
pure ok
|
||||
|
||||
acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg)
|
||||
acknowledgeMsg :: QueueRec -> MsgId -> M (Transmission BrokerMsg)
|
||||
acknowledgeMsg qr msgId = time "ACK" $ do
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing -> pure $ err NO_MSG
|
||||
@@ -789,7 +790,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
if msgId == msgId' || B.null msgId
|
||||
then pure $ Just s
|
||||
else putTMVar delivered msgId' $> Nothing
|
||||
updateStats :: Message -> m ()
|
||||
updateStats :: Message -> M ()
|
||||
updateStats = \case
|
||||
MessageQuota {} -> pure ()
|
||||
Message {msgFlags} -> do
|
||||
@@ -801,7 +802,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
atomically $ modifyTVar' (msgRecvNtf stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueuesNtf stats) queueId
|
||||
|
||||
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
|
||||
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> M (Transmission BrokerMsg)
|
||||
sendMessage qr msgFlags msgBody
|
||||
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
|
||||
| otherwise = case status qr of
|
||||
@@ -827,13 +828,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
||||
pure ok
|
||||
where
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> M Message
|
||||
mkMessage body = do
|
||||
msgId <- randomId =<< asks (msgIdBytes . config)
|
||||
msgTs <- liftIO getSystemTime
|
||||
pure $ Message msgId msgTs msgFlags body
|
||||
|
||||
expireMessages :: MsgQueue -> m ()
|
||||
expireMessages :: MsgQueue -> M ()
|
||||
expireMessages q = do
|
||||
msgExp <- asks $ messageExpiration . config
|
||||
old <- liftIO $ mapM expireBeforeEpoch msgExp
|
||||
@@ -861,7 +862,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
|
||||
pure . (cbNonce,) $ fromRight "" encNMsgMeta
|
||||
|
||||
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
|
||||
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> M (Transmission BrokerMsg)
|
||||
deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do
|
||||
readTVarIO sub >>= \case
|
||||
s@Sub {subThread = NoSub} ->
|
||||
@@ -872,7 +873,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
_ -> forkSub $> ok
|
||||
_ -> pure ok
|
||||
where
|
||||
forkSub :: m ()
|
||||
forkSub :: M ()
|
||||
forkSub = do
|
||||
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
|
||||
t <- mkWeakThreadId =<< forkIO subscriber
|
||||
@@ -890,7 +891,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
void $ setDelivered s msg
|
||||
writeTVar sub $! s {subThread = NoSub}
|
||||
|
||||
time :: T.Text -> m a -> m a
|
||||
time :: T.Text -> M a -> M a
|
||||
time name = timed name queueId
|
||||
|
||||
encryptMsg :: QueueRec -> Message -> RcvMessage
|
||||
@@ -906,13 +907,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
setDelivered :: Sub -> Message -> STM Bool
|
||||
setDelivered s msg = tryPutTMVar (delivered s) (messageId msg)
|
||||
|
||||
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
|
||||
getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue
|
||||
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
|
||||
ms <- asks msgStore
|
||||
quota <- asks $ msgQueueQuota . config
|
||||
atomically $ getMsgQueue ms rId quota
|
||||
|
||||
delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg)
|
||||
delQueueAndMsgs :: QueueStore -> M (Transmission BrokerMsg)
|
||||
delQueueAndMsgs st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
ms <- asks msgStore
|
||||
@@ -929,7 +930,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv
|
||||
okResp :: Either ErrorType () -> Transmission BrokerMsg
|
||||
okResp = either err $ const ok
|
||||
|
||||
updateDeletedStats :: (MonadUnliftIO m, MonadReader Env m) => QueueRec -> m ()
|
||||
updateDeletedStats :: QueueRec -> M ()
|
||||
updateDeletedStats q = do
|
||||
stats <- asks serverStats
|
||||
let delSel = if isNothing (senderKey q) then qDeletedNew else qDeletedSecured
|
||||
@@ -937,12 +938,12 @@ updateDeletedStats q = do
|
||||
atomically $ modifyTVar' (qDeletedAll stats) (+ 1)
|
||||
atomically $ modifyTVar' (qCount stats) (subtract 1)
|
||||
|
||||
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
|
||||
withLog :: (StoreLog 'WriteMode -> IO a) -> M ()
|
||||
withLog action = do
|
||||
env <- ask
|
||||
liftIO . mapM_ action $ storeLog (env :: Env)
|
||||
|
||||
timed :: MonadUnliftIO m => T.Text -> RecipientId -> m a -> m a
|
||||
timed :: T.Text -> RecipientId -> M a -> M a
|
||||
timed name qId a = do
|
||||
t <- liftIO getSystemTime
|
||||
r <- a
|
||||
@@ -954,10 +955,10 @@ timed name qId a = do
|
||||
diff t t' = (systemSeconds t' - systemSeconds t) * sec + fromIntegral (systemNanoseconds t' - systemNanoseconds t)
|
||||
sec = 1000_000000
|
||||
|
||||
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString
|
||||
randomId :: Int -> M ByteString
|
||||
randomId n = atomically . C.randomBytes n =<< asks random
|
||||
|
||||
saveServerMessages :: (MonadUnliftIO m, MonadReader Env m) => Bool -> m ()
|
||||
saveServerMessages :: Bool -> M ()
|
||||
saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessages
|
||||
where
|
||||
saveMessages f = do
|
||||
@@ -972,7 +973,7 @@ saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessag
|
||||
atomically (getMessages ms rId)
|
||||
>>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId)
|
||||
|
||||
restoreServerMessages :: forall m. (MonadUnliftIO m, MonadReader Env m) => m Int
|
||||
restoreServerMessages :: M Int
|
||||
restoreServerMessages = asks (storeMsgsFile . config) >>= \case
|
||||
Just f -> ifM (doesFileExist f) (restoreMessages f) (pure 0)
|
||||
Nothing -> pure 0
|
||||
@@ -1008,7 +1009,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case
|
||||
msgErr :: Show e => String -> e -> String
|
||||
msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)
|
||||
|
||||
saveServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
|
||||
saveServerStats :: M ()
|
||||
saveServerStats =
|
||||
asks (serverStatsBackupFile . config)
|
||||
>>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f)
|
||||
@@ -1018,7 +1019,7 @@ saveServerStats =
|
||||
B.writeFile f $ strEncode stats
|
||||
logInfo "server stats saved"
|
||||
|
||||
restoreServerStats :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ()
|
||||
restoreServerStats :: Int -> M ()
|
||||
restoreServerStats expiredWhileRestoring = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
|
||||
where
|
||||
restoreStats f = whenM (doesFileExist f) $ do
|
||||
|
||||
@@ -173,25 +173,25 @@ newSubscription subThread = do
|
||||
delivered <- newEmptyTMVar
|
||||
return Sub {subThread, delivered}
|
||||
|
||||
newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
|
||||
newEnv :: ServerConfig -> IO Env
|
||||
newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, storeLogFile} = do
|
||||
server <- atomically newServer
|
||||
queueStore <- atomically newQueueStore
|
||||
msgStore <- atomically newMsgStore
|
||||
random <- liftIO C.newRandom
|
||||
storeLog <- restoreQueues queueStore `mapM` storeLogFile
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- loadFingerprint caCertificateFile
|
||||
let serverIdentity = KeyHash fp
|
||||
serverStats <- atomically . newServerStats =<< liftIO getCurrentTime
|
||||
serverStats <- atomically . newServerStats =<< getCurrentTime
|
||||
sockets <- atomically newSocketState
|
||||
clientSeq <- newTVarIO 0
|
||||
clients <- newTVarIO mempty
|
||||
return Env {config, server, serverIdentity, queueStore, msgStore, random, storeLog, tlsServerParams, serverStats, sockets, clientSeq, clients}
|
||||
where
|
||||
restoreQueues :: QueueStore -> FilePath -> m (StoreLog 'WriteMode)
|
||||
restoreQueues :: QueueStore -> FilePath -> IO (StoreLog 'WriteMode)
|
||||
restoreQueues QueueStore {queues, senders, notifiers} f = do
|
||||
(qs, s) <- liftIO $ readWriteStoreLog f
|
||||
(qs, s) <- readWriteStoreLog f
|
||||
atomically $ do
|
||||
writeTVar queues =<< mapM newTVar qs
|
||||
writeTVar senders $! M.foldr' addSender M.empty qs
|
||||
|
||||
@@ -23,7 +23,6 @@ where
|
||||
import Control.Applicative (optional)
|
||||
import Control.Logger.Simple (logError)
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
@@ -126,10 +125,10 @@ clientTransportConfig TransportClientConfig {logTLSErrors} =
|
||||
TransportConfig {logTLSErrors, transportTimeout = Nothing}
|
||||
|
||||
-- | Connect to passed TCP host:port and pass handle to the client.
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTransportClient :: Transport c => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a
|
||||
runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a
|
||||
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do
|
||||
serverCert <- newEmptyTMVarIO
|
||||
let hostName = B.unpack $ strEncode host
|
||||
@@ -137,7 +136,7 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
|
||||
connectTCP = case socksProxy of
|
||||
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
|
||||
_ -> connectTCPClient hostName
|
||||
c <- liftIO $ do
|
||||
c <- do
|
||||
sock <- connectTCP port
|
||||
mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e)
|
||||
let tCfg = clientTransportConfig cfg
|
||||
@@ -148,7 +147,7 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
|
||||
closeTLS tls >> error "onServerCertificate failed"
|
||||
Just c -> pure c
|
||||
getClientConnection tCfg chain tls
|
||||
client c `E.finally` liftIO (closeConnection c)
|
||||
client c `E.finally` closeConnection c
|
||||
where
|
||||
hostAddr = \case
|
||||
THIPv4 addr -> SocksAddrIPV4 $ tupleToHostAddress addr
|
||||
|
||||
@@ -26,7 +26,6 @@ where
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import qualified Crypto.Store.X509 as SX
|
||||
import Data.Default (def)
|
||||
import Data.List (find)
|
||||
@@ -70,27 +69,26 @@ serverTransportConfig TransportServerConfig {logTLSErrors} =
|
||||
-- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar.
|
||||
--
|
||||
-- All accepted connections are passed to the passed function.
|
||||
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
|
||||
runTransportServer :: forall c. Transport c => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO ()
|
||||
runTransportServer started port params cfg server = do
|
||||
ss <- atomically newSocketState
|
||||
runTransportServerState ss started port params cfg server
|
||||
|
||||
runTransportServerState :: forall c m. (Transport c, MonadUnliftIO m) => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
|
||||
runTransportServerState :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO ()
|
||||
runTransportServerState ss started port = runTransportServerSocketState ss started (startTCPServer started port) (transportName (TProxy :: TProxy c))
|
||||
|
||||
-- | Run a transport server with provided connection setup and handler.
|
||||
runTransportServerSocket :: (MonadUnliftIO m, Transport a) => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocket :: Transport a => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO ()
|
||||
runTransportServerSocket started getSocket threadLabel serverParams cfg server = do
|
||||
ss <- atomically newSocketState
|
||||
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server
|
||||
|
||||
-- | Run a transport server with provided connection setup and handler.
|
||||
runTransportServerSocketState :: (MonadUnliftIO m, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocketState :: Transport a => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO ()
|
||||
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server = do
|
||||
u <- askUnliftIO
|
||||
labelMyThread $ "transport server for " <> threadLabel
|
||||
liftIO . runTCPServerSocket ss started getSocket $ \conn ->
|
||||
E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection (unliftIO u . server)
|
||||
runTCPServerSocket ss started getSocket $ \conn ->
|
||||
E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection server
|
||||
where
|
||||
tCfg = serverTransportConfig cfg
|
||||
setup conn = timeout (tlsSetupTimeout cfg) $ do
|
||||
|
||||
@@ -50,26 +50,18 @@ maybeWord :: (a -> ByteString) -> Maybe a -> ByteString
|
||||
maybeWord f = maybe "" $ B.cons ' ' . f
|
||||
{-# INLINE maybeWord #-}
|
||||
|
||||
liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a
|
||||
liftIOEither a = liftIO a >>= liftEither
|
||||
{-# INLINE liftIOEither #-}
|
||||
|
||||
liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a
|
||||
liftError f = liftEitherError f . runExceptT
|
||||
liftError :: MonadIO m => (e -> e') -> ExceptT e IO a -> ExceptT e' m a
|
||||
liftError f = liftError' f . runExceptT
|
||||
{-# INLINE liftError #-}
|
||||
|
||||
liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a
|
||||
liftEitherError f a = liftIOEither (first f <$> a)
|
||||
{-# INLINE liftEitherError #-}
|
||||
liftError' :: MonadIO m => (e -> e') -> IO (Either e a) -> ExceptT e' m a
|
||||
liftError' f = ExceptT . fmap (first f) . liftIO
|
||||
{-# INLINE liftError' #-}
|
||||
|
||||
liftEitherWith :: MonadError e' m => (e -> e') -> Either e a -> m a
|
||||
liftEitherWith :: MonadIO m => (e -> e') -> Either e a -> ExceptT e' m a
|
||||
liftEitherWith f = liftEither . first f
|
||||
{-# INLINE liftEitherWith #-}
|
||||
|
||||
liftE :: (e -> e') -> ExceptT e IO a -> ExceptT e' IO a
|
||||
liftE f a = ExceptT $ first f <$> runExceptT a
|
||||
{-# INLINE liftE #-}
|
||||
|
||||
ifM :: Monad m => m Bool -> m a -> m a -> m a
|
||||
ifM ba t f = ba >>= \b -> if b then t else f
|
||||
{-# INLINE ifM #-}
|
||||
@@ -109,10 +101,18 @@ tryAllErrors :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m
|
||||
tryAllErrors err action = tryError action `UE.catch` (pure . Left . err)
|
||||
{-# INLINE tryAllErrors #-}
|
||||
|
||||
tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a)
|
||||
tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err)
|
||||
{-# INLINE tryAllErrors' #-}
|
||||
|
||||
catchAllErrors :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m a -> (e -> m a) -> m a
|
||||
catchAllErrors err action handler = tryAllErrors err action >>= either handler pure
|
||||
{-# INLINE catchAllErrors #-}
|
||||
|
||||
catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a
|
||||
catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure
|
||||
{-# INLINE catchAllErrors' #-}
|
||||
|
||||
catchThrow :: (MonadUnliftIO m, MonadError e m) => m a -> (E.SomeException -> e) -> m a
|
||||
catchThrow action err = catchAllErrors err action throwError
|
||||
{-# INLINE catchThrow #-}
|
||||
@@ -148,8 +148,8 @@ safeDecodeUtf8 = decodeUtf8With onError
|
||||
where
|
||||
onError _ _ = Just '?'
|
||||
|
||||
timeoutThrow :: (MonadUnliftIO m, MonadError e m) => e -> Int -> m a -> m a
|
||||
timeoutThrow e ms action = timeout ms action >>= maybe (throwError e) pure
|
||||
timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a
|
||||
timeoutThrow e ms action = ExceptT (sequence <$> (ms `timeout` runExceptT action)) >>= maybe (throwError e) pure
|
||||
|
||||
threadDelay' :: Int64 -> IO ()
|
||||
threadDelay' time
|
||||
|
||||
@@ -103,14 +103,14 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
found@(RCCtrlAddress {address} :| _) <- findCtrlAddress
|
||||
c@RCHClient_ {startedPort, announcer} <- liftIO mkClient
|
||||
hostKeys <- atomically genHostKeys
|
||||
action <- runClient c r hostKeys `putRCError` r
|
||||
action <- liftIO $ runClient c r hostKeys
|
||||
-- wait for the port to make invitation
|
||||
portNum <- atomically $ readTMVar startedPort
|
||||
signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
|
||||
when multicast $ case knownHost of
|
||||
Nothing -> throwError RCENewController
|
||||
Just KnownHostPairing {hostDhPubKey} -> do
|
||||
ann <- async . liftIO . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
|
||||
ann <- liftIO . async . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
|
||||
atomically $ putTMVar announcer ann
|
||||
pure (found, signedInv, RCHostClient {action, client_ = c}, r)
|
||||
where
|
||||
@@ -125,9 +125,9 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
endSession <- newEmptyTMVarIO
|
||||
hostCAHash <- newEmptyTMVarIO
|
||||
pure RCHClient_ {startedPort, announcer, hostCAHash, endSession}
|
||||
runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> ExceptT RCErrorType IO (Async ())
|
||||
runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> IO (Async ())
|
||||
runClient RCHClient_ {startedPort, announcer, hostCAHash, endSession} r hostKeys = do
|
||||
tlsCreds <- liftIO $ genTLSCredentials drg caKey caCert
|
||||
tlsCreds <- genTLSCredentials drg caKey caCert
|
||||
startTLSServer port_ startedPort tlsCreds (tlsHooks r knownHost hostCAHash) $ \tls ->
|
||||
void . runExceptT $ do
|
||||
r' <- newEmptyTMVarIO
|
||||
@@ -265,7 +265,7 @@ connectRCCtrl_ :: TVar ChaChaDRG -> RCCtrlPairing -> RCInvitation -> J.Value ->
|
||||
connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, host, port} hostAppInfo = do
|
||||
r <- newEmptyTMVarIO
|
||||
c <- liftIO mkClient
|
||||
action <- async $ runClient c r `putRCError` r
|
||||
action <- liftIO . async . void . runExceptT $ runClient c r `putRCError` r
|
||||
pure (RCCtrlClient {action, client_ = c}, r)
|
||||
where
|
||||
mkClient :: IO RCCClient_
|
||||
@@ -280,7 +280,7 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
|
||||
TLS.Credentials (creds : _) -> pure $ Just creds
|
||||
_ -> throwError $ RCEInternal "genTLSCredentials must generate credentials"
|
||||
let clientConfig = defaultTransportClientConfig {clientCredentials}
|
||||
runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> do
|
||||
ExceptT . runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> runExceptT $ do
|
||||
-- pump socket to detect connection problems
|
||||
liftIO $ peekBuffered tlsBuffer 100000 (TLS.recvData tlsContext) >>= logDebug . tshow -- should normally be ("", Nothing) here
|
||||
logDebug "Got TLS connection"
|
||||
@@ -360,11 +360,11 @@ prepareCtrlSession
|
||||
-- * Multicast discovery
|
||||
|
||||
announceRC :: TVar ChaChaDRG -> Int -> C.PrivateKeyEd25519 -> C.PublicKeyX25519 -> RCHostKeys -> RCInvitation -> ExceptT RCErrorType IO ()
|
||||
announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = withSender $ \sender -> do
|
||||
announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = ExceptT $ withSender $ \sender -> runExceptT $ do
|
||||
replicateM_ maxCount $ do
|
||||
logDebug "Announcing..."
|
||||
nonce <- atomically $ C.randomCbNonce drg
|
||||
encInvitation <- liftEitherWith undefined $ C.cbEncrypt sharedKey nonce sigInvitation encInvitationSize
|
||||
encInvitation <- liftEitherWith (const RCEEncrypt) $ C.cbEncrypt sharedKey nonce sigInvitation encInvitationSize
|
||||
liftIO . UDP.send sender $ smpEncode RCEncInvitation {dhPubKey, nonce, encInvitation}
|
||||
threadDelay 1000000
|
||||
where
|
||||
@@ -375,9 +375,9 @@ announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv =
|
||||
|
||||
discoverRCCtrl :: TMVar Int -> NonEmpty RCCtrlPairing -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
|
||||
discoverRCCtrl subscribers pairings =
|
||||
timeoutThrow RCENotDiscovered 30000000 $ withListener subscribers $ \listener ->
|
||||
loop $ do
|
||||
(source, bytes) <- recvAnnounce listener
|
||||
timeoutThrow RCENotDiscovered 30000000 $ ExceptT $ withListener subscribers $ \listener ->
|
||||
runExceptT . loop $ do
|
||||
(source, bytes) <- liftIO $ recvAnnounce listener
|
||||
encInvitation <- liftEitherWith (const RCEInvitation) $ smpDecode bytes
|
||||
r@(_, RCVerifiedInvitation RCInvitation {host}) <- findRCCtrlPairing pairings encInvitation
|
||||
case source of
|
||||
@@ -386,10 +386,7 @@ discoverRCCtrl subscribers pairings =
|
||||
pure r
|
||||
where
|
||||
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
|
||||
loop action =
|
||||
liftIO (runExceptT action) >>= \case
|
||||
Left err -> logError (tshow err) >> loop action
|
||||
Right res -> pure res
|
||||
loop action = action `catchRCError` \e -> logError (tshow e) >> loop action
|
||||
|
||||
findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
|
||||
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
|
||||
|
||||
@@ -68,7 +68,7 @@ preferAddress RCCtrlAddress {address, interface} addrs =
|
||||
matchAddr RCCtrlAddress {address = a} = a == address
|
||||
matchIface RCCtrlAddress {interface = i} = i == interface
|
||||
|
||||
startTLSServer :: MonadUnliftIO m => Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> m (Async ())
|
||||
startTLSServer :: Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> IO (Async ())
|
||||
startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ do
|
||||
started <- newEmptyTMVarIO
|
||||
bracketOnError (startTCPServer started $ maybe "0" show port_) (\_e -> setPort Nothing) $ \socket ->
|
||||
@@ -91,14 +91,14 @@ startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ d
|
||||
TLS.serverSupported = supportedParameters
|
||||
}
|
||||
|
||||
withSender :: MonadUnliftIO m => (UDP.UDPSocket -> m a) -> m a
|
||||
withSender = bracket (liftIO $ UDP.clientSocket MULTICAST_ADDR_V4 DISCOVERY_PORT False) (liftIO . UDP.close)
|
||||
withSender :: (UDP.UDPSocket -> IO a) -> IO a
|
||||
withSender = bracket (UDP.clientSocket MULTICAST_ADDR_V4 DISCOVERY_PORT False) (UDP.close)
|
||||
|
||||
withListener :: MonadUnliftIO m => TMVar Int -> (UDP.ListenSocket -> m a) -> m a
|
||||
withListener :: TMVar Int -> (UDP.ListenSocket -> IO a) -> IO a
|
||||
withListener subscribers = bracket (openListener subscribers) (closeListener subscribers)
|
||||
|
||||
openListener :: MonadIO m => TMVar Int -> m UDP.ListenSocket
|
||||
openListener subscribers = liftIO $ do
|
||||
openListener :: TMVar Int -> IO UDP.ListenSocket
|
||||
openListener subscribers = do
|
||||
sock <- UDP.serverSocket (MULTICAST_ADDR_V4, read DISCOVERY_PORT)
|
||||
logDebug $ "Discovery listener socket: " <> tshow sock
|
||||
let raw = UDP.listenSocket sock
|
||||
@@ -106,10 +106,9 @@ openListener subscribers = liftIO $ do
|
||||
joinMulticast subscribers raw (listenerHostAddr4 sock)
|
||||
pure sock
|
||||
|
||||
closeListener :: MonadIO m => TMVar Int -> UDP.ListenSocket -> m ()
|
||||
closeListener :: TMVar Int -> UDP.ListenSocket -> IO ()
|
||||
closeListener subscribers sock =
|
||||
liftIO $
|
||||
partMulticast subscribers (UDP.listenSocket sock) (listenerHostAddr4 sock) `finally` UDP.stop sock
|
||||
partMulticast subscribers (UDP.listenSocket sock) (listenerHostAddr4 sock) `finally` UDP.stop sock
|
||||
|
||||
joinMulticast :: TMVar Int -> N.Socket -> N.HostAddress -> IO ()
|
||||
joinMulticast subscribers sock group = do
|
||||
@@ -132,7 +131,7 @@ listenerHostAddr4 sock = case UDP.mySockAddr sock of
|
||||
N.SockAddrInet _port host -> host
|
||||
_ -> error "MULTICAST_ADDR_V4 is V4"
|
||||
|
||||
recvAnnounce :: MonadIO m => UDP.ListenSocket -> m (N.SockAddr, ByteString)
|
||||
recvAnnounce sock = liftIO $ do
|
||||
recvAnnounce :: UDP.ListenSocket -> IO (N.SockAddr, ByteString)
|
||||
recvAnnounce sock = do
|
||||
(invite, UDP.ClientSockAddr source _cmsg) <- UDP.recvFrom sock
|
||||
pure (source, invite)
|
||||
|
||||
@@ -238,17 +238,6 @@ type SessionCode = ByteString
|
||||
|
||||
type RCStepTMVar a = TMVar (Either RCErrorType a)
|
||||
|
||||
type Tasks = TVar [Async ()]
|
||||
|
||||
asyncRegistered :: MonadUnliftIO m => Tasks -> m () -> m ()
|
||||
asyncRegistered tasks action = async action >>= registerAsync tasks
|
||||
|
||||
registerAsync :: MonadIO m => Tasks -> Async () -> m ()
|
||||
registerAsync tasks = atomically . modifyTVar tasks . (:)
|
||||
|
||||
cancelTasks :: MonadIO m => Tasks -> m ()
|
||||
cancelTasks tasks = readTVarIO tasks >>= mapM_ cancel
|
||||
|
||||
$(JQ.deriveJSON (sumTypeJSON $ dropPrefix "RCE") ''RCErrorType)
|
||||
|
||||
$(JQ.deriveJSON defaultJSON ''RCCtrlAddress)
|
||||
|
||||
@@ -233,13 +233,13 @@ inAnyOrder g rs = do
|
||||
expected :: a -> (a -> Bool) -> Bool
|
||||
expected r rp = rp r
|
||||
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c)
|
||||
createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c)
|
||||
createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn)
|
||||
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
|
||||
joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE ConnId
|
||||
joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQSupportOn
|
||||
|
||||
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId
|
||||
sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId
|
||||
sendMessage c connId msgFlags msgBody = do
|
||||
(msgId, pqEnc) <- A.sendMessage c connId PQEncOn msgFlags msgBody
|
||||
liftIO $ pqEnc `shouldBe` PQEncOn
|
||||
@@ -664,7 +664,7 @@ testAsyncInitiatingOffline :: HasCallStack => IO ()
|
||||
testAsyncInitiatingOffline =
|
||||
withAgentClients2 $ \alice bob -> runRight_ $ do
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
|
||||
disposeAgentClient alice
|
||||
liftIO $ disposeAgentClient alice
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe
|
||||
alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB
|
||||
subscribeConnection alice' bobId
|
||||
@@ -680,7 +680,7 @@ testAsyncJoiningOfflineBeforeActivation =
|
||||
withAgentClients2 $ \alice bob -> runRight_ $ do
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe
|
||||
disposeAgentClient bob
|
||||
liftIO $ disposeAgentClient bob
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB2
|
||||
@@ -694,9 +694,9 @@ testAsyncBothOffline :: HasCallStack => IO ()
|
||||
testAsyncBothOffline =
|
||||
withAgentClients2 $ \alice bob -> runRight_ $ do
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
|
||||
disposeAgentClient alice
|
||||
liftIO $ disposeAgentClient alice
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe
|
||||
disposeAgentClient bob
|
||||
liftIO $ disposeAgentClient bob
|
||||
alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice'
|
||||
@@ -1067,8 +1067,7 @@ testExpireMessageQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testP
|
||||
b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2
|
||||
(aId, bId) <- runRight $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
liftIO $ threadDelay 500000
|
||||
disposeAgentClient b
|
||||
liftIO $ threadDelay 500000 >> disposeAgentClient b
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "1"
|
||||
get a ##> ("", bId, SENT 4)
|
||||
5 <- sendMessage a bId SMP.noMsgFlags "2"
|
||||
@@ -1091,8 +1090,7 @@ testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1}
|
||||
b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2
|
||||
(aId, bId) <- runRight $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
liftIO $ threadDelay 500000
|
||||
disposeAgentClient b
|
||||
liftIO $ threadDelay 500000 >> disposeAgentClient b
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "1"
|
||||
get a ##> ("", bId, SENT 4)
|
||||
5 <- sendMessage a bId SMP.noMsgFlags "2"
|
||||
@@ -1161,13 +1159,13 @@ setupDesynchronizedRatchet alice bob = do
|
||||
runRight_ $ do
|
||||
subscribeConnection bob2 aliceId
|
||||
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False
|
||||
|
||||
8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5"
|
||||
get alice ##> ("", bobId, SENT 8)
|
||||
get bob2 =##> ratchetSyncP aliceId RSRequired
|
||||
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6"
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6"
|
||||
pure ()
|
||||
|
||||
pure (aliceId, bobId, bob2)
|
||||
@@ -1224,7 +1222,7 @@ testRatchetSyncClientRestart t = do
|
||||
("", "", DOWN _ _) <- nGet bob2
|
||||
ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False
|
||||
liftIO $ ratchetSyncState `shouldBe` RSStarted
|
||||
disposeAgentClient bob2
|
||||
liftIO $ disposeAgentClient bob2
|
||||
bob3 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB2
|
||||
withSmpServerStoreMsgLogOn t testPort $ \_ -> do
|
||||
runRight_ $ do
|
||||
@@ -1420,12 +1418,12 @@ testSuspendingAgent =
|
||||
get a ##> ("", bId, SENT 4)
|
||||
get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
ackMessage b aId 4 Nothing
|
||||
suspendAgent b 1000000
|
||||
liftIO $ suspendAgent b 1000000
|
||||
get' b ##> ("", "", SUSPENDED)
|
||||
5 <- sendMessage a bId SMP.noMsgFlags "hello 2"
|
||||
get a ##> ("", bId, SENT 5)
|
||||
Nothing <- 100000 `timeout` get b
|
||||
foregroundAgent b
|
||||
liftIO $ foregroundAgent b
|
||||
get b =##> \case ("", c, Msg "hello 2") -> c == aId; _ -> False
|
||||
|
||||
testSuspendingAgentCompleteSending :: ATransport -> IO ()
|
||||
@@ -1444,7 +1442,7 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do
|
||||
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
|
||||
6 <- sendMessage b aId SMP.noMsgFlags "how are you?"
|
||||
liftIO $ threadDelay 100000
|
||||
suspendAgent b 5000000
|
||||
liftIO $ suspendAgent b 5000000
|
||||
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do
|
||||
pGet b =##> \case ("", c, APC _ (SENT 5)) -> c == aId; ("", "", APC _ UP {}) -> True; _ -> False
|
||||
@@ -1473,7 +1471,7 @@ testSuspendingAgentTimeout t = withAgentClients2 $ \a b -> do
|
||||
("", "", DOWN {}) <- nGet b
|
||||
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
|
||||
6 <- sendMessage b aId SMP.noMsgFlags "how are you?"
|
||||
suspendAgent b 100000
|
||||
liftIO $ suspendAgent b 100000
|
||||
("", "", SUSPENDED) <- nGet b
|
||||
pure ()
|
||||
|
||||
@@ -2095,7 +2093,7 @@ testSwitchDelete servers = do
|
||||
runRight_ $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
disposeAgentClient b
|
||||
liftIO $ disposeAgentClient b
|
||||
stats <- switchConnectionAsync a "" bId
|
||||
liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted]
|
||||
phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing]
|
||||
@@ -2120,7 +2118,7 @@ testAbortSwitchStarted servers = do
|
||||
liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted]
|
||||
phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing]
|
||||
-- repeat switch is prohibited
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ switchConnectionAsync a "" bId
|
||||
-- abort current switch
|
||||
stats' <- abortConnectionSwitch a bId
|
||||
liftIO $ rcvSwchStatuses' stats' `shouldMatchList` [Nothing]
|
||||
@@ -2242,7 +2240,7 @@ testCannotAbortSwitchSecured servers = do
|
||||
withA' $ \a -> do
|
||||
phaseRcv a bId SPConfirmed [Just RSSendingQADD, Nothing]
|
||||
phaseRcv a bId SPSecured [Just RSSendingQUSE, Nothing]
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId
|
||||
Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ abortConnectionSwitch a bId
|
||||
pure ()
|
||||
withA $ \a -> withB $ \b -> runRight_ $ do
|
||||
subscribeConnection a bId
|
||||
@@ -2407,7 +2405,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
|
||||
testSMPServerConnectionTest t newQueueBasicAuth srv =
|
||||
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
|
||||
a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB -- initially passed server is not running
|
||||
runRight $ testProtocolServer a 1 srv
|
||||
testProtocolServer a 1 srv
|
||||
|
||||
testRatchetAdHash :: HasCallStack => IO ()
|
||||
testRatchetAdHash =
|
||||
@@ -2551,7 +2549,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
|
||||
exchangeGreetings a bId1' b aId1'
|
||||
a `hasClients` 1
|
||||
b `hasClients` 1
|
||||
setNetworkConfig a nc {sessionMode = TSMEntity}
|
||||
liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
("", "", UP _ _) <- nGet a
|
||||
@@ -2560,7 +2558,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
|
||||
exchangeGreetingsMsgId 6 a bId1 b aId1
|
||||
exchangeGreetingsMsgId 6 a bId1' b aId1'
|
||||
liftIO $ threadDelay 250000
|
||||
setNetworkConfig a nc {sessionMode = TSMUser}
|
||||
liftIO $ setNetworkConfig a nc {sessionMode = TSMUser}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
@@ -2575,7 +2573,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
|
||||
exchangeGreetings a bId2' b aId2'
|
||||
a `hasClients` 2
|
||||
b `hasClients` 1
|
||||
setNetworkConfig a nc {sessionMode = TSMEntity}
|
||||
liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
@@ -2587,7 +2585,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do
|
||||
exchangeGreetingsMsgId 6 a bId2 b aId2
|
||||
exchangeGreetingsMsgId 6 a bId2' b aId2'
|
||||
liftIO $ threadDelay 250000
|
||||
setNetworkConfig a nc {sessionMode = TSMUser}
|
||||
liftIO $ setNetworkConfig a nc {sessionMode = TSMUser}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
("", "", DOWN _ _) <- nGet a
|
||||
@@ -2625,9 +2623,10 @@ testServerMultipleIdentities =
|
||||
get bob ##> ("", aliceId, CON)
|
||||
exchangeGreetings alice bobId bob aliceId
|
||||
-- this saves queue with second server identity
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
|
||||
disposeAgentClient bob
|
||||
bob' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB2
|
||||
bob' <- liftIO $ do
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
|
||||
disposeAgentClient bob
|
||||
getSMPAgentClient' 3 agentCfg initAgentServers testDB2
|
||||
subscribeConnection bob' aliceId
|
||||
exchangeGreetingsMsgId 6 alice bobId bob' aliceId
|
||||
where
|
||||
|
||||
@@ -42,7 +42,6 @@ import Control.Monad.Trans.Except
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
@@ -51,6 +50,7 @@ import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, te
|
||||
import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage)
|
||||
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore')
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers)
|
||||
import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO)
|
||||
import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken)
|
||||
@@ -179,7 +179,7 @@ testNotificationToken APNSMockServer {apnsQ} = do
|
||||
deleteNtfToken a tkn
|
||||
-- agent deleted this token
|
||||
Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn
|
||||
disposeAgentClient a
|
||||
liftIO $ disposeAgentClient a
|
||||
|
||||
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
|
||||
v .-> key = do
|
||||
@@ -211,7 +211,7 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do
|
||||
-- can still use the first verification code, it is the same after decryption
|
||||
verifyNtfToken a tkn nonce verification
|
||||
NTActive <- checkNtfToken a tkn
|
||||
disposeAgentClient a
|
||||
liftIO $ disposeAgentClient a
|
||||
|
||||
testNtfTokenSecondRegistration :: APNSMockServer -> IO ()
|
||||
testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
|
||||
@@ -247,8 +247,9 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
|
||||
Left (NTF AUTH) <- tryE $ checkNtfToken a tkn
|
||||
-- and the second is active
|
||||
NTActive <- checkNtfToken a' tkn
|
||||
disposeAgentClient a
|
||||
disposeAgentClient a'
|
||||
pure ()
|
||||
disposeAgentClient a
|
||||
disposeAgentClient a'
|
||||
|
||||
testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO ()
|
||||
testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
|
||||
@@ -277,11 +278,11 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
|
||||
liftIO $ sendApnsResponse' APNSRespOk
|
||||
verifyNtfToken a' tkn nonce' verification'
|
||||
NTActive <- checkNtfToken a' tkn
|
||||
disposeAgentClient a'
|
||||
liftIO $ disposeAgentClient a'
|
||||
|
||||
getTestNtfTokenPort :: (MonadUnliftIO m, MonadError AgentErrorType m) => AgentClient -> m String
|
||||
getTestNtfTokenPort :: AgentClient -> AE String
|
||||
getTestNtfTokenPort a =
|
||||
runReaderT (withStore' a getSavedNtfToken) (agentEnv a) >>= \case
|
||||
ExceptT (runExceptT (withStore' a getSavedNtfToken) `runReaderT` agentEnv a) >>= \case
|
||||
Just NtfToken {ntfServer = ProtocolServer {port}} -> pure port
|
||||
Nothing -> error "no active NtfToken"
|
||||
|
||||
@@ -317,18 +318,18 @@ testNtfTokenChangeServers t APNSMockServer {apnsQ} =
|
||||
a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB
|
||||
tkn <- registerTestToken a "abcd" NMInstant apnsQ
|
||||
NTActive <- checkNtfToken a tkn
|
||||
setNtfServers a [testNtfServer2]
|
||||
liftIO $ setNtfServers a [testNtfServer2]
|
||||
NTActive <- checkNtfToken a tkn -- still works on old server
|
||||
disposeAgentClient a
|
||||
liftIO $ disposeAgentClient a
|
||||
pure tkn
|
||||
|
||||
threadDelay 1000000
|
||||
|
||||
a <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB
|
||||
a <- getSMPAgentClient' 2 agentCfg initAgentServers testDB
|
||||
runRight_ $ do
|
||||
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort
|
||||
NTActive <- checkNtfToken a tkn1
|
||||
setNtfServers a [testNtfServer2] -- just change configured server list
|
||||
liftIO $ setNtfServers a [testNtfServer2] -- just change configured server list
|
||||
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort -- not yet changed
|
||||
-- trigger token replace
|
||||
tkn2 <- registerTestToken a "xyzw" NMInstant apnsQ
|
||||
@@ -345,7 +346,7 @@ testRunNTFServerTests :: ATransport -> NtfServer -> IO (Maybe ProtocolTestFailur
|
||||
testRunNTFServerTests t srv =
|
||||
withNtfServerThreadOn t ntfTestPort $ \ntf -> do
|
||||
a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB
|
||||
r <- runRight $ testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing
|
||||
r <- testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing
|
||||
killThread ntf
|
||||
pure r
|
||||
|
||||
@@ -712,7 +713,7 @@ testNotificationsOldToken APNSMockServer {apnsQ} = do
|
||||
liftIO $ threadDelay 250000
|
||||
testMessageAB "hello"
|
||||
-- change server
|
||||
setNtfServers a [testNtfServer2] -- server 2 isn't running now, don't use
|
||||
liftIO $ setNtfServers a [testNtfServer2] -- server 2 isn't running now, don't use
|
||||
-- replacing token keeps server
|
||||
_ <- registerTestToken a "xyzw" NMInstant apnsQ
|
||||
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort
|
||||
@@ -738,7 +739,7 @@ testNotificationsNewToken APNSMockServer {apnsQ} oldNtf = do
|
||||
liftIO $ threadDelay 250000
|
||||
testMessageAB "hello"
|
||||
-- switch
|
||||
setNtfServers a [testNtfServer2]
|
||||
liftIO $ setNtfServers a [testNtfServer2]
|
||||
deleteNtfToken a tkn
|
||||
_ <- registerTestToken a "abcd" NMInstant apnsQ
|
||||
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort2
|
||||
|
||||
+3
-4
@@ -16,7 +16,6 @@ module NtfClient where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..), (.:))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
@@ -71,13 +70,13 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
ntfTestStoreLogFile :: FilePath
|
||||
ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
|
||||
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a
|
||||
testNtfClient :: Transport c => (THandleNTF c -> IO a) -> IO a
|
||||
testNtfClient client = do
|
||||
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do
|
||||
g <- liftIO C.newRandom
|
||||
g <- C.newRandom
|
||||
ks <- atomically $ C.generateKeyPair g
|
||||
liftIO (runExceptT $ ntfClientHandshake h ks testKeyHash supportedClientNTFVRange) >>= \case
|
||||
runExceptT (ntfClientHandshake h ks testKeyHash supportedClientNTFVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import Control.Concurrent (threadDelay)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import NtfClient
|
||||
@@ -35,6 +34,7 @@ import ServerTests
|
||||
import qualified Simplex.Messaging.Agent.Protocol as AP
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
|
||||
+9
-12
@@ -12,7 +12,6 @@ module SMPAgentClient where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Map.Strict (Map)
|
||||
@@ -202,7 +201,7 @@ initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
defaultAgentConfig
|
||||
{ tcpPort = agentTestPort,
|
||||
{ tcpPort = Just agentTestPort,
|
||||
tbqSize = 4,
|
||||
-- database = testDB,
|
||||
smpCfg = defaultSMPClientConfig {qSize = 1, defaultTransport = (testPort, transport @TLS), networkConfig},
|
||||
@@ -224,11 +223,9 @@ fastRetryInterval = defaultReconnectInterval {initialInterval = 50_000}
|
||||
fastMessageRetryInterval :: RetryInterval2
|
||||
fastMessageRetryInterval = RetryInterval2 {riFast = fastRetryInterval, riSlow = fastRetryInterval}
|
||||
|
||||
type AgentTestMonad m = (MonadUnliftIO m, MonadRandom m, MonadFail m)
|
||||
|
||||
withSmpAgentThreadOn_ :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> m () -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn_ :: ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> IO () -> (ThreadId -> IO a) -> IO a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess =
|
||||
let cfg' = agentCfg {tcpPort = port'}
|
||||
let cfg' = agentCfg {tcpPort = Just port'}
|
||||
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
|
||||
in serverBracket
|
||||
( \started -> do
|
||||
@@ -241,24 +238,24 @@ withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess =
|
||||
userServers :: NonEmpty (ProtoServerWithAuth p) -> Map UserId (NonEmpty (ProtoServerWithAuth p))
|
||||
userServers srvs = M.fromList [(1, srvs)]
|
||||
|
||||
withSmpAgentThreadOn :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> IO a) -> IO a
|
||||
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a 0 $ removeFile db'
|
||||
|
||||
withSmpAgentOn :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> m a -> m a
|
||||
withSmpAgentOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> IO a -> IO a
|
||||
withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const
|
||||
|
||||
withSmpAgent :: AgentTestMonad m => ATransport -> m a -> m a
|
||||
withSmpAgent :: ATransport -> IO a -> IO a
|
||||
withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
||||
|
||||
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a
|
||||
testSMPAgentClientOn :: Transport c => ServiceName -> (c -> IO a) -> IO a
|
||||
testSMPAgentClientOn port' client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
|
||||
line <- liftIO $ getLn h
|
||||
line <- getLn h
|
||||
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
then client h
|
||||
else do
|
||||
error $ "wrong welcome message: " <> B.unpack line
|
||||
|
||||
testSMPAgentClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (c -> m a) -> m a
|
||||
testSMPAgentClient :: Transport c => (c -> IO a) -> IO a
|
||||
testSMPAgentClient = testSMPAgentClientOn agentTestPort
|
||||
|
||||
+6
-7
@@ -13,7 +13,6 @@
|
||||
module SMPClient where
|
||||
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Network.Socket
|
||||
@@ -68,23 +67,23 @@ xit'' d t = do
|
||||
ci <- runIO $ lookupEnv "CI"
|
||||
(if ci == Just "true" then skip "skipped on CI" . it d else it d) t
|
||||
|
||||
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a
|
||||
testSMPClient :: Transport c => (THandleSMP c -> IO a) -> IO a
|
||||
testSMPClient = testSMPClientVR supportedClientSMPRelayVRange
|
||||
|
||||
testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a
|
||||
testSMPClientVR :: Transport c => VersionRangeSMP -> (THandleSMP c -> IO a) -> IO a
|
||||
testSMPClientVR vr client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do
|
||||
g <- liftIO C.newRandom
|
||||
g <- C.newRandom
|
||||
ks <- atomically $ C.generateKeyPair g
|
||||
liftIO (runExceptT $ smpClientHandshake h ks testKeyHash vr) >>= \case
|
||||
runExceptT (smpClientHandshake h ks testKeyHash vr) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
cfg :: ServerConfig
|
||||
cfg =
|
||||
ServerConfig
|
||||
{ transports = undefined,
|
||||
{ transports = [],
|
||||
smpHandshakeTimeout = 60000000,
|
||||
tbqSize = 1,
|
||||
-- serverTbqSize = 1,
|
||||
@@ -129,7 +128,7 @@ withSmpServerConfigOn t cfg' port' =
|
||||
withSmpServerThreadOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
|
||||
withSmpServerThreadOn t = withSmpServerConfigOn t cfg
|
||||
|
||||
serverBracket :: (HasCallStack, MonadUnliftIO m) => (TMVar Bool -> m ()) -> m () -> (HasCallStack => ThreadId -> m a) -> m a
|
||||
serverBracket :: HasCallStack => (TMVar Bool -> IO ()) -> IO () -> (HasCallStack => ThreadId -> IO a) -> IO a
|
||||
serverBracket process afterProcess f = do
|
||||
started <- newEmptyTMVarIO
|
||||
E.bracket
|
||||
|
||||
@@ -22,7 +22,6 @@ import Control.Exception (SomeException, try)
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Set as S
|
||||
@@ -31,6 +30,7 @@ import GHC.Stack (withFrozenCallStack)
|
||||
import SMPClient
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol
|
||||
@@ -103,7 +103,7 @@ tPut1 h t = do
|
||||
[r] <- tPut h [Right t]
|
||||
pure r
|
||||
|
||||
tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd)
|
||||
tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (SignedTransmission err cmd)
|
||||
tGet1 h = do
|
||||
[r] <- liftIO $ tGet h
|
||||
pure r
|
||||
|
||||
+13
-18
@@ -103,7 +103,7 @@ testXFTPAgentSendReceive = withXFTPServer $ do
|
||||
sndr <- getSMPAgentClient' 1 agentCfg initAgentServers testDB
|
||||
(rfd1, rfd2) <- runRight $ do
|
||||
(sfId, _, rfd1, rfd2) <- testSend sndr filePath
|
||||
xftpDeleteSndFileInternal sndr sfId
|
||||
liftIO $ xftpDeleteSndFileInternal sndr sfId
|
||||
pure (rfd1, rfd2)
|
||||
|
||||
-- receive file, delete rcv file
|
||||
@@ -112,9 +112,8 @@ testXFTPAgentSendReceive = withXFTPServer $ do
|
||||
where
|
||||
testReceiveDelete clientId rfd originalFilePath = do
|
||||
rcp <- getSMPAgentClient' clientId agentCfg initAgentServers testDB2
|
||||
runRight_ $ do
|
||||
rfId <- testReceive rcp rfd originalFilePath
|
||||
xftpDeleteRcvFile rcp rfId
|
||||
rfId <- runRight $ testReceive rcp rfd originalFilePath
|
||||
xftpDeleteRcvFile rcp rfId
|
||||
disposeAgentClient rcp
|
||||
|
||||
testXFTPAgentSendReceiveEncrypted :: HasCallStack => IO ()
|
||||
@@ -127,7 +126,7 @@ testXFTPAgentSendReceiveEncrypted = withXFTPServer $ do
|
||||
sndr <- getSMPAgentClient' 1 agentCfg initAgentServers testDB
|
||||
(rfd1, rfd2) <- runRight $ do
|
||||
(sfId, _, rfd1, rfd2) <- testSendCF sndr file
|
||||
xftpDeleteSndFileInternal sndr sfId
|
||||
liftIO $ xftpDeleteSndFileInternal sndr sfId
|
||||
pure (rfd1, rfd2)
|
||||
-- receive file, delete rcv file
|
||||
testReceiveDelete 2 rfd1 filePath g
|
||||
@@ -136,9 +135,8 @@ testXFTPAgentSendReceiveEncrypted = withXFTPServer $ do
|
||||
testReceiveDelete clientId rfd originalFilePath g = do
|
||||
rcp <- getSMPAgentClient' clientId agentCfg initAgentServers testDB2
|
||||
cfArgs <- atomically $ Just <$> CF.randomArgs g
|
||||
runRight_ $ do
|
||||
rfId <- testReceiveCF rcp rfd cfArgs originalFilePath
|
||||
xftpDeleteRcvFile rcp rfId
|
||||
rfId <- runRight $ testReceiveCF rcp rfd cfArgs originalFilePath
|
||||
xftpDeleteRcvFile rcp rfId
|
||||
disposeAgentClient rcp
|
||||
|
||||
testXFTPAgentSendReceiveRedirect :: HasCallStack => IO ()
|
||||
@@ -468,11 +466,9 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $
|
||||
length <$> listDirectory xftpServerFiles `shouldReturn` 6
|
||||
|
||||
-- delete file
|
||||
runRight $ do
|
||||
xftpStartWorkers sndr (Just senderFiles)
|
||||
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
|
||||
Nothing <- liftIO $ 100000 `timeout` sfGet sndr
|
||||
pure ()
|
||||
runRight_ $ xftpStartWorkers sndr (Just senderFiles)
|
||||
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
|
||||
Nothing <- 100000 `timeout` sfGet sndr
|
||||
disposeAgentClient rcp1
|
||||
|
||||
threadDelay 1000000
|
||||
@@ -505,10 +501,9 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do
|
||||
|
||||
-- delete file - should not succeed with server down
|
||||
sndr <- getSMPAgentClient' 3 agentCfg initAgentServers testDB
|
||||
runRight $ do
|
||||
xftpStartWorkers sndr (Just senderFiles)
|
||||
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
|
||||
liftIO $ timeout 300000 (get sndr) `shouldReturn` Nothing -- wait for worker attempt
|
||||
runRight_ $ xftpStartWorkers sndr (Just senderFiles)
|
||||
xftpDeleteSndFileRemote sndr 1 sfId sndDescr
|
||||
timeout 300000 (get sndr) `shouldReturn` Nothing -- wait for worker attempt
|
||||
disposeAgentClient sndr
|
||||
|
||||
threadDelay 300000
|
||||
@@ -636,4 +631,4 @@ testXFTPServerTest :: HasCallStack => Maybe BasicAuth -> XFTPServerWithAuth -> I
|
||||
testXFTPServerTest newFileBasicAuth srv =
|
||||
withXFTPServerCfg testXFTPServerConfig {newFileBasicAuth, xftpPort = xftpTestPort2} $ \_ -> do
|
||||
a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB -- initially passed server is not running
|
||||
runRight $ testProtocolServer a 1 srv
|
||||
testProtocolServer a 1 srv
|
||||
|
||||
@@ -12,7 +12,6 @@ import Control.Exception (SomeException)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import qualified Data.ByteString.Base64.URL as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
@@ -26,9 +25,9 @@ import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..)
|
||||
import Simplex.Messaging.Client (ProtocolClientError (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Protocol (BasicAuth, SenderId)
|
||||
import Simplex.Messaging.Server.Expiration (ExpirationConfig (..))
|
||||
import Simplex.Messaging.Util (liftIOEither)
|
||||
import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive, removeFile)
|
||||
import System.FilePath ((</>))
|
||||
import Test.Hspec
|
||||
@@ -60,6 +59,7 @@ xftpServerTests =
|
||||
it "prohibited when FNEW disabled" $ testFileBasicAuth False (Just "pwd") (Just "pwd") False
|
||||
it "allowed with correct basic auth" $ testFileBasicAuth True (Just "pwd") (Just "pwd") True
|
||||
it "allowed with auth on server without auth" $ testFileBasicAuth True Nothing (Just "any") True
|
||||
it "should not change content for uploaded and committed files" testFileSkipCommitted
|
||||
|
||||
chSize :: Integral a => a
|
||||
chSize = kb 128
|
||||
@@ -75,7 +75,7 @@ createTestChunk fp = do
|
||||
pure bytes
|
||||
|
||||
readChunk :: SenderId -> IO ByteString
|
||||
readChunk sId = B.readFile (xftpServerFiles </> B.unpack (B64.encode sId))
|
||||
readChunk sId = B.readFile (xftpServerFiles </> B.unpack (U.encode sId))
|
||||
|
||||
testFileChunkDelivery :: Expectation
|
||||
testFileChunkDelivery = xftpTest $ \c -> runRight_ $ runTestFileChunkDelivery c c
|
||||
@@ -219,7 +219,7 @@ testFileChunkExpiration = withXFTPServerCfg testXFTPServerConfig {fileExpiration
|
||||
testInactiveClientExpiration :: Expectation
|
||||
testInactiveClientExpiration = withXFTPServerCfg testXFTPServerConfig {inactiveClientExpiration} $ \_ -> runRight_ $ do
|
||||
disconnected <- newEmptyTMVarIO
|
||||
c <- liftIOEither $ getXFTPClient (1, testXFTPServer, Nothing) testXFTPClientConfig (\_ -> atomically $ putTMVar disconnected ())
|
||||
c <- ExceptT $ getXFTPClient (1, testXFTPServer, Nothing) testXFTPClientConfig (\_ -> atomically $ putTMVar disconnected ())
|
||||
pingXFTP c
|
||||
liftIO $ do
|
||||
threadDelay 100000
|
||||
@@ -372,3 +372,22 @@ testFileBasicAuth allowNewFiles newFileBasicAuth clntAuth success =
|
||||
else do
|
||||
void (createXFTPChunk c spKey file [rcvKey] clntAuth)
|
||||
`catchError` (liftIO . (`shouldBe` PCEProtocolError AUTH))
|
||||
|
||||
testFileSkipCommitted :: IO ()
|
||||
testFileSkipCommitted =
|
||||
withXFTPServerCfg testXFTPServerConfig $
|
||||
\_ -> testXFTPClient $ \c -> do
|
||||
g <- C.newRandom
|
||||
(sndKey, spKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
(rcvKey, rpKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
bytes <- createTestChunk testChunkPath
|
||||
digest <- LC.sha256Hash <$> LB.readFile testChunkPath
|
||||
let file = FileInfo {sndKey, size = chSize, digest}
|
||||
chunkSpec = XFTPChunkSpec {filePath = testChunkPath, chunkOffset = 0, chunkSize = chSize}
|
||||
runRight_ $ do
|
||||
(sId, [rId]) <- createXFTPChunk c spKey file [rcvKey] Nothing
|
||||
uploadXFTPChunk c spKey sId chunkSpec
|
||||
void . liftIO $ createTestChunk testChunkPath -- trash chunk contents
|
||||
uploadXFTPChunk c spKey sId chunkSpec -- upload again to get FROk without getting stuck
|
||||
downloadXFTPChunk g c rpKey rId $ XFTPRcvChunkSpec "tests/tmp/received_chunk" chSize digest
|
||||
liftIO $ B.readFile "tests/tmp/received_chunk" `shouldReturn` bytes -- new chunk content got ignored
|
||||
|
||||
Reference in New Issue
Block a user