mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-25 14:14:54 +00:00
Merge remote-tracking branch 'origin/master' into ab/async-subs
This commit is contained in:
@@ -1,3 +1,22 @@
|
||||
# 5.8.0
|
||||
|
||||
Version 5.8.0.10
|
||||
|
||||
SMP server and client:
|
||||
- protocol extension to forward messages to the destination servers, to protect sending client IP address and transport session.
|
||||
|
||||
Agent:
|
||||
- process timed out subscription responses to reduce the number of resubscriptions.
|
||||
- avoid sending messages and commands when waiting for response timed out (except batched SUB and DEL commands).
|
||||
- fix issue with stuck message reception on slow connection (when response to ACK timed out, and the new message was not processed until resubscribed).
|
||||
- fix issue when temporary file sending or receiving error was treated as permanent.
|
||||
|
||||
SMP server:
|
||||
- include OK responses to all batched SUB requests to reduce subscription timeouts.
|
||||
|
||||
XFTP server:
|
||||
- report file upload timeout as TIMEOUT, to avoid delivery failure.
|
||||
|
||||
# 5.7.6
|
||||
|
||||
XFTP agent:
|
||||
|
||||
@@ -147,15 +147,6 @@ executables:
|
||||
- -threaded
|
||||
- -rtsopts
|
||||
|
||||
smp-agent:
|
||||
source-dirs: apps/smp-agent
|
||||
main: Main.hs
|
||||
dependencies:
|
||||
- simplexmq
|
||||
ghc-options:
|
||||
- -threaded
|
||||
- -rtsopts
|
||||
|
||||
xftp:
|
||||
source-dirs: apps/xftp
|
||||
main: Main.hs
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Evolving agent API
|
||||
|
||||
## Problem
|
||||
|
||||
Historically, agent API started as a TCP protocol with encoding. We do not use the actual protocol and maintaining the encoding complicates the evolution of the API.
|
||||
|
||||
Currently, I was trying to add ERRS event to combine multiple subscription errors into one to prevent overloading the UI with processing multiple subscription errors (e.g.):
|
||||
|
||||
```haskell
|
||||
ERRS :: (ConnId, AgentErrorType) -> ACommand Agent AEConn
|
||||
```
|
||||
|
||||
This constructor is not possible to encode/parse in a sensible way other than including lengths of errors.
|
||||
|
||||
## Proposal
|
||||
|
||||
Remove commands type and encodings for commands and events.
|
||||
|
||||
Only keep encodings for the commands that are saved to the database: NEW, JOIN, LET, ACK, SWCH, DEL (this one is no longer used but needs to be supported for backwards compatibility).
|
||||
@@ -95,7 +95,6 @@ library
|
||||
Simplex.Messaging.Agent.Protocol
|
||||
Simplex.Messaging.Agent.QueryString
|
||||
Simplex.Messaging.Agent.RetryInterval
|
||||
Simplex.Messaging.Agent.Server
|
||||
Simplex.Messaging.Agent.Store
|
||||
Simplex.Messaging.Agent.Store.SQLite
|
||||
Simplex.Messaging.Agent.Store.SQLite.Common
|
||||
@@ -352,81 +351,6 @@ executable ntf-server
|
||||
, template-haskell ==2.16.*
|
||||
, text >=1.2.3.0 && <1.3
|
||||
|
||||
executable smp-agent
|
||||
main-is: Main.hs
|
||||
other-modules:
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/smp-agent
|
||||
default-extensions:
|
||||
StrictData
|
||||
ghc-options: -Weverything -Wno-missing-exported-signatures -Wno-missing-import-lists -Wno-missed-specialisations -Wno-all-missed-specialisations -Wno-unsafe -Wno-safe -Wno-missing-local-signatures -Wno-missing-kind-signatures -Wno-missing-deriving-strategies -Wno-monomorphism-restriction -Wno-prepositive-qualified-module -Wno-unused-packages -Wno-implicit-prelude -Wno-missing-safe-haskell-mode -Wno-missing-export-lists -Wno-partial-fields -Wcompat -Werror=incomplete-record-updates -Werror=incomplete-patterns -Werror=incomplete-uni-patterns -Werror=missing-methods -Werror=tabs -Wredundant-constraints -Wincomplete-record-updates -Wunused-type-patterns -O2 -threaded -rtsopts
|
||||
build-depends:
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
, containers ==0.6.*
|
||||
, crypton ==0.34.*
|
||||
, crypton-x509 ==1.7.*
|
||||
, crypton-x509-store ==1.6.*
|
||||
, crypton-x509-validation ==1.6.*
|
||||
, cryptostore ==0.3.*
|
||||
, data-default ==0.7.*
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 >=4.2.2 && <4.3
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
, memory ==0.18.*
|
||||
, mtl >=2.3.1 && <3.0
|
||||
, network >=3.1.2.7 && <3.2
|
||||
, network-info ==0.2.*
|
||||
, network-transport ==0.5.6
|
||||
, network-udp ==0.0.*
|
||||
, optparse-applicative >=0.15 && <0.17
|
||||
, process ==1.6.*
|
||||
, random >=1.1 && <1.3
|
||||
, simple-logger ==0.1.*
|
||||
, simplexmq
|
||||
, socks ==0.6.*
|
||||
, sqlcipher-simple ==0.4.*
|
||||
, stm ==2.5.*
|
||||
, temporary ==1.3.*
|
||||
, time ==1.12.*
|
||||
, time-manager ==0.0.*
|
||||
, tls >=1.7.0 && <1.8
|
||||
, transformers ==0.6.*
|
||||
, unliftio ==0.2.*
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
if impl(ghc >= 9.6.2)
|
||||
build-depends:
|
||||
bytestring ==0.11.*
|
||||
, template-haskell ==2.20.*
|
||||
, text >=2.0.1 && <2.2
|
||||
if impl(ghc < 9.6.2)
|
||||
build-depends:
|
||||
bytestring ==0.10.*
|
||||
, template-haskell ==2.16.*
|
||||
, text >=1.2.3.0 && <1.3
|
||||
|
||||
executable smp-server
|
||||
main-is: Main.hs
|
||||
other-modules:
|
||||
@@ -677,7 +601,6 @@ test-suite simplexmq-test
|
||||
CoreTests.CryptoFileTests
|
||||
CoreTests.CryptoTests
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
CoreTests.RetryIntervalTests
|
||||
CoreTests.TRcvQueuesTests
|
||||
CoreTests.UtilTests
|
||||
|
||||
@@ -32,12 +32,13 @@ import Control.Logger.Simple (logError)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Composition ((.:))
|
||||
import Data.Either (rights)
|
||||
import Data.Either (partitionEithers, rights)
|
||||
import Data.Int (Int64)
|
||||
import Data.List (foldl', partition, sortOn)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
@@ -56,6 +57,7 @@ import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
|
||||
import qualified Simplex.FileTransfer.Transport as XFTP
|
||||
import Simplex.FileTransfer.Types
|
||||
import qualified Simplex.FileTransfer.Types as FT
|
||||
import Simplex.FileTransfer.Util (removePath, uniqueCombine)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -69,7 +71,8 @@ import qualified Simplex.Messaging.Crypto.File as CF
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String (strDecode, strEncode)
|
||||
import Simplex.Messaging.Protocol (EntityId, XFTPServer)
|
||||
import Simplex.Messaging.Protocol (EntityId, ProtocolServer, ProtocolType (..), XFTPServer)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM)
|
||||
import System.FilePath (takeFileName, (</>))
|
||||
import UnliftIO
|
||||
@@ -141,7 +144,7 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi
|
||||
downloadChunk :: AgentClient -> FileChunk -> AM ()
|
||||
downloadChunk c FileChunk {replicas = (FileChunkReplica {server} : _)} = do
|
||||
lift . void $ getXFTPRcvWorker True c (Just server)
|
||||
downloadChunk _ _ = throwError $ INTERNAL "no replicas"
|
||||
downloadChunk _ _ = throwE $ INTERNAL "no replicas"
|
||||
|
||||
getPrefixPath :: String -> AM' FilePath
|
||||
getPrefixPath suffix = do
|
||||
@@ -174,7 +177,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} =
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpConsecutiveRetries} =
|
||||
withWork c doWork (\db -> getNextRcvChunkToDownload db srv rcvFilesTTL) $ \case
|
||||
(RcvFileChunk {rcvFileId, rcvFileEntityId, fileTmpPath, replicas = []}, _) -> rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) (INTERNAL "chunk has no replicas")
|
||||
(fc@RcvFileChunk {userId, rcvFileId, rcvFileEntityId, digest, fileTmpPath, replicas = replica@RcvFileChunkReplica {rcvChunkReplicaId, server, delay} : _}, approvedRelays) -> do
|
||||
@@ -186,7 +189,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when notifyOnRetry $ notify c rcvFileEntityId $ RFERR e
|
||||
when (serverHostError e) $ notify c rcvFileEntityId $ RFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
@@ -194,7 +197,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
retryDone = rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath)
|
||||
downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> Bool -> AM ()
|
||||
downloadFileChunk RcvFileChunk {userId, rcvFileId, rcvFileEntityId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath} replica approvedRelays = do
|
||||
unlessM ((approvedRelays ||) <$> ipAddressProtected') $ throwError $ XFTP "" XFTP.NOT_APPROVED
|
||||
unlessM ((approvedRelays ||) <$> ipAddressProtected') $ throwE $ FILE NOT_APPROVED
|
||||
fsFileTmpPath <- lift $ toFSFilePath fileTmpPath
|
||||
chunkPath <- uniqueCombine fsFileTmpPath $ show chunkNo
|
||||
let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest)
|
||||
@@ -235,7 +238,7 @@ withRetryIntervalLimit maxN ri action =
|
||||
retryOnError :: Text -> AM a -> AM a -> AgentErrorType -> AM a
|
||||
retryOnError name loop done e = do
|
||||
logError $ name <> " error: " <> tshow e
|
||||
if temporaryAgentError e
|
||||
if temporaryOrHostError e
|
||||
then loop
|
||||
else done
|
||||
|
||||
@@ -267,11 +270,11 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
withStore' c $ \db -> updateRcvFileStatus db rcvFileId RFSDecrypting
|
||||
chunkPaths <- getChunkPaths chunks
|
||||
encSize <- liftIO $ foldM (\s path -> (s +) . fromIntegral <$> getFileSize path) 0 chunkPaths
|
||||
when (FileSize encSize /= size) $ throwError $ XFTP "" XFTP.SIZE
|
||||
when (FileSize encSize /= size) $ throwE $ XFTP "" XFTP.SIZE
|
||||
encDigest <- liftIO $ LC.sha512Hash <$> readChunks chunkPaths
|
||||
when (FileDigest encDigest /= digest) $ throwError $ XFTP "" XFTP.DIGEST
|
||||
when (FileDigest encDigest /= digest) $ throwE $ XFTP "" XFTP.DIGEST
|
||||
let destFile = CryptoFile fsSavePath cfArgs
|
||||
void $ liftError (INTERNAL . show) $ decryptChunks encSize chunkPaths key nonce $ \_ -> pure destFile
|
||||
void $ liftError (FILE . FILE_IO . show) $ decryptChunks encSize chunkPaths key nonce $ \_ -> pure destFile
|
||||
case redirect of
|
||||
Nothing -> do
|
||||
notify c rcvFileEntityId $ RFDONE fsSavePath
|
||||
@@ -284,13 +287,13 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
atomically $ waitUntilForeground c
|
||||
withStore' c (`updateRcvFileComplete` rcvFileId)
|
||||
-- proceed with redirect
|
||||
yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of
|
||||
-- TODO switch to another error constructor
|
||||
Left _ -> throwError . XFTP "" $ XFTP.REDIRECT "decode error"
|
||||
Left _ -> throwE . FILE $ REDIRECT "decode error"
|
||||
Right (ValidFileDescription fd@FileDescription {size = dstSize, digest = dstDigest})
|
||||
| dstSize /= redirectSize -> throwError . XFTP "" $ XFTP.REDIRECT "size mismatch"
|
||||
| dstDigest /= redirectDigest -> throwError . XFTP "" $ XFTP.REDIRECT "digest mismatch"
|
||||
| dstSize /= redirectSize -> throwE . FILE $ REDIRECT "size mismatch"
|
||||
| dstDigest /= redirectDigest -> throwE . FILE $ REDIRECT "digest mismatch"
|
||||
| otherwise -> pure fd
|
||||
-- register and download chunks from the actual file
|
||||
withStore c $ \db -> updateRcvFileRedirect db redirectDbId next
|
||||
@@ -303,7 +306,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
fsPath <- lift $ toFSFilePath path
|
||||
pure $ fsPath : ps
|
||||
getChunkPaths (RcvFileChunk {chunkTmpPath = Nothing} : _cs) =
|
||||
throwError $ INTERNAL "no chunk path"
|
||||
throwE $ INTERNAL "no chunk path"
|
||||
|
||||
xftpDeleteRcvFile' :: AgentClient -> RcvFileId -> AM' ()
|
||||
xftpDeleteRcvFile' c rcvFileEntityId = xftpDeleteRcvFiles' c [rcvFileEntityId]
|
||||
@@ -323,8 +326,8 @@ xftpDeleteRcvFiles' c rcvFileEntityIds = do
|
||||
batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> AM' [Either AgentErrorType a]
|
||||
batchFiles f rcvFiles = withStoreBatch' c $ \db -> map (\RcvFile {rcvFileId} -> f db rcvFileId) rcvFiles
|
||||
|
||||
notify :: forall m e. (MonadIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m ()
|
||||
notify c entId cmd = atomically $ writeTBQueue (subQ c) ("", entId, APC (sAEntity @e) cmd)
|
||||
notify :: forall m e. (MonadIO m, AEntityI e) => AgentClient -> EntityId -> AEvent e -> m ()
|
||||
notify c entId cmd = atomically $ writeTBQueue (subQ c) ("", entId, AEvt (sAEntity @e) cmd)
|
||||
|
||||
xftpSendFile' :: AgentClient -> UserId -> CryptoFile -> Int -> AM SndFileId
|
||||
xftpSendFile' c userId file numRecipients = do
|
||||
@@ -348,7 +351,7 @@ xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {si
|
||||
let directYaml = prefixPath </> "direct.yaml"
|
||||
cfArgs <- atomically $ CF.randomArgs g
|
||||
let file = CryptoFile directYaml (Just cfArgs)
|
||||
liftError (INTERNAL . show) $ CF.writeFile file (LB.fromStrict $ strEncode fdDirect)
|
||||
liftError (FILE . FILE_IO . show) $ CF.writeFile file (LB.fromStrict $ strEncode fdDirect)
|
||||
key <- atomically $ C.randomSbKey g
|
||||
nonce <- atomically $ C.randomCbNonce g
|
||||
fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce $ Just RedirectFileInfo {size, digest}
|
||||
@@ -376,11 +379,11 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL} =
|
||||
withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $
|
||||
\f@SndFile {sndFileId, sndFileEntityId, prefixPath} ->
|
||||
prepareFile cfg f `catchAgentError` (sndWorkerInternalError c sndFileId sndFileEntityId prefixPath . show)
|
||||
prepareFile cfg f `catchAgentError` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
|
||||
prepareFile :: AgentConfig -> SndFile -> AM ()
|
||||
prepareFile _ SndFile {prefixPath = Nothing} =
|
||||
throwError $ INTERNAL "no prefix path"
|
||||
prepareFile cfg sndFile@SndFile {sndFileId, userId, prefixPath = Just ppath, status} = do
|
||||
throwE $ INTERNAL "no prefix path"
|
||||
prepareFile cfg sndFile@SndFile {sndFileId, sndFileEntityId, userId, prefixPath = Just ppath, status} = do
|
||||
SndFile {numRecipients, chunks} <-
|
||||
if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting
|
||||
then do
|
||||
@@ -394,9 +397,14 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
getSndFile db sndFileId
|
||||
else pure sndFile
|
||||
let numRecipients' = min numRecipients maxRecipients
|
||||
-- in case chunk preparation previously failed mid-way, some chunks may already be created -
|
||||
-- here we split previously prepared chunks from the pending ones to then build full list of servers
|
||||
let (pendingChunks, preparedSrvs) = partitionEithers $ map srvOrPendingChunk chunks
|
||||
-- concurrently?
|
||||
-- separate worker to create chunks? record retries and delay on snd_file_chunks?
|
||||
forM_ (filter (\SndFileChunk {replicas} -> null replicas) chunks) $ createChunk numRecipients'
|
||||
srvs <- forM pendingChunks $ createChunk numRecipients'
|
||||
let allSrvs = S.fromList $ preparedSrvs <> srvs
|
||||
lift $ forM_ allSrvs $ \srv -> getXFTPSndWorker True c (Just srv)
|
||||
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading
|
||||
where
|
||||
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
|
||||
@@ -405,48 +413,60 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
let CryptoFile {filePath} = srcFile
|
||||
fileName = takeFileName filePath
|
||||
fileSize <- liftIO $ fromInteger <$> CF.getFileContentsSize srcFile
|
||||
when (fileSize > maxFileSizeHard) $ throwError $ INTERNAL "max file size exceeded"
|
||||
when (fileSize > maxFileSizeHard) $ throwE $ FILE FT.SIZE
|
||||
let fileHdr = smpEncode FileHeader {fileName, fileExtra = Nothing}
|
||||
fileSize' = fromIntegral (B.length fileHdr) + fileSize
|
||||
payloadSize = fileSize' + fileSizeLen + authTagSize
|
||||
chunkSizes <- case redirect of
|
||||
Nothing -> pure $ prepareChunkSizes payloadSize
|
||||
Just _ -> case singleChunkSize payloadSize of
|
||||
Nothing -> throwError $ INTERNAL "max file size exceeded for redirect"
|
||||
Nothing -> throwE $ FILE FT.SIZE
|
||||
Just chunkSize -> pure [chunkSize]
|
||||
let encSize = sum $ map fromIntegral chunkSizes
|
||||
void $ liftError (INTERNAL . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath
|
||||
void $ liftError (FILE . FILE_IO . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath
|
||||
digest <- liftIO $ LC.sha512Hash <$> LB.readFile fsEncPath
|
||||
let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes
|
||||
chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs
|
||||
pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests)
|
||||
createChunk :: Int -> SndFileChunk -> AM ()
|
||||
srvOrPendingChunk :: SndFileChunk -> Either SndFileChunk (ProtocolServer 'PXFTP)
|
||||
srvOrPendingChunk ch@SndFileChunk {replicas} = case replicas of
|
||||
[] -> Left ch
|
||||
SndFileChunkReplica {server} : _ -> Right server
|
||||
createChunk :: Int -> SndFileChunk -> AM (ProtocolServer 'PXFTP)
|
||||
createChunk numRecipients' ch = do
|
||||
atomically $ assertAgentForeground c
|
||||
(replica, ProtoServerWithAuth srv _) <- tryCreate
|
||||
withStore' c $ \db -> createSndFileReplica db ch replica
|
||||
lift . void $ getXFTPSndWorker True c (Just srv)
|
||||
pure srv
|
||||
where
|
||||
tryCreate = do
|
||||
usedSrvs <- newTVarIO ([] :: [XFTPServer])
|
||||
withRetryInterval (riFast ri) $ \_ loop -> do
|
||||
let AgentClient {xftpServers} = c
|
||||
userSrvCount <- length <$> atomically (TM.lookup userId xftpServers)
|
||||
withRetryIntervalCount (riFast ri) $ \n _ loop -> do
|
||||
liftIO $ waitForUserNetwork c
|
||||
let triedAllSrvs = n > userSrvCount
|
||||
createWithNextSrv usedSrvs
|
||||
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop) (throwError e) e
|
||||
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
|
||||
where
|
||||
retryLoop loop = atomically (assertAgentForeground c) >> loop
|
||||
-- we don't do closeXFTPServerClient here to not risk closing connection for concurrent chunk upload
|
||||
retryLoop loop triedAllSrvs e = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e
|
||||
atomically $ assertAgentForeground c
|
||||
loop
|
||||
createWithNextSrv usedSrvs = do
|
||||
deleted <- withStore' c $ \db -> getSndFileDeleted db sndFileId
|
||||
when deleted $ throwError $ INTERNAL "file deleted, aborting chunk creation"
|
||||
when deleted $ throwE $ FILE NO_FILE
|
||||
withNextSrv c userId usedSrvs [] $ \srvAuth -> do
|
||||
replica <- agentXFTPNewChunk c ch numRecipients' srvAuth
|
||||
pure (replica, srvAuth)
|
||||
|
||||
sndWorkerInternalError :: AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> AM ()
|
||||
sndWorkerInternalError c sndFileId sndFileEntityId prefixPath internalErrStr = do
|
||||
sndWorkerInternalError :: AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> AgentErrorType -> AM ()
|
||||
sndWorkerInternalError c sndFileId sndFileEntityId prefixPath err = do
|
||||
lift . forM_ prefixPath $ removePath <=< toFSFilePath
|
||||
withStore' c $ \db -> updateSndFileError db sndFileId internalErrStr
|
||||
notify c sndFileEntityId $ SFERR $ INTERNAL internalErrStr
|
||||
withStore' c $ \db -> updateSndFileError db sndFileId (show err)
|
||||
notify c sndFileEntityId $ SFERR err
|
||||
|
||||
runXFTPSndWorker :: AgentClient -> XFTPServer -> Worker -> AM ()
|
||||
runXFTPSndWorker c srv Worker {doWork} = do
|
||||
@@ -457,9 +477,9 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL, reconnectInterval = ri, xftpConsecutiveRetries} = do
|
||||
withWork c doWork (\db -> getNextSndChunkToUpload db srv sndFilesTTL) $ \case
|
||||
SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) "chunk has no replicas"
|
||||
SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) (INTERNAL "chunk has no replicas")
|
||||
fc@SndFileChunk {userId, sndFileId, sndFileEntityId, filePrefixPath, digest, replicas = replica@SndFileChunkReplica {sndChunkReplicaId, server, delay} : _} -> do
|
||||
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
|
||||
@@ -469,17 +489,17 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when notifyOnRetry $ notify c sndFileEntityId $ SFERR e
|
||||
when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
loop
|
||||
retryDone e = sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) (show e)
|
||||
retryDone = sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath)
|
||||
uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> AM ()
|
||||
uploadFileChunk AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients} sndFileChunk@SndFileChunk {sndFileId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath}, digest = chunkDigest} replica = do
|
||||
replica'@SndFileChunkReplica {sndChunkReplicaId} <- addRecipients sndFileChunk replica
|
||||
fsFilePath <- lift $ toFSFilePath filePath
|
||||
unlessM (doesFileExist fsFilePath) $ throwError $ INTERNAL "encrypted file doesn't exist on upload"
|
||||
unlessM (doesFileExist fsFilePath) $ throwE $ FILE NO_FILE
|
||||
let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec
|
||||
atomically $ assertAgentForeground c
|
||||
agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec'
|
||||
@@ -499,7 +519,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
where
|
||||
addRecipients :: SndFileChunk -> SndFileChunkReplica -> AM SndFileChunkReplica
|
||||
addRecipients ch@SndFileChunk {numRecipients} cr@SndFileChunkReplica {rcvIdsKeys}
|
||||
| length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients"
|
||||
| length rcvIdsKeys > numRecipients = throwE $ INTERNAL "too many recipients"
|
||||
| length rcvIdsKeys == numRecipients = pure cr
|
||||
| otherwise = do
|
||||
let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients
|
||||
@@ -507,22 +527,22 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
|
||||
addRecipients ch cr'
|
||||
sndFileToDescrs :: SndFile -> AM (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
|
||||
sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest"
|
||||
sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks"
|
||||
sndFileToDescrs SndFile {digest = Nothing} = throwE $ INTERNAL "snd file has no digest"
|
||||
sndFileToDescrs SndFile {chunks = []} = throwE $ INTERNAL "snd file has no chunks"
|
||||
sndFileToDescrs SndFile {digest = Just digest, key, nonce, chunks = chunks@(fstChunk : _), redirect} = do
|
||||
let chunkSize = FileSize $ sndChunkSize fstChunk
|
||||
size = FileSize $ sum $ map (fromIntegral . sndChunkSize) chunks
|
||||
-- snd description
|
||||
sndDescrChunks <- mapM toSndDescrChunk chunks
|
||||
let fdSnd = FileDescription {party = SFSender, size, digest, key, nonce, chunkSize, chunks = sndDescrChunks, redirect = Nothing}
|
||||
validFdSnd <- either (throwError . INTERNAL) pure $ validateFileDescription fdSnd
|
||||
validFdSnd <- either (throwE . INTERNAL) pure $ validateFileDescription fdSnd
|
||||
-- rcv descriptions
|
||||
let fdRcv = FileDescription {party = SFRecipient, size, digest, key, nonce, chunkSize, chunks = [], redirect}
|
||||
fdRcvs = createRcvFileDescriptions fdRcv chunks
|
||||
validFdRcvs <- either (throwError . INTERNAL) pure $ mapM validateFileDescription fdRcvs
|
||||
validFdRcvs <- either (throwE . INTERNAL) pure $ mapM validateFileDescription fdRcvs
|
||||
pure (validFdSnd, validFdRcvs)
|
||||
toSndDescrChunk :: SndFileChunk -> AM FileChunk
|
||||
toSndDescrChunk SndFileChunk {replicas = []} = throwError $ INTERNAL "snd file chunk has no replicas"
|
||||
toSndDescrChunk SndFileChunk {replicas = []} = throwE $ INTERNAL "snd file chunk has no replicas"
|
||||
toSndDescrChunk ch@SndFileChunk {chunkNo, digest = chDigest, replicas = (SndFileChunkReplica {server, replicaId, replicaKey} : _)} = do
|
||||
let chunkSize = FileSize $ sndChunkSize ch
|
||||
replicas = [FileChunkReplica {server, replicaId, replicaKey}]
|
||||
@@ -623,7 +643,7 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
runXFTPOperation cfg
|
||||
where
|
||||
runXFTPOperation :: AgentConfig -> AM ()
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpConsecutiveRetries} = do
|
||||
-- no point in deleting files older than rcv ttl, as they will be expired on server
|
||||
withWork c doWork (\db -> getNextDeletedSndChunkReplica db srv rcvFilesTTL) processDeletedReplica
|
||||
where
|
||||
@@ -636,7 +656,7 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
when notifyOnRetry $ notify c "" $ SFERR e
|
||||
when (serverHostError e) $ notify c "" $ SFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server chunkDigest
|
||||
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
|
||||
@@ -138,9 +138,9 @@ xftpClientHandshakeV1 serverVRange keyHash@(C.KeyHash kh) c@HTTP2Client {session
|
||||
liftTransportErr (TEHandshake PARSE) . smpDecode =<< liftTransportErr TEBadBlock (C.unPad shsBody)
|
||||
processServerHandshake :: XFTPServerHandshake -> ExceptT XFTPClientError IO (VersionRangeXFTP, C.PublicKeyX25519)
|
||||
processServerHandshake XFTPServerHandshake {xftpVersionRange, sessionId = serverSessId, authPubKey = serverAuth} = do
|
||||
unless (sessionId == serverSessId) $ throwError $ PCETransportError TEBadSession
|
||||
unless (sessionId == serverSessId) $ throwE $ PCETransportError TEBadSession
|
||||
case xftpVersionRange `compatibleVRange` serverVRange of
|
||||
Nothing -> throwError $ PCETransportError TEVersion
|
||||
Nothing -> throwE $ PCETransportError TEVersion
|
||||
Just (Compatible vr) ->
|
||||
fmap (vr,) . liftTransportErr (TEHandshake BAD_AUTH) $ do
|
||||
let (X.CertificateChain cert, exact) = serverAuth
|
||||
@@ -154,7 +154,7 @@ xftpClientHandshakeV1 serverVRange keyHash@(C.KeyHash kh) c@HTTP2Client {session
|
||||
chs' <- liftTransportErr TELargeMsg $ C.pad (smpEncode chs) xftpBlockSize
|
||||
let chsReq = H.requestBuilder "POST" "/" [] $ byteString chs'
|
||||
HTTP2Response {respBody = HTTP2Body {bodyHead}} <- liftError' xftpClientError $ sendRequest c chsReq Nothing
|
||||
unless (B.null bodyHead) $ throwError $ PCETransportError TEBadBlock
|
||||
unless (B.null bodyHead) $ throwE $ PCETransportError TEBadBlock
|
||||
liftTransportErr e = liftEitherWith (const $ PCETransportError e)
|
||||
|
||||
closeXFTPClient :: XFTPClient -> IO ()
|
||||
@@ -200,14 +200,14 @@ sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = d
|
||||
let req = H.requestStreaming N.methodPost "/" [] streamBody
|
||||
reqTimeout = xftpReqTimeout config $ (\XFTPChunkSpec {chunkSize} -> chunkSize) <$> chunkSpec_
|
||||
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req (Just reqTimeout)
|
||||
when (B.length bodyHead /= xftpBlockSize) $ throwError $ PCEResponseError BLOCK
|
||||
when (B.length bodyHead /= xftpBlockSize) $ throwE $ PCEResponseError BLOCK
|
||||
-- TODO validate that the file ID is the same as in the request?
|
||||
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
|
||||
case respOrErr of
|
||||
Right r -> case protocolError r of
|
||||
Just e -> throwError $ PCEProtocolError e
|
||||
Just e -> throwE $ PCEProtocolError e
|
||||
_ -> pure (r, body)
|
||||
Left e -> throwError $ PCEResponseError e
|
||||
Left e -> throwE $ PCEResponseError e
|
||||
where
|
||||
streamBody :: (Builder -> IO ()) -> IO () -> IO ()
|
||||
streamBody send done = do
|
||||
@@ -250,7 +250,7 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
|
||||
let dhSecret = C.dh' sDhKey rpDhKey
|
||||
cbState <- liftEither . first PCECryptoError $ LC.cbInit dhSecret cbNonce
|
||||
let t = chunkTimeout config chunkSize
|
||||
ExceptT (sequence <$> (t `timeout` (download cbState `catches` errors))) >>= maybe (throwError PCEResponseTimeout) pure
|
||||
ExceptT (sequence <$> (t `timeout` (download cbState `catches` errors))) >>= maybe (throwE PCEResponseTimeout) pure
|
||||
where
|
||||
errors =
|
||||
[ Handler $ \(_e :: H.HTTP2Error) -> pure $ Left PCENetworkError,
|
||||
@@ -260,8 +260,8 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
|
||||
download cbState =
|
||||
runExceptT . withExceptT PCEResponseError $
|
||||
receiveEncFile chunkPart cbState chunkSpec `catchError` \e ->
|
||||
whenM (doesFileExist filePath) (removeFile filePath) >> throwError e
|
||||
_ -> throwError $ PCEResponseError NO_FILE
|
||||
whenM (doesFileExist filePath) (removeFile filePath) >> throwE e
|
||||
_ -> throwE $ PCEResponseError NO_FILE
|
||||
(r, _) -> throwE $ unexpectedResponse r
|
||||
|
||||
xftpReqTimeout :: XFTPClientConfig -> Maybe Word32 -> Int
|
||||
@@ -296,7 +296,7 @@ okResponse = \case
|
||||
-- TODO this currently does not check anything because response size is not set and bodyPart is always Just
|
||||
noFile :: HTTP2Body -> a -> ExceptT XFTPClientError IO a
|
||||
noFile HTTP2Body {bodyPart} a = case bodyPart of
|
||||
Just _ -> pure a -- throwError $ PCEResponseError HAS_FILE
|
||||
Just _ -> pure a -- throwE $ PCEResponseError HAS_FILE
|
||||
_ -> pure a
|
||||
|
||||
-- FACK :: FileCommand Recipient
|
||||
|
||||
@@ -11,6 +11,7 @@ import Control.Logger.Simple (logInfo)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans (lift)
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text (Text)
|
||||
@@ -108,7 +109,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
|
||||
else atomically $ do
|
||||
putTMVar clientVar r
|
||||
TM.delete srv xftpClients
|
||||
throwError e
|
||||
throwE e
|
||||
tryConnectAsync :: ME ()
|
||||
tryConnectAsync = void . lift . async . runExceptT $ do
|
||||
withRetryInterval (reconnectInterval config) $ \_ loop -> void $ tryConnectClient loop
|
||||
|
||||
@@ -30,6 +30,7 @@ where
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
@@ -292,7 +293,7 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re
|
||||
encryptFileForUpload :: TVar ChaChaDRG -> String -> ExceptT CLIError IO (FilePath, FileDescription 'FRecipient, FileDescription 'FSender, [XFTPChunkSpec], Int64)
|
||||
encryptFileForUpload g fileName = do
|
||||
fileSize <- fromInteger <$> getFileSize filePath
|
||||
when (fileSize > maxFileSize) $ throwError $ CLIError $ "Files bigger than " <> maxFileSizeStr <> " are not supported"
|
||||
when (fileSize > maxFileSize) $ throwE $ CLIError $ "Files bigger than " <> maxFileSizeStr <> " are not supported"
|
||||
encPath <- getEncPath tempPath "xftp"
|
||||
key <- atomically $ C.randomSbKey g
|
||||
nonce <- atomically $ C.randomCbNonce g
|
||||
@@ -323,7 +324,7 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re
|
||||
-- upload doesn't allow other requests within the same client until complete (but download does allow).
|
||||
logInfo $ "uploading " <> tshow (length chunks) <> " chunks..."
|
||||
(errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 chunks' . mapM $ runExceptT . uploadFileChunk a)
|
||||
mapM_ throwError errs
|
||||
mapM_ throwE errs
|
||||
pure $ map snd (sortOn fst rs)
|
||||
where
|
||||
uploadFileChunk :: XFTPClientAgent -> (Int, XFTPChunkSpec, XFTPServerWithAuth) -> ExceptT CLIError IO (Int, SentFileChunk)
|
||||
@@ -437,12 +438,12 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
|
||||
srvChunks = groupAllOn srv chunks
|
||||
g <- liftIO C.newRandom
|
||||
(errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 srvChunks $ mapM $ runExceptT . downloadFileChunk g a encPath size downloadedChunks)
|
||||
mapM_ throwError errs
|
||||
mapM_ throwE errs
|
||||
let chunkPaths = map snd $ sortOn fst rs
|
||||
encDigest <- liftIO $ LC.sha512Hash <$> readChunks chunkPaths
|
||||
when (encDigest /= unFileDigest digest) $ throwError $ CLIError "File digest mismatch"
|
||||
when (encDigest /= unFileDigest digest) $ throwE $ CLIError "File digest mismatch"
|
||||
encSize <- liftIO $ foldM (\s path -> (s +) . fromIntegral <$> getFileSize path) 0 chunkPaths
|
||||
when (FileSize encSize /= size) $ throwError $ CLIError "File size mismatch"
|
||||
when (FileSize encSize /= size) $ throwE $ CLIError "File size mismatch"
|
||||
liftIO $ printNoNewLine "Decrypting file..."
|
||||
CryptoFile path _ <- withExceptT cliCryptoError $ decryptChunks encSize chunkPaths key nonce $ fmap CF.plain . getFilePath
|
||||
forM_ chunks $ acknowledgeFileChunk a
|
||||
@@ -464,20 +465,20 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
|
||||
printProgress "Downloaded" downloaded encSize
|
||||
when verbose $ putStrLn ""
|
||||
pure (chunkNo, chunkPath)
|
||||
downloadFileChunk _ _ _ _ _ _ = throwError $ CLIError "chunk has no replicas"
|
||||
downloadFileChunk _ _ _ _ _ _ = throwE $ CLIError "chunk has no replicas"
|
||||
getFilePath :: String -> ExceptT String IO FilePath
|
||||
getFilePath name =
|
||||
case filePath of
|
||||
Just path ->
|
||||
ifM (doesDirectoryExist path) (uniqueCombine path name) $
|
||||
ifM (doesFileExist path) (throwError "File already exists") (pure path)
|
||||
ifM (doesFileExist path) (throwE "File already exists") (pure path)
|
||||
_ -> (`uniqueCombine` name) . (</> "Downloads") =<< getHomeDirectory
|
||||
acknowledgeFileChunk :: XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
|
||||
acknowledgeFileChunk a FileChunk {replicas = replica : _} = do
|
||||
let FileChunkReplica {server, replicaId, replicaKey} = replica
|
||||
c <- withRetry retryCount $ getXFTPServerClient a server
|
||||
withRetry retryCount $ ackXFTPChunk c replicaKey (unChunkReplicaId replicaId)
|
||||
acknowledgeFileChunk _ _ = throwError $ CLIError "chunk has no replicas"
|
||||
acknowledgeFileChunk _ _ = throwE $ CLIError "chunk has no replicas"
|
||||
|
||||
printProgress :: String -> Int64 -> Int64 -> IO ()
|
||||
printProgress s part total = printNoNewLine $ s <> " " <> show ((part * 100) `div` total) <> "%"
|
||||
@@ -503,7 +504,7 @@ cliDeleteFile DeleteOptions {fileDescription, retryCount, yes} = do
|
||||
let FileChunkReplica {server, replicaId, replicaKey} = replica
|
||||
withReconnect a server retryCount $ \c -> deleteXFTPChunk c replicaKey (unChunkReplicaId replicaId)
|
||||
logInfo $ "deleted chunk " <> tshow chunkNo <> " from " <> showServer server
|
||||
deleteFileChunk _ _ = throwError $ CLIError "chunk has no replicas"
|
||||
deleteFileChunk _ _ = throwE $ CLIError "chunk has no replicas"
|
||||
|
||||
cliFileDescrInfo :: InfoOptions -> ExceptT CLIError IO ()
|
||||
cliFileDescrInfo InfoOptions {fileDescription} = do
|
||||
@@ -533,7 +534,7 @@ getFileDescription path =
|
||||
getFileDescription' :: FilePartyI p => FilePath -> ExceptT CLIError IO (ValidFileDescription p)
|
||||
getFileDescription' path =
|
||||
getFileDescription path >>= \case
|
||||
AVFD fd -> either (throwError . CLIError) pure $ checkParty fd
|
||||
AVFD fd -> either (throwE . CLIError) pure $ checkParty fd
|
||||
|
||||
singleChunkSize :: Int64 -> Maybe Word32
|
||||
singleChunkSize size' =
|
||||
@@ -574,13 +575,13 @@ withReconnect a srv n run = withRetry n $ do
|
||||
c <- withRetry n $ getXFTPServerClient a srv
|
||||
withExceptT (CLIError . show) (run c) `catchError` \e -> do
|
||||
liftIO $ closeXFTPServerClient a srv
|
||||
throwError e
|
||||
throwE e
|
||||
|
||||
withRetry :: Show e => Int -> ExceptT e IO a -> ExceptT CLIError IO a
|
||||
withRetry retryCount = withRetry' retryCount . withExceptT (CLIError . show)
|
||||
where
|
||||
withRetry' :: Int -> ExceptT CLIError IO a -> ExceptT CLIError IO a
|
||||
withRetry' 0 _ = throwError $ CLIError "internal: no retry attempts"
|
||||
withRetry' 0 _ = throwE $ CLIError "internal: no retry attempts"
|
||||
withRetry' 1 a = a
|
||||
withRetry' n a =
|
||||
a `catchError` \e -> do
|
||||
|
||||
@@ -8,6 +8,7 @@ module Simplex.FileTransfer.Crypto where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteArray as BA
|
||||
@@ -48,17 +49,17 @@ encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do
|
||||
| otherwise = do
|
||||
let chSize = min len 65536
|
||||
ch <- liftIO $ get chSize
|
||||
when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
|
||||
when (B.length ch /= fromIntegral chSize) $ throwE $ FTCEFileIOError "encrypting file: unexpected EOF"
|
||||
let (ch', sb') = LC.sbEncryptChunk sb ch
|
||||
liftIO $ B.hPut w ch'
|
||||
encryptChunks_ get w (sb', len - chSize)
|
||||
|
||||
decryptChunks :: Int64 -> [FilePath] -> C.SbKey -> C.CbNonce -> (String -> ExceptT String IO CryptoFile) -> ExceptT FTCryptoError IO CryptoFile
|
||||
decryptChunks _ [] _ _ _ = throwError $ FTCEInvalidHeader "empty"
|
||||
decryptChunks _ [] _ _ _ = throwE $ FTCEInvalidHeader "empty"
|
||||
decryptChunks encSize (chPath : chPaths) key nonce getDestFile = case reverse chPaths of
|
||||
[] -> do
|
||||
(!authOk, !f) <- liftEither . first FTCECryptoError . LC.sbDecryptTailTag key nonce (encSize - authTagSize) =<< liftIO (LB.readFile chPath)
|
||||
unless authOk $ throwError FTCEInvalidAuthTag
|
||||
unless authOk $ throwE FTCEInvalidAuthTag
|
||||
(FileHeader {fileName}, !f') <- parseFileHeader f
|
||||
destFile <- withExceptT FTCEFileIOError $ getDestFile fileName
|
||||
CF.writeFile destFile f'
|
||||
@@ -73,7 +74,7 @@ decryptChunks encSize (chPath : chPaths) key nonce getDestFile = case reverse ch
|
||||
decryptLastChunk h state' expectedLen
|
||||
unless authOk $ do
|
||||
removeFile path
|
||||
throwError FTCEInvalidAuthTag
|
||||
throwE FTCEInvalidAuthTag
|
||||
pure destFile
|
||||
where
|
||||
decryptFirstChunk = do
|
||||
@@ -105,8 +106,8 @@ decryptChunks encSize (chPath : chPaths) key nonce getDestFile = case reverse ch
|
||||
parseFileHeader s = do
|
||||
let (hdrStr, s') = LB.splitAt 1024 s
|
||||
case A.parse smpP $ LB.toStrict hdrStr of
|
||||
A.Fail _ _ e -> throwError $ FTCEInvalidHeader e
|
||||
A.Partial _ -> throwError $ FTCEInvalidHeader "incomplete"
|
||||
A.Fail _ _ e -> throwE $ FTCEInvalidHeader e
|
||||
A.Partial _ -> throwE $ FTCEInvalidHeader "incomplete"
|
||||
A.Done rest hdr -> pure (hdr, LB.fromStrict rest <> s')
|
||||
|
||||
readChunks :: [FilePath] -> IO LB.ByteString
|
||||
|
||||
@@ -18,6 +18,7 @@ import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as B64
|
||||
import Data.ByteString.Builder (Builder, byteString)
|
||||
@@ -136,7 +137,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
either sendError pure r
|
||||
where
|
||||
processHello = do
|
||||
unless (B.null bodyHead) $ throwError HANDSHAKE
|
||||
unless (B.null bodyHead) $ throwE HANDSHAKE
|
||||
(k, pk) <- atomically . C.generateKeyPair =<< asks random
|
||||
atomically $ TM.insert sessionId (HandshakeSent pk) sessions
|
||||
let authPubKey = (chain, C.signX509 serverSignKey $ C.publicToX509 k)
|
||||
@@ -148,11 +149,11 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
liftIO . sendResponse $ H.responseBuilder N.ok200 [] shs
|
||||
pure Nothing
|
||||
processClientHandshake pk = do
|
||||
unless (B.length bodyHead == xftpBlockSize) $ throwError HANDSHAKE
|
||||
unless (B.length bodyHead == xftpBlockSize) $ throwE HANDSHAKE
|
||||
body <- liftHS $ C.unPad bodyHead
|
||||
XFTPClientHandshake {xftpVersion = v, keyHash} <- liftHS $ smpDecode body
|
||||
kh <- asks serverIdentity
|
||||
unless (keyHash == kh) $ throwError HANDSHAKE
|
||||
unless (keyHash == kh) $ throwE HANDSHAKE
|
||||
case compatibleVRange' xftpServerVRange v of
|
||||
Just (Compatible vr) -> do
|
||||
let auth = THAuthServer {serverPrivKey = pk, sessSecret' = Nothing}
|
||||
@@ -163,7 +164,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
#endif
|
||||
liftIO . sendResponse $ H.responseNoBody N.ok200 []
|
||||
pure Nothing
|
||||
Nothing -> throwError HANDSHAKE
|
||||
Nothing -> throwE HANDSHAKE
|
||||
sendError :: XFTPErrorType -> M (Maybe (THandleParams XFTPVersion 'TServer))
|
||||
sendError err = do
|
||||
runExceptT (encodeXftp err) >>= \case
|
||||
@@ -395,7 +396,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
st <- asks store
|
||||
r <- runExceptT $ do
|
||||
sizes <- asks $ allowedChunkSizes . config
|
||||
unless (size file `elem` sizes) $ throwError SIZE
|
||||
unless (size file `elem` sizes) $ throwE SIZE
|
||||
ts <- liftIO getSystemTime
|
||||
-- TODO validate body empty
|
||||
sId <- ExceptT $ addFileRetry st file 3 ts
|
||||
|
||||
@@ -194,7 +194,7 @@ receiveFile_ :: (Handle -> Word32 -> IO (Either XFTPErrorType ())) -> XFTPRcvChu
|
||||
receiveFile_ receive XFTPRcvChunkSpec {filePath, chunkSize, chunkDigest} = do
|
||||
ExceptT $ withFile filePath WriteMode (`receive` chunkSize)
|
||||
digest' <- liftIO $ LC.sha256Hash <$> LB.readFile filePath
|
||||
when (digest' /= chunkDigest) $ throwError DIGEST
|
||||
when (digest' /= chunkDigest) $ throwE DIGEST
|
||||
|
||||
data XFTPErrorType
|
||||
= -- | incorrect block format, encoding or signature size
|
||||
@@ -223,10 +223,6 @@ data XFTPErrorType
|
||||
FILE_IO
|
||||
| -- | file sending or receiving timeout
|
||||
TIMEOUT
|
||||
| -- | bad redirect data
|
||||
REDIRECT {redirectError :: String}
|
||||
| -- | cannot proceed with download from not approved relays without proxy
|
||||
NOT_APPROVED
|
||||
| -- | internal server error
|
||||
INTERNAL
|
||||
| -- | used internally, never returned by the server (to be removed)
|
||||
@@ -236,11 +232,9 @@ data XFTPErrorType
|
||||
instance StrEncoding XFTPErrorType where
|
||||
strEncode = \case
|
||||
CMD e -> "CMD " <> bshow e
|
||||
REDIRECT e -> "REDIRECT " <> bshow e
|
||||
e -> bshow e
|
||||
strP =
|
||||
"CMD " *> (CMD <$> parseRead1)
|
||||
<|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString)
|
||||
<|> parseRead1
|
||||
|
||||
instance Encoding XFTPErrorType where
|
||||
@@ -258,8 +252,6 @@ instance Encoding XFTPErrorType where
|
||||
HAS_FILE -> "HAS_FILE"
|
||||
FILE_IO -> "FILE_IO"
|
||||
TIMEOUT -> "TIMEOUT"
|
||||
REDIRECT err -> "REDIRECT " <> smpEncode err
|
||||
NOT_APPROVED -> "NOT_APPROVED"
|
||||
INTERNAL -> "INTERNAL"
|
||||
DUPLICATE_ -> "DUPLICATE_"
|
||||
|
||||
@@ -278,8 +270,6 @@ instance Encoding XFTPErrorType where
|
||||
"HAS_FILE" -> pure HAS_FILE
|
||||
"FILE_IO" -> pure FILE_IO
|
||||
"TIMEOUT" -> pure TIMEOUT
|
||||
"REDIRECT" -> REDIRECT <$> _smpP
|
||||
"NOT_APPROVED" -> pure NOT_APPROVED
|
||||
"INTERNAL" -> pure INTERNAL
|
||||
"DUPLICATE_" -> pure DUPLICATE_
|
||||
_ -> fail "bad error type"
|
||||
|
||||
@@ -2,24 +2,33 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
module Simplex.FileTransfer.Types where
|
||||
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Word (Word32)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Simplex.FileTransfer.Client (XFTPChunkSpec (..))
|
||||
import Simplex.FileTransfer.Description
|
||||
import Simplex.Messaging.Agent.Protocol (RcvFileId, SndFileId)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.File (CryptoFile (..))
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (fromTextField_)
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol (XFTPServer)
|
||||
import System.FilePath ((</>))
|
||||
|
||||
type RcvFileId = ByteString
|
||||
|
||||
type SndFileId = ByteString
|
||||
|
||||
authTagSize :: Int64
|
||||
authTagSize = fromIntegral C.authTagSize
|
||||
|
||||
@@ -236,3 +245,35 @@ data DeletedSndChunkReplica = DeletedSndChunkReplica
|
||||
retries :: Int
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data FileErrorType
|
||||
= -- | cannot proceed with download from not approved relays without proxy
|
||||
NOT_APPROVED
|
||||
| -- | max file size exceeded
|
||||
SIZE
|
||||
| -- | bad redirect data
|
||||
REDIRECT {redirectError :: String}
|
||||
| -- | file crypto error
|
||||
FILE_IO {fileIOError :: String}
|
||||
| -- | file not found or was deleted
|
||||
NO_FILE
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding FileErrorType where
|
||||
strP =
|
||||
A.takeTill (== ' ')
|
||||
>>= \case
|
||||
"NOT_APPROVED" -> pure NOT_APPROVED
|
||||
"SIZE" -> pure SIZE
|
||||
"REDIRECT" -> REDIRECT <$> (A.space *> textP)
|
||||
"FILE_IO" -> FILE_IO <$> (A.space *> textP)
|
||||
"NO_FILE" -> pure NO_FILE
|
||||
_ -> fail "bad FileErrorType"
|
||||
strEncode = \case
|
||||
NOT_APPROVED -> "NOT_APPROVED"
|
||||
SIZE -> "SIZE"
|
||||
REDIRECT e -> "REDIRECT " <> encodeUtf8 (T.pack e)
|
||||
FILE_IO e -> "FILE_IO " <> encodeUtf8 (T.pack e)
|
||||
NO_FILE -> "NO_FILE"
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''FileErrorType)
|
||||
|
||||
+68
-105
@@ -29,10 +29,7 @@
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
|
||||
module Simplex.Messaging.Agent
|
||||
( -- * queue-based SMP agent
|
||||
runAgentClient,
|
||||
|
||||
-- * SMP agent functional API
|
||||
( -- * SMP agent functional API
|
||||
AgentClient (..),
|
||||
AE,
|
||||
SubscriptionsInfo (..),
|
||||
@@ -151,6 +148,7 @@ import Data.Word (Word16)
|
||||
import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, deleteSndFilesInternal, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile')
|
||||
import Simplex.FileTransfer.Description (ValidFileDescription)
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..))
|
||||
import Simplex.FileTransfer.Types (RcvFileId, SndFileId)
|
||||
import Simplex.FileTransfer.Util (removePath)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -185,7 +183,6 @@ import Simplex.RemoteControl.Client
|
||||
import Simplex.RemoteControl.Invitation
|
||||
import Simplex.RemoteControl.Types
|
||||
import System.Mem.Weak (deRefWeak)
|
||||
import UnliftIO.Async (race_)
|
||||
import UnliftIO.Concurrent (forkFinally, forkIO, killThread, mkWeakThreadId, threadDelay)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
@@ -219,7 +216,7 @@ getSMPAgentClient_ clientId cfg initServers store backgroundMode =
|
||||
run AgentClient {subQ, acThread} name a =
|
||||
a `E.catchAny` \e -> whenM (isJust <$> readTVarIO acThread) $ do
|
||||
logError $ "Agent thread " <> name <> " crashed: " <> tshow e
|
||||
atomically $ writeTBQueue subQ ("", "", APC SAEConn $ ERR $ CRITICAL True $ show e)
|
||||
atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ CRITICAL True $ show e)
|
||||
|
||||
disconnectAgentClient :: AgentClient -> IO ()
|
||||
disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns, xftpAgent = xa}} = do
|
||||
@@ -573,40 +570,6 @@ logConnection c connected =
|
||||
let event = if connected then "connected to" else "disconnected from"
|
||||
in logInfo $ T.unwords ["client", tshow (clientId c), event, "Agent"]
|
||||
|
||||
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
|
||||
runAgentClient :: AgentClient -> AM' ()
|
||||
runAgentClient c = race_ (subscriber c) (client c)
|
||||
{-# INLINE runAgentClient #-}
|
||||
|
||||
client :: AgentClient -> AM' ()
|
||||
client c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
(corrId, entId, cmd) <- atomically $ readTBQueue rcvQ
|
||||
runExceptT (processCommand c (entId, cmd))
|
||||
>>= atomically . writeTBQueue subQ . \case
|
||||
Left e -> (corrId, entId, APC SAEConn $ ERR e)
|
||||
Right (entId', resp) -> (corrId, entId', resp)
|
||||
|
||||
-- | execute any SMP agent command
|
||||
processCommand :: AgentClient -> (EntityId, APartyCmd 'Client) -> AM (EntityId, APartyCmd 'Agent)
|
||||
processCommand c (connId, APC e cmd) =
|
||||
second (APC e) <$> case cmd of
|
||||
NEW enableNtfs (ACM cMode) pqIK subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing pqIK subMode
|
||||
JOIN enableNtfs (ACR _ cReq) pqEnc subMode connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo pqEnc subMode
|
||||
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
|
||||
ACPT invId pqEnc ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo pqEnc SMSubscribe
|
||||
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
|
||||
SUB -> subscribeConnection' c connId $> (connId, OK)
|
||||
SEND pqEnc msgFlags msgBody -> (connId,) . uncurry MID <$> sendMessage' c connId pqEnc msgFlags msgBody
|
||||
ACK msgId rcptInfo_ -> ackMessage' c connId msgId rcptInfo_ $> (connId, OK)
|
||||
SWCH -> switchConnection' c connId $> (connId, OK)
|
||||
OFF -> suspendConnection' c connId $> (connId, OK)
|
||||
DEL -> deleteConnection' c connId $> (connId, OK)
|
||||
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
|
||||
where
|
||||
-- command interface does not support different users
|
||||
userId :: UserId
|
||||
userId = 1
|
||||
|
||||
createUser' :: AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> AM UserId
|
||||
createUser' c smp xftp = do
|
||||
userId <- withStore' c createUserRecord
|
||||
@@ -623,12 +586,12 @@ deleteUser' c userId delSMPQueues = do
|
||||
where
|
||||
delUser =
|
||||
whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId)
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAENone $ DEL_USER userId)
|
||||
|
||||
newConnAsync :: ConnectionModeI c => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> AM ConnId
|
||||
newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do
|
||||
connId <- newConnNoQueues c userId "" enableNtfs cMode (CR.connPQEncryption pqInitKeys)
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pqInitKeys subMode
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode) pqInitKeys subMode
|
||||
pure connId
|
||||
|
||||
newConnNoQueues :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQSupport -> AM ConnId
|
||||
@@ -647,9 +610,9 @@ joinConnAsync c userId corrId enableNtfs cReqUri@CRInvitationUri {} cInfo pqSup
|
||||
let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v)
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqSupport subMode cInfo
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqSupport subMode cInfo
|
||||
pure connId
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo _pqEncryption =
|
||||
throwE $ CMD PROHIBITED "joinConnAsync"
|
||||
|
||||
@@ -657,7 +620,7 @@ allowConnectionAsync' :: AgentClient -> ACorrId -> ConnId -> ConfirmationId -> C
|
||||
allowConnectionAsync' c corrId connId confId ownConnInfo =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (RcvConnection _ RcvQueue {server}) ->
|
||||
enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo
|
||||
enqueueCommand c corrId connId (Just server) $ AClientCommand $ LET confId ownConnInfo
|
||||
_ -> throwE $ CMD PROHIBITED "allowConnectionAsync"
|
||||
|
||||
acceptContactAsync' :: AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId
|
||||
@@ -668,7 +631,7 @@ acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqSupport subMode = do
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
throwE err
|
||||
_ -> throwE $ CMD PROHIBITED "acceptContactAsync"
|
||||
|
||||
ackMessageAsync' :: AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AM ()
|
||||
@@ -677,7 +640,7 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do
|
||||
case cType of
|
||||
SCDuplex -> enqueueAck
|
||||
SCRcv -> enqueueAck
|
||||
SCSnd -> throwError $ CONN SIMPLEX
|
||||
SCSnd -> throwE $ CONN SIMPLEX
|
||||
SCContact -> throwE $ CMD PROHIBITED "ackMessageAsync: SCContact"
|
||||
SCNew -> throwE $ CMD PROHIBITED "ackMessageAsync: SCNew"
|
||||
where
|
||||
@@ -687,7 +650,7 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do
|
||||
RcvMsg {msgType} <- withStore c $ \db -> getRcvMsg db connId mId
|
||||
when (isJust rcptInfo_ && msgType /= AM_A_MSG_) $ throwE $ CMD PROHIBITED "ackMessageAsync: receipt not allowed"
|
||||
(RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId mId
|
||||
enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId rcptInfo_
|
||||
enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId rcptInfo_
|
||||
|
||||
deleteConnectionAsync' :: AgentClient -> Bool -> ConnId -> AM ()
|
||||
deleteConnectionAsync' c waitDelivery connId = deleteConnectionsAsync' c waitDelivery [connId]
|
||||
@@ -717,7 +680,7 @@ switchConnectionAsync' c corrId connId =
|
||||
| otherwise -> do
|
||||
when (ratchetSyncSendProhibited cData) $ throwE $ CMD PROHIBITED "switchConnectionAsync: send prohibited"
|
||||
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
|
||||
let rqs' = updatedQs rq1 rqs
|
||||
pure . connectionStats $ DuplexConnection cData rqs' sqs
|
||||
_ -> throwE $ CMD PROHIBITED "switchConnectionAsync: not duplex"
|
||||
@@ -740,7 +703,7 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv
|
||||
(SCMContact, CR.IKUsePQ) -> throwE $ CMD PROHIBITED "newRcvConnSrv"
|
||||
_ -> pure ()
|
||||
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwE e
|
||||
rq' <- withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId
|
||||
when enableNtfs $ do
|
||||
@@ -760,11 +723,11 @@ newConnToJoin c userId connId enableNtfs cReq pqSup = case cReq of
|
||||
CRInvitationUri {} ->
|
||||
lift (compatibleInvitationUri cReq) >>= \case
|
||||
Just (_, (Compatible (CR.E2ERatchetParams v _ _ _)), aVersion) -> create aVersion (Just v)
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
CRContactUri {} ->
|
||||
lift (compatibleContactUri cReq) >>= \case
|
||||
Just (_, aVersion) -> create aVersion Nothing
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
where
|
||||
create :: Compatible VersionSMPA -> Maybe CR.VersionE2E -> AM ConnId
|
||||
create (Compatible connAgentVersion) e2eV_ = do
|
||||
@@ -796,7 +759,7 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup =
|
||||
q <- lift $ newSndQueue userId "" qInfo
|
||||
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport}
|
||||
pure (cData, q, rc, e2eSndParams)
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
|
||||
connRequestPQSupport :: AgentClient -> PQSupport -> ConnectionRequestUri c -> IO (Maybe (VersionSMPA, PQSupport))
|
||||
connRequestPQSupport c pqSup cReq = withAgentEnv' c $ case cReq of
|
||||
@@ -846,14 +809,14 @@ joinConnSrv c userId connId hasNewConn enableNtfs inv@CRInvitationUri {} cInfo p
|
||||
Left e -> do
|
||||
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
|
||||
void $ withStore' c $ \db -> deleteConn db Nothing connId'
|
||||
throwError e
|
||||
throwE e
|
||||
joinConnSrv c userId connId hasNewConn enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv =
|
||||
lift (compatibleContactUri cReqUri) >>= \case
|
||||
Just (qInfo, vrsn) -> do
|
||||
(connId', cReq) <- newConnSrv c userId connId hasNewConn enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv
|
||||
void $ sendInvitation c userId qInfo vrsn cReq cInfo
|
||||
pure connId'
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
|
||||
joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM ()
|
||||
joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do
|
||||
@@ -899,7 +862,7 @@ acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withCon
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConn c userId connId False enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
throwE err
|
||||
_ -> throwE $ CMD PROHIBITED "acceptContact"
|
||||
|
||||
-- | Reject contact (RJCT command) in Reader monad
|
||||
@@ -916,8 +879,8 @@ subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c
|
||||
toConnResult :: ConnId -> Map ConnId (Either AgentErrorType ()) -> AM ()
|
||||
toConnResult connId rs = case M.lookup connId rs of
|
||||
Just (Right ()) -> when (M.size rs > 1) $ logError $ T.pack $ "too many results " <> show (M.size rs)
|
||||
Just (Left e) -> throwError e
|
||||
_ -> throwError $ INTERNAL $ "no result for connection " <> B.unpack connId
|
||||
Just (Left e) -> throwE e
|
||||
_ -> throwE $ INTERNAL $ "no result for connection " <> B.unpack connId
|
||||
|
||||
type QCmdResult = (QueueStatus, Either AgentErrorType ())
|
||||
|
||||
@@ -984,7 +947,7 @@ subscribeConnections' c connIds = do
|
||||
let actual = M.size rs
|
||||
expected = length connIds
|
||||
when (actual /= expected) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
|
||||
resubscribeConnection' :: AgentClient -> ConnId -> AM ()
|
||||
resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId]
|
||||
@@ -1006,7 +969,7 @@ getConnectionMessage' c connId = do
|
||||
DuplexConnection _ (rq :| _) _ -> getQueueMessage c rq
|
||||
RcvConnection _ rq -> getQueueMessage c rq
|
||||
ContactConnection _ rq -> getQueueMessage c rq
|
||||
SndConnection _ _ -> throwError $ CONN SIMPLEX
|
||||
SndConnection _ _ -> throwE $ CONN SIMPLEX
|
||||
NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection"
|
||||
|
||||
getNotificationMessage' :: AgentClient -> C.CbNonce -> ByteString -> AM (NotificationInfo, [SMPMsgMeta])
|
||||
@@ -1114,7 +1077,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
where
|
||||
processCmd :: RetryInterval -> PendingCommand -> AM ()
|
||||
processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} = case command of
|
||||
AClientCommand (APC _ cmd) -> case cmd of
|
||||
AClientCommand cmd -> case cmd of
|
||||
NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
tryCommand . withNextSrv c userId usedSrvs [] $ \srv -> do
|
||||
@@ -1146,7 +1109,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
RcvConnection cData rq -> do
|
||||
secure rq senderKey
|
||||
mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf)
|
||||
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
|
||||
_ -> throwE $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
|
||||
ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do
|
||||
secure rq senderKey
|
||||
void $ enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
|
||||
@@ -1182,8 +1145,8 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
tryError (deleteQueue c rq') >>= \case
|
||||
Right () -> finalizeSwitch
|
||||
Left e
|
||||
| temporaryOrHostError e -> throwError e
|
||||
| otherwise -> finalizeSwitch >> throwError e
|
||||
| temporaryOrHostError e -> throwE e
|
||||
| otherwise -> finalizeSwitch >> throwE e
|
||||
where
|
||||
finalizeSwitch = do
|
||||
withStore' c $ \db -> deleteConnRcvQueue db rq'
|
||||
@@ -1223,13 +1186,13 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
tryWithLock name = tryCommand . withConnLock c connId name
|
||||
internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command)
|
||||
cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
|
||||
notify :: forall e. AEntityI e => ACommand 'Agent e -> AM ()
|
||||
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd)
|
||||
notify :: forall e. AEntityI e => AEvent e -> AM ()
|
||||
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, AEvt (sAEntity @e) cmd)
|
||||
-- ^ ^ ^ async command processing /
|
||||
|
||||
enqueueMessages :: AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, PQEncryption)
|
||||
enqueueMessages c cData sqs msgFlags aMessage = do
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ INTERNAL "enqueueMessages: ratchet is not synchronized"
|
||||
when (ratchetSyncSendProhibited cData) $ throwE $ INTERNAL "enqueueMessages: ratchet is not synchronized"
|
||||
enqueueMessages' c cData sqs msgFlags aMessage
|
||||
|
||||
enqueueMessages' :: AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, CR.PQEncryption)
|
||||
@@ -1460,9 +1423,9 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork
|
||||
delMsg = delMsgKeep False
|
||||
delMsgKeep :: Bool -> InternalId -> AM ()
|
||||
delMsgKeep keepForReceipt msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId keepForReceipt
|
||||
notify :: forall e. AEntityI e => ACommand 'Agent e -> AM ()
|
||||
notify cmd = atomically $ writeTBQueue subQ ("", connId, APC (sAEntity @e) cmd)
|
||||
notifyDel :: AEntityI e => InternalId -> ACommand 'Agent e -> AM ()
|
||||
notify :: forall e. AEntityI e => AEvent e -> AM ()
|
||||
notify cmd = atomically $ writeTBQueue subQ ("", connId, AEvt (sAEntity @e) cmd)
|
||||
notifyDel :: AEntityI e => InternalId -> AEvent e -> AM ()
|
||||
notifyDel msgId cmd = notify cmd >> delMsg msgId
|
||||
connError msgId = notifyDel msgId . ERR . CONN
|
||||
qError msgId = notifyDel msgId . ERR . AGENT . A_QUEUE
|
||||
@@ -1482,7 +1445,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do
|
||||
case conn of
|
||||
DuplexConnection {} -> ack >> sendRcpt conn >> del
|
||||
RcvConnection {} -> ack >> del
|
||||
SndConnection {} -> throwError $ CONN SIMPLEX
|
||||
SndConnection {} -> throwE $ CONN SIMPLEX
|
||||
ContactConnection {} -> throwE $ CMD PROHIBITED "ackMessage: ContactConnection"
|
||||
NewConnection _ -> throwE $ CMD PROHIBITED "ackMessage: NewConnection"
|
||||
where
|
||||
@@ -1566,7 +1529,7 @@ abortConnectionSwitch' c connId =
|
||||
let rqs'' = updatedQs rq' rqs'
|
||||
conn' = DuplexConnection cData rqs'' sqs
|
||||
pure $ connectionStats conn'
|
||||
_ -> throwError $ INTERNAL "won't delete all rcv queues in connection"
|
||||
_ -> throwE $ INTERNAL "won't delete all rcv queues in connection"
|
||||
| otherwise -> throwE $ CMD PROHIBITED "abortConnectionSwitch: no rcv queues left"
|
||||
_ -> throwE $ CMD PROHIBITED "abortConnectionSwitch: not allowed"
|
||||
_ -> throwE $ CMD PROHIBITED "abortConnectionSwitch: not duplex"
|
||||
@@ -1596,7 +1559,7 @@ ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM ()
|
||||
ackQueueMessage c rq srvMsgId =
|
||||
sendAck c rq srvMsgId `catchAgentError` \case
|
||||
SMP _ SMP.NO_MSG -> pure ()
|
||||
e -> throwError e
|
||||
e -> throwE e
|
||||
|
||||
-- | Suspend SMP agent connection (OFF command) in Reader monad
|
||||
suspendConnection' :: AgentClient -> ConnId -> AM ()
|
||||
@@ -1606,7 +1569,7 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
|
||||
DuplexConnection _ rqs _ -> mapM_ (suspendQueue c) rqs
|
||||
RcvConnection _ rq -> suspendQueue c rq
|
||||
ContactConnection _ rq -> suspendQueue c rq
|
||||
SndConnection _ _ -> throwError $ CONN SIMPLEX
|
||||
SndConnection _ _ -> throwE $ CONN SIMPLEX
|
||||
NewConnection _ -> throwE $ CMD PROHIBITED "suspendConnection"
|
||||
|
||||
-- | Delete SMP agent connection (DEL command) in Reader monad
|
||||
@@ -1663,7 +1626,7 @@ prepareDeleteConnections_ getConnections c waitDelivery connIds = do
|
||||
-- ! between completed deletions of connections, and deletions delayed due to wait for delivery (see deleteConn)
|
||||
deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing
|
||||
rs' <- lift $ catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) (M.keys delRs))
|
||||
forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN)
|
||||
forM_ rs' $ \cId -> notify ("", cId, AEvt SAEConn DEL_CONN)
|
||||
pure (errs' <> delRs, rqs, connIds')
|
||||
where
|
||||
rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
|
||||
@@ -1678,7 +1641,7 @@ deleteConnQueues c waitDelivery ntf rqs = do
|
||||
let connIds = M.keys $ M.filter isRight rs
|
||||
deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing
|
||||
rs' <- catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) connIds)
|
||||
forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN)
|
||||
forM_ rs' $ \cId -> notify ("", cId, AEvt SAEConn DEL_CONN)
|
||||
pure rs
|
||||
where
|
||||
deleteQueueRecs :: [(RcvQueue, Either AgentErrorType ())] -> AM' [(RcvQueue, Either AgentErrorType ())]
|
||||
@@ -1698,7 +1661,7 @@ deleteConnQueues c waitDelivery ntf rqs = do
|
||||
Left e
|
||||
| temporaryOrHostError e && deleteErrors rq + 1 < maxErrs -> incRcvDeleteErrors db rq $> ((rq, r), Nothing)
|
||||
| otherwise -> deleteConnRcvQueue db rq $> ((rq, Right ()), Just (notifyRQ rq (Just e)))
|
||||
notifyRQ rq e_ = notify ("", qConnId rq, APC SAEConn $ DEL_RCVQ (qServer rq) (queueId rq) e_)
|
||||
notifyRQ rq e_ = notify ("", qConnId rq, AEvt SAEConn $ DEL_RCVQ (qServer rq) (queueId rq) e_)
|
||||
notify = when ntf . atomically . writeTBQueue (subQ c)
|
||||
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
|
||||
connResults = M.map snd . foldl' addResult M.empty
|
||||
@@ -1735,7 +1698,7 @@ deleteConnections_ getConnections ntf waitDelivery c connIds = do
|
||||
let actual = M.size rs
|
||||
expected = length connIds
|
||||
when (actual /= expected) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
|
||||
getConnectionServers' :: AgentClient -> ConnId -> AM ConnectionStats
|
||||
getConnectionServers' c connId = do
|
||||
@@ -1818,7 +1781,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
|
||||
ns <- asks ntfSupervisor
|
||||
tryReplace ns `catchAgentError` \e ->
|
||||
if temporaryOrHostError e
|
||||
then throwError e
|
||||
then throwE e
|
||||
else do
|
||||
withStore' c $ \db -> removeNtfToken db tkn
|
||||
atomically $ nsRemoveNtfToken ns
|
||||
@@ -1906,7 +1869,7 @@ toggleConnectionNtfs' c connId enable = do
|
||||
DuplexConnection cData _ _ -> toggle cData
|
||||
RcvConnection cData _ -> toggle cData
|
||||
ContactConnection cData _ -> toggle cData
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwE $ CONN SIMPLEX
|
||||
where
|
||||
toggle :: ConnData -> AM ()
|
||||
toggle cData
|
||||
@@ -1926,7 +1889,7 @@ deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do
|
||||
atomically $ nsUpdateToken ns tkn {ntfTknStatus, ntfTknAction}
|
||||
agentNtfDeleteToken c tknId tkn `catchAgentError` \case
|
||||
NTF _ AUTH -> pure ()
|
||||
e -> throwError e
|
||||
e -> throwE e
|
||||
withStore' c $ \db -> removeNtfToken db tkn
|
||||
atomically $ nsRemoveNtfToken ns
|
||||
|
||||
@@ -1946,8 +1909,8 @@ withToken c tkn@NtfToken {deviceToken, ntfMode} from_ (toStatus, toAction_) f =
|
||||
withStore' c $ \db -> removeNtfToken db tkn
|
||||
atomically $ nsRemoveNtfToken ns
|
||||
void $ registerNtfToken' c deviceToken ntfMode
|
||||
throwError e
|
||||
Left e -> throwError e
|
||||
throwE e
|
||||
Left e -> throwE e
|
||||
|
||||
initializeNtfSubs :: AgentClient -> AM ()
|
||||
initializeNtfSubs c = sendNtfConnCommands c NSCCreate
|
||||
@@ -1968,7 +1931,7 @@ sendNtfConnCommands c cmd = do
|
||||
Just (ConnData {enableNtfs}, _) ->
|
||||
when enableNtfs . atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd)
|
||||
_ ->
|
||||
atomically $ writeTBQueue (subQ c) ("", connId, APC SAEConn $ ERR $ INTERNAL "no connection data")
|
||||
atomically $ writeTBQueue (subQ c) ("", connId, AEvt SAEConn $ ERR $ INTERNAL "no connection data")
|
||||
|
||||
setNtfServers :: AgentClient -> [NtfServer] -> IO ()
|
||||
setNtfServers c = atomically . writeTVar (ntfServers c)
|
||||
@@ -2050,7 +2013,7 @@ cleanupManager c@AgentClient {subQ} = do
|
||||
run SFERR deleteExpiredReplicasForDeletion
|
||||
liftIO $ threadDelay' int
|
||||
where
|
||||
run :: forall e. AEntityI e => (AgentErrorType -> ACommand 'Agent e) -> AM () -> AM' ()
|
||||
run :: forall e. AEntityI e => (AgentErrorType -> AEvent e) -> AM () -> AM' ()
|
||||
run err a = do
|
||||
waitActive . runExceptT $ a `catchAgentError` (notify "" . err)
|
||||
step <- asks $ cleanupStepInterval . config
|
||||
@@ -2097,8 +2060,8 @@ cleanupManager c@AgentClient {subQ} = do
|
||||
deleteExpiredReplicasForDeletion = do
|
||||
rcvFilesTTL <- asks $ rcvFilesTTL . config
|
||||
withStore' c (`deleteDeletedSndChunkReplicasExpired` rcvFilesTTL)
|
||||
notify :: forall e. AEntityI e => EntityId -> ACommand 'Agent e -> AM ()
|
||||
notify entId cmd = atomically $ writeTBQueue subQ ("", entId, APC (sAEntity @e) cmd)
|
||||
notify :: forall e. AEntityI e => EntityId -> AEvent e -> AM ()
|
||||
notify entId cmd = atomically $ writeTBQueue subQ ("", entId, AEvt (sAEntity @e) cmd)
|
||||
|
||||
data ACKd = ACKd | ACKPending
|
||||
|
||||
@@ -2151,8 +2114,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
atomically . whenM (isPendingSub connId) $ failSubscription c rq e
|
||||
lift $ notifyErr connId e
|
||||
isPendingSub connId = (&&) <$> hasPendingSubscription c connId <*> activeClientSession c tSess sessId
|
||||
notify' :: forall e m. (AEntityI e, MonadIO m) => ConnId -> ACommand 'Agent e -> m ()
|
||||
notify' connId msg = atomically $ writeTBQueue subQ ("", connId, APC (sAEntity @e) msg)
|
||||
notify' :: forall e m. (AEntityI e, MonadIO m) => ConnId -> AEvent e -> m ()
|
||||
notify' connId msg = atomically $ writeTBQueue subQ ("", connId, AEvt (sAEntity @e) msg)
|
||||
notifyErr :: ConnId -> SMPClientError -> AM' ()
|
||||
notifyErr connId = notify' connId . ERR . protocolClientError SMP (B.unpack $ strEncode srv)
|
||||
processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> BrokerMsg -> AM ()
|
||||
@@ -2179,7 +2142,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
|
||||
parseMessage msgBody
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
|
||||
unless (phVer `isCompatible` clientVRange) . throwE $ AGENT A_VERSION
|
||||
case (e2eDhSecret, e2ePubKey_) of
|
||||
(Nothing, Just e2ePubKey) -> do
|
||||
let e2eDh = C.dh' e2ePubKey e2ePrivKey
|
||||
@@ -2275,7 +2238,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
checkDuplicateHash :: AgentErrorType -> ByteString -> AM ()
|
||||
checkDuplicateHash e encryptedMsgHash =
|
||||
unlessM (withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash) $
|
||||
throwError e
|
||||
throwE e
|
||||
updateTotalMsgCount :: STM ()
|
||||
updateTotalMsgCount =
|
||||
TM.lookup connId (msgCounts c) >>= \case
|
||||
@@ -2343,7 +2306,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
SMP.ERR e -> notify $ ERR $ SMP (B.unpack $ strEncode srv) e
|
||||
r -> unexpected r
|
||||
where
|
||||
notify :: forall e m. (AEntityI e, MonadIO m) => ACommand 'Agent e -> m ()
|
||||
notify :: forall e m. (AEntityI e, MonadIO m) => AEvent e -> m ()
|
||||
notify = notify' connId
|
||||
|
||||
prohibited :: String -> AM ()
|
||||
@@ -2368,7 +2331,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
-- aVRange <- asks $ smpAgentVRange . config
|
||||
-- if agentVersion agentEnvelope `isCompatible` aVRange
|
||||
-- then pure (privHeader, agentEnvelope)
|
||||
-- else throwError $ AGENT A_VERSION
|
||||
-- else throwE $ AGENT A_VERSION
|
||||
pure (privHeader, agentEnvelope)
|
||||
|
||||
parseMessage :: Encoding a => ByteString -> AM a
|
||||
@@ -2381,12 +2344,12 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
let ConnData {pqSupport} = toConnData conn'
|
||||
unless
|
||||
(agentVersion `isCompatible` smpAgentVRange && smpClientVersion `isCompatible` smpClientVRange)
|
||||
(throwError $ AGENT A_VERSION)
|
||||
(throwE $ AGENT A_VERSION)
|
||||
case status of
|
||||
New -> case (conn', e2eEncryption) of
|
||||
-- party initiating connection
|
||||
(RcvConnection _ _, Just (CR.AE2ERatchetParams _ e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _ _))) -> do
|
||||
unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION)
|
||||
unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwE $ AGENT A_VERSION)
|
||||
(pk1, rcDHRs, pKem) <- withStore c (`getRatchetX3dhKeys` connId)
|
||||
rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams
|
||||
let rcVs = CR.RatchetVersions {current = e2eVersion, maxSupported = maxVersion e2eEncryptVRange}
|
||||
@@ -2482,7 +2445,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
qAddMsg :: SMP.MsgId -> NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> AM ()
|
||||
qAddMsg _ ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
|
||||
qAddMsg srvMsgId ((qUri, Just addr) :| _) (DuplexConnection cData' rqs sqs) = do
|
||||
when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized")
|
||||
when (ratchetSyncSendProhibited cData') $ throwE $ AGENT (A_QUEUE "ratchet is not synchronized")
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case qUri `compatibleVersion` clientVRange of
|
||||
Just qInfo@(Compatible sqInfo@SMPQueueInfo {queueAddress}) ->
|
||||
@@ -2509,14 +2472,14 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
_ -> qError "absent sender keys"
|
||||
_ -> qError "QADD: won't delete all snd queues in connection"
|
||||
_ -> qError "QADD: replaced queue address is not found in connection"
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
_ -> throwE $ AGENT A_VERSION
|
||||
|
||||
-- processed by queue recipient
|
||||
qKeyMsg :: SMP.MsgId -> NonEmpty (SMPQueueInfo, SndPublicAuthKey) -> Connection 'CDuplex -> AM ()
|
||||
qKeyMsg srvMsgId ((qInfo, senderKey) :| _) conn'@(DuplexConnection cData' rqs _) = do
|
||||
when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized")
|
||||
when (ratchetSyncSendProhibited cData') $ throwE $ AGENT (A_QUEUE "ratchet is not synchronized")
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
|
||||
unless (qInfo `isCompatible` clientVRange) . throwE $ AGENT A_VERSION
|
||||
case findRQ (smpServer, senderId) rqs of
|
||||
Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'}
|
||||
| status' == New || status' == Confirmed -> do
|
||||
@@ -2536,7 +2499,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
qUseMsg :: SMP.MsgId -> NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> AM ()
|
||||
-- NOTE: does not yet support the change of the primary status during the rotation
|
||||
qUseMsg srvMsgId ((addr, _primary) :| _) (DuplexConnection cData' rqs sqs) = do
|
||||
when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized")
|
||||
when (ratchetSyncSendProhibited cData') $ throwE $ AGENT (A_QUEUE "ratchet is not synchronized")
|
||||
case findQ addr sqs of
|
||||
Just sq'@SndQueue {dbReplaceQueueId = Just replaceQId} -> do
|
||||
case find ((replaceQId ==) . dbQId) sqs of
|
||||
@@ -2555,7 +2518,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
_ -> qError "QUSE: switched queue address not found in connection"
|
||||
|
||||
qError :: String -> AM a
|
||||
qError = throwError . AGENT . A_QUEUE
|
||||
qError = throwE . AGENT . A_QUEUE
|
||||
|
||||
ereadyMsg :: CR.RatchetX448 -> Connection 'CDuplex -> AM ()
|
||||
ereadyMsg rcPrev (DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = do
|
||||
@@ -2591,7 +2554,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqSupport} _ sqs) =
|
||||
unlessM ratchetExists $ do
|
||||
AgentConfig {e2eEncryptVRange} <- asks config
|
||||
unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION)
|
||||
unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwE $ AGENT A_VERSION)
|
||||
keys <- getSendRatchetKeys
|
||||
let rcVs = CR.RatchetVersions {current = e2eVersion, maxSupported = maxVersion e2eEncryptVRange}
|
||||
initRatchet rcVs keys
|
||||
@@ -2616,7 +2579,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(_, srv, _), _v, sessId, ts)
|
||||
-- can communicate for other client to reset to RSRequired
|
||||
-- - need to add new AgentMsgEnvelope, AgentMessage, AgentMessageType
|
||||
-- - need to deduplicate on receiving side
|
||||
throwError $ AGENT (A_CRYPTO RATCHET_SYNC)
|
||||
throwE $ AGENT (A_CRYPTO RATCHET_SYNC)
|
||||
where
|
||||
sendReplyKey = do
|
||||
g <- asks random
|
||||
@@ -2671,7 +2634,7 @@ checkSQSwchStatus sq@SndQueue {sndSwchStatus} expected =
|
||||
|
||||
switchStatusError :: (SMPQueueRec q, Show a) => q -> a -> Maybe a -> AM ()
|
||||
switchStatusError q expected actual =
|
||||
throwError . INTERNAL $
|
||||
throwE . INTERNAL $
|
||||
("unexpected switch status, queueId=" <> show (queueId q))
|
||||
<> (", expected=" <> show expected)
|
||||
<> (", actual=" <> show actual)
|
||||
@@ -2680,7 +2643,7 @@ connectReplyQueues :: AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueIn
|
||||
connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case qInfo `proveCompatible` clientVRange of
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
Just qInfo' -> do
|
||||
sq <- lift $ newSndQueue userId connId qInfo'
|
||||
sq' <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
|
||||
@@ -273,8 +273,7 @@ type XFTPTransportSession = TransportSession FileResponse
|
||||
data AgentClient = AgentClient
|
||||
{ acThread :: TVar (Maybe (Weak ThreadId)),
|
||||
active :: TVar Bool,
|
||||
rcvQ :: TBQueue (ATransmission 'Client),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
subQ :: TBQueue ATransmission,
|
||||
msgQ :: TBQueue (ServerTransmissionBatch SMPVersion ErrorType BrokerMsg),
|
||||
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
|
||||
smpClients :: TMap SMPTransportSession SMPClientVar,
|
||||
@@ -373,7 +372,7 @@ getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
notifyErr err = do
|
||||
let e = either ((", error: " <>) . show) (\_ -> ", no error") e_
|
||||
msg = "Worker " <> name <> " for " <> show key <> " terminated " <> show (restartCount rc) <> " times" <> e
|
||||
writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ err msg)
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err msg)
|
||||
|
||||
newWorker :: AgentClient -> STM Worker
|
||||
newWorker c = do
|
||||
@@ -449,7 +448,6 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
|
||||
qSize = tbqSize cfg
|
||||
acThread <- newTVar Nothing
|
||||
active <- newTVar True
|
||||
rcvQ <- newTBQueue qSize
|
||||
subQ <- newTBQueue qSize
|
||||
msgQ <- newTBQueue qSize
|
||||
smpServers <- newTVar smp
|
||||
@@ -487,7 +485,6 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
|
||||
AgentClient
|
||||
{ acThread,
|
||||
active,
|
||||
rcvQ,
|
||||
subQ,
|
||||
msgQ,
|
||||
smpServers,
|
||||
@@ -586,7 +583,7 @@ instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
|
||||
|
||||
getSMPServerClient :: AgentClient -> SMPTransportSession -> AM SMPConnectedClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, workerSeq} tSess = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
unlessM (readTVarIO active) $ throwE INACTIVE
|
||||
ts <- liftIO getCurrentTime
|
||||
atomically (getSessVar workerSeq tSess smpClients ts)
|
||||
>>= either newClient (waitForProtocolClient c tSess smpClients)
|
||||
@@ -597,7 +594,7 @@ getSMPServerClient c@AgentClient {active, smpClients, workerSeq} tSess = do
|
||||
|
||||
getSMPProxyClient :: AgentClient -> SMPTransportSession -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay)
|
||||
getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq} destSess@(userId, destSrv, qId) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
unlessM (readTVarIO active) $ throwE INACTIVE
|
||||
proxySrv <- getNextServer c userId [destSrv]
|
||||
ts <- liftIO getCurrentTime
|
||||
atomically (getClientVar proxySrv ts) >>= \(tSess, auth, v) ->
|
||||
@@ -633,7 +630,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
|
||||
liftIO $ incClientStat c userId clnt "PROXY" "OK"
|
||||
pure $ Right sess
|
||||
Left e -> do
|
||||
liftIO $ incClientStat c userId clnt "PROXY" $ strEncode e
|
||||
liftIO $ incClientStat c userId clnt "PROXY" $ bshow e
|
||||
atomically $ do
|
||||
unless (serverHostError e) $ do
|
||||
removeSessVar rv destSrv prs
|
||||
@@ -652,7 +649,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
|
||||
smpConnectClient :: AgentClient -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient
|
||||
smpConnectClient c@AgentClient {smpClients, msgQ} tSess@(_, srv, _) prs v =
|
||||
newProtocolClient c tSess smpClients connectClient v
|
||||
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwError e
|
||||
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
|
||||
where
|
||||
connectClient :: SMPClientVar -> AM SMPConnectedClient
|
||||
connectClient v' = do
|
||||
@@ -692,8 +689,9 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess
|
||||
atomically $ mapM_ (releaseGetLock c) qs
|
||||
runReaderT (resubscribeSMPSession c tSess) env
|
||||
|
||||
notifySub :: forall e m. (AEntityI e, MonadIO m) => AgentClient -> ConnId -> ACommand 'Agent e -> m ()
|
||||
notifySub c connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
|
||||
{-# INLINE notifySub #-}
|
||||
notifySub :: forall e m. (AEntityI e, MonadIO m) => AgentClient -> ConnId -> AEvent e -> m ()
|
||||
notifySub c connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, AEvt (sAEntity @e) cmd)
|
||||
|
||||
resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' ()
|
||||
resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do
|
||||
@@ -744,7 +742,7 @@ reconnectSMPClient c tSess qs = handleNotify $ do
|
||||
|
||||
getNtfServerClient :: AgentClient -> NtfTransportSession -> AM NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients, workerSeq} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
unlessM (readTVarIO active) $ throwE INACTIVE
|
||||
ts <- liftIO getCurrentTime
|
||||
atomically (getSessVar workerSeq tSess ntfClients ts)
|
||||
>>= either
|
||||
@@ -763,12 +761,12 @@ getNtfServerClient c@AgentClient {active, ntfClients, workerSeq} tSess@(userId,
|
||||
clientDisconnected v client = do
|
||||
atomically $ removeSessVar v tSess ntfClients
|
||||
incClientStat c userId client "DISCONNECT" ""
|
||||
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client)
|
||||
atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ hostEvent DISCONNECT client)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient
|
||||
getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
unlessM (readTVarIO active) $ throwE INACTIVE
|
||||
ts <- liftIO getCurrentTime
|
||||
atomically (getSessVar workerSeq tSess xftpClients ts)
|
||||
>>= either
|
||||
@@ -787,7 +785,7 @@ getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq} tSess@(userId
|
||||
clientDisconnected v client = do
|
||||
atomically $ removeSessVar v tSess xftpClients
|
||||
incClientStat c userId client "DISCONNECT" ""
|
||||
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client)
|
||||
atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ hostEvent DISCONNECT client)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
waitForProtocolClient ::
|
||||
@@ -827,10 +825,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
|
||||
atomically $ putTMVar (sessionVar v) (Right client)
|
||||
liftIO $ incClientStat c userId client "CLIENT" "OK"
|
||||
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent CONNECT client)
|
||||
atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ hostEvent CONNECT client)
|
||||
pure client
|
||||
Left e -> do
|
||||
liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e
|
||||
liftIO $ incServerStat c userId srv "CLIENT" $ bshow e
|
||||
ei <- asks $ persistErrorInterval . config
|
||||
if ei == 0
|
||||
then atomically $ do
|
||||
@@ -841,11 +839,11 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
atomically $ putTMVar (sessionVar v) (Left (e, Just ts))
|
||||
throwE e -- signal error to caller
|
||||
|
||||
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
|
||||
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> AEvent 'AENone) -> Client msg -> AEvent 'AENone
|
||||
hostEvent event = hostEvent' event . protocolClient
|
||||
{-# INLINE hostEvent #-}
|
||||
|
||||
hostEvent' :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> ProtoClient msg -> ACommand 'Agent 'AENone
|
||||
hostEvent' :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> AEvent 'AENone) -> ProtoClient msg -> AEvent 'AENone
|
||||
hostEvent' event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
|
||||
|
||||
getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v)
|
||||
@@ -982,9 +980,9 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
stat cl = liftIO . incClientStat c userId cl statCmd
|
||||
logServerError :: Client msg -> AgentErrorType -> AM a
|
||||
logServerError cl e = do
|
||||
logServer "<--" c srv "" $ strEncode e
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
logServer "<--" c srv "" $ bshow e
|
||||
stat cl $ bshow e
|
||||
throwE e
|
||||
|
||||
withProxySession :: AgentClient -> SMPTransportSession -> SMP.SenderId -> ByteString -> ((SMPConnectedClient, ProxiedRelay) -> AM a) -> AM a
|
||||
withProxySession c destSess@(userId, destSrv, _) entId cmdStr action = do
|
||||
@@ -1001,9 +999,9 @@ withProxySession c destSess@(userId, destSrv, _) entId cmdStr action = do
|
||||
proxySrv = showServer . protocolClientServer' . protocolClient
|
||||
logServerError :: SMPConnectedClient -> AgentErrorType -> AM a
|
||||
logServerError cl e = do
|
||||
logServer ("<-- " <> proxySrv cl <> " <") c destSrv "" $ strEncode e
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
logServer ("<-- " <> proxySrv cl <> " <") c destSrv "" $ bshow e
|
||||
stat cl $ bshow e
|
||||
throwE e
|
||||
|
||||
withLogClient_ :: ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> AM a) -> AM a
|
||||
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
|
||||
@@ -1188,7 +1186,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
liftError (testErr TSUploadFile) $ X.uploadXFTPChunk xftp spKey sId chunkSpec
|
||||
liftError (testErr TSDownloadFile) $ X.downloadXFTPChunk g xftp rpKey rId $ XFTPRcvChunkSpec rcvPath chSize digest
|
||||
rcvDigest <- liftIO $ C.sha256Hash <$> B.readFile rcvPath
|
||||
unless (digest == rcvDigest) $ throwError $ ProtocolTestFailure TSCompareFile $ XFTP (B.unpack $ strEncode srv) DIGEST
|
||||
unless (digest == rcvDigest) $ throwE $ ProtocolTestFailure TSCompareFile $ XFTP (B.unpack $ strEncode srv) DIGEST
|
||||
liftError (testErr TSDeleteFile) $ X.deleteXFTPChunk xftp spKey sId
|
||||
ok <- tcpTimeout xftpNetworkConfig `timeout` X.closeXFTPClient xftp
|
||||
incClientStat c userId xftp "XFTP_TEST" "OK"
|
||||
@@ -1482,7 +1480,7 @@ sendConfirmation c sq@SndQueue {userId, server, sndId, sndPublicKey = Just sndPu
|
||||
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
|
||||
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
|
||||
sendOrProxySMPMessage c userId server "<CONF>" Nothing sndId (MsgFlags {notification = True}) msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
sendConfirmation _ _ _ = throwE $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> AM (Maybe SMPServer)
|
||||
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
|
||||
@@ -1653,7 +1651,7 @@ xftpRcvKeys n = do
|
||||
rKeys <- atomically . replicateM n . C.generateAuthKeyPair C.SEd25519 =<< asks random
|
||||
case L.nonEmpty rKeys of
|
||||
Just rKeys' -> pure rKeys'
|
||||
_ -> throwError $ INTERNAL "non-positive number of recipients"
|
||||
_ -> throwE $ INTERNAL "non-positive number of recipients"
|
||||
|
||||
xftpRcvIdsKeys :: NonEmpty ByteString -> NonEmpty C.AAuthKeyPair -> NonEmpty (ChunkReplicaId, C.APrivateAuthKey)
|
||||
xftpRcvIdsKeys rIds rKeys = L.map ChunkReplicaId rIds `L.zip` L.map snd rKeys
|
||||
@@ -1715,7 +1713,7 @@ withWork c doWork getWork action =
|
||||
Left e -> notifyErr INTERNAL e
|
||||
where
|
||||
noWork = liftIO $ noWorkToDo doWork
|
||||
notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ err $ show e)
|
||||
notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e)
|
||||
|
||||
noWorkToDo :: TMVar () -> IO ()
|
||||
noWorkToDo = void . atomically . tryTakeTMVar
|
||||
@@ -1758,7 +1756,7 @@ suspendOperation c op endedAction = do
|
||||
notifySuspended :: AgentClient -> STM ()
|
||||
notifySuspended c = do
|
||||
-- unsafeIOToSTM $ putStrLn "notifySuspended"
|
||||
writeTBQueue (subQ c) ("", "", APC SAENone SUSPENDED)
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAENone SUSPENDED)
|
||||
writeTVar (agentState c) ASSuspended
|
||||
|
||||
endOperation :: AgentClient -> AgentOperation -> STM () -> STM ()
|
||||
@@ -1891,7 +1889,7 @@ withUserServers :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient
|
||||
withUserServers c userId action =
|
||||
atomically (TM.lookup userId $ userServers c) >>= \case
|
||||
Just srvs -> action srvs
|
||||
_ -> throwError $ INTERNAL "unknown userId - no user servers"
|
||||
_ -> throwE $ INTERNAL "unknown userId - no user servers"
|
||||
|
||||
withNextSrv :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> AM a) -> AM a
|
||||
withNextSrv c userId usedSrvs initUsed action = do
|
||||
|
||||
@@ -105,7 +105,6 @@ data AgentConfig = AgentConfig
|
||||
storedMsgDataTTL :: NominalDiffTime,
|
||||
rcvFilesTTL :: NominalDiffTime,
|
||||
sndFilesTTL :: NominalDiffTime,
|
||||
xftpNotifyErrsOnRetry :: Bool,
|
||||
xftpConsecutiveRetries :: Int,
|
||||
xftpMaxRecipientsPerRequest :: Int,
|
||||
deleteErrorCount :: Int,
|
||||
@@ -176,7 +175,6 @@ defaultAgentConfig =
|
||||
storedMsgDataTTL = 21 * nominalDay,
|
||||
rcvFilesTTL = 2 * nominalDay,
|
||||
sndFilesTTL = nominalDay,
|
||||
xftpNotifyErrsOnRetry = True,
|
||||
xftpConsecutiveRetries = 3,
|
||||
xftpMaxRecipientsPerRequest = 200,
|
||||
deleteErrorCount = 10,
|
||||
|
||||
@@ -29,7 +29,7 @@ import Data.Time (UTCTime, addUTCTime, getCurrentTime)
|
||||
import Data.Time.Clock (diffUTCTime)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol (ACommand (..), APartyCmd (..), AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..), SAEntity (..))
|
||||
import Simplex.Messaging.Agent.Protocol (AEvent (..), AEvt (..), AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..), SAEntity (..))
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
@@ -306,7 +306,7 @@ workerInternalError c connId internalErrStr = do
|
||||
|
||||
-- TODO change error
|
||||
notifyInternalError :: MonadIO m => AgentClient -> ConnId -> String -> m ()
|
||||
notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, APC SAEConn $ ERR $ INTERNAL internalErrStr)
|
||||
notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, AEvt SAEConn $ ERR $ INTERNAL internalErrStr)
|
||||
{-# INLINE notifyInternalError #-}
|
||||
|
||||
getNtfToken :: AM' (Maybe NtfToken)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,85 +0,0 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Server
|
||||
( -- * SMP agent over TCP
|
||||
runSMPAgent,
|
||||
runSMPAgentBlocking,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Logger.Simple (logInfo)
|
||||
import Control.Monad
|
||||
import Control.Monad.Reader
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Client (newAgentClient)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion)
|
||||
import Simplex.Messaging.Transport.Server (defaultTransportServerConfig, loadTLSServerParams, runTransportServer)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import UnliftIO.Async (race_)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> IO ()
|
||||
runSMPAgent t cfg initServers store =
|
||||
runSMPAgentBlocking t cfg initServers store 0 =<< newEmptyTMVarIO
|
||||
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration with signalling.
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> IO ()
|
||||
runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started =
|
||||
case tcpPort of
|
||||
Just port -> newSMPAgentEnv cfg store >>= smpAgent t port
|
||||
Nothing -> E.throwIO $ userError "no agent port"
|
||||
where
|
||||
smpAgent :: forall c. Transport c => TProxy c -> ServiceName -> Env -> IO ()
|
||||
smpAgent _ port env = do
|
||||
-- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation
|
||||
tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile Nothing
|
||||
clientId <- newTVarIO initClientId
|
||||
runTransportServer started port tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
|
||||
putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
cId <- atomically $ stateTVar clientId $ \i -> (i + 1, i + 1)
|
||||
c <- atomically $ newAgentClient cId initServers env
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runAgentClient c `runReaderT` env)
|
||||
`E.finally` (disconnectAgentClient c)
|
||||
|
||||
connectClient :: Transport c => c -> AgentClient -> IO ()
|
||||
connectClient h c = race_ (send h c) (receive h c)
|
||||
|
||||
receive :: forall c. Transport c => c -> AgentClient -> IO ()
|
||||
receive h c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
(corrId, entId, cmdOrErr) <- tGet SClient h
|
||||
case cmdOrErr of
|
||||
Right cmd -> write rcvQ (corrId, entId, cmd)
|
||||
Left e -> write subQ (corrId, entId, APC SAEConn $ ERR e)
|
||||
where
|
||||
write :: TBQueue (ATransmission p) -> ATransmission p -> IO ()
|
||||
write q t = do
|
||||
logClient c "-->" t
|
||||
atomically $ writeTBQueue q t
|
||||
|
||||
send :: Transport c => c -> AgentClient -> IO ()
|
||||
send h c@AgentClient {subQ} = forever $ do
|
||||
t <- atomically $ readTBQueue subQ
|
||||
tPut h t
|
||||
logClient c "<--" t
|
||||
|
||||
logClient :: AgentClient -> ByteString -> ATransmission a -> IO ()
|
||||
logClient AgentClient {clientId} dir (corrId, connId, APC _ cmd) = do
|
||||
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
@@ -47,7 +47,6 @@ import Simplex.Messaging.Protocol
|
||||
VersionSMPC,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
-- * Queue types
|
||||
|
||||
@@ -344,20 +343,20 @@ instance StrEncoding AgentCmdType where
|
||||
_ -> fail "bad AgentCmdType"
|
||||
|
||||
data AgentCommand
|
||||
= AClientCommand (APartyCmd 'Client)
|
||||
= AClientCommand ACommand
|
||||
| AInternalCommand InternalCommand
|
||||
|
||||
instance StrEncoding AgentCommand where
|
||||
strEncode = \case
|
||||
AClientCommand (APC _ cmd) -> strEncode (ACClient, Str $ serializeCommand cmd)
|
||||
AClientCommand cmd -> strEncode (ACClient, Str $ serializeCommand cmd)
|
||||
AInternalCommand cmd -> strEncode (ACInternal, cmd)
|
||||
strP =
|
||||
strP_ >>= \case
|
||||
ACClient -> AClientCommand <$> ((\(ACmd _ e cmd) -> checkParty $ APC e cmd) <$?> dbCommandP)
|
||||
ACClient -> AClientCommand <$> dbCommandP
|
||||
ACInternal -> AInternalCommand <$> strP
|
||||
|
||||
data AgentCommandTag
|
||||
= AClientCommandTag (APartyCmdTag 'Client)
|
||||
= AClientCommandTag ACommandTag
|
||||
| AInternalCommandTag InternalCommandTag
|
||||
deriving (Show)
|
||||
|
||||
@@ -436,7 +435,7 @@ instance StrEncoding InternalCommandTag where
|
||||
|
||||
agentCommandTag :: AgentCommand -> AgentCommandTag
|
||||
agentCommandTag = \case
|
||||
AClientCommand cmd -> AClientCommandTag $ aPartyCmdTag cmd
|
||||
AClientCommand cmd -> AClientCommandTag $ aCommandTag cmd
|
||||
AInternalCommand cmd -> AInternalCommandTag $ internalCmdTag cmd
|
||||
|
||||
internalCmdTag :: InternalCommand -> InternalCommandTag
|
||||
|
||||
@@ -225,6 +225,7 @@ import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
@@ -1045,7 +1046,7 @@ getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreErr
|
||||
getWorkItem itemName getId getItem markFailed =
|
||||
runExceptT $ handleErr "getId" getId >>= mapM tryGetItem
|
||||
where
|
||||
tryGetItem itemId = ExceptT (getItem itemId) `catchStoreErrors` \e -> mark itemId >> throwError e
|
||||
tryGetItem itemId = ExceptT (getItem itemId) `catchStoreErrors` \e -> mark itemId >> throwE e
|
||||
mark itemId = handleErr ("markFailed ID " <> bshow itemId) $ markFailed itemId
|
||||
catchStoreErrors = catchAllErrors (SEInternal . bshow)
|
||||
-- Errors caught by this function will suspend worker as if there is no more work,
|
||||
|
||||
@@ -933,8 +933,8 @@ forwardSMPMessage :: SMPClient -> CorrId -> VersionSMP -> C.PublicKeyX25519 -> E
|
||||
forwardSMPMessage c@ProtocolClient {thParams, client_ = PClient {clientCorrId = g}} fwdCorrId fwdVersion fwdKey fwdTransmission = do
|
||||
-- prepare params
|
||||
sessSecret <- case thAuth thParams of
|
||||
Nothing -> throwError $ PCETransportError TENoServerAuth
|
||||
Just THAuthClient {sessSecret} -> maybe (throwError $ PCETransportError TENoServerAuth) pure sessSecret
|
||||
Nothing -> throwE $ PCETransportError TENoServerAuth
|
||||
Just THAuthClient {sessSecret} -> maybe (throwE $ PCETransportError TENoServerAuth) pure sessSecret
|
||||
nonce <- liftIO . atomically $ C.randomCbNonce g
|
||||
-- wrap
|
||||
let fwdT = FwdTransmission {fwdCorrId, fwdVersion, fwdKey, fwdTransmission}
|
||||
|
||||
@@ -23,6 +23,7 @@ where
|
||||
import Control.Exception
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.ByteArray as BA
|
||||
@@ -56,10 +57,10 @@ readFile (CryptoFile path cfArgs) = do
|
||||
case cfArgs of
|
||||
Just (CFArgs (C.SbKey key) (C.CbNonce nonce)) -> do
|
||||
let len = LB.length s - fromIntegral C.authTagSize
|
||||
when (len < 0) $ throwError FTCEInvalidFileSize
|
||||
when (len < 0) $ throwE FTCEInvalidFileSize
|
||||
let (s', tag') = LB.splitAt len s
|
||||
(tag :| cs) <- liftEitherWith FTCECryptoError $ LC.secretBox LC.sbDecryptChunk key nonce s'
|
||||
unless (BA.constEq (LB.toStrict tag') tag) $ throwError FTCEInvalidAuthTag
|
||||
unless (BA.constEq (LB.toStrict tag') tag) $ throwE FTCEInvalidAuthTag
|
||||
pure $ LB.fromChunks cs
|
||||
Nothing -> pure s
|
||||
|
||||
@@ -96,7 +97,7 @@ hGetTag :: CryptoFileHandle -> ExceptT FTCryptoError IO ()
|
||||
hGetTag (CFHandle h sb_) = forM_ sb_ $ \sb -> do
|
||||
tag <- liftIO $ B.hGet h C.authTagSize
|
||||
tag' <- LC.sbAuth <$> readTVarIO sb
|
||||
unless (BA.constEq tag tag') $ throwError FTCEInvalidAuthTag
|
||||
unless (BA.constEq tag tag') $ throwE FTCEInvalidAuthTag
|
||||
|
||||
data FTCryptoError
|
||||
= FTCECryptoError C.CryptoError
|
||||
|
||||
@@ -447,7 +447,7 @@ pqX3dhRcv rpk1 rpk2 rpKem_ (E2ERatchetParams v sk1 sk2 sKem_) = do
|
||||
Just (PrivateRKParamsProposed ks@(_, pk)) -> do
|
||||
shared <- liftIO $ sntrup761Dec ct pk
|
||||
pure $ Just (ks, RatchetKEMAccepted k' shared ct)
|
||||
Nothing -> throwError CERatchetKEMState
|
||||
Nothing -> throwE CERatchetKEMState
|
||||
_ -> pure Nothing -- both parties can send "proposal" in case of ratchet renegotiation
|
||||
|
||||
pqX3dh :: DhAlgorithm a => (PublicKey a, PublicKey a) -> DhSecret a -> DhSecret a -> DhSecret a -> Maybe RatchetKEMAccepted -> RatchetInitParams
|
||||
|
||||
@@ -15,6 +15,7 @@ import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Hash.Algorithms (SHA256 (..))
|
||||
import qualified Crypto.PubKey.ECC.ECDSA as EC
|
||||
import qualified Crypto.PubKey.ECC.Types as ECT
|
||||
@@ -353,18 +354,18 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke
|
||||
| status == Just N.ok200 = pure ()
|
||||
| status == Just N.badRequest400 =
|
||||
case reason' of
|
||||
"BadDeviceToken" -> throwError PPTokenInvalid
|
||||
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
|
||||
"TopicDisallowed" -> throwError PPPermanentError
|
||||
"BadDeviceToken" -> throwE PPTokenInvalid
|
||||
"DeviceTokenNotForTopic" -> throwE PPTokenInvalid
|
||||
"TopicDisallowed" -> throwE PPPermanentError
|
||||
_ -> err status reason'
|
||||
| status == Just N.forbidden403 = case reason' of
|
||||
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
|
||||
"InvalidProviderToken" -> throwError PPPermanentError
|
||||
"ExpiredProviderToken" -> throwE PPPermanentError -- there should be no point retrying it as the token was refreshed
|
||||
"InvalidProviderToken" -> throwE PPPermanentError
|
||||
_ -> err status reason'
|
||||
| status == Just N.gone410 = throwError PPTokenInvalid
|
||||
| status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater
|
||||
| status == Just N.gone410 = throwE PPTokenInvalid
|
||||
| status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwE PPRetryLater
|
||||
-- Just tooManyRequests429 -> TooManyRequests - too many requests for the same token
|
||||
| otherwise = err status reason'
|
||||
err :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
|
||||
err s r = throwError $ PPResponseError s r
|
||||
err s r = throwE $ PPResponseError s r
|
||||
liftHTTPS2 a = ExceptT $ first PPConnection <$> a
|
||||
|
||||
@@ -116,7 +116,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
|
||||
getHandshake th >>= \case
|
||||
NtfClientHandshake {ntfVersion = v, keyHash}
|
||||
| keyHash /= kh ->
|
||||
throwError $ TEHandshake IDENTITY
|
||||
throwE $ TEHandshake IDENTITY
|
||||
| otherwise ->
|
||||
case compatibleVRange' ntfVersionRange v of
|
||||
Just (Compatible vr) -> pure $ ntfThHandleServer th v vr pk
|
||||
@@ -128,7 +128,7 @@ ntfClientHandshake c keyHash ntfVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
|
||||
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th
|
||||
if sessionId /= sessId
|
||||
then throwError TEBadSession
|
||||
then throwE TEBadSession
|
||||
else case ntfVersionRange `compatibleVRange` ntfVRange of
|
||||
Just (Compatible vr) -> do
|
||||
ck_ <- forM sk' $ \signedKey -> liftEitherWith (const $ TEHandshake BAD_AUTH) $ do
|
||||
|
||||
@@ -24,7 +24,7 @@ import Database.SQLite.Simple (ResultError (..), SQLData (..))
|
||||
import Database.SQLite.Simple.FromField (FieldParser, returnError)
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8, (<$?>))
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
base64P :: Parser ByteString
|
||||
@@ -154,3 +154,6 @@ singleFieldJSON_ objectTag tagModifier =
|
||||
|
||||
defaultJSON :: J.Options
|
||||
defaultJSON = J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
textP :: Parser String
|
||||
textP = T.unpack . safeDecodeUtf8 <$> A.takeByteString
|
||||
|
||||
@@ -7,6 +7,7 @@ import qualified Control.Exception as E
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -114,11 +115,11 @@ catchAllErrors' err action handler = tryAllErrors' err action >>= either handler
|
||||
{-# INLINE catchAllErrors' #-}
|
||||
|
||||
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (E.SomeException -> e) -> ExceptT e m a
|
||||
catchThrow action err = catchAllErrors err action throwError
|
||||
catchThrow action err = catchAllErrors err action throwE
|
||||
{-# INLINE catchThrow #-}
|
||||
|
||||
allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a
|
||||
allFinally err action final = tryAllErrors err action >>= \r -> final >> either throwError pure r
|
||||
allFinally err action final = tryAllErrors err action >>= \r -> final >> either throwE pure r
|
||||
{-# INLINE allFinally #-}
|
||||
|
||||
eitherToMaybe :: Either a b -> Maybe b
|
||||
@@ -149,7 +150,7 @@ safeDecodeUtf8 = decodeUtf8With onError
|
||||
onError _ _ = Just '?'
|
||||
|
||||
timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a
|
||||
timeoutThrow e ms action = ExceptT (sequence <$> (ms `timeout` runExceptT action)) >>= maybe (throwError e) pure
|
||||
timeoutThrow e ms action = ExceptT (sequence <$> (ms `timeout` runExceptT action)) >>= maybe (throwE e) pure
|
||||
|
||||
threadDelay' :: Int64 -> IO ()
|
||||
threadDelay' = loop
|
||||
|
||||
@@ -30,6 +30,7 @@ import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Aeson as J
|
||||
import Data.ByteString (ByteString)
|
||||
@@ -106,9 +107,9 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
action <- liftIO $ runClient c r hostKeys
|
||||
-- wait for the port to make invitation
|
||||
portNum <- atomically $ readTMVar startedPort
|
||||
signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
|
||||
signedInv@RCSignedInvitation {invitation} <- maybe (throwE RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum
|
||||
when multicast $ case knownHost of
|
||||
Nothing -> throwError RCENewController
|
||||
Nothing -> throwE RCENewController
|
||||
Just KnownHostPairing {hostDhPubKey} -> do
|
||||
ann <- liftIO . async . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation
|
||||
atomically $ putTMVar announcer ann
|
||||
@@ -117,7 +118,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
findCtrlAddress :: ExceptT RCErrorType IO (NonEmpty RCCtrlAddress)
|
||||
findCtrlAddress = do
|
||||
found' <- liftIO $ getLocalAddress rcAddrPrefs_
|
||||
maybe (throwError RCENoLocalAddress) pure $ L.nonEmpty found'
|
||||
maybe (throwE RCENoLocalAddress) pure $ L.nonEmpty found'
|
||||
mkClient :: IO RCHClient_
|
||||
mkClient = do
|
||||
startedPort <- newEmptyTMVarIO
|
||||
@@ -211,10 +212,10 @@ prepareHostSession
|
||||
let sharedKey = C.dh' dhPubKey dhPrivKey
|
||||
helloBody <- liftEitherWith (const RCEDecrypt) $ C.cbDecrypt sharedKey nonce encBody
|
||||
hostHello@RCHostHello {v, ca, kem = kemPubKey} <- liftEitherWith RCESyntax $ J.eitherDecodeStrict helloBody
|
||||
unless (ca == tlsHostFingerprint) $ throwError RCEIdentity
|
||||
unless (ca == tlsHostFingerprint) $ throwE RCEIdentity
|
||||
(kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey
|
||||
let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey
|
||||
unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion
|
||||
unless (isCompatible v supportedRCPVRange) $ throwE RCEVersion
|
||||
let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey}
|
||||
knownHost' <- updateKnownHost ca dhPubKey
|
||||
let ctrlHello = RCCtrlHello {}
|
||||
@@ -227,7 +228,7 @@ prepareHostSession
|
||||
updateKnownHost :: C.KeyHash -> C.PublicKeyX25519 -> ExceptT RCErrorType IO KnownHostPairing
|
||||
updateKnownHost ca hostDhPubKey = case knownHost_ of
|
||||
Just h -> do
|
||||
unless (hostFingerprint h == tlsHostFingerprint) . throwError $
|
||||
unless (hostFingerprint h == tlsHostFingerprint) . throwE $
|
||||
RCEInternal "TLS host CA is different from host pairing, should be caught in TLS handshake"
|
||||
pure (h :: KnownHostPairing) {hostDhPubKey}
|
||||
Nothing -> pure KnownHostPairing {hostFingerprint = ca, hostDhPubKey}
|
||||
@@ -257,7 +258,7 @@ connectRCCtrl drg (RCVerifiedInvitation inv@RCInvitation {ca, idkey}) pairing_ h
|
||||
pure RCCtrlPairing {caKey, caCert, ctrlFingerprint = ca, idPubKey = idkey, dhPrivKey, prevDhPrivKey = Nothing}
|
||||
updateCtrlPairing :: RCCtrlPairing -> ExceptT RCErrorType IO RCCtrlPairing
|
||||
updateCtrlPairing pairing@RCCtrlPairing {ctrlFingerprint, idPubKey, dhPrivKey = currDhPrivKey} = do
|
||||
unless (ca == ctrlFingerprint && idPubKey == idkey) $ throwError RCEIdentity
|
||||
unless (ca == ctrlFingerprint && idPubKey == idkey) $ throwE RCEIdentity
|
||||
(_, dhPrivKey) <- atomically $ C.generateKeyPair drg
|
||||
pure pairing {dhPrivKey, prevDhPrivKey = Just currDhPrivKey}
|
||||
|
||||
@@ -278,7 +279,7 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
|
||||
clientCredentials <-
|
||||
liftIO (genTLSCredentials drg caKey caCert) >>= \case
|
||||
TLS.Credentials (creds : _) -> pure $ Just creds
|
||||
_ -> throwError $ RCEInternal "genTLSCredentials must generate credentials"
|
||||
_ -> throwE $ RCEInternal "genTLSCredentials must generate credentials"
|
||||
let clientConfig = defaultTransportClientConfig {clientCredentials}
|
||||
ExceptT . runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> runExceptT $ do
|
||||
-- pump socket to detect connection problems
|
||||
@@ -303,11 +304,13 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
|
||||
logDebug "Session ended"
|
||||
|
||||
catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a
|
||||
catchRCError = catchAllErrors (RCEException . show)
|
||||
catchRCError = catchAllErrors $ \e -> case fromException e of
|
||||
Just (TLS.Terminated _ _ (TLS.Error_Protocol (_, _, TLS.UnknownCa))) -> RCEIdentity
|
||||
_ -> RCEException $ show e
|
||||
{-# INLINE catchRCError #-}
|
||||
|
||||
putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a
|
||||
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwError e
|
||||
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
|
||||
|
||||
sendRCPacket :: Encoding a => TLS -> a -> ExceptT RCErrorType IO ()
|
||||
sendRCPacket tls pkt = do
|
||||
@@ -317,7 +320,7 @@ sendRCPacket tls pkt = do
|
||||
receiveRCPacket :: Encoding a => TLS -> ExceptT RCErrorType IO a
|
||||
receiveRCPacket tls = do
|
||||
b <- liftIO $ cGet tls xrcpBlockSize
|
||||
when (B.length b /= xrcpBlockSize) $ throwError RCEBlockSize
|
||||
when (B.length b /= xrcpBlockSize) $ throwE RCEBlockSize
|
||||
b' <- liftEitherWith (const RCEBlockSize) $ C.unPad b
|
||||
liftEitherWith RCESyntax $ smpDecode b'
|
||||
|
||||
@@ -329,7 +332,7 @@ prepareHostHello
|
||||
hostAppInfo = do
|
||||
logDebug "Preparing session"
|
||||
case compatibleVersion v supportedRCPVRange of
|
||||
Nothing -> throwError RCEVersion
|
||||
Nothing -> throwE RCEVersion
|
||||
Just (Compatible v') -> do
|
||||
nonce <- liftIO . atomically $ C.randomCbNonce drg
|
||||
(kemPubKey, kemPrivKey) <- liftIO $ sntrup761Keypair drg
|
||||
@@ -355,7 +358,7 @@ prepareCtrlSession
|
||||
pure CtrlSessKeys {hybridKey, idPubKey, sessPubKey = skey}
|
||||
RCCtrlEncError {nonce, encMessage} -> do
|
||||
message <- liftEitherWith (const RCEDecrypt) $ C.cbDecrypt sharedKey nonce encMessage
|
||||
throwError $ RCECtrlError $ T.unpack $ safeDecodeUtf8 message
|
||||
throwE $ RCECtrlError $ T.unpack $ safeDecodeUtf8 message
|
||||
|
||||
-- * Multicast discovery
|
||||
|
||||
@@ -382,7 +385,7 @@ discoverRCCtrl subscribers pairings =
|
||||
r@(_, RCVerifiedInvitation RCInvitation {host}) <- findRCCtrlPairing pairings encInvitation
|
||||
case source of
|
||||
SockAddrInet _ ha | THIPv4 (hostAddressToTuple ha) == host -> pure ()
|
||||
_ -> throwError RCEInvitation
|
||||
_ -> throwE RCEInvitation
|
||||
pure r
|
||||
where
|
||||
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
|
||||
@@ -392,8 +395,8 @@ findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErro
|
||||
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
|
||||
(pairing, signedInvStr) <- liftEither $ decrypt (L.toList pairings)
|
||||
signedInv <- liftEitherWith RCESyntax $ strDecode signedInvStr
|
||||
inv@(RCVerifiedInvitation RCInvitation {dh = invDh}) <- maybe (throwError RCEInvitation) pure $ verifySignedInvitation signedInv
|
||||
unless (invDh == dhPubKey) $ throwError RCEInvitation
|
||||
inv@(RCVerifiedInvitation RCInvitation {dh = invDh}) <- maybe (throwE RCEInvitation) pure $ verifySignedInvitation signedInv
|
||||
unless (invDh == dhPubKey) $ throwE RCEInvitation
|
||||
pure (pairing, inv)
|
||||
where
|
||||
decrypt :: [RCCtrlPairing] -> Either RCErrorType (RCCtrlPairing, ByteString)
|
||||
@@ -433,7 +436,7 @@ rcEncryptBody drg hybridKey s = do
|
||||
rcDecryptBody :: KEMHybridSecret -> C.CbNonce -> LazyByteString -> ExceptT RCErrorType IO LazyByteString
|
||||
rcDecryptBody hybridKey nonce ct = do
|
||||
let len = LB.length ct - 16
|
||||
when (len < 0) $ throwError RCEDecrypt
|
||||
when (len < 0) $ throwE RCEDecrypt
|
||||
(ok, s) <- liftEitherWith (const RCEDecrypt) $ LC.kcbDecryptTailTag hybridKey nonce len ct
|
||||
unless ok $ throwError RCEDecrypt
|
||||
unless ok $ throwE RCEDecrypt
|
||||
pure s
|
||||
|
||||
+2
-618
@@ -12,32 +12,12 @@ module AgentTests (agentTests) where
|
||||
|
||||
import AgentTests.ConnectionRequestTests
|
||||
import AgentTests.DoubleRatchetTests (doubleRatchetTests)
|
||||
import AgentTests.FunctionalAPITests (functionalAPITests, inAnyOrder, pattern Msg, pattern Msg', pattern SENT)
|
||||
import AgentTests.FunctionalAPITests (functionalAPITests)
|
||||
import AgentTests.MigrationTests (migrationTests)
|
||||
import AgentTests.NotificationTests (notificationTests)
|
||||
import AgentTests.SQLiteTests (storeTests)
|
||||
import Control.Concurrent
|
||||
import Control.Monad (forM_, when)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Type.Equality
|
||||
import GHC.Stack (withFrozenCallStack)
|
||||
import Network.HTTP.Types (urlEncode)
|
||||
import SMPAgentClient
|
||||
import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent.Protocol hiding (CONF, INFO, MID, REQ, SENT)
|
||||
import qualified Simplex.Messaging.Agent.Protocol as A
|
||||
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn)
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (ErrorType (..))
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..))
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import System.Directory (removeFile)
|
||||
import System.Timeout
|
||||
import Simplex.Messaging.Transport (ATransport (..))
|
||||
import Test.Hspec
|
||||
import Util
|
||||
|
||||
agentTests :: ATransport -> Spec
|
||||
agentTests (ATransport t) = do
|
||||
@@ -47,599 +27,3 @@ agentTests (ATransport t) = do
|
||||
describe "Notification tests" $ notificationTests (ATransport t)
|
||||
describe "SQLite store" storeTests
|
||||
describe "Migration tests" migrationTests
|
||||
describe "SMP agent protocol syntax" $ syntaxTests t
|
||||
describe "Establishing duplex connection (via agent protocol)" $ do
|
||||
skip "These tests are disabled because the agent does not work correctly with multiple connected TCP clients" $
|
||||
describe "one agent" $ do
|
||||
it "should connect via one server and one agent" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnection t
|
||||
it "should connect via one server and one agent (random IDs)" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via one server and 2 agents" $ do
|
||||
smpAgentTest2_2_1 $ testDuplexConnection t
|
||||
it "should connect via one server and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
|
||||
describe "should connect via 2 servers and 2 agents" $ do
|
||||
pqMatrix2 t smpAgentTest2_2_2 testDuplexConnection'
|
||||
describe "should connect via 2 servers and 2 agents (random IDs)" $ do
|
||||
pqMatrix2 t smpAgentTest2_2_2 testDuplexConnRandomIds'
|
||||
describe "Establishing connections via `contact connection`" $ do
|
||||
describe "should connect via contact connection with one server and 3 agents" $ do
|
||||
pqMatrix3 t smpAgentTest3 testContactConnection
|
||||
describe "should connect via contact connection with one server and 2 agents (random IDs)" $ do
|
||||
pqMatrix2NoInv t smpAgentTest2_2_1 testContactConnRandomIds
|
||||
it "should support rejecting contact request" $ do
|
||||
smpAgentTest2_2_1 $ testRejectContactRequest t
|
||||
describe "Connection subscriptions" $ do
|
||||
it "should connect via one server and one agent" $ do
|
||||
smpAgentTest3_1_1 $ testSubscription t
|
||||
it "should send notifications to client when server disconnects" $ do
|
||||
smpAgentServerTest $ testSubscrNotification t
|
||||
describe "Message delivery and server reconnection" $ do
|
||||
describe "should deliver messages after losing server connection and re-connecting" $
|
||||
pqMatrix2 t smpAgentTest2_2_2_needs_server testMsgDeliveryServerRestart
|
||||
it "should connect to the server when server goes up if it initially was down" $ do
|
||||
smpAgentTestN [] $ testServerConnectionAfterError t
|
||||
it "should deliver pending messages after agent restarting" $ do
|
||||
smpAgentTest1_1_1 $ testMsgDeliveryAgentRestart t
|
||||
it "should concurrently deliver messages to connections without blocking" $ do
|
||||
smpAgentTest2_2_1 $ testConcurrentMsgDelivery t
|
||||
it "should deliver messages if one of connections has quota exceeded" $ do
|
||||
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
|
||||
it "should resume delivering messages after exceeding quota once all messages are received" $ do
|
||||
smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t
|
||||
|
||||
type AEntityTransmission p e = (ACorrId, ConnId, ACommand p e)
|
||||
|
||||
type AEntityTransmissionOrError p e = (ACorrId, ConnId, Either AgentErrorType (ACommand p e))
|
||||
|
||||
tGetAgent :: (Transport c, HasCallStack) => c -> IO (AEntityTransmissionOrError 'Agent 'AEConn)
|
||||
tGetAgent = tGetAgent' True
|
||||
|
||||
tGetAgent' :: forall c e. (Transport c, AEntityI e, HasCallStack) => Bool -> c -> IO (AEntityTransmissionOrError 'Agent e)
|
||||
tGetAgent' skipErr h = do
|
||||
(corrId, connId, cmdOrErr) <- pGetAgent skipErr h
|
||||
case cmdOrErr of
|
||||
Right (APC e cmd) -> case testEquality e (sAEntity @e) of
|
||||
Just Refl -> pure (corrId, connId, Right cmd)
|
||||
_ -> error $ "unexpected command " <> show cmd
|
||||
Left err -> pure (corrId, connId, Left err)
|
||||
|
||||
pGetAgent :: forall c. Transport c => Bool -> c -> IO (ATransmissionOrError 'Agent)
|
||||
pGetAgent skipErr h = do
|
||||
(corrId, connId, cmdOrErr) <- tGet SAgent h
|
||||
case cmdOrErr of
|
||||
Right (APC _ CONNECT {}) -> pGetAgent skipErr h
|
||||
Right (APC _ DISCONNECT {}) -> pGetAgent skipErr h
|
||||
Right (APC _ UP {}) -> pGetAgent skipErr h
|
||||
Right (APC _ (ERR (BROKER _ NETWORK))) | skipErr -> pGetAgent skipErr h
|
||||
cmd -> pure (corrId, connId, cmd)
|
||||
|
||||
-- | receive message to handle `h`
|
||||
(<#:) :: Transport c => c -> IO (AEntityTransmissionOrError 'Agent 'AEConn)
|
||||
(<#:) = tGetAgent
|
||||
|
||||
(<#:?) :: Transport c => c -> IO (ATransmissionOrError 'Agent)
|
||||
(<#:?) = pGetAgent True
|
||||
|
||||
(<#:.) :: Transport c => c -> IO (AEntityTransmissionOrError 'Agent 'AENone)
|
||||
(<#:.) = tGetAgent' True
|
||||
|
||||
-- | send transmission `t` to handle `h` and get response
|
||||
(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (AEntityTransmissionOrError 'Agent 'AEConn)
|
||||
h #: t = tPutRaw h t >> (<#:) h
|
||||
|
||||
(#:!) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (AEntityTransmissionOrError 'Agent 'AEConn)
|
||||
h #:! t = tPutRaw h t >> tGetAgent' False h
|
||||
|
||||
-- | action and expected response
|
||||
-- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r`
|
||||
(#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> AEntityTransmission 'Agent 'AEConn -> Expectation
|
||||
action #> (corrId, connId, cmd) = withFrozenCallStack $ action `shouldReturn` (corrId, connId, Right cmd)
|
||||
|
||||
-- | action and predicate for the response
|
||||
-- `h #:t =#> p` is the test that sends `t` to `h` and validates the response using `p`
|
||||
(=#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation
|
||||
action =#> p = withFrozenCallStack $ action >>= (`shouldSatisfy` p . correctTransmission)
|
||||
|
||||
pattern MID :: AgentMsgId -> ACommand 'Agent 'AEConn
|
||||
pattern MID msgId = A.MID msgId PQEncOn
|
||||
|
||||
correctTransmission :: (ACorrId, ConnId, Either AgentErrorType cmd) -> (ACorrId, ConnId, cmd)
|
||||
correctTransmission (corrId, connId, cmdOrErr) = case cmdOrErr of
|
||||
Right cmd -> (corrId, connId, cmd)
|
||||
Left e -> error $ show e
|
||||
|
||||
-- | receive message to handle `h` and validate that it is the expected one
|
||||
(<#) :: (HasCallStack, Transport c) => c -> AEntityTransmission 'Agent 'AEConn -> Expectation
|
||||
h <# (corrId, connId, cmd) = timeout 5000000 (h <#:) `shouldReturn` Just (corrId, connId, Right cmd)
|
||||
|
||||
(<#.) :: (HasCallStack, Transport c) => c -> AEntityTransmission 'Agent 'AENone -> Expectation
|
||||
h <#. (corrId, connId, cmd) = timeout 5000000 (h <#:.) `shouldReturn` Just (corrId, connId, Right cmd)
|
||||
|
||||
-- | receive message to handle `h` and validate it using predicate `p`
|
||||
(<#=) :: (HasCallStack, Transport c) => c -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation
|
||||
h <#= p = timeout 5000000 (h <#:) >>= (`shouldSatisfy` p . correctTransmission . fromJust)
|
||||
|
||||
(<#=?) :: (HasCallStack, Transport c) => c -> (ATransmission 'Agent -> Bool) -> Expectation
|
||||
h <#=? p = timeout 5000000 (h <#:?) >>= (`shouldSatisfy` p . correctTransmission . fromJust)
|
||||
|
||||
-- | test that nothing is delivered to handle `h` during 10ms
|
||||
(#:#) :: Transport c => c -> String -> Expectation
|
||||
h #:# err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` tGetAgent h >>= \case
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
|
||||
type PQMatrix2 c =
|
||||
HasCallStack =>
|
||||
TProxy c ->
|
||||
(HasCallStack => (c -> c -> IO ()) -> Expectation) ->
|
||||
(HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> IO ()) ->
|
||||
Spec
|
||||
|
||||
pqMatrix2 :: PQMatrix2 c
|
||||
pqMatrix2 = pqMatrix2_ True
|
||||
|
||||
pqMatrix2NoInv :: PQMatrix2 c
|
||||
pqMatrix2NoInv = pqMatrix2_ False
|
||||
|
||||
pqMatrix2_ :: Bool -> PQMatrix2 c
|
||||
pqMatrix2_ pqInv _ smpTest test = do
|
||||
it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOff)
|
||||
it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOn)
|
||||
it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOff)
|
||||
it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOn)
|
||||
when pqInv $ do
|
||||
it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOff)
|
||||
it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOn)
|
||||
|
||||
pqMatrix3 ::
|
||||
HasCallStack =>
|
||||
TProxy c ->
|
||||
(HasCallStack => (c -> c -> c -> IO ()) -> Expectation) ->
|
||||
(HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()) ->
|
||||
Spec
|
||||
pqMatrix3 _ smpTest test = do
|
||||
it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOff)
|
||||
it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOn)
|
||||
it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOff)
|
||||
it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOn)
|
||||
it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOff)
|
||||
it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOn)
|
||||
it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOff)
|
||||
it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOn)
|
||||
|
||||
testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQSupportOn)
|
||||
|
||||
testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testDuplexConnection' (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "bob", Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:)
|
||||
pqSup' `shouldBe` CR.connPQEncryption aPQ
|
||||
alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
bob <# ("", "alice", A.INFO bPQ "alice's connInfo")
|
||||
bob <# ("", "alice", CON pq)
|
||||
alice <# ("", "bob", CON pq)
|
||||
-- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4
|
||||
alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", A.MID 4 pq)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg' 4 pq' "hello") -> pq == pq'; _ -> False
|
||||
bob #: ("12", "alice", "ACK 4") #> ("12", "alice", OK)
|
||||
alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", A.MID 5 pq)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
bob <#= \case ("", "alice", Msg' 5 pq' "how are you?") -> pq == pq'; _ -> False
|
||||
bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK)
|
||||
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", A.MID 6 pq)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice <#= \case ("", "bob", Msg' 6 pq' "hello too") -> pq == pq'; _ -> False
|
||||
alice #: ("3a", "bob", "ACK 6") #> ("3a", "bob", OK)
|
||||
bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", A.MID 7 pq)
|
||||
bob <# ("", "alice", SENT 7)
|
||||
alice <#= \case ("", "bob", Msg' 7 pq' "message 1") -> pq == pq'; _ -> False
|
||||
alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK)
|
||||
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
|
||||
bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", A.MID 8 pq)
|
||||
bob <#= \case ("", "alice", MERR 8 (SMP _ AUTH)) -> True; _ -> False
|
||||
alice #: ("6", "bob", "DEL") #> ("6", "bob", OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQSupportOn)
|
||||
|
||||
testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo")
|
||||
("", bobConn', Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:)
|
||||
pqSup' `shouldBe` CR.connPQEncryption aPQ
|
||||
bobConn' `shouldBe` bobConn
|
||||
alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False
|
||||
bob <# ("", aliceConn, A.INFO bPQ "alice's connInfo")
|
||||
bob <# ("", aliceConn, CON pq)
|
||||
alice <# ("", bobConn, CON pq)
|
||||
alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, A.MID 4 pq)
|
||||
alice <# ("", bobConn, SENT 4)
|
||||
bob <#= \case ("", c, Msg' 4 pq' "hello") -> c == aliceConn && pq == pq'; _ -> False
|
||||
bob #: ("12", aliceConn, "ACK 4") #> ("12", aliceConn, OK)
|
||||
alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, A.MID 5 pq)
|
||||
alice <# ("", bobConn, SENT 5)
|
||||
bob <#= \case ("", c, Msg' 5 pq' "how are you?") -> c == aliceConn && pq == pq'; _ -> False
|
||||
bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK)
|
||||
bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, A.MID 6 pq)
|
||||
bob <# ("", aliceConn, SENT 6)
|
||||
alice <#= \case ("", c, Msg' 6 pq' "hello too") -> c == bobConn && pq == pq'; _ -> False
|
||||
alice #: ("3a", bobConn, "ACK 6") #> ("3a", bobConn, OK)
|
||||
bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, A.MID 7 pq)
|
||||
bob <# ("", aliceConn, SENT 7)
|
||||
alice <#= \case ("", c, Msg' 7 pq' "message 1") -> c == bobConn && pq == pq'; _ -> False
|
||||
alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK)
|
||||
alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK)
|
||||
bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, A.MID 8 pq)
|
||||
bob <#= \case ("", cId, MERR 8 (SMP _ AUTH)) -> cId == aliceConn; _ -> False
|
||||
alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()
|
||||
testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do
|
||||
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
abPQ = pqConnectionMode aPQ bPQ
|
||||
aPQMode = CR.connPQEncryption aPQ
|
||||
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "alice_contact", Right (A.REQ aInvId PQSupportOn _ "bob's connInfo")) <- (alice <#:)
|
||||
alice #: ("2", "bob", "ACPT " <> aInvId <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
("", "alice", Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:)
|
||||
pqSup'' `shouldBe` bPQ
|
||||
bob #: ("12", "alice", "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", "alice", OK)
|
||||
alice <# ("", "bob", A.INFO (CR.connPQEncryption aPQ) "bob's connInfo 2")
|
||||
alice <# ("", "bob", CON abPQ)
|
||||
bob <# ("", "alice", CON abPQ)
|
||||
alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", A.MID 4 abPQ)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg' 4 pq' "hi") -> pq' == abPQ; _ -> False
|
||||
bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK)
|
||||
|
||||
let atPQ = pqConnectionMode aPQ tPQ
|
||||
tom #: ("21", "alice", "JOIN T " <> cReq' <> enableKEMStr tPQ <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK)
|
||||
("", "alice_contact", Right (A.REQ aInvId' PQSupportOn _ "tom's connInfo")) <- (alice <#:)
|
||||
alice #: ("4", "tom", "ACPT " <> aInvId' <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("4", "tom", OK)
|
||||
("", "alice", Right (A.CONF tConfId pqSup4 _ "alice's connInfo")) <- (tom <#:)
|
||||
pqSup4 `shouldBe` tPQ
|
||||
tom #: ("22", "alice", "LET " <> tConfId <> " 16\ntom's connInfo 2") #> ("22", "alice", OK)
|
||||
alice <# ("", "tom", A.INFO (CR.connPQEncryption aPQ) "tom's connInfo 2")
|
||||
alice <# ("", "tom", CON atPQ)
|
||||
tom <# ("", "alice", CON atPQ)
|
||||
alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", A.MID 4 atPQ)
|
||||
alice <# ("", "tom", SENT 4)
|
||||
tom <#= \case ("", "alice", Msg' 4 pq' "hi there") -> pq' == atPQ; _ -> False
|
||||
tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK)
|
||||
|
||||
testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo")
|
||||
("", aliceContact', Right (A.REQ aInvId PQSupportOn _ "bob's connInfo")) <- (alice <#:)
|
||||
aliceContact' `shouldBe` aliceContact
|
||||
|
||||
("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> enableKEMStr (CR.connPQEncryption aPQ) <> " 16\nalice's connInfo")
|
||||
("", aliceConn', Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:)
|
||||
pqSup'' `shouldBe` bPQ
|
||||
aliceConn' `shouldBe` aliceConn
|
||||
|
||||
bob #: ("12", aliceConn, "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", aliceConn, OK)
|
||||
alice <# ("", bobConn, A.INFO (CR.connPQEncryption aPQ) "bob's connInfo 2")
|
||||
alice <# ("", bobConn, CON pq)
|
||||
bob <# ("", aliceConn, CON pq)
|
||||
|
||||
alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, A.MID 4 pq)
|
||||
alice <# ("", bobConn, SENT 4)
|
||||
bob <#= \case ("", c, Msg' 4 pq' "hi") -> c == aliceConn && pq == pq'; _ -> False
|
||||
bob #: ("13", aliceConn, "ACK 4") #> ("13", aliceConn, OK)
|
||||
|
||||
testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testRejectContactRequest _ alice bob = do
|
||||
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 10\nbob's info") #> ("11", "alice", OK)
|
||||
("", "a_contact", Right (A.REQ aInvId PQSupportOn _ "bob's info")) <- (alice <#:)
|
||||
-- RJCT must use correct contact connection
|
||||
alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND)
|
||||
alice #: ("2b", "a_contact", "RJCT " <> aInvId) #> ("2b", "a_contact", OK)
|
||||
alice #: ("3", "bob", "ACPT " <> aInvId <> " 12\nalice's info") =#> \case ("3", "bob", ERR (A.CMD PROHIBITED _)) -> True; _ -> False
|
||||
bob #:# "nothing should be delivered to bob"
|
||||
|
||||
testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testSubscription _ alice1 alice2 bob = do
|
||||
(alice1, "alice") `connect` (bob, "bob")
|
||||
bob #: ("12", "alice", "SEND F 5\nhello") #> ("12", "alice", MID 4)
|
||||
bob <# ("", "alice", SENT 4)
|
||||
alice1 <#= \case ("", "bob", Msg "hello") -> True; _ -> False
|
||||
alice1 #: ("1", "bob", "ACK 4") #> ("1", "bob", OK)
|
||||
bob #: ("13", "alice", "SEND F 11\nhello again") #> ("13", "alice", MID 5)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
alice1 <#= \case ("", "bob", Msg "hello again") -> True; _ -> False
|
||||
alice1 #: ("2", "bob", "ACK 5") #> ("2", "bob", OK)
|
||||
alice2 #: ("21", "bob", "SUB") #> ("21", "bob", OK)
|
||||
alice1 <# ("", "bob", END)
|
||||
bob #: ("14", "alice", "SEND F 2\nhi") #> ("14", "alice", MID 6)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice2 <#= \case ("", "bob", Msg "hi") -> True; _ -> False
|
||||
alice2 #: ("22", "bob", "ACK 6") #> ("22", "bob", OK)
|
||||
alice1 #:# "nothing else should be delivered to alice1"
|
||||
|
||||
testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO ()
|
||||
testSubscrNotification t (server, _) client = do
|
||||
client #: ("1", "conn1", "NEW T INV subscribe") =#> \case ("1", "conn1", INV {}) -> True; _ -> False
|
||||
client #:# "nothing should be delivered to client before the server is killed"
|
||||
killThread server
|
||||
client <#. ("", "", DOWN testSMPServer ["conn1"])
|
||||
withSmpServer (ATransport t) $
|
||||
client <#= \case ("", "conn1", ERR (SMP _ AUTH)) -> True; _ -> False -- this new server does not have the queue
|
||||
|
||||
testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
withServer $ do
|
||||
connect' (alice, "alice", aPQ) (bob, "bob", bPQ)
|
||||
bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", A.MID 4 pq)
|
||||
bob <# ("", "alice", SENT 4)
|
||||
alice <#= \case ("", "bob", Msg' _ pq' "hi") -> pq == pq'; _ -> False
|
||||
alice #: ("11", "bob", "ACK 4") #> ("11", "bob", OK)
|
||||
alice #:# "nothing else delivered before the server is killed"
|
||||
|
||||
let server = SMPServer "localhost" testPort2 testKeyHash
|
||||
alice <#. ("", "", DOWN server ["bob"])
|
||||
bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", A.MID 5 pq)
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
withServer $ do
|
||||
bob <# ("", "alice", SENT 5)
|
||||
alice <#= \case ("", "bob", Msg' _ pq' "hello again") -> pq == pq'; _ -> False
|
||||
alice #: ("12", "bob", "ACK 5") #> ("12", "bob", OK)
|
||||
|
||||
removeFile testStoreLogFile
|
||||
where
|
||||
withServer test' = withSmpServerStoreLogOn (transport @c) testPort2 (const test') `shouldReturn` ()
|
||||
|
||||
testServerConnectionAfterError :: forall c. Transport c => TProxy c -> [c] -> IO ()
|
||||
testServerConnectionAfterError t _ = do
|
||||
withAgent1 $ \bob -> do
|
||||
withAgent2 $ \alice -> do
|
||||
withServer $ do
|
||||
connect (bob, "bob") (alice, "alice")
|
||||
bob <#. ("", "", DOWN server ["alice"])
|
||||
alice <#. ("", "", DOWN server ["bob"])
|
||||
alice #: ("1", "bob", "SEND F 5\nhello") #> ("1", "bob", MID 4)
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
withAgent1 $ \bob -> do
|
||||
withAgent2 $ \alice -> do
|
||||
bob #:! ("1", "alice", "SUB") =#> \case ("1", "alice", ERR (BROKER _ e)) -> e == NETWORK || e == TIMEOUT; _ -> False
|
||||
alice #:! ("1", "bob", "SUB") =#> \case ("1", "bob", ERR (BROKER _ e)) -> e == NETWORK || e == TIMEOUT; _ -> False
|
||||
withServer $ do
|
||||
alice <#=? \case ("", "bob", APC SAEConn (SENT 4)) -> True; _ -> False
|
||||
bob <#=? \case ("", "alice", APC _ (Msg "hello")) -> True; _ -> False
|
||||
bob #: ("2", "alice", "ACK 4") #> ("2", "alice", OK)
|
||||
alice #: ("1", "bob", "SEND F 11\nhello again") #> ("1", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testDB
|
||||
removeFile testDB2
|
||||
where
|
||||
server = SMPServer "localhost" testPort2 testKeyHash
|
||||
withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` ()
|
||||
withAgent1 = withAgent agentTestPort testDB 0
|
||||
withAgent2 = withAgent agentTestPort2 testDB2 10
|
||||
withAgent :: String -> FilePath -> Int -> (c -> IO a) -> IO a
|
||||
withAgent agentPort agentDB initClientId = withSmpAgentThreadOn_ (ATransport t) (agentPort, testPort2, agentDB) initClientId (pure ()) . const . testSMPAgentClientOn agentPort
|
||||
|
||||
testMsgDeliveryAgentRestart :: Transport c => TProxy c -> c -> IO ()
|
||||
testMsgDeliveryAgentRestart t bob = do
|
||||
let server = SMPServer "localhost" testPort2 testKeyHash
|
||||
withAgent $ \alice -> do
|
||||
withServer $ do
|
||||
connect (bob, "bob") (alice, "alice")
|
||||
alice #: ("1", "bob", "SEND F 5\nhello") #> ("1", "bob", MID 4)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob #: ("11", "alice", "ACK 4") #> ("11", "alice", OK)
|
||||
bob #:# "nothing else delivered before the server is down"
|
||||
|
||||
bob <#. ("", "", DOWN server ["alice"])
|
||||
alice #: ("2", "bob", "SEND F 11\nhello again") #> ("2", "bob", MID 5)
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
withAgent $ \alice -> do
|
||||
withServer $ do
|
||||
tPutRaw alice ("3", "bob", "SUB")
|
||||
alice <#= \case
|
||||
(corrId, "bob", cmd) ->
|
||||
(corrId == "3" && cmd == OK)
|
||||
|| (corrId == "" && cmd == SENT 5)
|
||||
_ -> False
|
||||
bob <#=? \case ("", "alice", APC _ (Msg "hello again")) -> True; _ -> False
|
||||
bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK)
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testDB
|
||||
where
|
||||
withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` ()
|
||||
withAgent = withSmpAgentThreadOn_ (ATransport t) (agentTestPort, testPort, testDB) 0 (pure ()) . const . testSMPAgentClientOn agentTestPort
|
||||
|
||||
testConcurrentMsgDelivery :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testConcurrentMsgDelivery _ alice bob = do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
|
||||
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice2", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice2", OK)
|
||||
("", "bob2", Right (A.CONF _confId PQSupportOff _ "bob's connInfo")) <- (alice <#:)
|
||||
-- below commands would be needed to accept bob's connection, but alice does not
|
||||
-- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
-- bob <# ("", "alice", INFO "alice's connInfo")
|
||||
-- bob <# ("", "alice", CON)
|
||||
-- alice <# ("", "bob", CON)
|
||||
|
||||
-- the first connection should not be blocked by the second one
|
||||
sendMessage (alice, "alice") (bob, "bob") "hello"
|
||||
-- alice #: ("2", "bob", "SEND F :hello") #> ("2", "bob", MID 1)
|
||||
-- alice <# ("", "bob", SENT 1)
|
||||
-- bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
-- bob #: ("12", "alice", "ACK 1") #> ("12", "alice", OK)
|
||||
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 5)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
-- if delivery is blocked it won't go further
|
||||
alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False
|
||||
alice #: ("3", "bob", "ACK 5") #> ("3", "bob", OK)
|
||||
|
||||
testMsgDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testMsgDeliveryQuotaExceeded _ alice bob = do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
connect (alice, "alice2") (bob, "bob2")
|
||||
forM_ [1 .. 4 :: Int] $ \i -> do
|
||||
let corrId = bshow i
|
||||
msg = "message " <> bshow i
|
||||
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg)
|
||||
alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False
|
||||
(_, "bob", Right (MID _)) <- alice #: ("5", "bob", "SEND F :over quota")
|
||||
alice <#= \case ("", "bob", MWARN _ (SMP _ QUOTA)) -> True; _ -> False
|
||||
|
||||
alice #: ("1", "bob2", "SEND F :hello") #> ("1", "bob2", MID 4)
|
||||
-- if delivery is blocked it won't go further
|
||||
alice <# ("", "bob2", SENT 4)
|
||||
|
||||
testResumeDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testResumeDeliveryQuotaExceeded _ alice bob = do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
forM_ [1 .. 4 :: Int] $ \i -> do
|
||||
let corrId = bshow i
|
||||
msg = "message " <> bshow i
|
||||
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg)
|
||||
alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False
|
||||
("5", "bob", Right (MID 8)) <- alice #: ("5", "bob", "SEND F :over quota")
|
||||
alice <#= \case ("", "bob", MWARN 8 (SMP _ QUOTA)) -> True; _ -> False
|
||||
alice #:# "the last message not sent yet"
|
||||
bob <#= \case ("", "alice", Msg "message 1") -> True; _ -> False
|
||||
bob #: ("1", "alice", "ACK 4") #> ("1", "alice", OK)
|
||||
alice #:# "the last message not sent"
|
||||
bob <#= \case ("", "alice", Msg "message 2") -> True; _ -> False
|
||||
bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK)
|
||||
alice #:# "the last message not sent"
|
||||
bob <#= \case ("", "alice", Msg "message 3") -> True; _ -> False
|
||||
bob #: ("3", "alice", "ACK 6") #> ("3", "alice", OK)
|
||||
alice #:# "the last message not sent"
|
||||
bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False
|
||||
bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK)
|
||||
inAnyOrder
|
||||
(tGetAgent alice)
|
||||
[ \case ("", c, Right (SENT 8)) -> c == "bob"; _ -> False,
|
||||
\case ("", c, Right QCONT) -> c == "bob"; _ -> False
|
||||
]
|
||||
bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False
|
||||
-- message 8 is skipped because of alice agent sending "QCONT" message
|
||||
bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK)
|
||||
|
||||
connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQSupportOn)
|
||||
|
||||
connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQSupport) -> IO ()
|
||||
connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do
|
||||
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
pq = pqConnectionMode pqMode1 pqMode2
|
||||
h2 #: ("c2", name1, "JOIN T " <> cReq' <> enableKEMStr pqMode2 <> " subscribe 5\ninfo2") #> ("c2", name1, OK)
|
||||
("", _, Right (A.CONF connId pqSup' _ "info2")) <- (h1 <#:)
|
||||
pqSup' `shouldBe` CR.connPQEncryption pqMode1
|
||||
h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK)
|
||||
h2 <# ("", name1, A.INFO pqMode2 "info1")
|
||||
h2 <# ("", name1, CON pq)
|
||||
h1 <# ("", name2, CON pq)
|
||||
|
||||
pqConnectionMode :: InitialKeys -> PQSupport -> PQEncryption
|
||||
pqConnectionMode pqMode1 pqMode2 = PQEncryption $ supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2
|
||||
|
||||
enableKEMStr :: PQSupport -> ByteString
|
||||
enableKEMStr PQSupportOn = " " <> strEncode PQSupportOn
|
||||
enableKEMStr _ = ""
|
||||
|
||||
pqConnModeStr :: InitialKeys -> ByteString
|
||||
pqConnModeStr (IKNoPQ PQSupportOff) = ""
|
||||
pqConnModeStr pq = " " <> strEncode pq
|
||||
|
||||
sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO ()
|
||||
sendMessage (h1, name1) (h2, name2) msg = do
|
||||
("m1", name2', Right (MID mId)) <- h1 #: ("m1", name2, "SEND F :" <> msg)
|
||||
name2' `shouldBe` name2
|
||||
h1 <#= \case ("", n, SENT m) -> n == name2 && m == mId; _ -> False
|
||||
("", name1', Right (MSG MsgMeta {recipient = (msgId', _)} _ msg')) <- (h2 <#:)
|
||||
name1' `shouldBe` name1
|
||||
msg' `shouldBe` msg
|
||||
h2 #: ("m2", name1, "ACK " <> bshow msgId') =#> \case ("m2", n, OK) -> n == name1; _ -> False
|
||||
|
||||
-- connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString)
|
||||
-- connect' h1 h2 = do
|
||||
-- ("c1", conn2, Right (INV cReq)) <- h1 #: ("c1", "", "NEW T INV subscribe")
|
||||
-- let cReq' = strEncode cReq
|
||||
-- ("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN T " <> cReq' <> " subscribe 5\ninfo2")
|
||||
-- ("", _, Right (REQ connId _ "info2")) <- (h1 <#:)
|
||||
-- h1 #: ("c3", conn2, "ACPT " <> connId <> " 5\ninfo1") =#> \case ("c3", c, OK) -> c == conn2; _ -> False
|
||||
-- h2 <# ("", conn1, INFO "info1")
|
||||
-- h2 <# ("", conn1, CON)
|
||||
-- h1 <# ("", conn2, CON)
|
||||
-- pure (conn1, conn2)
|
||||
|
||||
sampleDhKey :: ByteString
|
||||
sampleDhKey = "MCowBQYDK2VuAyEAjiswwI3O_NlS8Fk3HJUW870EY2bAwmttMBsvRB9eV3o="
|
||||
|
||||
syntaxTests :: forall c. Transport c => TProxy c -> Spec
|
||||
syntaxTests t = do
|
||||
it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX parseCommand")
|
||||
describe "NEW" $ do
|
||||
describe "valid" $ do
|
||||
it "with correct parameter" $ ("211", "", "NEW T INV subscribe") >#>= \case ("211", _, "INV" : _) -> True; _ -> False
|
||||
describe "invalid" $ do
|
||||
it "with incorrect parameter" $ ("222", "", "NEW T hi subscribe") >#> ("222", "", "ERR CMD SYNTAX parseCommand")
|
||||
|
||||
describe "JOIN" $ do
|
||||
describe "valid" $ do
|
||||
it "using same server as in invitation" $
|
||||
( "311",
|
||||
"a",
|
||||
"JOIN T https://simpex.chat/invitation#/?smp=smp%3A%2F%2F"
|
||||
<> urlEncode True "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
<> "%40localhost%3A5001%2F3456-w%3D%3D%23"
|
||||
<> urlEncode True sampleDhKey
|
||||
<> "&v=2"
|
||||
<> "&e2e=v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> " subscribe "
|
||||
<> "14\nbob's connInfo"
|
||||
)
|
||||
>#> ("311", "a", "ERR SMP smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001 AUTH")
|
||||
describe "invalid" $ do
|
||||
it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX parseCommand")
|
||||
where
|
||||
-- simple test for one command with the expected response
|
||||
(>#>) :: ARawTransmission -> ARawTransmission -> Expectation
|
||||
command >#> response = withFrozenCallStack $ smpAgentTest t command `shouldReturn` response
|
||||
|
||||
-- simple test for one command with a predicate for the expected response
|
||||
(>#>=) :: ARawTransmission -> ((ByteString, ByteString, [ByteString]) -> Bool) -> Expectation
|
||||
command >#>= p = withFrozenCallStack $ smpAgentTest t command >>= (`shouldSatisfy` p . \(cId, connId, cmd) -> (cId, connId, B.words cmd))
|
||||
|
||||
@@ -88,7 +88,7 @@ import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteS
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction')
|
||||
import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), SMPProxyFallback (..), SMPProxyMode (..), TransportSessionMode (TSMEntity, TSMUser), defaultClientConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn)
|
||||
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn)
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion, authBatchCmdsNTFVersion, pattern VersionNTF)
|
||||
@@ -98,7 +98,7 @@ import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.QueueStore.QueueInfo
|
||||
import Simplex.Messaging.Transport (ATransport (..), SMPVersion, VersionSMP, authCmdsSMPVersion, basicAuthSMPVersion, batchCmdsSMPVersion, currentServerSMPRelayVersion, supportedSMPHandshakes)
|
||||
import Simplex.Messaging.Util (diffToMicroseconds)
|
||||
import Simplex.Messaging.Util (bshow, diffToMicroseconds)
|
||||
import Simplex.Messaging.Version (VersionRange (..))
|
||||
import qualified Simplex.Messaging.Version as V
|
||||
import Simplex.Messaging.Version.Internal (Version (..))
|
||||
@@ -108,7 +108,7 @@ import UnliftIO
|
||||
import Util
|
||||
import XFTPClient (testXFTPServer)
|
||||
|
||||
type AEntityTransmission e = (ACorrId, ConnId, ACommand 'Agent e)
|
||||
type AEntityTransmission e = (ACorrId, ConnId, AEvent e)
|
||||
|
||||
-- deriving instance Eq (ValidFileDescription p)
|
||||
|
||||
@@ -144,60 +144,62 @@ nGet c = withFrozenCallStack $ get' @'AENone c
|
||||
|
||||
nGetUP :: (MonadIO m, HasCallStack) => AgentClient -> m (AEntityTransmission 'AENone)
|
||||
nGetUP c = withFrozenCallStack $ liftIO $ do
|
||||
timeout 15000000 (pGet_ c True) >>= \case
|
||||
Just (corrId, connId, APC _ cmd@UP {}) -> pure (corrId, connId, cmd)
|
||||
Just (_, _, APC _ cmd) -> error $ "unexpected command " <> show cmd
|
||||
timeout 15000000 (pGet' c True False) >>= \case
|
||||
Just (corrId, connId, AEvt _ cmd@UP {}) -> pure (corrId, connId, cmd)
|
||||
Just (_, _, AEvt _ cmd) -> error $ "unexpected command " <> show cmd
|
||||
Nothing -> error "timed out waiting for UP"
|
||||
|
||||
get' :: forall e m. (MonadIO m, AEntityI e, HasCallStack) => AgentClient -> m (AEntityTransmission e)
|
||||
get' c = withFrozenCallStack $ do
|
||||
(corrId, connId, APC e cmd) <- pGet c
|
||||
(corrId, connId, AEvt e cmd) <- pGet c
|
||||
case testEquality e (sAEntity @e) of
|
||||
Just Refl -> pure (corrId, connId, cmd)
|
||||
_ -> error $ "unexpected command " <> show cmd
|
||||
|
||||
pGet :: forall m. (MonadIO m, HasCallStack) => AgentClient -> m (ATransmission 'Agent)
|
||||
pGet c = withFrozenCallStack $ pGet_ c False
|
||||
pGet :: forall m. MonadIO m => AgentClient -> m ATransmission
|
||||
pGet c = pGet' c True True
|
||||
|
||||
pGet_ :: forall m. (MonadIO m, HasCallStack) => AgentClient -> Bool -> m (ATransmission 'Agent)
|
||||
pGet_ c expectUp = withFrozenCallStack $ do
|
||||
t@(_, _, APC _ cmd) <- atomically (readTBQueue $ subQ c)
|
||||
pGet' :: forall m. MonadIO m => AgentClient -> Bool -> Bool -> m ATransmission
|
||||
pGet' c skipWarn skipUp = do
|
||||
t@(_, _, AEvt _ cmd) <- atomically (readTBQueue $ subQ c)
|
||||
case cmd of
|
||||
CONNECT {} -> pGet_ c expectUp
|
||||
DISCONNECT {} -> pGet_ c expectUp
|
||||
ERR (BROKER _ NETWORK) -> pGet_ c expectUp
|
||||
MWARN {} -> pGet_ c expectUp
|
||||
UP {} | not expectUp -> pGet_ c expectUp
|
||||
CONNECT {} -> pGet' c skipWarn skipUp
|
||||
DISCONNECT {} -> pGet' c skipWarn skipUp
|
||||
ERR (BROKER _ NETWORK) -> pGet' c skipWarn skipUp
|
||||
MWARN {} | skipWarn -> pGet' c skipWarn skipUp
|
||||
RFWARN {} | skipWarn -> pGet' c skipWarn skipUp
|
||||
SFWARN {} | skipWarn -> pGet' c skipWarn skipUp
|
||||
UP {} | skipUp -> pGet' c skipWarn skipUp
|
||||
_ -> pure t
|
||||
|
||||
pattern CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand 'Agent e
|
||||
pattern CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> AEvent e
|
||||
pattern CONF conId srvs connInfo <- A.CONF conId PQSupportOn srvs connInfo
|
||||
|
||||
pattern INFO :: ConnInfo -> ACommand 'Agent 'AEConn
|
||||
pattern INFO :: ConnInfo -> AEvent 'AEConn
|
||||
pattern INFO connInfo = A.INFO PQSupportOn connInfo
|
||||
|
||||
pattern REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand 'Agent e
|
||||
pattern REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> AEvent e
|
||||
pattern REQ invId srvs connInfo <- A.REQ invId PQSupportOn srvs connInfo
|
||||
|
||||
pattern CON :: ACommand 'Agent 'AEConn
|
||||
pattern CON :: AEvent 'AEConn
|
||||
pattern CON = A.CON PQEncOn
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent e
|
||||
pattern Msg :: MsgBody -> AEvent e
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk, pqEncryption = PQEncOn} _ msgBody
|
||||
|
||||
pattern Msg' :: AgentMsgId -> PQEncryption -> MsgBody -> ACommand 'Agent e
|
||||
pattern Msg' :: AgentMsgId -> PQEncryption -> MsgBody -> AEvent e
|
||||
pattern Msg' aMsgId pq msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _), pqEncryption = pq} _ msgBody
|
||||
|
||||
pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> ACommand 'Agent 'AEConn
|
||||
pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> AEvent 'AEConn
|
||||
pattern MsgErr msgId err msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err} _ msgBody
|
||||
|
||||
pattern MsgErr' :: AgentMsgId -> MsgErrorType -> PQEncryption -> MsgBody -> ACommand 'Agent 'AEConn
|
||||
pattern MsgErr' :: AgentMsgId -> MsgErrorType -> PQEncryption -> MsgBody -> AEvent 'AEConn
|
||||
pattern MsgErr' msgId err pq msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err, pqEncryption = pq} _ msgBody
|
||||
|
||||
pattern SENT :: AgentMsgId -> ACommand 'Agent 'AEConn
|
||||
pattern SENT :: AgentMsgId -> AEvent 'AEConn
|
||||
pattern SENT msgId = A.SENT msgId Nothing
|
||||
|
||||
pattern Rcvd :: AgentMsgId -> ACommand 'Agent 'AEConn
|
||||
pattern Rcvd :: AgentMsgId -> AEvent 'AEConn
|
||||
pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}]
|
||||
|
||||
smpCfgVPrev :: ProtocolClientConfig SMPVersion
|
||||
@@ -255,7 +257,7 @@ runRight action =
|
||||
Right x -> pure x
|
||||
Left e -> error $ "Unexpected error: " <> show e
|
||||
|
||||
getInAnyOrder :: HasCallStack => AgentClient -> [ATransmission 'Agent -> Bool] -> Expectation
|
||||
getInAnyOrder :: HasCallStack => AgentClient -> [ATransmission -> Bool] -> Expectation
|
||||
getInAnyOrder c ts = withFrozenCallStack $ inAnyOrder (pGet c) ts
|
||||
|
||||
inAnyOrder :: (Show a, MonadUnliftIO m, HasCallStack) => m a -> [a -> Bool] -> m ()
|
||||
@@ -292,12 +294,20 @@ functionalAPITests t = do
|
||||
withSmpServer t testAgentClient3
|
||||
it "should establish connection without PQ encryption and enable it" $
|
||||
withSmpServer t testEnablePQEncryption
|
||||
describe "Establishing duplex connection, different PQ settings" $ do
|
||||
testPQMatrix2 t $ runAgentClientTestPQ True
|
||||
describe "Establishing duplex connection v2, different Ratchet versions" $
|
||||
testRatchetMatrix2 t runAgentClientTest
|
||||
describe "Establish duplex connection via contact address" $
|
||||
testMatrix2 t runAgentClientContactTest
|
||||
describe "Establish duplex connection via contact address, different PQ settings" $ do
|
||||
testPQMatrix2NoInv t $ runAgentClientContactTestPQ True PQSupportOn
|
||||
describe "Establish duplex connection via contact address v2, different Ratchet versions" $
|
||||
testRatchetMatrix2 t runAgentClientContactTest
|
||||
describe "Establish duplex connection via contact address, different PQ settings" $ do
|
||||
testPQMatrix3 t $ runAgentClientContactTestPQ3 True
|
||||
it "should support rejecting contact request" $
|
||||
withSmpServer t testRejectContactRequest
|
||||
describe "Establishing connection asynchronously" $ do
|
||||
it "should connect with initiating client going offline" $
|
||||
withSmpServer t testAsyncInitiatingOffline
|
||||
@@ -324,6 +334,10 @@ functionalAPITests t = do
|
||||
testDuplicateMessage t
|
||||
it "should report error via msg integrity on skipped messages" $
|
||||
testSkippedMessages t
|
||||
it "should connect to the server when server goes up if it initially was down" $
|
||||
testDeliveryAfterSubscriptionError t
|
||||
it "should deliver messages if one of connections has quota exceeded" $
|
||||
testMsgDeliveryQuotaExceeded t
|
||||
describe "message expiration" $ do
|
||||
it "should expire one message" $ testExpireMessage t
|
||||
it "should expire multiple messages" $ testExpireManyMessages t
|
||||
@@ -485,7 +499,7 @@ canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) =
|
||||
let v = basicAuthSMPVersion
|
||||
in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth))
|
||||
|
||||
testMatrix2 :: ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testMatrix2 :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testMatrix2 t runTest = do
|
||||
it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfg agentProxyCfg (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True
|
||||
it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn False
|
||||
@@ -497,7 +511,7 @@ testMatrix2 t runTest = do
|
||||
it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff False
|
||||
it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff False
|
||||
|
||||
testRatchetMatrix2 :: ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testRatchetMatrix2 :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testRatchetMatrix2 t runTest = do
|
||||
it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfg agentProxyCfg (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True
|
||||
it "ratchet next" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn False
|
||||
@@ -508,11 +522,50 @@ testRatchetMatrix2 t runTest = do
|
||||
it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 $ runTest PQSupportOff False
|
||||
it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 $ runTest PQSupportOff False
|
||||
|
||||
testServerMatrix2 :: ATransport -> (InitialAgentServers -> IO ()) -> Spec
|
||||
testServerMatrix2 :: HasCallStack => ATransport -> (InitialAgentServers -> IO ()) -> Spec
|
||||
testServerMatrix2 t runTest = do
|
||||
it "1 server" $ withSmpServer t $ runTest initAgentServers
|
||||
it "2 servers" $ withSmpServer t . withSmpServerOn t testPort2 $ runTest initAgentServers2
|
||||
|
||||
testPQMatrix2 :: HasCallStack => ATransport -> (HasCallStack => (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()) -> Spec
|
||||
testPQMatrix2 = pqMatrix2_ True
|
||||
|
||||
testPQMatrix2NoInv :: HasCallStack => ATransport -> (HasCallStack => (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()) -> Spec
|
||||
testPQMatrix2NoInv = pqMatrix2_ False
|
||||
|
||||
pqMatrix2_ :: HasCallStack => Bool -> ATransport -> (HasCallStack => (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()) -> Spec
|
||||
pqMatrix2_ pqInv t test = do
|
||||
it "dh/dh handshake" $ runTest $ \a b -> test (a, IKPQOff) (b, PQSupportOff)
|
||||
it "dh/pq handshake" $ runTest $ \a b -> test (a, IKPQOff) (b, PQSupportOn)
|
||||
it "pq/dh handshake" $ runTest $ \a b -> test (a, IKPQOn) (b, PQSupportOff)
|
||||
it "pq/pq handshake" $ runTest $ \a b -> test (a, IKPQOn) (b, PQSupportOn)
|
||||
when pqInv $ do
|
||||
it "pq-inv/dh handshake" $ runTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOff)
|
||||
it "pq-inv/pq handshake" $ runTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOn)
|
||||
where
|
||||
runTest = withSmpServerProxy t . runTestCfgServers2 agentProxyCfg agentProxyCfg (initAgentServersProxy SPMAlways SPFProhibit) 3
|
||||
|
||||
testPQMatrix3 ::
|
||||
HasCallStack =>
|
||||
ATransport ->
|
||||
(HasCallStack => (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()) ->
|
||||
Spec
|
||||
testPQMatrix3 t test = do
|
||||
it "dh" $ runTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOff)
|
||||
it "dh/dh/pq" $ runTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOn)
|
||||
it "dh/pq/dh" $ runTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOff)
|
||||
it "dh/pq/pq" $ runTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOn)
|
||||
it "pq/dh/dh" $ runTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOff)
|
||||
it "pq/dh/pq" $ runTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOn)
|
||||
it "pq/pq/dh" $ runTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOff)
|
||||
it "pq" $ runTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOn)
|
||||
where
|
||||
runTest test' =
|
||||
withSmpServerProxy t $
|
||||
runTestCfgServers2 agentProxyCfg agentProxyCfg servers 3 $ \a b baseMsgId ->
|
||||
withAgent 3 agentProxyCfg servers testDB3 $ \c -> test' a b c baseMsgId
|
||||
servers = initAgentServersProxy SPMAlways SPFProhibit
|
||||
|
||||
runTestCfg2 :: HasCallStack => AgentConfig -> AgentConfig -> AgentMsgId -> (HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> IO ()
|
||||
runTestCfg2 aCfg bCfg = runTestCfgServers2 aCfg bCfg initAgentServers
|
||||
{-# INLINE runTestCfg2 #-}
|
||||
@@ -522,17 +575,17 @@ runTestCfgServers2 aCfg bCfg servers baseMsgId runTest =
|
||||
withAgentClientsCfgServers2 aCfg bCfg servers $ \a b -> runTest a b baseMsgId
|
||||
{-# INLINE runTestCfgServers2 #-}
|
||||
|
||||
withAgentClientsCfgServers2 :: HasCallStack => AgentConfig -> AgentConfig -> InitialAgentServers -> (HasCallStack => AgentClient -> AgentClient -> IO ()) -> IO ()
|
||||
withAgentClientsCfgServers2 :: HasCallStack => AgentConfig -> AgentConfig -> InitialAgentServers -> (HasCallStack => AgentClient -> AgentClient -> IO a) -> IO a
|
||||
withAgentClientsCfgServers2 aCfg bCfg servers runTest =
|
||||
withAgent 1 aCfg servers testDB $ \a ->
|
||||
withAgent 2 bCfg servers testDB2 $ \b ->
|
||||
runTest a b
|
||||
|
||||
withAgentClientsCfg2 :: HasCallStack => AgentConfig -> AgentConfig -> (HasCallStack => AgentClient -> AgentClient -> IO ()) -> IO ()
|
||||
withAgentClientsCfg2 :: HasCallStack => AgentConfig -> AgentConfig -> (HasCallStack => AgentClient -> AgentClient -> IO a) -> IO a
|
||||
withAgentClientsCfg2 aCfg bCfg = withAgentClientsCfgServers2 aCfg bCfg initAgentServers
|
||||
{-# INLINE withAgentClientsCfg2 #-}
|
||||
|
||||
withAgentClients2 :: HasCallStack => (HasCallStack => AgentClient -> AgentClient -> IO ()) -> IO ()
|
||||
withAgentClients2 :: HasCallStack => (HasCallStack => AgentClient -> AgentClient -> IO a) -> IO a
|
||||
withAgentClients2 = withAgentClientsCfg2 agentCfg agentCfg
|
||||
{-# INLINE withAgentClients2 #-}
|
||||
|
||||
@@ -543,16 +596,20 @@ withAgentClients3 runTest =
|
||||
runTest a b c
|
||||
|
||||
runAgentClientTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientTest pqSupport viaProxy alice@AgentClient {} bob baseId =
|
||||
runAgentClientTest pqSupport viaProxy alice bob baseId =
|
||||
runAgentClientTestPQ viaProxy (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId
|
||||
|
||||
runAgentClientTestPQ :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()
|
||||
runAgentClientTestPQ viaProxy (alice, aPQ) (bob, bPQ) baseId =
|
||||
runRight_ $ do
|
||||
(bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqSupport) SMSubscribe
|
||||
aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" pqSupport SMSubscribe
|
||||
(bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing aPQ SMSubscribe
|
||||
aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" bPQ SMSubscribe
|
||||
("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice
|
||||
liftIO $ pqSup' `shouldBe` pqSupport
|
||||
liftIO $ pqSup' `shouldBe` CR.connPQEncryption aPQ
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
let pqEnc = CR.pqSupportToEnc pqSupport
|
||||
let pqEnc = PQEncryption $ pqConnectionMode aPQ bPQ
|
||||
get alice ##> ("", bobId, A.CON pqEnc)
|
||||
get bob ##> ("", aliceId, A.INFO pqSupport "alice's connInfo")
|
||||
get bob ##> ("", aliceId, A.INFO bPQ "alice's connInfo")
|
||||
get bob ##> ("", aliceId, A.CON pqEnc)
|
||||
-- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4
|
||||
let proxySrv = if viaProxy then Just testSMPServer else Nothing
|
||||
@@ -580,6 +637,9 @@ runAgentClientTest pqSupport viaProxy alice@AgentClient {} bob baseId =
|
||||
where
|
||||
msgId = subtract baseId . fst
|
||||
|
||||
pqConnectionMode :: InitialKeys -> PQSupport -> Bool
|
||||
pqConnectionMode pqMode1 pqMode2 = supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2
|
||||
|
||||
testEnablePQEncryption :: HasCallStack => IO ()
|
||||
testEnablePQEncryption =
|
||||
withAgentClients2 $ \ca cb -> runRight_ $ do
|
||||
@@ -672,19 +732,23 @@ testAgentClient3 =
|
||||
|
||||
runAgentClientContactTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientContactTest pqSupport viaProxy alice bob baseId =
|
||||
runAgentClientContactTestPQ viaProxy pqSupport (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId
|
||||
|
||||
runAgentClientContactTestPQ :: HasCallStack => Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()
|
||||
runAgentClientContactTestPQ viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId =
|
||||
runRight_ $ do
|
||||
(_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqSupport) SMSubscribe
|
||||
aliceId <- A.prepareConnectionToJoin bob 1 True qInfo pqSupport
|
||||
aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe
|
||||
(_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing aPQ SMSubscribe
|
||||
aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ
|
||||
aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" bPQ SMSubscribe
|
||||
liftIO $ aliceId' `shouldBe` aliceId
|
||||
("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice
|
||||
liftIO $ pqSup' `shouldBe` pqSupport
|
||||
bobId <- acceptContact alice True invId "alice's connInfo" PQSupportOn SMSubscribe
|
||||
liftIO $ pqSup' `shouldBe` reqPQSupport
|
||||
bobId <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe
|
||||
("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob
|
||||
liftIO $ pqSup'' `shouldBe` pqSupport
|
||||
liftIO $ pqSup'' `shouldBe` bPQ
|
||||
allowConnection bob aliceId confId "bob's connInfo"
|
||||
let pqEnc = CR.pqSupportToEnc pqSupport
|
||||
get alice ##> ("", bobId, A.INFO pqSupport "bob's connInfo")
|
||||
let pqEnc = PQEncryption $ pqConnectionMode aPQ bPQ
|
||||
get alice ##> ("", bobId, A.INFO (CR.connPQEncryption aPQ) "bob's connInfo")
|
||||
get alice ##> ("", bobId, A.CON pqEnc)
|
||||
get bob ##> ("", aliceId, A.CON pqEnc)
|
||||
-- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4
|
||||
@@ -713,6 +777,41 @@ runAgentClientContactTest pqSupport viaProxy alice bob baseId =
|
||||
where
|
||||
msgId = subtract baseId . fst
|
||||
|
||||
runAgentClientContactTestPQ3 :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> (AgentClient, PQSupport) -> AgentMsgId -> IO ()
|
||||
runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId = runRight_ $ do
|
||||
(_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing aPQ SMSubscribe
|
||||
(bAliceId, bobId, abPQEnc) <- connectViaContact bob bPQ qInfo
|
||||
sentMessages abPQEnc alice bobId bob bAliceId
|
||||
(tAliceId, tomId, atPQEnc) <- connectViaContact tom tPQ qInfo
|
||||
sentMessages atPQEnc alice tomId tom tAliceId
|
||||
where
|
||||
msgId = subtract baseId . fst
|
||||
connectViaContact b pq qInfo = do
|
||||
aId <- A.prepareConnectionToJoin b 1 True qInfo pq
|
||||
aId' <- A.joinConnection b 1 (Just aId) True qInfo "bob's connInfo" pq SMSubscribe
|
||||
liftIO $ aId' `shouldBe` aId
|
||||
("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice
|
||||
liftIO $ pqSup' `shouldBe` PQSupportOn
|
||||
bId <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe
|
||||
("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get b
|
||||
liftIO $ pqSup'' `shouldBe` pq
|
||||
allowConnection b aId confId "bob's connInfo"
|
||||
let pqEnc = PQEncryption $ pqConnectionMode aPQ pq
|
||||
get alice ##> ("", bId, A.INFO (CR.connPQEncryption aPQ) "bob's connInfo")
|
||||
get alice ##> ("", bId, A.CON pqEnc)
|
||||
get b ##> ("", aId, A.CON pqEnc)
|
||||
pure (aId, bId, pqEnc)
|
||||
sentMessages pqEnc a bId b aId = do
|
||||
let proxySrv = if viaProxy then Just testSMPServer else Nothing
|
||||
1 <- msgId <$> A.sendMessage a bId pqEnc SMP.noMsgFlags "hello"
|
||||
get a ##> ("", bId, A.SENT (baseId + 1) proxySrv)
|
||||
get b =##> \case ("", c, Msg' _ pq "hello") -> c == aId && pq == pqEnc; _ -> False
|
||||
ackMessage b aId (baseId + 1) Nothing
|
||||
2 <- msgId <$> A.sendMessage b aId pqEnc SMP.noMsgFlags "hello too"
|
||||
get b ##> ("", aId, A.SENT (baseId + 2) proxySrv)
|
||||
get a =##> \case ("", c, Msg' _ pq "hello too") -> c == bId && pq == pqEnc; _ -> False
|
||||
ackMessage a bId (baseId + 2) Nothing
|
||||
|
||||
noMessages :: HasCallStack => AgentClient -> String -> Expectation
|
||||
noMessages c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
@@ -721,6 +820,18 @@ noMessages c err = tryGet `shouldReturn` ()
|
||||
Just msg -> error $ err <> ": " <> show msg
|
||||
_ -> return ()
|
||||
|
||||
testRejectContactRequest :: HasCallStack => IO ()
|
||||
testRejectContactRequest =
|
||||
withAgentClients2 $ \alice bob -> runRight_ $ do
|
||||
(addrConnId, qInfo) <- A.createConnection alice 1 True SCMContact Nothing IKPQOn SMSubscribe
|
||||
aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn
|
||||
aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" PQSupportOn SMSubscribe
|
||||
liftIO $ aliceId' `shouldBe` aliceId
|
||||
("", _, A.REQ invId PQSupportOn _ "bob's connInfo") <- get alice
|
||||
liftIO $ runExceptT (rejectContact alice "abcd" invId) `shouldReturn` Left (CONN NOT_FOUND)
|
||||
rejectContact alice addrConnId invId
|
||||
liftIO $ noMessages bob "nothing delivered to bob"
|
||||
|
||||
testAsyncInitiatingOffline :: HasCallStack => IO ()
|
||||
testAsyncInitiatingOffline =
|
||||
withAgent 2 agentCfg initAgentServers testDB2 $ \bob -> runRight_ $ do
|
||||
@@ -1084,6 +1195,53 @@ testSkippedMessages t = do
|
||||
disposeAgentClient alice2
|
||||
disposeAgentClient bob2
|
||||
|
||||
testDeliveryAfterSubscriptionError :: HasCallStack => ATransport -> IO ()
|
||||
testDeliveryAfterSubscriptionError t = do
|
||||
(aId, bId) <- withAgentClients2 $ \a b -> do
|
||||
(aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ makeConnection a b
|
||||
nGet a =##> \case ("", "", DOWN _ [c]) -> c == bId; _ -> False
|
||||
nGet b =##> \case ("", "", DOWN _ [c]) -> c == aId; _ -> False
|
||||
4 <- runRight $ sendMessage a bId SMP.noMsgFlags "hello"
|
||||
liftIO $ noMessages b "not delivered"
|
||||
pure (aId, bId)
|
||||
|
||||
withAgentClients2 $ \a b -> do
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection a bId
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection b aId
|
||||
pure ()
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
|
||||
get a =##> \case ("", c, SENT 4) -> c == bId; _ -> False
|
||||
get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
ackMessage b aId 4 Nothing
|
||||
|
||||
testMsgDeliveryQuotaExceeded :: HasCallStack => ATransport -> IO ()
|
||||
testMsgDeliveryQuotaExceeded t =
|
||||
withAgentClients2 $ \a b -> withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
(aId', bId') <- makeConnection a b
|
||||
forM_ ([1 .. 4] :: [Int]) $ \i -> do
|
||||
mId <- sendMessage a bId SMP.noMsgFlags $ "message " <> bshow i
|
||||
get a =##> \case ("", c, SENT mId') -> bId == c && mId == mId'; _ -> False
|
||||
8 <- sendMessage a bId SMP.noMsgFlags "over quota"
|
||||
pGet' a False True =##> \case ("", c, AEvt _ (MWARN 8 (SMP _ QUOTA))) -> bId == c; _ -> False
|
||||
4 <- sendMessage a bId' SMP.noMsgFlags "hello"
|
||||
get a =##> \case ("", c, SENT 4) -> bId' == c; _ -> False
|
||||
get b =##> \case ("", c, Msg "message 1") -> aId == c; _ -> False
|
||||
get b =##> \case ("", c, Msg "hello") -> aId' == c; _ -> False
|
||||
ackMessage b aId' 4 Nothing
|
||||
ackMessage b aId 4 Nothing
|
||||
get b =##> \case ("", c, Msg "message 2") -> aId == c; _ -> False
|
||||
ackMessage b aId 5 Nothing
|
||||
get b =##> \case ("", c, Msg "message 3") -> aId == c; _ -> False
|
||||
ackMessage b aId 6 Nothing
|
||||
get b =##> \case ("", c, Msg "message 4") -> aId == c; _ -> False
|
||||
ackMessage b aId 7 Nothing
|
||||
get a =##> \case ("", c, QCONT) -> bId == c; _ -> False
|
||||
get b =##> \case ("", c, Msg "over quota") -> aId == c; _ -> False
|
||||
ackMessage b aId 9 Nothing -- msg 8 was QCONT
|
||||
get a =##> \case ("", c, SENT 8) -> bId == c; _ -> False
|
||||
liftIO $ concurrently_ (noMessages a "no more events") (noMessages b "no more events")
|
||||
|
||||
testExpireMessage :: HasCallStack => ATransport -> IO ()
|
||||
testExpireMessage t =
|
||||
withAgent 1 agentCfg {messageTimeout = 1, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a ->
|
||||
@@ -1150,8 +1308,8 @@ testExpireMessageQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testP
|
||||
get b' =##> \case ("", c, Msg "1") -> c == aId; _ -> False
|
||||
ackMessage b' aId 4 Nothing
|
||||
liftIO . getInAnyOrder a $
|
||||
[ \case ("", c, APC SAEConn (SENT 6)) -> c == bId; _ -> False,
|
||||
\case ("", c, APC SAEConn QCONT) -> c == bId; _ -> False
|
||||
[ \case ("", c, AEvt SAEConn (SENT 6)) -> c == bId; _ -> False,
|
||||
\case ("", c, AEvt SAEConn QCONT) -> c == bId; _ -> False
|
||||
]
|
||||
get b' =##> \case ("", c, MsgErr 6 (MsgSkipped 4 4) "3") -> c == aId; _ -> False
|
||||
ackMessage b' aId 6 Nothing
|
||||
@@ -1187,8 +1345,8 @@ testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1}
|
||||
get b' =##> \case ("", c, Msg "1") -> c == aId; _ -> False
|
||||
ackMessage b' aId 4 Nothing
|
||||
liftIO . getInAnyOrder a $
|
||||
[ \case ("", c, APC SAEConn (SENT 8)) -> c == bId; _ -> False,
|
||||
\case ("", c, APC SAEConn QCONT) -> c == bId; _ -> False
|
||||
[ \case ("", c, AEvt SAEConn (SENT 8)) -> c == bId; _ -> False,
|
||||
\case ("", c, AEvt SAEConn QCONT) -> c == bId; _ -> False
|
||||
]
|
||||
get b' =##> \case ("", c, MsgErr 6 (MsgSkipped 4 6) "5") -> c == aId; _ -> False
|
||||
ackMessage b' aId 6 Nothing
|
||||
@@ -1261,9 +1419,9 @@ ratchetSyncP cId rss = \case
|
||||
cId' == cId && rss' == rss && ratchetSyncState == rss
|
||||
_ -> False
|
||||
|
||||
ratchetSyncP' :: ConnId -> RatchetSyncState -> ATransmission 'Agent -> Bool
|
||||
ratchetSyncP' :: ConnId -> RatchetSyncState -> ATransmission -> Bool
|
||||
ratchetSyncP' cId rss = \case
|
||||
(_, cId', APC SAEConn (RSYNC rss' _ ConnectionStats {ratchetSyncState})) ->
|
||||
(_, cId', AEvt SAEConn (RSYNC rss' _ ConnectionStats {ratchetSyncState})) ->
|
||||
cId' == cId && rss' == rss && ratchetSyncState == rss
|
||||
_ -> False
|
||||
|
||||
@@ -1432,8 +1590,8 @@ testInactiveNoSubs t = do
|
||||
withSmpServerConfigOn t cfg' testPort $ \_ ->
|
||||
withAgent 1 agentCfg initAgentServers testDB $ \alice -> do
|
||||
runRight_ . void $ createConnection alice 1 True SCMInvitation Nothing SMOnlyCreate -- do not subscribe to pass noSubscriptions check
|
||||
Just (_, _, APC SAENone (CONNECT _ _)) <- timeout 2000000 $ atomically (readTBQueue $ subQ alice)
|
||||
Just (_, _, APC SAENone (DISCONNECT _ _)) <- timeout 5000000 $ atomically (readTBQueue $ subQ alice)
|
||||
Just (_, _, AEvt SAENone (CONNECT _ _)) <- timeout 2000000 $ atomically (readTBQueue $ subQ alice)
|
||||
Just (_, _, AEvt SAENone (DISCONNECT _ _)) <- timeout 5000000 $ atomically (readTBQueue $ subQ alice)
|
||||
pure ()
|
||||
|
||||
testInactiveWithSubs :: ATransport -> IO ()
|
||||
@@ -1510,11 +1668,11 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do
|
||||
liftIO $ suspendAgent b 5000000
|
||||
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do
|
||||
pGet b =##> \case ("", c, APC SAEConn (SENT 5)) -> c == aId; _ -> False
|
||||
pGet b =##> \case ("", c, APC SAEConn (SENT 6)) -> c == aId; _ -> False
|
||||
pGet b =##> \case ("", c, AEvt SAEConn (SENT 5)) -> c == aId; _ -> False
|
||||
pGet b =##> \case ("", c, AEvt SAEConn (SENT 6)) -> c == aId; _ -> False
|
||||
("", "", SUSPENDED) <- nGet b
|
||||
|
||||
pGet a =##> \case ("", c, APC _ (Msg "hello too")) -> c == bId; _ -> False
|
||||
pGet a =##> \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; _ -> False
|
||||
ackMessage a bId 5 Nothing
|
||||
get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False
|
||||
ackMessage a bId 6 Nothing
|
||||
@@ -1968,7 +2126,7 @@ testJoinConnectionAsyncReplyError t = do
|
||||
ConnectionStats {rcvQueuesInfo = [RcvQueueInfo {}], sndQueuesInfo = [SndQueueInfo {}]} <- getConnectionServers b aId
|
||||
pure ()
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
|
||||
pGet a =##> \case ("3", c, APC _ OK) -> c == bId; _ -> False
|
||||
pGet a =##> \case ("3", c, AEvt _ OK) -> c == bId; _ -> False
|
||||
get a ##> ("", bId, CON)
|
||||
get b ##> ("", aId, INFO "alice's connInfo")
|
||||
get b ##> ("", aId, CON)
|
||||
@@ -2113,7 +2271,7 @@ testSwitchAsync servers = do
|
||||
withB = withAgent 2 agentCfg servers testDB2
|
||||
|
||||
withAgent :: HasCallStack => Int -> AgentConfig -> InitialAgentServers -> FilePath -> (HasCallStack => AgentClient -> IO a) -> IO a
|
||||
withAgent clientId cfg' servers dbPath = bracket (getSMPAgentClient' clientId cfg' servers dbPath) disposeAgentClient
|
||||
withAgent clientId cfg' servers dbPath = bracket (getSMPAgentClient' clientId cfg' servers dbPath) (\a -> disposeAgentClient a >> threadDelay 100000)
|
||||
|
||||
sessionSubscribe :: (forall a. (AgentClient -> IO a) -> IO a) -> [ConnId] -> (AgentClient -> ExceptT AgentErrorType IO ()) -> IO ()
|
||||
sessionSubscribe withC connIds a =
|
||||
@@ -2240,20 +2398,20 @@ testAbortSwitchStartedReinitiate servers = do
|
||||
withB :: (AgentClient -> IO a) -> IO a
|
||||
withB = withAgent 2 agentCfg servers testDB2
|
||||
|
||||
switchPhaseRcvP :: ConnId -> SwitchPhase -> [Maybe RcvSwitchStatus] -> ATransmission 'Agent -> Bool
|
||||
switchPhaseRcvP :: ConnId -> SwitchPhase -> [Maybe RcvSwitchStatus] -> ATransmission -> Bool
|
||||
switchPhaseRcvP cId sphase swchStatuses = switchPhaseP cId QDRcv sphase (\stats -> rcvSwchStatuses' stats == swchStatuses)
|
||||
|
||||
switchPhaseSndP :: ConnId -> SwitchPhase -> [Maybe SndSwitchStatus] -> ATransmission 'Agent -> Bool
|
||||
switchPhaseSndP :: ConnId -> SwitchPhase -> [Maybe SndSwitchStatus] -> ATransmission -> Bool
|
||||
switchPhaseSndP cId sphase swchStatuses = switchPhaseP cId QDSnd sphase (\stats -> sndSwchStatuses' stats == swchStatuses)
|
||||
|
||||
switchPhaseP :: ConnId -> QueueDirection -> SwitchPhase -> (ConnectionStats -> Bool) -> ATransmission 'Agent -> Bool
|
||||
switchPhaseP :: ConnId -> QueueDirection -> SwitchPhase -> (ConnectionStats -> Bool) -> ATransmission -> Bool
|
||||
switchPhaseP cId qd sphase statsP = \case
|
||||
(_, cId', APC SAEConn (SWITCH qd' sphase' stats)) -> cId' == cId && qd' == qd && sphase' == sphase && statsP stats
|
||||
(_, cId', AEvt SAEConn (SWITCH qd' sphase' stats)) -> cId' == cId && qd' == qd && sphase' == sphase && statsP stats
|
||||
_ -> False
|
||||
|
||||
errQueueNotFoundP :: ConnId -> ATransmission 'Agent -> Bool
|
||||
errQueueNotFoundP :: ConnId -> ATransmission -> Bool
|
||||
errQueueNotFoundP cId = \case
|
||||
(_, cId', APC SAEConn (ERR AGENT {agentErr = A_QUEUE {queueErr = "QKEY: queue address not found in connection"}})) -> cId' == cId
|
||||
(_, cId', AEvt SAEConn (ERR AGENT {agentErr = A_QUEUE {queueErr = "QKEY: queue address not found in connection"}})) -> cId' == cId
|
||||
_ -> False
|
||||
|
||||
testCannotAbortSwitchSecured :: HasCallStack => InitialAgentServers -> IO ()
|
||||
|
||||
@@ -57,7 +57,7 @@ import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMes
|
||||
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore')
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers)
|
||||
import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, SENT)
|
||||
import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken)
|
||||
import Simplex.Messaging.Agent.Store.SQLite (closeSQLiteStore, getSavedNtfToken, reopenSQLiteStore)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
@@ -162,11 +162,12 @@ testNtfMatrix t runTest = do
|
||||
it "servers: next SMP v7, curr NTF v2; clients: curr/new" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfg agentCfgV7 runTest
|
||||
|
||||
runNtfTestCfg :: ATransport -> ServerConfig -> NtfServerConfig -> AgentConfig -> AgentConfig -> (APNSMockServer -> AgentClient -> AgentClient -> IO ()) -> IO ()
|
||||
runNtfTestCfg t smpCfg ntfCfg aCfg bCfg runTest =
|
||||
runNtfTestCfg t smpCfg ntfCfg aCfg bCfg runTest = do
|
||||
withSmpServerConfigOn t smpCfg testPort $ \_ ->
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServerCfg ntfCfg {transports = [(ntfTestPort, t)]} $ \_ ->
|
||||
withAgentClientsCfg2 aCfg bCfg $ runTest apns
|
||||
threadDelay 100000
|
||||
|
||||
testNotificationToken :: APNSMockServer -> IO ()
|
||||
testNotificationToken APNSMockServer {apnsQ} = do
|
||||
@@ -346,7 +347,7 @@ testRunNTFServerTests t srv =
|
||||
testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing
|
||||
|
||||
testNotificationSubscriptionExistingConnection :: APNSMockServer -> AgentClient -> AgentClient -> IO ()
|
||||
testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} alice@AgentClient {agentEnv = Env {config = aliceCfg}} bob = do
|
||||
testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} alice@AgentClient {agentEnv = Env {config = aliceCfg, store}} bob = do
|
||||
(bobId, aliceId, nonce, message) <- runRight $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe
|
||||
@@ -377,11 +378,21 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} alice@Agen
|
||||
-- alice client already has subscription for the connection
|
||||
Left (CMD PROHIBITED _) <- runExceptT $ getNotificationMessage alice nonce message
|
||||
|
||||
threadDelay 200000
|
||||
suspendAgent alice 0
|
||||
closeSQLiteStore store
|
||||
threadDelay 200000
|
||||
|
||||
-- aliceNtf client doesn't have subscription and is allowed to get notification message
|
||||
withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> runRight_ $ do
|
||||
(_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message
|
||||
pure ()
|
||||
|
||||
threadDelay 200000
|
||||
reopenSQLiteStore store
|
||||
foregroundAgent alice
|
||||
threadDelay 200000
|
||||
|
||||
runRight_ $ do
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId (baseId + 1) Nothing
|
||||
|
||||
@@ -663,7 +663,7 @@ testGetPendingServerCommand st = do
|
||||
Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1)
|
||||
corrId' `shouldBe` "4"
|
||||
where
|
||||
command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe
|
||||
command = AClientCommand $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe
|
||||
corruptCmd :: DB.Connection -> ByteString -> ConnId -> IO ()
|
||||
corruptCmd db corrId connId = DB.execute db "UPDATE commands SET command = cast('bad' as blob) WHERE conn_id = ? AND corr_id = ?" (connId, corrId)
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module CoreTests.ProtocolErrorTests where
|
||||
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Agent.Protocol as Agent
|
||||
import Simplex.Messaging.Client (ProxyClientError (..))
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (CommandError (..), ErrorType (..))
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (HandshakeError (..), TransportError (..))
|
||||
import Simplex.RemoteControl.Types (RCErrorType (..))
|
||||
import Test.Hspec
|
||||
import Test.Hspec.QuickCheck (modifyMaxSuccess)
|
||||
import Test.QuickCheck
|
||||
|
||||
protocolErrorTests :: Spec
|
||||
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
|
||||
describe "errors parsing / serializing" $ do
|
||||
it "should parse SMP protocol errors" . property . forAll possibleErrorType $ \err ->
|
||||
smpDecode (smpEncode err) == Right err
|
||||
it "should parse SMP agent errors" . property . forAll possibleAgentErrorType $ \err ->
|
||||
strDecode (strEncode err) == Right err
|
||||
where
|
||||
possibleErrorType :: Gen ErrorType
|
||||
possibleErrorType = arbitrary >>= \e -> if skipErrorType e then discard else pure e
|
||||
possibleAgentErrorType :: Gen AgentErrorType
|
||||
possibleAgentErrorType =
|
||||
arbitrary >>= \case
|
||||
BROKER srv _ | hasSpaces srv -> discard
|
||||
SMP srv e | hasSpaces srv || skipErrorType e -> discard
|
||||
NTF srv e | hasSpaces srv || skipErrorType e -> discard
|
||||
XFTP srv _ | hasSpaces srv -> discard
|
||||
Agent.PROXY pxy srv _ | hasSpaces pxy || hasSpaces srv -> discard
|
||||
Agent.PROXY _ _ (ProxyProtocolError e) | skipErrorType e -> discard
|
||||
Agent.PROXY _ _ (ProxyUnexpectedResponse e) | hasUnicode e -> discard
|
||||
Agent.PROXY _ _ (ProxyResponseError e) | skipErrorType e -> discard
|
||||
ok -> pure ok
|
||||
hasSpaces :: String -> Bool
|
||||
hasSpaces = any (== ' ')
|
||||
hasUnicode :: String -> Bool
|
||||
hasUnicode = any (>= '\255')
|
||||
skipErrorType = \case
|
||||
SMP.PROXY (SMP.PROTOCOL e) -> skipErrorType e
|
||||
SMP.PROXY (SMP.BROKER (UNEXPECTED s)) -> hasUnicode s
|
||||
SMP.PROXY (SMP.BROKER (RESPONSE s)) -> hasUnicode s
|
||||
_ -> False
|
||||
|
||||
deriving instance Generic AgentErrorType
|
||||
|
||||
deriving instance Generic CommandErrorType
|
||||
|
||||
deriving instance Generic ConnectionErrorType
|
||||
|
||||
deriving instance Generic ProxyClientError
|
||||
|
||||
deriving instance Generic BrokerErrorType
|
||||
|
||||
deriving instance Generic SMPAgentError
|
||||
|
||||
deriving instance Generic AgentCryptoError
|
||||
|
||||
deriving instance Generic ErrorType
|
||||
|
||||
deriving instance Generic CommandError
|
||||
|
||||
deriving instance Generic SMP.ProxyError
|
||||
|
||||
deriving instance Generic TransportError
|
||||
|
||||
deriving instance Generic HandshakeError
|
||||
|
||||
deriving instance Generic XFTPErrorType
|
||||
|
||||
deriving instance Generic RCErrorType
|
||||
|
||||
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary ProxyClientError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary CommandError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary SMP.ProxyError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary TransportError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary RCErrorType where arbitrary = genericArbitraryU
|
||||
+3
-180
@@ -10,54 +10,20 @@
|
||||
|
||||
module SMPAgentClient where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import qualified Database.SQLite.Simple as SQL
|
||||
import Network.Socket (ServiceName)
|
||||
import NtfClient (ntfTestPort)
|
||||
import SMPClient
|
||||
( proxyVRange,
|
||||
serverBracket,
|
||||
testKeyHash,
|
||||
testPort,
|
||||
testPort2,
|
||||
withSmpServer,
|
||||
withSmpServerOn,
|
||||
withSmpServerThreadOn,
|
||||
)
|
||||
import SMPClient (proxyVRange, testPort)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteStore (dbNew))
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction')
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), SMPProxyFallback, SMPProxyMode, chooseTransportHost, defaultNetworkConfig, defaultSMPClientConfig)
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), SMPProxyFallback, SMPProxyMode, defaultNetworkConfig, defaultSMPClientConfig)
|
||||
import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol (NtfServer, ProtoServerWithAuth)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Test.Hspec
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Directory
|
||||
import XFTPClient (testXFTPServer)
|
||||
|
||||
agentTestHost :: NonEmpty TransportHost
|
||||
agentTestHost = "localhost"
|
||||
|
||||
agentTestPort :: ServiceName
|
||||
agentTestPort = "5010"
|
||||
|
||||
agentTestPort2 :: ServiceName
|
||||
agentTestPort2 = "5011"
|
||||
|
||||
agentTestPort3 :: ServiceName
|
||||
agentTestPort3 = "5012"
|
||||
|
||||
testDB :: FilePath
|
||||
testDB = "tests/tmp/smp-agent.test.protocol.db"
|
||||
|
||||
@@ -67,114 +33,6 @@ testDB2 = "tests/tmp/smp-agent2.test.protocol.db"
|
||||
testDB3 :: FilePath
|
||||
testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
|
||||
|
||||
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
|
||||
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h
|
||||
where
|
||||
get h = do
|
||||
t@(_, _, cmdStr) <- tGetRaw h
|
||||
case parseAll networkCommandP cmdStr of
|
||||
Right (ACmd SAgent _ CONNECT {}) -> get h
|
||||
Right (ACmd SAgent _ DISCONNECT {}) -> get h
|
||||
_ -> pure t
|
||||
|
||||
runSmpAgentTest :: forall c a. Transport c => (c -> IO a) -> IO a
|
||||
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
|
||||
where
|
||||
t = transport @c
|
||||
|
||||
runSmpAgentServerTest :: forall c a. Transport c => ((ThreadId, ThreadId) -> c -> IO a) -> IO a
|
||||
runSmpAgentServerTest test =
|
||||
withSmpServerThreadOn t testPort $
|
||||
\server -> withSmpAgentThreadOn t (agentTestPort, testPort, testDB) $
|
||||
\agent -> testSMPAgentClient $ test (server, agent)
|
||||
where
|
||||
t = transport @c
|
||||
|
||||
smpAgentServerTest :: Transport c => ((ThreadId, ThreadId) -> c -> IO ()) -> Expectation
|
||||
smpAgentServerTest test' = runSmpAgentServerTest test' `shouldReturn` ()
|
||||
|
||||
runSmpAgentTestN :: forall c a. Transport c => [(ServiceName, ServiceName, FilePath)] -> ([c] -> IO a) -> IO a
|
||||
runSmpAgentTestN agents test = withSmpServer t $ run agents []
|
||||
where
|
||||
run :: [(ServiceName, ServiceName, FilePath)] -> [c] -> IO a
|
||||
run [] hs = test hs
|
||||
run (a@(p, _, _) : as) hs = withSmpAgentOn t a $ testSMPAgentClientOn p $ \h -> run as (h : hs)
|
||||
t = transport @c
|
||||
|
||||
runSmpAgentTestN_1 :: forall c a. Transport c => Int -> ([c] -> IO a) -> IO a
|
||||
runSmpAgentTestN_1 nClients test = withSmpServer t . withSmpAgent t $ run nClients []
|
||||
where
|
||||
run :: Int -> [c] -> IO a
|
||||
run 0 hs = test hs
|
||||
run n hs = testSMPAgentClient $ \h -> run (n - 1) (h : hs)
|
||||
t = transport @c
|
||||
|
||||
smpAgentTestN :: Transport c => [(ServiceName, ServiceName, FilePath)] -> ([c] -> IO ()) -> Expectation
|
||||
smpAgentTestN agents test' = runSmpAgentTestN agents test' `shouldReturn` ()
|
||||
|
||||
smpAgentTestN_1 :: Transport c => Int -> ([c] -> IO ()) -> Expectation
|
||||
smpAgentTestN_1 n test' = runSmpAgentTestN_1 n test' `shouldReturn` ()
|
||||
|
||||
smpAgentTest2_2_2 :: forall c. Transport c => (c -> c -> IO ()) -> Expectation
|
||||
smpAgentTest2_2_2 test' =
|
||||
withSmpServerOn (transport @c) testPort2 $
|
||||
smpAgentTest2_2_2_needs_server test'
|
||||
|
||||
smpAgentTest2_2_2_needs_server :: forall c. Transport c => (c -> c -> IO ()) -> Expectation
|
||||
smpAgentTest2_2_2_needs_server test' =
|
||||
smpAgentTestN
|
||||
[ (agentTestPort, testPort, testDB),
|
||||
(agentTestPort2, testPort2, testDB2)
|
||||
]
|
||||
_test
|
||||
where
|
||||
_test [h1, h2] = test' h1 h2
|
||||
_test _ = error "expected 2 handles"
|
||||
|
||||
smpAgentTest2_2_1 :: Transport c => (c -> c -> IO ()) -> Expectation
|
||||
smpAgentTest2_2_1 test' =
|
||||
smpAgentTestN
|
||||
[ (agentTestPort, testPort, testDB),
|
||||
(agentTestPort2, testPort, testDB2)
|
||||
]
|
||||
_test
|
||||
where
|
||||
_test [h1, h2] = test' h1 h2
|
||||
_test _ = error "expected 2 handles"
|
||||
|
||||
smpAgentTest2_1_1 :: Transport c => (c -> c -> IO ()) -> Expectation
|
||||
smpAgentTest2_1_1 test' = smpAgentTestN_1 2 _test
|
||||
where
|
||||
_test [h1, h2] = test' h1 h2
|
||||
_test _ = error "expected 2 handles"
|
||||
|
||||
smpAgentTest3 :: Transport c => (c -> c -> c -> IO ()) -> Expectation
|
||||
smpAgentTest3 test' =
|
||||
smpAgentTestN
|
||||
[ (agentTestPort, testPort, testDB),
|
||||
(agentTestPort2, testPort, testDB2),
|
||||
(agentTestPort3, testPort, testDB3)
|
||||
]
|
||||
_test
|
||||
where
|
||||
_test [h1, h2, h3] = test' h1 h2 h3
|
||||
_test _ = error "expected 3 handles"
|
||||
|
||||
smpAgentTest3_1_1 :: Transport c => (c -> c -> c -> IO ()) -> Expectation
|
||||
smpAgentTest3_1_1 test' = smpAgentTestN_1 3 _test
|
||||
where
|
||||
_test [h1, h2, h3] = test' h1 h2 h3
|
||||
_test _ = error "expected 3 handles"
|
||||
|
||||
smpAgentTest1_1_1 :: forall c. Transport c => (c -> IO ()) -> Expectation
|
||||
smpAgentTest1_1_1 test' =
|
||||
smpAgentTestN
|
||||
[(agentTestPort2, testPort2, testDB2)]
|
||||
_test
|
||||
where
|
||||
_test [h] = test' h
|
||||
_test _ = error "expected 1 handle"
|
||||
|
||||
testSMPServer :: SMPServer
|
||||
testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"
|
||||
|
||||
@@ -206,14 +64,13 @@ initAgentServersProxy smpProxyMode smpProxyFallback =
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
defaultAgentConfig
|
||||
{ tcpPort = Just agentTestPort,
|
||||
{ tcpPort = Nothing,
|
||||
tbqSize = 4,
|
||||
-- database = testDB,
|
||||
smpCfg = defaultSMPClientConfig {qSize = 1, defaultTransport = (testPort, transport @TLS), networkConfig},
|
||||
ntfCfg = defaultNTFClientConfig {qSize = 1, defaultTransport = (ntfTestPort, transport @TLS), networkConfig},
|
||||
reconnectInterval = fastRetryInterval,
|
||||
persistErrorInterval = 1,
|
||||
xftpNotifyErrsOnRetry = False,
|
||||
ntfWorkerDelay = 100,
|
||||
ntfSMPWorkerDelay = 100,
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
@@ -232,39 +89,5 @@ fastRetryInterval = defaultReconnectInterval {initialInterval = 50_000}
|
||||
fastMessageRetryInterval :: RetryInterval2
|
||||
fastMessageRetryInterval = RetryInterval2 {riFast = fastRetryInterval, riSlow = fastRetryInterval}
|
||||
|
||||
withSmpAgentThreadOn_ :: ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> IO () -> (ThreadId -> IO a) -> IO a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess =
|
||||
let cfg' = agentCfg {tcpPort = Just port'}
|
||||
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
|
||||
in serverBracket
|
||||
( \started -> do
|
||||
Right st <- liftIO $ createAgentStore db' "" False MCError
|
||||
when (dbNew st) . liftIO $ withTransaction' st (`SQL.execute_` "INSERT INTO users (user_id) VALUES (1)")
|
||||
runSMPAgentBlocking t cfg' initServers' st initClientId started
|
||||
)
|
||||
afterProcess
|
||||
|
||||
userServers :: NonEmpty (ProtoServerWithAuth p) -> Map UserId (NonEmpty (ProtoServerWithAuth p))
|
||||
userServers srvs = M.fromList [(1, srvs)]
|
||||
|
||||
withSmpAgentThreadOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> IO a) -> IO a
|
||||
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a 0 $ removeFile db'
|
||||
|
||||
withSmpAgentOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> IO a -> IO a
|
||||
withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const
|
||||
|
||||
withSmpAgent :: ATransport -> IO a -> IO a
|
||||
withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
||||
|
||||
testSMPAgentClientOn :: Transport c => ServiceName -> (c -> IO a) -> IO a
|
||||
testSMPAgentClientOn port' client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
|
||||
line <- getLn h
|
||||
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
then client h
|
||||
else do
|
||||
error $ "wrong welcome message: " <> B.unpack line
|
||||
|
||||
testSMPAgentClient :: Transport c => (c -> IO a) -> IO a
|
||||
testSMPAgentClient = testSMPAgentClientOn agentTestPort
|
||||
|
||||
@@ -11,7 +11,6 @@ import CoreTests.BatchingTests
|
||||
import CoreTests.CryptoFileTests
|
||||
import CoreTests.CryptoTests
|
||||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
import CoreTests.RetryIntervalTests
|
||||
import CoreTests.TRcvQueuesTests
|
||||
import CoreTests.UtilTests
|
||||
@@ -49,7 +48,6 @@ main = do
|
||||
describe "Core tests" $ do
|
||||
describe "Batching tests" batchingTests
|
||||
describe "Encoding tests" encodingTests
|
||||
describe "Protocol error tests" protocolErrorTests
|
||||
describe "Version range" versionRangeTests
|
||||
describe "Encryption tests" cryptoTests
|
||||
describe "Encrypted files tests" cryptoFileTests
|
||||
|
||||
+5
-3
@@ -20,15 +20,17 @@ import Data.Int (Int64)
|
||||
import Data.List (find, isSuffixOf)
|
||||
import Data.Maybe (fromJust)
|
||||
import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2, testDB3)
|
||||
import SMPClient (xit'')
|
||||
import Simplex.FileTransfer.Client (XFTPClientConfig (..))
|
||||
import Simplex.FileTransfer.Description (FileChunk (..), FileDescription (..), FileDescriptionURI (..), ValidFileDescription, fileDescriptionURI, kb, mb, qrSizeLimit, pattern ValidFileDescription)
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..))
|
||||
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType (AUTH))
|
||||
import Simplex.FileTransfer.Types (RcvFileId, SndFileId)
|
||||
import Simplex.Messaging.Agent (AgentClient, testProtocolServer, xftpDeleteRcvFile, xftpDeleteSndFileInternal, xftpDeleteSndFileRemote, xftpReceiveFile, xftpSendDescription, xftpSendFile, xftpStartWorkers)
|
||||
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..))
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, xftpCfg)
|
||||
import Simplex.Messaging.Agent.Protocol (ACommand (..), AgentErrorType (..), BrokerErrorType (..), RcvFileId, SndFileId, noAuthSrv)
|
||||
import Simplex.Messaging.Agent.Protocol (AEvent (..), AgentErrorType (..), BrokerErrorType (..), noAuthSrv)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs)
|
||||
import qualified Simplex.Messaging.Crypto.File as CF
|
||||
@@ -58,7 +60,7 @@ xftpAgentTests = around_ testBracket . describe "agent XFTP API" $ do
|
||||
it "should resume receiving file after restart" testXFTPAgentReceiveRestore
|
||||
it "should cleanup rcv tmp path after permanent error" testXFTPAgentReceiveCleanup
|
||||
it "should resume sending file after restart" testXFTPAgentSendRestore
|
||||
xit "should cleanup snd prefix path after permanent error" testXFTPAgentSendCleanup
|
||||
xit'' "should cleanup snd prefix path after permanent error" testXFTPAgentSendCleanup
|
||||
it "should delete sent file on server" testXFTPAgentDelete
|
||||
it "should resume deleting file after restart" testXFTPAgentDeleteRestore
|
||||
-- TODO when server is fixed to correctly send AUTH error, this test has to be modified to expect AUTH error
|
||||
@@ -475,7 +477,7 @@ testXFTPAgentSendCleanup = withGlobalLogging logCfgNoLogs $ do
|
||||
-- send file - should fail with AUTH error
|
||||
withAgent 2 agentCfg initAgentServers testDB $ \sndr' -> do
|
||||
runRight_ $ xftpStartWorkers sndr' (Just senderFiles)
|
||||
("", sfId', SFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <-
|
||||
("", sfId', SFERR (XFTP "xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000" AUTH)) <-
|
||||
sfGet sndr'
|
||||
sfId' `shouldBe` sfId
|
||||
|
||||
|
||||
Reference in New Issue
Block a user