mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-06-07 04:31:46 +00:00
Merge remote-tracking branch 'origin/master' into ab/bench-target
This commit is contained in:
@@ -179,7 +179,8 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
RcvFileChunk {rcvFileId, rcvFileEntityId, fileTmpPath, replicas = []} -> rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) "chunk has no replicas"
|
||||
fc@RcvFileChunk {userId, rcvFileId, rcvFileEntityId, digest, fileTmpPath, replicas = replica@RcvFileChunkReplica {rcvChunkReplicaId, server, delay} : _} -> do
|
||||
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop ->
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
downloadFileChunk fc replica
|
||||
`catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
@@ -389,7 +390,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
let numRecipients' = min numRecipients maxRecipients
|
||||
-- concurrently?
|
||||
-- separate worker to create chunks? record retries and delay on snd_file_chunks?
|
||||
forM_ (filter (not . chunkCreated) chunks) $ createChunk numRecipients'
|
||||
forM_ (filter (\SndFileChunk {replicas} -> null replicas) chunks) $ createChunk numRecipients'
|
||||
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading
|
||||
where
|
||||
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
|
||||
@@ -413,9 +414,6 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes
|
||||
chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs
|
||||
pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests)
|
||||
chunkCreated :: SndFileChunk -> Bool
|
||||
chunkCreated SndFileChunk {replicas} =
|
||||
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas
|
||||
createChunk :: Int -> SndFileChunk -> AM ()
|
||||
createChunk numRecipients' ch = do
|
||||
atomically $ assertAgentForeground c
|
||||
@@ -425,7 +423,8 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
where
|
||||
tryCreate = do
|
||||
usedSrvs <- newTVarIO ([] :: [XFTPServer])
|
||||
withRetryInterval (riFast ri) $ \_ loop ->
|
||||
withRetryInterval (riFast ri) $ \_ loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
createWithNextSrv usedSrvs
|
||||
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop) (throwError e) e
|
||||
where
|
||||
@@ -457,7 +456,8 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) "chunk has no replicas"
|
||||
fc@SndFileChunk {userId, sndFileId, sndFileEntityId, filePrefixPath, digest, replicas = replica@SndFileChunkReplica {sndChunkReplicaId, server, delay} : _} -> do
|
||||
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop ->
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
uploadFileChunk cfg fc replica
|
||||
`catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
@@ -623,7 +623,8 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
where
|
||||
processDeletedReplica replica@DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, chunkDigest, delay} = do
|
||||
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop ->
|
||||
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
deleteChunkReplica
|
||||
`catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
|
||||
@@ -4,11 +4,14 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Simplex.FileTransfer.Client where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
@@ -20,9 +23,10 @@ import Data.Int (Int64)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.Time (UTCTime)
|
||||
import Data.Word (Word32)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Data.X509.Validation as XV
|
||||
import qualified Network.HTTP.Types as N
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Simplex.FileTransfer.Description (mb)
|
||||
import Simplex.FileTransfer.Protocol
|
||||
import Simplex.FileTransfer.Transport
|
||||
import Simplex.Messaging.Client
|
||||
@@ -37,6 +41,7 @@ import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Encoding (smpDecode, smpEncode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
( BasicAuth,
|
||||
@@ -45,12 +50,13 @@ import Simplex.Messaging.Protocol
|
||||
RecipientId,
|
||||
SenderId,
|
||||
)
|
||||
import Simplex.Messaging.Transport (THandleParams (..), supportedParameters)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost)
|
||||
import Simplex.Messaging.Transport (HandshakeError (VERSION), THandleAuth (..), THandleParams (..), TransportError (..), supportedParameters)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost, alpn)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import Simplex.Messaging.Transport.HTTP2.File
|
||||
import Simplex.Messaging.Util (bshow, whenM)
|
||||
import Simplex.Messaging.Util (bshow, liftEitherWith, liftError', tshow, whenM)
|
||||
import Simplex.Messaging.Version (compatibleVersion, pattern Compatible)
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory
|
||||
|
||||
@@ -63,7 +69,7 @@ data XFTPClient = XFTPClient
|
||||
|
||||
data XFTPClientConfig = XFTPClientConfig
|
||||
{ xftpNetworkConfig :: NetworkConfig,
|
||||
uploadTimeoutPerMb :: Int64
|
||||
serverVRange :: VersionRangeXFTP
|
||||
}
|
||||
|
||||
data XFTPChunkBody = XFTPChunkBody
|
||||
@@ -85,12 +91,12 @@ defaultXFTPClientConfig :: XFTPClientConfig
|
||||
defaultXFTPClientConfig =
|
||||
XFTPClientConfig
|
||||
{ xftpNetworkConfig = defaultNetworkConfig,
|
||||
uploadTimeoutPerMb = 10000000 -- 10 seconds
|
||||
serverVRange = supportedFileServerVRange
|
||||
}
|
||||
|
||||
getXFTPClient :: TransportSession FileResponse -> XFTPClientConfig -> (XFTPClient -> IO ()) -> IO (Either XFTPClientError XFTPClient)
|
||||
getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkConfig} disconnected = runExceptT $ do
|
||||
let tcConfig = transportClientConfig xftpNetworkConfig
|
||||
getXFTPClient :: TVar ChaChaDRG -> TransportSession FileResponse -> XFTPClientConfig -> (XFTPClient -> IO ()) -> IO (Either XFTPClientError XFTPClient)
|
||||
getXFTPClient g transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkConfig, serverVRange} disconnected = runExceptT $ do
|
||||
let tcConfig = (transportClientConfig xftpNetworkConfig) {alpn = Just ["xftp/1"]}
|
||||
http2Config = xftpHTTP2Config tcConfig config
|
||||
username = proxyUsername transportSession
|
||||
ProtocolServer _ host port keyHash = srv
|
||||
@@ -98,13 +104,50 @@ getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkC
|
||||
clientVar <- newTVarIO Nothing
|
||||
let usePort = if null port then "443" else port
|
||||
clientDisconnected = readTVarIO clientVar >>= mapM_ disconnected
|
||||
http2Client <- withExceptT xftpClientError . ExceptT $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
|
||||
let HTTP2Client {sessionId} = http2Client
|
||||
thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True}
|
||||
c = XFTPClient {http2Client, thParams, transportSession, config}
|
||||
http2Client <- liftError' xftpClientError $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
|
||||
let HTTP2Client {sessionId, sessionALPN} = http2Client
|
||||
thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = VersionXFTP 1, thAuth = Nothing, implySessId = False, batch = True}
|
||||
logDebug $ "Client negotiated handshake protocol: " <> tshow sessionALPN
|
||||
thParams <- case sessionALPN of
|
||||
Just "xftp/1" -> xftpClientHandshakeV1 g serverVRange keyHash http2Client thParams0
|
||||
Nothing -> pure thParams0
|
||||
_ -> throwError $ PCETransportError (TEHandshake VERSION)
|
||||
let c = XFTPClient {http2Client, thParams, transportSession, config}
|
||||
atomically $ writeTVar clientVar $ Just c
|
||||
pure c
|
||||
|
||||
xftpClientHandshakeV1 :: TVar ChaChaDRG -> VersionRangeXFTP -> C.KeyHash -> HTTP2Client -> THandleParamsXFTP -> ExceptT XFTPClientError IO THandleParamsXFTP
|
||||
xftpClientHandshakeV1 g serverVRange keyHash@(C.KeyHash kh) c@HTTP2Client {sessionId, serverKey} thParams0 = do
|
||||
shs <- getServerHandshake
|
||||
(v, sk) <- processServerHandshake shs
|
||||
(k, pk) <- atomically $ C.generateKeyPair g
|
||||
sendClientHandshake XFTPClientHandshake {xftpVersion = v, keyHash, authPubKey = k}
|
||||
pure thParams0 {thAuth = Just THandleAuth {peerPubKey = sk, privKey = pk}, thVersion = v}
|
||||
where
|
||||
getServerHandshake = do
|
||||
let helloReq = H.requestNoBody "POST" "/" []
|
||||
HTTP2Response {respBody = HTTP2Body {bodyHead = shsBody}} <-
|
||||
liftError' (const $ PCEResponseError HANDSHAKE) $ sendRequest c helloReq Nothing
|
||||
liftHS . smpDecode =<< liftHS (C.unPad shsBody)
|
||||
processServerHandshake XFTPServerHandshake {xftpVersionRange, sessionId = serverSessId, authPubKey = serverAuth} = do
|
||||
unless (sessionId == serverSessId) $ throwError $ PCEResponseError SESSION
|
||||
case xftpVersionRange `compatibleVersion` serverVRange of
|
||||
Nothing -> throwError $ PCEResponseError HANDSHAKE
|
||||
Just (Compatible v) ->
|
||||
fmap (v,) . liftHS $ do
|
||||
let (X.CertificateChain cert, exact) = serverAuth
|
||||
case cert of
|
||||
[_leaf, ca] | XV.Fingerprint kh == XV.getFingerprint ca X.HashSHA256 -> pure ()
|
||||
_ -> throwError "bad certificate"
|
||||
pubKey <- maybe (throwError "bad server key type") (`C.verifyX509` exact) serverKey
|
||||
C.x509ToPublic (pubKey, []) >>= C.pubKey
|
||||
sendClientHandshake chs = do
|
||||
chs' <- liftHS $ C.pad (smpEncode chs) xftpBlockSize
|
||||
let chsReq = H.requestBuilder "POST" "/" [] $ byteString chs'
|
||||
HTTP2Response {respBody = HTTP2Body {bodyHead}} <- liftError' (const $ PCEResponseError HANDSHAKE) $ sendRequest c chsReq Nothing
|
||||
unless (B.null bodyHead) $ throwError $ PCEResponseError HANDSHAKE
|
||||
liftHS = liftEitherWith (const $ PCEResponseError HANDSHAKE)
|
||||
|
||||
closeXFTPClient :: XFTPClient -> IO ()
|
||||
closeXFTPClient XFTPClient {http2Client} = closeHTTP2Client http2Client
|
||||
|
||||
@@ -144,8 +187,8 @@ sendXFTPCommand c@XFTPClient {thParams} pKey fId cmd chunkSpec_ = do
|
||||
sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body)
|
||||
sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = do
|
||||
let req = H.requestStreaming N.methodPost "/" [] streamBody
|
||||
reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_
|
||||
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req reqTimeout
|
||||
reqTimeout = xftpReqTimeout config $ (\XFTPChunkSpec {chunkSize} -> chunkSize) <$> chunkSpec_
|
||||
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req (Just reqTimeout)
|
||||
when (B.length bodyHead /= xftpBlockSize) $ throwError $ PCEResponseError BLOCK
|
||||
-- TODO validate that the file ID is the same as in the request?
|
||||
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
|
||||
@@ -198,15 +241,20 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
|
||||
let t = chunkTimeout config chunkSize
|
||||
ExceptT (sequence <$> (t `timeout` download cbState)) >>= maybe (throwError PCEResponseTimeout) pure
|
||||
where
|
||||
download cbState = runExceptT $
|
||||
withExceptT PCEResponseError $
|
||||
download cbState =
|
||||
runExceptT . withExceptT PCEResponseError $
|
||||
receiveEncFile chunkPart cbState chunkSpec `catchError` \e ->
|
||||
whenM (doesFileExist filePath) (removeFile filePath) >> throwError e
|
||||
_ -> throwError $ PCEResponseError NO_FILE
|
||||
(r, _) -> throwError . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
xftpReqTimeout :: XFTPClientConfig -> Maybe Word32 -> Int
|
||||
xftpReqTimeout cfg@XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpTimeout}} chunkSize_ =
|
||||
maybe tcpTimeout (chunkTimeout cfg) chunkSize_
|
||||
|
||||
chunkTimeout :: XFTPClientConfig -> Word32 -> Int
|
||||
chunkTimeout config chunkSize = fromIntegral $ (fromIntegral chunkSize * uploadTimeoutPerMb config) `div` mb 1
|
||||
chunkTimeout XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpTimeout, tcpTimeoutPerKb}} sz =
|
||||
tcpTimeout + fromIntegral (min ((fromIntegral sz `div` 1024) * tcpTimeoutPerKb) (fromIntegral (maxBound :: Int)))
|
||||
|
||||
deleteXFTPChunk :: XFTPClient -> C.APrivateAuthKey -> SenderId -> ExceptT XFTPClientError IO ()
|
||||
deleteXFTPChunk c spKey sId = sendXFTPCommand c spKey sId FDEL Nothing >>= okResponse
|
||||
|
||||
@@ -11,6 +11,7 @@ import Control.Logger.Simple (logInfo)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans (lift)
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text (Text)
|
||||
@@ -60,15 +61,15 @@ newXFTPAgent config = do
|
||||
|
||||
type ME a = ExceptT XFTPClientAgentError IO a
|
||||
|
||||
getXFTPServerClient :: XFTPClientAgent -> XFTPServer -> ME XFTPClient
|
||||
getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
|
||||
getXFTPServerClient :: TVar ChaChaDRG -> XFTPClientAgent -> XFTPServer -> ME XFTPClient
|
||||
getXFTPServerClient g XFTPClientAgent {xftpClients, config} srv = do
|
||||
atomically getClientVar >>= either newXFTPClient waitForXFTPClient
|
||||
where
|
||||
connectClient :: ME XFTPClient
|
||||
connectClient =
|
||||
ExceptT $
|
||||
first (XFTPClientAgentError srv)
|
||||
<$> getXFTPClient (1, srv, Nothing) (xftpConfig config) clientDisconnected
|
||||
<$> getXFTPClient g (1, srv, Nothing) (xftpConfig config) clientDisconnected
|
||||
|
||||
clientDisconnected :: XFTPClient -> IO ()
|
||||
clientDisconnected _ = do
|
||||
|
||||
@@ -333,9 +333,9 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re
|
||||
rKeys <- atomically $ L.fromList <$> replicateM numRecipients (C.generateAuthKeyPair C.SEd25519 g)
|
||||
digest <- liftIO $ getChunkDigest chunkSpec
|
||||
let ch = FileInfo {sndKey, size = fromIntegral chunkSize, digest}
|
||||
c <- withRetry retryCount $ getXFTPServerClient a xftpServer
|
||||
c <- withRetry retryCount $ getXFTPServerClient g a xftpServer
|
||||
(sndId, rIds) <- withRetry retryCount $ createXFTPChunk c spKey ch (L.map fst rKeys) auth
|
||||
withReconnect a xftpServer retryCount $ \c' -> uploadXFTPChunk c' spKey sndId chunkSpec
|
||||
withReconnect g a xftpServer retryCount $ \c' -> uploadXFTPChunk c' spKey sndId chunkSpec
|
||||
logInfo $ "uploaded chunk " <> tshow chunkNo
|
||||
uploaded <- atomically . stateTVar uploadedChunks $ \cs ->
|
||||
let cs' = fromIntegral chunkSize : cs in (sum cs', cs')
|
||||
@@ -445,7 +445,7 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
|
||||
when (FileSize encSize /= size) $ throwError $ CLIError "File size mismatch"
|
||||
liftIO $ printNoNewLine "Decrypting file..."
|
||||
CryptoFile path _ <- withExceptT cliCryptoError $ decryptChunks encSize chunkPaths key nonce $ fmap CF.plain . getFilePath
|
||||
forM_ chunks $ acknowledgeFileChunk a
|
||||
forM_ chunks $ acknowledgeFileChunk g a
|
||||
whenM (doesPathExist encPath) $ removeDirectoryRecursive encPath
|
||||
liftIO $ do
|
||||
printNoNewLine $ "File downloaded: " <> path
|
||||
@@ -456,7 +456,7 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
|
||||
logInfo $ "downloading chunk " <> tshow chunkNo <> " from " <> showServer server <> "..."
|
||||
chunkPath <- uniqueCombine encPath $ show chunkNo
|
||||
let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest)
|
||||
withReconnect a server retryCount $ \c -> downloadXFTPChunk g c replicaKey (unChunkReplicaId replicaId) chunkSpec
|
||||
withReconnect g a server retryCount $ \c -> downloadXFTPChunk g c replicaKey (unChunkReplicaId replicaId) chunkSpec
|
||||
logInfo $ "downloaded chunk " <> tshow chunkNo <> " to " <> T.pack chunkPath
|
||||
downloaded <- atomically . stateTVar downloadedChunks $ \cs ->
|
||||
let cs' = fromIntegral (unFileSize chunkSize) : cs in (sum cs', cs')
|
||||
@@ -472,12 +472,12 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
|
||||
ifM (doesDirectoryExist path) (uniqueCombine path name) $
|
||||
ifM (doesFileExist path) (throwError "File already exists") (pure path)
|
||||
_ -> (`uniqueCombine` name) . (</> "Downloads") =<< getHomeDirectory
|
||||
acknowledgeFileChunk :: XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
|
||||
acknowledgeFileChunk a FileChunk {replicas = replica : _} = do
|
||||
acknowledgeFileChunk :: TVar ChaChaDRG -> XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
|
||||
acknowledgeFileChunk g a FileChunk {replicas = replica : _} = do
|
||||
let FileChunkReplica {server, replicaId, replicaKey} = replica
|
||||
c <- withRetry retryCount $ getXFTPServerClient a server
|
||||
c <- withRetry retryCount $ getXFTPServerClient g a server
|
||||
withRetry retryCount $ ackXFTPChunk c replicaKey (unChunkReplicaId replicaId)
|
||||
acknowledgeFileChunk _ _ = throwError $ CLIError "chunk has no replicas"
|
||||
acknowledgeFileChunk _ _ _ = throwError $ CLIError "chunk has no replicas"
|
||||
|
||||
printProgress :: String -> Int64 -> Int64 -> IO ()
|
||||
printProgress s part total = printNoNewLine $ s <> " " <> show ((part * 100) `div` total) <> "%"
|
||||
@@ -501,7 +501,8 @@ cliDeleteFile DeleteOptions {fileDescription, retryCount, yes} = do
|
||||
deleteFileChunk :: XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
|
||||
deleteFileChunk a FileChunk {chunkNo, replicas = replica : _} = do
|
||||
let FileChunkReplica {server, replicaId, replicaKey} = replica
|
||||
withReconnect a server retryCount $ \c -> deleteXFTPChunk c replicaKey (unChunkReplicaId replicaId)
|
||||
g <- liftIO C.newRandom
|
||||
withReconnect g a server retryCount $ \c -> deleteXFTPChunk c replicaKey (unChunkReplicaId replicaId)
|
||||
logInfo $ "deleted chunk " <> tshow chunkNo <> " from " <> showServer server
|
||||
deleteFileChunk _ _ = throwError $ CLIError "chunk has no replicas"
|
||||
|
||||
@@ -569,9 +570,9 @@ prepareChunkSpecs filePath chunkSizes = reverse . snd $ foldl' addSpec (0, []) c
|
||||
getEncPath :: MonadIO m => Maybe FilePath -> String -> m FilePath
|
||||
getEncPath path name = (`uniqueCombine` (name <> ".encrypted")) =<< maybe (liftIO getCanonicalTemporaryDirectory) pure path
|
||||
|
||||
withReconnect :: Show e => XFTPClientAgent -> XFTPServer -> Int -> (XFTPClient -> ExceptT e IO a) -> ExceptT CLIError IO a
|
||||
withReconnect a srv n run = withRetry n $ do
|
||||
c <- withRetry n $ getXFTPServerClient a srv
|
||||
withReconnect :: Show e => TVar ChaChaDRG -> XFTPClientAgent -> XFTPServer -> Int -> (XFTPClient -> ExceptT e IO a) -> ExceptT CLIError IO a
|
||||
withReconnect g a srv n run = withRetry n $ do
|
||||
c <- withRetry n $ getXFTPServerClient g a srv
|
||||
withExceptT (CLIError . show) (run c) `catchError` \e -> do
|
||||
liftIO $ closeXFTPServerClient a srv
|
||||
throwError e
|
||||
|
||||
@@ -25,7 +25,7 @@ import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word32)
|
||||
import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake)
|
||||
import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, xftpClientHandshakeStub, pattern VersionXFTP)
|
||||
import Simplex.Messaging.Client (authTransmission)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
@@ -144,7 +144,7 @@ instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where
|
||||
instance Protocol XFTPVersion XFTPErrorType FileResponse where
|
||||
type ProtoCommand FileResponse = FileCmd
|
||||
type ProtoType FileResponse = 'PXFTP
|
||||
protocolClientHandshake = xftpClientHandshake
|
||||
protocolClientHandshake = xftpClientHandshakeStub
|
||||
protocolPing = FileCmd SFRecipient PING
|
||||
protocolError = \case
|
||||
FRErr e -> Just e
|
||||
@@ -329,9 +329,9 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of
|
||||
_ -> Nothing
|
||||
|
||||
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do
|
||||
xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey (corrId, fId, msg) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg)
|
||||
xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth
|
||||
xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth (Just pKey) corrId tForAuth
|
||||
|
||||
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeTransmission thParams (corrId, fId, msg) = do
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
@@ -10,6 +9,7 @@
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.FileTransfer.Server where
|
||||
|
||||
@@ -18,7 +18,8 @@ import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Builder (byteString)
|
||||
import qualified Data.ByteString.Base64.URL as B64
|
||||
import Data.ByteString.Builder (Builder, byteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Int (Int64)
|
||||
@@ -32,6 +33,7 @@ import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import Data.Time.Format.ISO8601 (iso8601Show)
|
||||
import Data.Word (Word32)
|
||||
import qualified Data.X509 as X
|
||||
import GHC.IO.Handle (hSetNewlineMode)
|
||||
import GHC.Stats (getRTSStats)
|
||||
import qualified Network.HTTP.Types as N
|
||||
@@ -46,19 +48,22 @@ import Simplex.FileTransfer.Server.StoreLog
|
||||
import Simplex.FileTransfer.Transport
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (CorrId, RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth)
|
||||
import Simplex.Messaging.Protocol (CorrId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth)
|
||||
import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.Transport (THandleParams (..))
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SessionId, THandleAuth (..), THandleParams (..))
|
||||
import Simplex.Messaging.Transport.Buffer (trimCR)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.HTTP2.File (fileBlockSize)
|
||||
import Simplex.Messaging.Transport.HTTP2.Server
|
||||
import Simplex.Messaging.Transport.Server (runTCPServer)
|
||||
import Simplex.Messaging.Transport.Server (runTCPServer, tlsServerCredentials)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version (isCompatible)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath ((</>))
|
||||
import System.IO (hPrint, hPutStrLn, universalNewlineMode)
|
||||
@@ -70,7 +75,7 @@ import qualified UnliftIO.Exception as E
|
||||
type M a = ReaderT XFTPEnv IO a
|
||||
|
||||
data XFTPTransportRequest = XFTPTransportRequest
|
||||
{ thParams :: THandleParams XFTPVersion,
|
||||
{ thParams :: THandleParamsXFTP,
|
||||
reqBody :: HTTP2Body,
|
||||
request :: H.Request,
|
||||
sendResponse :: H.Response -> IO ()
|
||||
@@ -84,20 +89,75 @@ runXFTPServer cfg = do
|
||||
runXFTPServerBlocking :: TMVar Bool -> XFTPServerConfig -> IO ()
|
||||
runXFTPServerBlocking started cfg = newXFTPServerEnv cfg >>= runReaderT (xftpServer cfg started)
|
||||
|
||||
data Handshake
|
||||
= HandshakeSent C.PrivateKeyX25519
|
||||
| HandshakeAccepted THandleAuth VersionXFTP
|
||||
|
||||
xftpServer :: XFTPServerConfig -> TMVar Bool -> M ()
|
||||
xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpiration} started = do
|
||||
xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpiration, fileExpiration} started = do
|
||||
mapM_ (expireServerFiles Nothing) fileExpiration
|
||||
restoreServerStats
|
||||
raceAny_ (runServer : expireFilesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg) `finally` stopServer
|
||||
where
|
||||
runServer :: M ()
|
||||
runServer = do
|
||||
serverParams <- asks tlsServerParams
|
||||
let (chain, pk) = tlsServerCredentials serverParams
|
||||
signKey <- liftIO $ case C.x509ToPrivate (pk, []) >>= C.privKey of
|
||||
Right pk' -> pure pk'
|
||||
Left e -> putStrLn ("servers has no valid key: " <> show e) >> exitFailure
|
||||
env <- ask
|
||||
liftIO $
|
||||
runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig inactiveClientExpiration $ \sessionId r sendResponse -> do
|
||||
reqBody <- getHTTP2Body r xftpBlockSize
|
||||
let thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True}
|
||||
processRequest XFTPTransportRequest {thParams, request = r, reqBody, sendResponse} `runReaderT` env
|
||||
sessions <- atomically TM.empty
|
||||
let cleanup sessionId = atomically $ TM.delete sessionId sessions
|
||||
liftIO . runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig inactiveClientExpiration cleanup $ \sessionId sessionALPN r sendResponse -> do
|
||||
reqBody <- getHTTP2Body r xftpBlockSize
|
||||
let thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = VersionXFTP 1, thAuth = Nothing, implySessId = False, batch = True}
|
||||
req0 = XFTPTransportRequest {thParams = thParams0, request = r, reqBody, sendResponse}
|
||||
flip runReaderT env $ case sessionALPN of
|
||||
Nothing -> processRequest req0
|
||||
Just "xftp/1" ->
|
||||
xftpServerHandshakeV1 chain signKey sessions req0 >>= \case
|
||||
Nothing -> pure () -- handshake response sent
|
||||
Just thParams -> processRequest req0 {thParams} -- proceed with new version (XXX: may as well switch the request handler here)
|
||||
_ -> liftIO . sendResponse $ H.responseNoBody N.ok200 [] -- shouldn't happen: means server picked handshake protocol it doesn't know about
|
||||
xftpServerHandshakeV1 :: X.CertificateChain -> C.APrivateSignKey -> TMap SessionId Handshake -> XFTPTransportRequest -> M (Maybe (THandleParams XFTPVersion))
|
||||
xftpServerHandshakeV1 chain serverSignKey sessions XFTPTransportRequest {thParams = thParams@THandleParams {sessionId}, reqBody = HTTP2Body {bodyHead}, sendResponse} = do
|
||||
s <- atomically $ TM.lookup sessionId sessions
|
||||
r <- runExceptT $ case s of
|
||||
Nothing -> processHello
|
||||
Just (HandshakeSent pk) -> processClientHandshake pk
|
||||
Just (HandshakeAccepted auth v) -> pure $ Just thParams {thAuth = Just auth, thVersion = v}
|
||||
either sendError pure r
|
||||
where
|
||||
processHello = do
|
||||
unless (B.null bodyHead) $ throwError HANDSHAKE
|
||||
(k, pk) <- atomically . C.generateKeyPair =<< asks random
|
||||
atomically $ TM.insert sessionId (HandshakeSent pk) sessions
|
||||
let authPubKey = (chain, C.signX509 serverSignKey $ C.publicToX509 k)
|
||||
let hs = XFTPServerHandshake {xftpVersionRange = supportedFileServerVRange, sessionId, authPubKey}
|
||||
shs <- encodeXftp hs
|
||||
liftIO . sendResponse $ H.responseBuilder N.ok200 [] shs
|
||||
pure Nothing
|
||||
processClientHandshake privKey = do
|
||||
unless (B.length bodyHead == xftpBlockSize) $ throwError HANDSHAKE
|
||||
body <- liftHS $ C.unPad bodyHead
|
||||
XFTPClientHandshake {xftpVersion, keyHash, authPubKey} <- liftHS $ smpDecode body
|
||||
kh <- asks serverIdentity
|
||||
unless (keyHash == kh) $ throwError HANDSHAKE
|
||||
unless (xftpVersion `isCompatible` supportedFileServerVRange) $ throwError HANDSHAKE
|
||||
let auth = THandleAuth {peerPubKey = authPubKey, privKey}
|
||||
atomically $ TM.insert sessionId (HandshakeAccepted auth xftpVersion) sessions
|
||||
liftIO . sendResponse $ H.responseNoBody N.ok200 []
|
||||
pure Nothing
|
||||
sendError :: XFTPErrorType -> M (Maybe (THandleParams XFTPVersion))
|
||||
sendError err = do
|
||||
runExceptT (encodeXftp err) >>= \case
|
||||
Right bs -> liftIO . sendResponse $ H.responseBuilder N.ok200 [] bs
|
||||
Left _ -> logError $ "Error encoding handshake error: " <> tshow err
|
||||
pure Nothing
|
||||
encodeXftp :: Encoding a => a -> ExceptT XFTPErrorType (ReaderT XFTPEnv IO) Builder
|
||||
encodeXftp a = byteString <$> liftHS (C.pad (smpEncode a) xftpBlockSize)
|
||||
liftHS = liftEitherWith (const HANDSHAKE)
|
||||
|
||||
stopServer :: M ()
|
||||
stopServer = do
|
||||
@@ -110,28 +170,10 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
|
||||
expireFiles :: ExpirationConfig -> M ()
|
||||
expireFiles expCfg = do
|
||||
st <- asks store
|
||||
let interval = checkInterval expCfg * 1000000
|
||||
forever $ do
|
||||
liftIO $ threadDelay' interval
|
||||
old <- liftIO $ expireBeforeEpoch expCfg
|
||||
sIds <- M.keysSet <$> readTVarIO (files st)
|
||||
forM_ sIds $ \sId -> do
|
||||
threadDelay 100000
|
||||
atomically (expiredFilePath st sId old)
|
||||
>>= mapM_ (maybeRemove $ delete st sId)
|
||||
where
|
||||
maybeRemove del = maybe del (remove del)
|
||||
remove del filePath =
|
||||
ifM
|
||||
(doesFileExist filePath)
|
||||
((removeFile filePath >> del) `catch` \(e :: SomeException) -> logError $ "failed to remove expired file " <> tshow filePath <> ": " <> tshow e)
|
||||
del
|
||||
delete st sId = do
|
||||
withFileLog (`logDeleteFile` sId)
|
||||
void $ atomically $ deleteFile st sId
|
||||
FileServerStats {filesExpired} <- asks serverStats
|
||||
atomically $ modifyTVar' filesExpired (+ 1)
|
||||
expireServerFiles (Just 100000) expCfg
|
||||
|
||||
serverStatsThread_ :: XFTPServerConfig -> [M ()]
|
||||
serverStatsThread_ XFTPServerConfig {logStatsInterval = Just interval, logStatsStartTime, serverStatsLogFile} =
|
||||
@@ -201,7 +243,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
role <- newTVarIO CPRNone
|
||||
cpLoop h role
|
||||
where
|
||||
cpLoop h role = do
|
||||
cpLoop h role = do
|
||||
s <- trimCR <$> B.hGetLine h
|
||||
case strDecode s of
|
||||
Right CPQuit -> hClose h
|
||||
@@ -235,12 +277,13 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
|
||||
CPQuit -> pure ()
|
||||
CPSkip -> pure ()
|
||||
where
|
||||
withUserRole action = readTVarIO role >>= \case
|
||||
CPRAdmin -> action
|
||||
CPRUser -> action
|
||||
_ -> do
|
||||
logError "Unauthorized control port command"
|
||||
hPutStrLn h "AUTH"
|
||||
withUserRole action =
|
||||
readTVarIO role >>= \case
|
||||
CPRAdmin -> action
|
||||
CPRUser -> action
|
||||
_ -> do
|
||||
logError "Unauthorized control port command"
|
||||
hPutStrLn h "AUTH"
|
||||
|
||||
data ServerFile = ServerFile
|
||||
{ filePath :: FilePath,
|
||||
@@ -253,10 +296,11 @@ processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHea
|
||||
| B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", "", FRErr BLOCK) Nothing
|
||||
| otherwise = do
|
||||
case xftpDecodeTransmission thParams bodyHead of
|
||||
Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do
|
||||
Right (sig_, signed, (corrId, fId, cmdOrErr)) ->
|
||||
case cmdOrErr of
|
||||
Right cmd -> do
|
||||
verifyXFTPTransmission sig_ signed fId cmd >>= \case
|
||||
let THandleParams {thAuth} = thParams
|
||||
verifyXFTPTransmission ((,C.cbNonce (bs corrId)) <$> thAuth) sig_ signed fId cmd >>= \case
|
||||
VRVerified req -> uncurry send =<< processXFTPRequest body req
|
||||
VRFailed -> send (FRErr AUTH) Nothing
|
||||
Left e -> send (FRErr e) Nothing
|
||||
@@ -264,7 +308,6 @@ processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHea
|
||||
send resp = sendXFTPResponse (corrId, fId, resp)
|
||||
Left e -> sendXFTPResponse ("", "", FRErr e) Nothing
|
||||
where
|
||||
sendXFTPResponse :: (CorrId, XFTPFileId, FileResponse) -> Maybe ServerFile -> M ()
|
||||
sendXFTPResponse (corrId, fId, resp) serverFile_ = do
|
||||
let t_ = xftpEncodeTransmission thParams (corrId, fId, resp)
|
||||
liftIO $ sendResponse $ H.responseStreaming N.ok200 [] $ streamBody t_
|
||||
@@ -283,8 +326,8 @@ processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHea
|
||||
|
||||
data VerificationResult = VRVerified XFTPRequest | VRFailed
|
||||
|
||||
verifyXFTPTransmission :: Maybe TransmissionAuth -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult
|
||||
verifyXFTPTransmission tAuth authorized fId cmd =
|
||||
verifyXFTPTransmission :: Maybe (THandleAuth, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult
|
||||
verifyXFTPTransmission auth_ tAuth authorized fId cmd =
|
||||
case cmd of
|
||||
FileCmd SFSender (FNEW file rcps auth') -> pure $ XFTPReqNew file rcps auth' `verifyWith` sndKey file
|
||||
FileCmd SFRecipient PING -> pure $ VRVerified XFTPReqPing
|
||||
@@ -299,7 +342,7 @@ verifyXFTPTransmission tAuth authorized fId cmd =
|
||||
Right (fr, k) -> XFTPReqCmd fId fr cmd `verifyWith` k
|
||||
_ -> maybe False (dummyVerifyCmd Nothing authorized) tAuth `seq` VRFailed
|
||||
-- TODO verify with DH authorization
|
||||
req `verifyWith` k = if verifyCmdAuthorization Nothing tAuth authorized k then VRVerified req else VRFailed
|
||||
req `verifyWith` k = if verifyCmdAuthorization auth_ tAuth authorized k then VRVerified req else VRFailed
|
||||
|
||||
processXFTPRequest :: HTTP2Body -> XFTPRequest -> M (FileResponse, Maybe ServerFile)
|
||||
processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
@@ -392,7 +435,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
\used -> let used' = used + fromIntegral size in if used' <= quota then (True, used') else (False, used)
|
||||
receive = do
|
||||
path <- asks $ filesPath . config
|
||||
let fPath = path </> B.unpack (U.encode senderId)
|
||||
let fPath = path </> B.unpack (B64.encode senderId)
|
||||
receiveChunk (XFTPRcvChunkSpec fPath size digest) >>= \case
|
||||
Right () -> do
|
||||
stats <- asks serverStats
|
||||
@@ -413,18 +456,20 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
|
||||
sendServerFile :: FileRec -> RcvPublicDhKey -> M (FileResponse, Maybe ServerFile)
|
||||
sendServerFile FileRec {senderId, filePath, fileInfo = FileInfo {size}} rDhKey = do
|
||||
readTVarIO filePath >>= \case
|
||||
Just path -> do
|
||||
g <- asks random
|
||||
(sDhKey, spDhKey) <- atomically $ C.generateKeyPair g
|
||||
let dhSecret = C.dh' rDhKey spDhKey
|
||||
cbNonce <- atomically $ C.randomCbNonce g
|
||||
case LC.cbInit dhSecret cbNonce of
|
||||
Right sbState -> do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar' (fileDownloads stats) (+ 1)
|
||||
atomically $ updatePeriodStats (filesDownloaded stats) senderId
|
||||
pure (FRFile sDhKey cbNonce, Just ServerFile {filePath = path, fileSize = size, sbState})
|
||||
_ -> pure (FRErr INTERNAL, Nothing)
|
||||
Just path -> ifM (doesFileExist path) sendFile (pure (FRErr AUTH, Nothing))
|
||||
where
|
||||
sendFile = do
|
||||
g <- asks random
|
||||
(sDhKey, spDhKey) <- atomically $ C.generateKeyPair g
|
||||
let dhSecret = C.dh' rDhKey spDhKey
|
||||
cbNonce <- atomically $ C.randomCbNonce g
|
||||
case LC.cbInit dhSecret cbNonce of
|
||||
Right sbState -> do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar' (fileDownloads stats) (+ 1)
|
||||
atomically $ updatePeriodStats (filesDownloaded stats) senderId
|
||||
pure (FRFile sDhKey cbNonce, Just ServerFile {filePath = path, fileSize = size, sbState})
|
||||
_ -> pure (FRErr INTERNAL, Nothing)
|
||||
_ -> pure (FRErr NO_FILE, Nothing)
|
||||
|
||||
deleteServerFile :: FileRec -> M FileResponse
|
||||
@@ -457,6 +502,33 @@ deleteServerFile_ FileRec {senderId, fileInfo, filePath} = do
|
||||
atomically $ modifyTVar' (filesCount stats) (subtract 1)
|
||||
atomically $ modifyTVar' (filesSize stats) (subtract $ fromIntegral $ size fileInfo)
|
||||
|
||||
expireServerFiles :: Maybe Int -> ExpirationConfig -> M ()
|
||||
expireServerFiles itemDelay expCfg = do
|
||||
st <- asks store
|
||||
usedStart <- readTVarIO $ usedStorage st
|
||||
old <- liftIO $ expireBeforeEpoch expCfg
|
||||
files' <- readTVarIO (files st)
|
||||
logInfo $ "Expiration check: " <> tshow (M.size files') <> " files"
|
||||
forM_ (M.keys files') $ \sId -> do
|
||||
mapM_ threadDelay itemDelay
|
||||
atomically (expiredFilePath st sId old)
|
||||
>>= mapM_ (maybeRemove $ delete st sId)
|
||||
usedEnd <- readTVarIO $ usedStorage st
|
||||
logInfo $ "Used " <> mbs usedStart <> " -> " <> mbs usedEnd <> ", " <> mbs (usedStart - usedEnd) <> " reclaimed."
|
||||
where
|
||||
mbs bs = tshow (bs `div` 1048576) <> "mb"
|
||||
maybeRemove del = maybe del (remove del)
|
||||
remove del filePath =
|
||||
ifM
|
||||
(doesFileExist filePath)
|
||||
((removeFile filePath >> del) `catch` \(e :: SomeException) -> logError $ "failed to remove expired file " <> tshow filePath <> ": " <> tshow e)
|
||||
del
|
||||
delete st sId = do
|
||||
withFileLog (`logDeleteFile` sId)
|
||||
void . atomically $ deleteFile st sId -- will not update usedStorage if sId isn't in store
|
||||
FileServerStats {filesExpired} <- asks serverStats
|
||||
atomically $ modifyTVar' filesExpired (+ 1)
|
||||
|
||||
randomId :: Int -> M ByteString
|
||||
randomId n = atomically . C.randomBytes n =<< asks random
|
||||
|
||||
|
||||
@@ -2,19 +2,23 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE StrictData #-}
|
||||
|
||||
module Simplex.FileTransfer.Server.Env where
|
||||
|
||||
import Control.Logger.Simple (logInfo)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int64)
|
||||
import Data.List (find)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
import Data.Word (Word32)
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
@@ -27,6 +31,7 @@ import Simplex.FileTransfer.Server.StoreLog
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (BasicAuth, RcvPublicAuthKey)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport (ALPN)
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Util (tshow)
|
||||
import System.IO (IOMode (..))
|
||||
@@ -94,6 +99,9 @@ defaultFileExpiration =
|
||||
checkInterval = 2 * 3600 -- seconds, 2 hours
|
||||
}
|
||||
|
||||
supportedXFTPhandshakes :: [ALPN]
|
||||
supportedXFTPhandshakes = ["xftp/1"]
|
||||
|
||||
newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv
|
||||
newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
random <- liftIO C.newRandom
|
||||
@@ -104,7 +112,14 @@ newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertifi
|
||||
forM_ fileSizeQuota $ \quota -> do
|
||||
logInfo $ "Total / available storage: " <> tshow quota <> " / " <> tshow (quota - used)
|
||||
when (quota < used) $ logInfo "WARNING: storage quota is less than used storage, no files can be uploaded!"
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
tlsServerParams' <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
let tlsServerParams =
|
||||
tlsServerParams'
|
||||
{ T.serverHooks =
|
||||
def
|
||||
{ T.onALPNClientSuggest = Just $ pure . fromMaybe "" . find (`elem` supportedXFTPhandshakes)
|
||||
}
|
||||
}
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
serverStats <- atomically . newFileServerStats =<< liftIO getCurrentTime
|
||||
pure XFTPEnv {config, store, storeLog, random, tlsServerParams, serverIdentity = C.KeyHash fp, serverStats}
|
||||
|
||||
@@ -22,6 +22,7 @@ import Control.Concurrent.STM
|
||||
import Control.Monad.Except
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Composition ((.:), (.:.))
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
@@ -88,7 +89,7 @@ readWriteFileStore f st = do
|
||||
pure s
|
||||
|
||||
readFileStore :: FilePath -> FileStore -> IO ()
|
||||
readFileStore f st = mapM_ addFileLogRecord . B.lines =<< B.readFile f
|
||||
readFileStore f st = mapM_ (addFileLogRecord . LB.toStrict) . LB.lines =<< LB.readFile f
|
||||
where
|
||||
addFileLogRecord s = case strDecode s of
|
||||
Left e -> B.putStrLn $ "Log parsing error (" <> B.pack e <> "): " <> B.take 100 s
|
||||
|
||||
@@ -9,9 +9,16 @@
|
||||
|
||||
module Simplex.FileTransfer.Transport
|
||||
( supportedFileServerVRange,
|
||||
xftpClientHandshake, -- stub
|
||||
XFTPVersion,
|
||||
xftpClientHandshakeStub,
|
||||
XFTPClientHandshake (..),
|
||||
-- xftpClientHandshake,
|
||||
XFTPServerHandshake (..),
|
||||
-- xftpServerHandshake,
|
||||
THandleXFTP,
|
||||
THandleParamsXFTP,
|
||||
VersionXFTP,
|
||||
VersionRangeXFTP,
|
||||
XFTPVersion,
|
||||
pattern VersionXFTP,
|
||||
XFTPErrorType (..),
|
||||
XFTPRcvChunkSpec (..),
|
||||
@@ -30,20 +37,21 @@ import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Builder (Builder, byteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Word (Word16, Word32)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol (CommandError)
|
||||
import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..))
|
||||
import Simplex.Messaging.Transport (HandshakeError (..), SessionId, THandle (..), THandleParams (..), TransportError (..))
|
||||
import Simplex.Messaging.Transport.HTTP2.File
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import Simplex.Messaging.Version
|
||||
@@ -68,6 +76,9 @@ type VersionRangeXFTP = VersionRange XFTPVersion
|
||||
pattern VersionXFTP :: Word16 -> VersionXFTP
|
||||
pattern VersionXFTP v = Version v
|
||||
|
||||
type THandleXFTP c = THandle XFTPVersion c
|
||||
type THandleParamsXFTP = THandleParams XFTPVersion
|
||||
|
||||
initialXFTPVersion :: VersionXFTP
|
||||
initialXFTPVersion = VersionXFTP 1
|
||||
|
||||
@@ -75,8 +86,45 @@ supportedFileServerVRange :: VersionRangeXFTP
|
||||
supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion
|
||||
|
||||
-- XFTP protocol does not support handshake
|
||||
xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c)
|
||||
xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION
|
||||
xftpClientHandshakeStub :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c)
|
||||
xftpClientHandshakeStub _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION
|
||||
|
||||
data XFTPServerHandshake = XFTPServerHandshake
|
||||
{ xftpVersionRange :: VersionRangeXFTP,
|
||||
sessionId :: SessionId,
|
||||
-- | pub key to agree shared secrets for command authorization and entity ID encryption.
|
||||
authPubKey :: (X.CertificateChain, X.SignedExact X.PubKey)
|
||||
}
|
||||
|
||||
data XFTPClientHandshake = XFTPClientHandshake
|
||||
{ -- | agreed XFTP server protocol version
|
||||
xftpVersion :: VersionXFTP,
|
||||
-- | server identity - CA certificate fingerprint
|
||||
keyHash :: C.KeyHash,
|
||||
-- | pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
|
||||
authPubKey :: C.PublicKeyX25519
|
||||
}
|
||||
|
||||
instance Encoding XFTPClientHandshake where
|
||||
smpEncode XFTPClientHandshake {xftpVersion, keyHash, authPubKey} =
|
||||
smpEncode (xftpVersion, keyHash, authPubKey)
|
||||
smpP = do
|
||||
(xftpVersion, keyHash) <- smpP
|
||||
authPubKey <- smpP
|
||||
Tail _compat <- smpP
|
||||
pure XFTPClientHandshake {xftpVersion, keyHash, authPubKey}
|
||||
|
||||
instance Encoding XFTPServerHandshake where
|
||||
smpEncode XFTPServerHandshake {xftpVersionRange, sessionId, authPubKey} =
|
||||
smpEncode (xftpVersionRange, sessionId, auth)
|
||||
where
|
||||
auth = bimap C.encodeCertChain C.SignedObject authPubKey
|
||||
smpP = do
|
||||
(xftpVersionRange, sessionId) <- smpP
|
||||
cert <- C.certChainP
|
||||
C.SignedObject key <- smpP
|
||||
Tail _compat <- smpP
|
||||
pure XFTPServerHandshake {xftpVersionRange, sessionId, authPubKey = (cert, key)}
|
||||
|
||||
sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO ()
|
||||
sendEncFile h send = go
|
||||
@@ -139,6 +187,8 @@ data XFTPErrorType
|
||||
BLOCK
|
||||
| -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929)
|
||||
SESSION
|
||||
| -- | incorrect handshake command
|
||||
HANDSHAKE
|
||||
| -- | SMP command is unknown or has invalid syntax
|
||||
CMD {cmdErr :: CommandError}
|
||||
| -- | command authorization error - bad signature or non-existing SMP queue
|
||||
@@ -181,6 +231,7 @@ instance Encoding XFTPErrorType where
|
||||
smpEncode = \case
|
||||
BLOCK -> "BLOCK"
|
||||
SESSION -> "SESSION"
|
||||
HANDSHAKE -> "HANDSHAKE"
|
||||
CMD err -> "CMD " <> smpEncode err
|
||||
AUTH -> "AUTH"
|
||||
SIZE -> "SIZE"
|
||||
@@ -199,6 +250,7 @@ instance Encoding XFTPErrorType where
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"BLOCK" -> pure BLOCK
|
||||
"SESSION" -> pure SESSION
|
||||
"HANDSHAKE" -> pure HANDSHAKE
|
||||
"CMD" -> CMD <$> _smpP
|
||||
"AUTH" -> pure AUTH
|
||||
"SIZE" -> pure SIZE
|
||||
|
||||
@@ -82,6 +82,7 @@ module Simplex.Messaging.Agent
|
||||
setNtfServers,
|
||||
setNetworkConfig,
|
||||
getNetworkConfig,
|
||||
setUserNetworkInfo,
|
||||
reconnectAllServers,
|
||||
registerNtfToken,
|
||||
verifyNtfToken,
|
||||
@@ -402,17 +403,32 @@ testProtocolServer c userId srv = withAgentEnv' c $ case protocolTypeI @p of
|
||||
SPXFTP -> runXFTPServerTest c userId srv
|
||||
SPNTF -> runNTFServerTest c userId srv
|
||||
|
||||
-- | set SOCKS5 proxy on/off and optionally set TCP timeout
|
||||
-- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network
|
||||
setNetworkConfig :: AgentClient -> NetworkConfig -> IO ()
|
||||
setNetworkConfig c cfg' = do
|
||||
cfg <- atomically $ do
|
||||
swapTVar (useNetworkConfig c) cfg'
|
||||
when (cfg /= cfg') $ reconnectAllServers c
|
||||
setNetworkConfig c@AgentClient {useNetworkConfig} cfg' = do
|
||||
changed <- atomically $ do
|
||||
(_, cfg) <- readTVar useNetworkConfig
|
||||
if cfg == cfg'
|
||||
then pure False
|
||||
else True <$ (writeTVar useNetworkConfig $! (slowNetworkConfig cfg', cfg'))
|
||||
when changed $ reconnectAllServers c
|
||||
|
||||
-- returns fast network config
|
||||
getNetworkConfig :: AgentClient -> IO NetworkConfig
|
||||
getNetworkConfig = readTVarIO . useNetworkConfig
|
||||
getNetworkConfig = fmap snd . readTVarIO . useNetworkConfig
|
||||
{-# INLINE getNetworkConfig #-}
|
||||
|
||||
setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO ()
|
||||
setUserNetworkInfo c@AgentClient {userNetworkState} UserNetworkInfo {networkType = nt'} = withAgentEnv' c $ do
|
||||
d <- asks $ initialInterval . userNetworkInterval . config
|
||||
ts <- liftIO getCurrentTime
|
||||
atomically $ do
|
||||
ns@UserNetworkState {networkType = nt} <- readTVar userNetworkState
|
||||
when (nt' /= nt) $
|
||||
writeTVar userNetworkState $! case nt' of
|
||||
UNNone -> ns {networkType = nt', offline = Just UNSOffline {offlineDelay = d, offlineFrom = ts}}
|
||||
_ -> ns {networkType = nt', offline = Nothing}
|
||||
|
||||
reconnectAllServers :: AgentClient -> IO ()
|
||||
reconnectAllServers c = do
|
||||
reconnectServerClients c smpClients
|
||||
@@ -1267,6 +1283,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork
|
||||
let mId = unId msgId
|
||||
ri' = maybe id updateRetryInterval2 msgRetryState ri
|
||||
withRetryLock2 ri' qLock $ \riState loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
resp <- tryError $ case msgType of
|
||||
AM_CONN_INFO -> sendConfirmation c sq msgBody
|
||||
AM_CONN_INFO_REPLY -> sendConfirmation c sq msgBody
|
||||
@@ -2047,7 +2064,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
|
||||
pure ack'
|
||||
where
|
||||
queueDrained = case conn of
|
||||
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq)
|
||||
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ A_QCONT (sndAddress rq)
|
||||
_ -> pure ()
|
||||
processClientMsg srvTs msgFlags msgBody = do
|
||||
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
|
||||
@@ -2096,7 +2113,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
pure ACKPending
|
||||
A_RCVD rcpts -> qDuplex conn'' "RCVD" $ messagesRcvd rcpts msgMeta
|
||||
QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending srvMsgId addr
|
||||
A_QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending srvMsgId addr
|
||||
QADD qs -> qDuplexAckDel conn'' "QADD" $ qAddMsg srvMsgId qs
|
||||
QKEY qs -> qDuplexAckDel conn'' "QKEY" $ qKeyMsg srvMsgId qs
|
||||
QUSE qs -> qDuplexAckDel conn'' "QUSE" $ qUseMsg srvMsgId qs
|
||||
@@ -2310,6 +2327,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
|
||||
atomically $
|
||||
TM.lookup (qAddress sq) (smpDeliveryWorkers c)
|
||||
>>= mapM_ (\(_, retryLock) -> tryPutTMVar retryLock ())
|
||||
notify QCONT
|
||||
Nothing -> qError "QCONT: queue address not found"
|
||||
|
||||
messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd
|
||||
@@ -2351,10 +2369,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
|
||||
case L.nonEmpty keepSqs of
|
||||
Just sqs' -> do
|
||||
-- move inside case?
|
||||
withStore' c $ \db -> mapM_ (deleteConnSndQueue db connId) delSqs
|
||||
sq_@SndQueue {sndPublicKey, e2ePubKey} <- lift $ newSndQueue userId connId qInfo
|
||||
let sq'' = (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
sq2 <- withStore c $ \db -> addConnSndQueue db connId sq''
|
||||
sq2 <- withStore c $ \db -> do
|
||||
liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs
|
||||
addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
case (sndPublicKey, e2ePubKey) of
|
||||
(Just sndPubKey, Just dhPublicKey) -> do
|
||||
logServer "<--" c srv rId $ "MSG <QADD>:" <> logSecret srvMsgId <> " " <> logSecret (senderId queueAddress)
|
||||
|
||||
@@ -27,6 +27,7 @@ module Simplex.Messaging.Agent.Client
|
||||
withConnLock,
|
||||
withConnLocks,
|
||||
withInvLock,
|
||||
withLockMap,
|
||||
closeAgentClient,
|
||||
closeProtocolServerClients,
|
||||
reconnectServerClients,
|
||||
@@ -80,6 +81,7 @@ module Simplex.Messaging.Agent.Client
|
||||
agentClientStore,
|
||||
agentDRG,
|
||||
getAgentSubscriptions,
|
||||
slowNetworkConfig,
|
||||
Worker (..),
|
||||
SessionVar (..),
|
||||
SubscriptionsInfo (..),
|
||||
@@ -99,6 +101,11 @@ module Simplex.Messaging.Agent.Client
|
||||
agentOperations,
|
||||
agentOperationBracket,
|
||||
waitUntilActive,
|
||||
UserNetworkInfo (..),
|
||||
UserNetworkType (..),
|
||||
UserNetworkState (..),
|
||||
UNSOffline (..),
|
||||
waitForUserNetwork,
|
||||
throwWhenInactive,
|
||||
throwWhenNoDelivery,
|
||||
beginAgentOperation,
|
||||
@@ -132,7 +139,7 @@ import Control.Applicative ((<|>))
|
||||
import Control.Concurrent (ThreadId, forkIO, threadDelay)
|
||||
import Control.Concurrent.Async (Async, uninterruptibleCancel)
|
||||
import Control.Concurrent.STM (retry, throwSTM)
|
||||
import Control.Exception (AsyncException (..))
|
||||
import Control.Exception (AsyncException (..), BlockedIndefinitelyOnSTM (..))
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
@@ -142,11 +149,13 @@ import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:.))
|
||||
import Data.Either (lefts, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.List (deleteFirstsBy, foldl', partition, (\\))
|
||||
import Data.List.NonEmpty (NonEmpty (..), (<|))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
@@ -157,7 +166,7 @@ import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text (Text)
|
||||
import Data.Text.Encoding
|
||||
import Data.Time (UTCTime, defaultTimeLocale, formatTime, getCurrentTime)
|
||||
import Data.Time (UTCTime, defaultTimeLocale, diffUTCTime, formatTime, getCurrentTime)
|
||||
import Data.Time.Clock.System (getSystemTime)
|
||||
import Data.Word (Word16)
|
||||
import Network.Socket (HostName)
|
||||
@@ -165,7 +174,7 @@ import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientCo
|
||||
import qualified Simplex.FileTransfer.Client as X
|
||||
import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb)
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse)
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion)
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType (DIGEST), XFTPRcvChunkSpec (..), XFTPVersion)
|
||||
import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..))
|
||||
import Simplex.FileTransfer.Util (uniqueCombine)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -181,7 +190,6 @@ import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
@@ -195,6 +203,7 @@ import Simplex.Messaging.Protocol
|
||||
ErrorType,
|
||||
MsgFlags (..),
|
||||
MsgId,
|
||||
NtfPublicAuthKey,
|
||||
NtfServer,
|
||||
NtfServerWithAuth,
|
||||
ProtoServer,
|
||||
@@ -206,22 +215,21 @@ import Simplex.Messaging.Protocol
|
||||
QueueIdsKeys (..),
|
||||
RcvMessage (..),
|
||||
RcvNtfPublicDhKey,
|
||||
NtfPublicAuthKey,
|
||||
SMPMsgMeta (..),
|
||||
SProtocolType (..),
|
||||
SndPublicAuthKey,
|
||||
SubscriptionMode (..),
|
||||
UserProtocol,
|
||||
VersionRangeSMPC,
|
||||
VersionSMPC,
|
||||
XFTPServer,
|
||||
XFTPServerWithAuth,
|
||||
VersionSMPC,
|
||||
VersionRangeSMPC,
|
||||
sameSrvAddr',
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (SMPVersion)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPVersion)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
@@ -264,7 +272,8 @@ data AgentClient = AgentClient
|
||||
ntfClients :: TMap NtfTransportSession NtfClientVar,
|
||||
xftpServers :: TMap UserId (NonEmpty XFTPServerWithAuth),
|
||||
xftpClients :: TMap XFTPTransportSession XFTPClientVar,
|
||||
useNetworkConfig :: TVar NetworkConfig,
|
||||
useNetworkConfig :: TVar (NetworkConfig, NetworkConfig), -- (slow, fast) networks
|
||||
userNetworkState :: TVar UserNetworkState,
|
||||
subscrConns :: TVar (Set ConnId),
|
||||
activeSubs :: TRcvQueues,
|
||||
pendingSubs :: TRcvQueues,
|
||||
@@ -395,6 +404,23 @@ data AgentStatsKey = AgentStatsKey
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data UserNetworkInfo = UserNetworkInfo
|
||||
{ networkType :: UserNetworkType
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther
|
||||
deriving (Eq, Show)
|
||||
|
||||
data UserNetworkState = UserNetworkState
|
||||
{ networkType :: UserNetworkType,
|
||||
offline :: Maybe UNSOffline
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data UNSOffline = UNSOffline {offlineDelay :: Int64, offlineFrom :: UTCTime}
|
||||
deriving (Show)
|
||||
|
||||
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
|
||||
newAgentClient :: Int -> InitialAgentServers -> Env -> STM AgentClient
|
||||
newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
|
||||
@@ -410,7 +436,8 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
|
||||
ntfClients <- TM.empty
|
||||
xftpServers <- newTVar xftp
|
||||
xftpClients <- TM.empty
|
||||
useNetworkConfig <- newTVar netCfg
|
||||
useNetworkConfig <- newTVar (slowNetworkConfig netCfg, netCfg)
|
||||
userNetworkState <- newTVar $ UserNetworkState UNOther Nothing
|
||||
subscrConns <- newTVar S.empty
|
||||
activeSubs <- RQ.empty
|
||||
pendingSubs <- RQ.empty
|
||||
@@ -445,6 +472,7 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
|
||||
xftpServers,
|
||||
xftpClients,
|
||||
useNetworkConfig,
|
||||
userNetworkState,
|
||||
subscrConns,
|
||||
activeSubs,
|
||||
pendingSubs,
|
||||
@@ -469,6 +497,13 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
|
||||
agentEnv
|
||||
}
|
||||
|
||||
slowNetworkConfig :: NetworkConfig -> NetworkConfig
|
||||
slowNetworkConfig cfg@NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpTimeoutPerKb} =
|
||||
cfg {tcpConnectTimeout = slow tcpConnectTimeout, tcpTimeout = slow tcpTimeout, tcpTimeoutPerKb = slow tcpTimeoutPerKb}
|
||||
where
|
||||
slow :: Integral a => a -> a
|
||||
slow t = (t * 3) `div` 2
|
||||
|
||||
agentClientStore :: AgentClient -> SQLiteStore
|
||||
agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
{-# INLINE agentClientStore #-}
|
||||
@@ -531,7 +566,8 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
|
||||
g <- asks random
|
||||
env <- ask
|
||||
liftError' (protocolClientError SMP $ B.unpack $ strEncode srv) $
|
||||
getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v
|
||||
getProtocolClient g tSess cfg (Just msgQ) $
|
||||
clientDisconnected env v
|
||||
|
||||
clientDisconnected :: Env -> SMPClientVar -> SMPClient -> IO ()
|
||||
clientDisconnected env v client = do
|
||||
@@ -580,6 +616,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
|
||||
withRetryInterval ri $ \_ loop -> do
|
||||
pending <- atomically getPending
|
||||
forM_ (L.nonEmpty pending) $ \qs -> do
|
||||
waitForUserNetwork c
|
||||
void . tryAgentError' $ reconnectSMPClient timeoutCounts c tSess qs
|
||||
loop
|
||||
getPending = RQ.getSessQueues tSess $ pendingSubs c
|
||||
@@ -592,7 +629,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
|
||||
|
||||
reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM ()
|
||||
reconnectSMPClient tc c tSess@(_, srv, _) qs = do
|
||||
NetworkConfig {tcpTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
NetworkConfig {tcpTimeout} <- atomically $ getNetworkConfig c
|
||||
-- this allows 3x of timeout per batch of subscription (90 queues per batch empirically)
|
||||
let t = (length qs `div` 90 + 1) * tcpTimeout * 3
|
||||
ExceptT (sequence <$> (t `timeout` runExceptT resubscribe)) >>= \case
|
||||
@@ -634,7 +671,8 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
|
||||
cfg <- lift $ getClientConfig c ntfCfg
|
||||
g <- asks random
|
||||
liftError' (protocolClientError NTF $ B.unpack $ strEncode srv) $
|
||||
getProtocolClient g tSess cfg Nothing $ clientDisconnected v
|
||||
getProtocolClient g tSess cfg Nothing $
|
||||
clientDisconnected v
|
||||
|
||||
clientDisconnected :: NtfClientVar -> NtfClient -> IO ()
|
||||
clientDisconnected v client = do
|
||||
@@ -644,7 +682,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient
|
||||
getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@(userId, srv, _) = do
|
||||
getXFTPServerClient c@AgentClient {active, xftpClients} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INACTIVE
|
||||
atomically (getTSessVar c tSess xftpClients)
|
||||
>>= either
|
||||
@@ -654,9 +692,11 @@ getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@
|
||||
connectClient :: XFTPClientVar -> AM XFTPClient
|
||||
connectClient v = do
|
||||
cfg <- asks $ xftpCfg . config
|
||||
xftpNetworkConfig <- readTVarIO useNetworkConfig
|
||||
g <- asks random
|
||||
xftpNetworkConfig <- atomically $ getNetworkConfig c
|
||||
liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $
|
||||
X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v
|
||||
X.getXFTPClient g tSess cfg {xftpNetworkConfig} $
|
||||
clientDisconnected v
|
||||
|
||||
clientDisconnected :: XFTPClientVar -> XFTPClient -> IO ()
|
||||
clientDisconnected v client = do
|
||||
@@ -688,7 +728,7 @@ removeTSessVar' v tSess vs =
|
||||
|
||||
waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg)
|
||||
waitForProtocolClient c (_, srv, _) v = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c
|
||||
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
|
||||
liftEither $ case client_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
@@ -724,11 +764,51 @@ hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerCli
|
||||
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
|
||||
|
||||
getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v)
|
||||
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
|
||||
getClientConfig c cfgSel = do
|
||||
cfg <- asks $ cfgSel . config
|
||||
networkConfig <- readTVarIO useNetworkConfig
|
||||
networkConfig <- atomically $ getNetworkConfig c
|
||||
pure cfg {networkConfig}
|
||||
|
||||
getNetworkConfig :: AgentClient -> STM NetworkConfig
|
||||
getNetworkConfig c = do
|
||||
(slowCfg, fastCfg) <- readTVar (useNetworkConfig c)
|
||||
UserNetworkState {networkType} <- readTVar (userNetworkState c)
|
||||
pure $ case networkType of
|
||||
UNCellular -> slowCfg
|
||||
UNNone -> slowCfg
|
||||
_ -> fastCfg
|
||||
|
||||
waitForUserNetwork :: AgentClient -> AM' ()
|
||||
waitForUserNetwork AgentClient {userNetworkState} =
|
||||
(offline <$> readTVarIO userNetworkState) >>= mapM_ waitWhileOffline
|
||||
where
|
||||
waitWhileOffline UNSOffline {offlineDelay = d} =
|
||||
unlessM (liftIO $ waitOnline d False) $ do -- network delay reached, increase delay
|
||||
ts' <- liftIO getCurrentTime
|
||||
ni <- asks $ userNetworkInterval . config
|
||||
atomically $ do
|
||||
ns@UserNetworkState {offline} <- readTVar userNetworkState
|
||||
forM_ offline $ \UNSOffline {offlineDelay = d', offlineFrom = ts} ->
|
||||
-- Using `min` to avoid multiple updates in a short period of time
|
||||
-- and to reset `offlineDelay` if network went `on` and `off` again.
|
||||
writeTVar userNetworkState $!
|
||||
let d'' = nextRetryDelay (diffToMicroseconds $ diffUTCTime ts' ts) (min d d') ni
|
||||
in ns {offline = Just UNSOffline {offlineDelay = d'', offlineFrom = ts}}
|
||||
waitOnline :: Int64 -> Bool -> IO Bool
|
||||
waitOnline t online'
|
||||
| t <= 0 = pure online'
|
||||
| otherwise =
|
||||
registerDelay (fromIntegral maxWait)
|
||||
>>= atomically . onlineOrDelay
|
||||
>>= waitOnline (t - maxWait)
|
||||
where
|
||||
maxWait = min t $ fromIntegral (maxBound :: Int)
|
||||
onlineOrDelay delay = do
|
||||
online <- isNothing . offline <$> readTVar userNetworkState
|
||||
expired <- readTVar delay
|
||||
unless (online || expired) retry
|
||||
pure online
|
||||
|
||||
closeAgentClient :: AgentClient -> IO ()
|
||||
closeAgentClient c = do
|
||||
atomically $ writeTVar (active c) False
|
||||
@@ -784,8 +864,8 @@ closeClient c clientSel tSess =
|
||||
|
||||
closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO ()
|
||||
closeClient_ c v = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
|
||||
NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c
|
||||
E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
|
||||
Just (Right client) -> closeProtocolServerClient client `catchAll_` pure ()
|
||||
_ -> pure ()
|
||||
|
||||
@@ -799,7 +879,7 @@ withConnLock c connId name = ExceptT . withConnLock' c connId name . runExceptT
|
||||
|
||||
withConnLock' :: AgentClient -> ConnId -> String -> AM' a -> AM' a
|
||||
withConnLock' _ "" _ = id
|
||||
withConnLock' AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
|
||||
withConnLock' AgentClient {connLocks} connId name = withLockMap connLocks connId name
|
||||
{-# INLINE withConnLock' #-}
|
||||
|
||||
withInvLock :: AgentClient -> ByteString -> String -> AM a -> AM a
|
||||
@@ -807,16 +887,16 @@ withInvLock c key name = ExceptT . withInvLock' c key name . runExceptT
|
||||
{-# INLINE withInvLock #-}
|
||||
|
||||
withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
|
||||
withInvLock' AgentClient {invLocks} = withLockMap_ invLocks
|
||||
withInvLock' AgentClient {invLocks} = withLockMap invLocks
|
||||
{-# INLINE withInvLock' #-}
|
||||
|
||||
withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
|
||||
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
|
||||
{-# INLINE withConnLocks #-}
|
||||
|
||||
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap_ = withGetLock . getMapLock
|
||||
{-# INLINE withLockMap_ #-}
|
||||
withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap = withGetLock . getMapLock
|
||||
{-# INLINE withLockMap #-}
|
||||
|
||||
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
|
||||
withLocksMap_ = withGetLocks . getMapLock
|
||||
@@ -945,13 +1025,13 @@ runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe P
|
||||
runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
cfg <- asks $ xftpCfg . config
|
||||
g <- asks random
|
||||
xftpNetworkConfig <- readTVarIO $ useNetworkConfig c
|
||||
xftpNetworkConfig <- atomically $ getNetworkConfig c
|
||||
workDir <- getXFTPWorkPath
|
||||
filePath <- getTempFilePath workDir
|
||||
rcvPath <- getTempFilePath workDir
|
||||
liftIO $ do
|
||||
let tSess = (userId, srv, Nothing)
|
||||
X.getXFTPClient tSess cfg {xftpNetworkConfig} (\_ -> pure ()) >>= \case
|
||||
X.getXFTPClient g tSess cfg {xftpNetworkConfig} (\_ -> pure ()) >>= \case
|
||||
Right xftp -> withTestChunk filePath $ do
|
||||
(sndKey, spKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
(rcvKey, rpKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
@@ -1035,7 +1115,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
|
||||
{-# INLINE mkSMPTSession #-}
|
||||
|
||||
getSessionMode :: AgentClient -> IO TransportSessionMode
|
||||
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
|
||||
getSessionMode = atomically . fmap sessionMode . getNetworkConfig
|
||||
{-# INLINE getSessionMode #-}
|
||||
|
||||
newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri)
|
||||
@@ -1127,7 +1207,7 @@ sendTSessionBatches statCmd statBatchSize toRQ action c qs =
|
||||
where
|
||||
batchQueues :: AM' [(SMPTransportSession, NonEmpty q)]
|
||||
batchQueues = do
|
||||
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
|
||||
mode <- atomically $ sessionMode <$> getNetworkConfig c
|
||||
pure . M.assocs $ foldl' (batch mode) M.empty qs
|
||||
where
|
||||
batch mode m q =
|
||||
@@ -1770,3 +1850,7 @@ $(J.deriveJSON defaultJSON ''WorkersSummary)
|
||||
$(J.deriveJSON defaultJSON {J.fieldLabelModifier = takeWhile (/= '_')} ''AgentWorkersDetails)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''AgentWorkersSummary)
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "UN") ''UserNetworkType)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''UserNetworkInfo)
|
||||
|
||||
@@ -92,6 +92,7 @@ data AgentConfig = AgentConfig
|
||||
xftpCfg :: XFTPClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
messageRetryInterval :: RetryInterval2,
|
||||
userNetworkInterval :: RetryInterval,
|
||||
messageTimeout :: NominalDiffTime,
|
||||
connDeleteDeliveryTimeout :: NominalDiffTime,
|
||||
helloTimeout :: NominalDiffTime,
|
||||
@@ -126,7 +127,7 @@ defaultReconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = 2_000000,
|
||||
increaseAfter = 10_000000,
|
||||
maxInterval = 180_000000
|
||||
maxInterval = 60_000000
|
||||
}
|
||||
|
||||
defaultMessageRetryInterval :: RetryInterval2
|
||||
@@ -134,18 +135,26 @@ defaultMessageRetryInterval =
|
||||
RetryInterval2
|
||||
{ riFast =
|
||||
RetryInterval
|
||||
{ initialInterval = 1_000000,
|
||||
{ initialInterval = 2_000000,
|
||||
increaseAfter = 10_000000,
|
||||
maxInterval = 60_000000
|
||||
},
|
||||
riSlow =
|
||||
RetryInterval
|
||||
{ initialInterval = 180_000000, -- 3 minutes
|
||||
{ initialInterval = 300_000000, -- 5 minutes
|
||||
increaseAfter = 60_000000,
|
||||
maxInterval = 3 * 3600_000000 -- 3 hours
|
||||
maxInterval = 6 * 3600_000000 -- 6 hours
|
||||
}
|
||||
}
|
||||
|
||||
defaultUserNetworkInterval :: RetryInterval
|
||||
defaultUserNetworkInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = 1200_000000, -- 20 minutes
|
||||
increaseAfter = 0,
|
||||
maxInterval = 7200_000000 -- 2 hours
|
||||
}
|
||||
|
||||
defaultAgentConfig :: AgentConfig
|
||||
defaultAgentConfig =
|
||||
AgentConfig
|
||||
@@ -161,6 +170,7 @@ defaultAgentConfig =
|
||||
xftpCfg = defaultXFTPClientConfig,
|
||||
reconnectInterval = defaultReconnectInterval,
|
||||
messageRetryInterval = defaultMessageRetryInterval,
|
||||
userNetworkInterval = defaultUserNetworkInterval,
|
||||
messageTimeout = 2 * nominalDay,
|
||||
connDeleteDeliveryTimeout = 2 * nominalDay,
|
||||
helloTimeout = 2 * nominalDay,
|
||||
|
||||
@@ -160,7 +160,8 @@ runNtfWorker c srv Worker {doWork} = do
|
||||
\nextSub@(NtfSubscription {connId}, _, _) -> do
|
||||
logInfo $ "runNtfWorker, nextSub " <> tshow nextSub
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \_ loop ->
|
||||
withRetryInterval ri $ \_ loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
processSub nextSub
|
||||
`catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show)
|
||||
processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM ()
|
||||
@@ -243,7 +244,8 @@ runNtfSMPWorker c srv Worker {doWork} = do
|
||||
\nextSub@(NtfSubscription {connId}, _, _) -> do
|
||||
logInfo $ "runNtfSMPWorker, nextSub " <> tshow nextSub
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \_ loop ->
|
||||
withRetryInterval ri $ \_ loop -> do
|
||||
lift $ waitForUserNetwork c
|
||||
processSub nextSub
|
||||
`catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show)
|
||||
processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> AM ()
|
||||
|
||||
@@ -163,6 +163,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
@@ -201,7 +202,6 @@ import Simplex.Messaging.Crypto.Ratchet
|
||||
SndE2ERatchetParams
|
||||
)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (base64P, encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol
|
||||
@@ -405,6 +405,7 @@ data ACommand (p :: AParty) (e :: AEntity) where
|
||||
MSGNTF :: SMPMsgMeta -> ACommand Agent AEConn
|
||||
ACK :: AgentMsgId -> Maybe MsgReceiptInfo -> ACommand Client AEConn
|
||||
RCVD :: MsgMeta -> NonEmpty MsgReceipt -> ACommand Agent AEConn
|
||||
QCONT :: ACommand Agent AEConn
|
||||
SWCH :: ACommand Client AEConn
|
||||
OFF :: ACommand Client AEConn
|
||||
DEL :: ACommand Client AEConn
|
||||
@@ -467,6 +468,7 @@ data ACommandTag (p :: AParty) (e :: AEntity) where
|
||||
MSGNTF_ :: ACommandTag Agent AEConn
|
||||
ACK_ :: ACommandTag Client AEConn
|
||||
RCVD_ :: ACommandTag Agent AEConn
|
||||
QCONT_ :: ACommandTag Agent AEConn
|
||||
SWCH_ :: ACommandTag Client AEConn
|
||||
OFF_ :: ACommandTag Client AEConn
|
||||
DEL_ :: ACommandTag Client AEConn
|
||||
@@ -522,6 +524,7 @@ aCommandTag = \case
|
||||
MSGNTF {} -> MSGNTF_
|
||||
ACK {} -> ACK_
|
||||
RCVD {} -> RCVD_
|
||||
QCONT -> QCONT_
|
||||
SWCH -> SWCH_
|
||||
OFF -> OFF_
|
||||
DEL -> DEL_
|
||||
@@ -996,7 +999,7 @@ agentMessageType = \case
|
||||
HELLO -> AM_HELLO_
|
||||
A_MSG _ -> AM_A_MSG_
|
||||
A_RCVD {} -> AM_A_RCVD_
|
||||
QCONT _ -> AM_QCONT_
|
||||
A_QCONT _ -> AM_QCONT_
|
||||
QADD _ -> AM_QADD_
|
||||
QKEY _ -> AM_QKEY_
|
||||
QUSE _ -> AM_QUSE_
|
||||
@@ -1020,7 +1023,7 @@ data AMsgType
|
||||
= HELLO_
|
||||
| A_MSG_
|
||||
| A_RCVD_
|
||||
| QCONT_
|
||||
| A_QCONT_
|
||||
| QADD_
|
||||
| QKEY_
|
||||
| QUSE_
|
||||
@@ -1033,7 +1036,7 @@ instance Encoding AMsgType where
|
||||
HELLO_ -> "H"
|
||||
A_MSG_ -> "M"
|
||||
A_RCVD_ -> "V"
|
||||
QCONT_ -> "QC"
|
||||
A_QCONT_ -> "QC"
|
||||
QADD_ -> "QA"
|
||||
QKEY_ -> "QK"
|
||||
QUSE_ -> "QU"
|
||||
@@ -1046,7 +1049,7 @@ instance Encoding AMsgType where
|
||||
'V' -> pure A_RCVD_
|
||||
'Q' ->
|
||||
A.anyChar >>= \case
|
||||
'C' -> pure QCONT_
|
||||
'C' -> pure A_QCONT_
|
||||
'A' -> pure QADD_
|
||||
'K' -> pure QKEY_
|
||||
'U' -> pure QUSE_
|
||||
@@ -1066,7 +1069,7 @@ data AMessage
|
||||
| -- | agent envelope for delivery receipt
|
||||
A_RCVD (NonEmpty AMessageReceipt)
|
||||
| -- | the message instructing the client to continue sending messages (after ERR QUOTA)
|
||||
QCONT SndQAddr
|
||||
A_QCONT SndQAddr
|
||||
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
|
||||
QADD (NonEmpty (SMPQueueUri, Maybe SndQAddr))
|
||||
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
|
||||
@@ -1124,7 +1127,7 @@ instance Encoding AMessage where
|
||||
HELLO -> smpEncode HELLO_
|
||||
A_MSG body -> smpEncode (A_MSG_, Tail body)
|
||||
A_RCVD mrs -> smpEncode (A_RCVD_, mrs)
|
||||
QCONT addr -> smpEncode (QCONT_, addr)
|
||||
A_QCONT addr -> smpEncode (A_QCONT_, addr)
|
||||
QADD qs -> smpEncode (QADD_, qs)
|
||||
QKEY qs -> smpEncode (QKEY_, qs)
|
||||
QUSE qs -> smpEncode (QUSE_, qs)
|
||||
@@ -1136,7 +1139,7 @@ instance Encoding AMessage where
|
||||
HELLO_ -> pure HELLO
|
||||
A_MSG_ -> A_MSG . unTail <$> smpP
|
||||
A_RCVD_ -> A_RCVD <$> smpP
|
||||
QCONT_ -> QCONT <$> smpP
|
||||
A_QCONT_ -> A_QCONT <$> smpP
|
||||
QADD_ -> QADD <$> smpP
|
||||
QKEY_ -> QKEY <$> smpP
|
||||
QUSE_ -> QUSE <$> smpP
|
||||
@@ -1668,6 +1671,7 @@ instance StrEncoding ACmdTag where
|
||||
"MSGNTF" -> ct MSGNTF_
|
||||
"ACK" -> t ACK_
|
||||
"RCVD" -> ct RCVD_
|
||||
"QCONT" -> ct QCONT_
|
||||
"SWCH" -> t SWCH_
|
||||
"OFF" -> t OFF_
|
||||
"DEL" -> t DEL_
|
||||
@@ -1725,6 +1729,7 @@ instance (APartyI p, AEntityI e) => StrEncoding (ACommandTag p e) where
|
||||
MSGNTF_ -> "MSGNTF"
|
||||
ACK_ -> "ACK"
|
||||
RCVD_ -> "RCVD"
|
||||
QCONT_ -> "QCONT"
|
||||
SWCH_ -> "SWCH"
|
||||
OFF_ -> "OFF"
|
||||
DEL_ -> "DEL"
|
||||
@@ -1794,6 +1799,7 @@ commandP binaryP =
|
||||
MSG_ -> s (MSG <$> strP <* A.space <*> smpP <* A.space <*> binaryP)
|
||||
MSGNTF_ -> s (MSGNTF <$> strP)
|
||||
RCVD_ -> s (RCVD <$> strP <* A.space <*> strP)
|
||||
QCONT_ -> pure QCONT
|
||||
DEL_RCVQ_ -> s (DEL_RCVQ <$> strP_ <*> strP_ <*> strP)
|
||||
DEL_CONN_ -> pure DEL_CONN
|
||||
DEL_USER_ -> s (DEL_USER <$> strP)
|
||||
@@ -1857,6 +1863,7 @@ serializeCommand = \case
|
||||
MSGNTF smpMsgMeta -> s (MSGNTF_, smpMsgMeta)
|
||||
ACK mId rcptInfo_ -> s (ACK_, mId) <> maybe "" (B.cons ' ' . serializeBinary) rcptInfo_
|
||||
RCVD msgMeta rcpts -> s (RCVD_, msgMeta, rcpts)
|
||||
QCONT -> s QCONT_
|
||||
SWCH -> s SWCH_
|
||||
OFF -> s OFF_
|
||||
DEL -> s DEL_
|
||||
|
||||
@@ -11,6 +11,7 @@ module Simplex.Messaging.Agent.RetryInterval
|
||||
withRetryIntervalCount,
|
||||
withRetryLock2,
|
||||
updateRetryInterval2,
|
||||
nextRetryDelay,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -60,7 +61,7 @@ withRetryIntervalCount ri action = callAction 0 0 $ initialInterval ri
|
||||
loop = do
|
||||
liftIO $ threadDelay' delay
|
||||
let elapsed' = elapsed + delay
|
||||
callAction (n + 1) elapsed' $ nextDelay elapsed' delay ri
|
||||
callAction (n + 1) elapsed' $ nextRetryDelay elapsed' delay ri
|
||||
|
||||
-- This function allows action to toggle between slow and fast retry intervals.
|
||||
withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> (RI2State -> (RetryIntervalMode -> m ()) -> m ()) -> m ()
|
||||
@@ -76,7 +77,7 @@ withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
|
||||
run (elapsed, delay) ri call = do
|
||||
wait delay
|
||||
let elapsed' = elapsed + delay
|
||||
delay' = nextDelay elapsed' delay ri
|
||||
delay' = nextRetryDelay elapsed' delay ri
|
||||
call (elapsed', delay')
|
||||
wait delay = do
|
||||
waiting <- newTVarIO True
|
||||
@@ -87,8 +88,8 @@ withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
|
||||
takeTMVar lock
|
||||
writeTVar waiting False
|
||||
|
||||
nextDelay :: Int64 -> Int64 -> RetryInterval -> Int64
|
||||
nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
|
||||
nextRetryDelay :: Int64 -> Int64 -> RetryInterval -> Int64
|
||||
nextRetryDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
|
||||
if elapsed < increaseAfter || delay == maxInterval
|
||||
then delay
|
||||
else min (delay * 3 `div` 2) maxInterval
|
||||
|
||||
@@ -231,6 +231,7 @@ import Data.Bifunctor (first, second)
|
||||
import Data.ByteArray (ScrubbedBytes)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
@@ -270,7 +271,6 @@ import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..))
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
@@ -1214,7 +1214,7 @@ setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
|
||||
db
|
||||
[sql|
|
||||
UPDATE ratchets
|
||||
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
|
||||
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
(x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, pqPrivKem, connId)
|
||||
|
||||
@@ -207,7 +207,7 @@ data NetworkConfig = NetworkConfig
|
||||
-- | timeout of protocol commands (microseconds)
|
||||
tcpTimeout :: Int,
|
||||
-- | additional timeout per kilobyte (1024 bytes) to be sent
|
||||
tcpTimeoutPerKb :: Int,
|
||||
tcpTimeoutPerKb :: Int64,
|
||||
-- | TCP keep-alive options, Nothing to skip enabling keep-alive
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
-- | period for SMP ping commands (microseconds, 0 to disable)
|
||||
@@ -230,7 +230,7 @@ defaultNetworkConfig =
|
||||
sessionMode = TSMUser,
|
||||
tcpConnectTimeout = 20_000_000,
|
||||
tcpTimeout = 15_000_000,
|
||||
tcpTimeoutPerKb = 45_000, -- 45ms, should be less than 130ms to avoid Int overflow on 32 bit systems
|
||||
tcpTimeoutPerKb = 5_000,
|
||||
tcpKeepAlive = Just defaultKeepAliveOpts,
|
||||
smpPingInterval = 600_000_000, -- 10min
|
||||
smpPingCount = 3,
|
||||
@@ -239,7 +239,7 @@ defaultNetworkConfig =
|
||||
|
||||
transportClientConfig :: NetworkConfig -> TransportClientConfig
|
||||
transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
|
||||
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
|
||||
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing, alpn = Nothing}
|
||||
{-# INLINE transportClientConfig #-}
|
||||
|
||||
-- | protocol client configuration.
|
||||
|
||||
@@ -211,6 +211,8 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.ByteArray (ByteArrayAccess)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Base64 (decode, encode)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.ByteString.Lazy (fromStrict, toStrict)
|
||||
@@ -228,8 +230,6 @@ import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+))
|
||||
import Network.Transport.Internal (decodeWord16, encodeWord16)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.Base64 (decode, encode)
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Compatibility wrappers for base64 package, Base64 (padded) variant.
|
||||
module Simplex.Messaging.Encoding.Base64 where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Base64.Types (extractBase64)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64 (decodeBase64Untyped, encodeBase64')
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
|
||||
encode :: ByteString -> ByteString
|
||||
encode = extractBase64 . encodeBase64'
|
||||
{-# INLINE encode #-}
|
||||
|
||||
decode :: ByteString -> Either String ByteString
|
||||
decode = first T.unpack . decodeBase64Untyped
|
||||
{-# INLINE decode #-}
|
||||
|
||||
base64P :: A.Parser ByteString
|
||||
base64P = do
|
||||
str <- A.takeWhile1 (`B.elem` base64Alphabet)
|
||||
pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
|
||||
either (fail . T.unpack) pure $ decodeBase64Untyped (str <> pad)
|
||||
|
||||
base64Alphabet :: ByteString
|
||||
base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
|
||||
@@ -1,33 +0,0 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Compatibility wrappers for base64 package, Base64URL-padded variant.
|
||||
module Simplex.Messaging.Encoding.Base64.URL where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Base64.Types (extractBase64)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64.URL (decodeBase64Lenient, decodeBase64UnpaddedUntyped, decodeBase64Untyped, encodeBase64')
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
|
||||
encode :: ByteString -> ByteString
|
||||
encode = extractBase64 . encodeBase64'
|
||||
{-# INLINE encode #-}
|
||||
|
||||
decode :: ByteString -> Either String ByteString
|
||||
decode = first T.unpack . decodeBase64Untyped
|
||||
{-# INLINE decode #-}
|
||||
|
||||
decodeLenient :: ByteString -> ByteString
|
||||
decodeLenient = decodeBase64Lenient
|
||||
{-# INLINE decodeLenient #-}
|
||||
|
||||
base64urlP :: A.Parser ByteString
|
||||
base64urlP = do
|
||||
str <- A.takeWhile1 (`B.elem` base64AlphabetURL)
|
||||
_pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
|
||||
either (fail . T.unpack) pure $ decodeBase64UnpaddedUntyped str
|
||||
|
||||
base64AlphabetURL :: ByteString
|
||||
base64AlphabetURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
@@ -10,6 +10,7 @@ module Simplex.Messaging.Encoding.String
|
||||
strToJSON,
|
||||
strToJEncoding,
|
||||
strParseJSON,
|
||||
base64urlP,
|
||||
strEncodeList,
|
||||
strListP,
|
||||
)
|
||||
@@ -22,8 +23,10 @@ import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum)
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Set (Set)
|
||||
@@ -35,7 +38,6 @@ import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Time.Format.ISO8601
|
||||
import Data.Word (Word16, Word32)
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
@@ -52,16 +54,19 @@ class StrEncoding a where
|
||||
strDecode :: ByteString -> Either String a
|
||||
strDecode = parseAll strP
|
||||
strP :: Parser a
|
||||
strP = strDecode <$?> U.base64urlP
|
||||
strP = strDecode <$?> base64urlP
|
||||
|
||||
-- base64url encoding/decoding of ByteStrings - the parser only allows non-empty strings
|
||||
instance StrEncoding ByteString where
|
||||
strEncode = U.encode
|
||||
{-# INLINE strEncode #-}
|
||||
strDecode = U.decode
|
||||
{-# INLINE strDecode #-}
|
||||
strP = U.base64urlP
|
||||
{-# INLINE strP #-}
|
||||
strP = base64urlP
|
||||
|
||||
base64urlP :: Parser ByteString
|
||||
base64urlP = do
|
||||
str <- A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
|
||||
pad <- A.takeWhile (== '=')
|
||||
either fail pure $ U.decode (str <> pad)
|
||||
|
||||
newtype Str = Str {unStr :: ByteString}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -11,12 +11,14 @@
|
||||
|
||||
module Simplex.Messaging.Notifications.Protocol where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..), (.:), (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
@@ -406,10 +408,12 @@ instance Encoding DeviceToken where
|
||||
|
||||
instance StrEncoding DeviceToken where
|
||||
strEncode (DeviceToken p t) = strEncode p <> " " <> t
|
||||
strP = DeviceToken <$> strP <* A.space <*> hexStringP
|
||||
strP = nullToken <|> hexToken
|
||||
where
|
||||
nullToken = "apns_null test_ntf_token" $> DeviceToken PPApnsNull "test_ntf_token"
|
||||
hexToken = DeviceToken <$> strP <* A.space <*> hexStringP
|
||||
hexStringP =
|
||||
A.takeWhile (\c -> A.isDigit c || (c >= 'a' && c <= 'f')) >>= \s ->
|
||||
A.takeWhile (`B.elem` "0123456789abcdef") >>= \s ->
|
||||
if even (B.length s) then pure s else fail "odd number of hex characters"
|
||||
|
||||
instance ToJSON DeviceToken where
|
||||
|
||||
@@ -108,7 +108,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
|
||||
logServerStats :: Int64 -> Int64 -> FilePath -> M ()
|
||||
logServerStats startAt logInterval statsFilePath = do
|
||||
initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
|
||||
liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
|
||||
logInfo $ "server stats log enabled: " <> T.pack statsFilePath
|
||||
liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
|
||||
NtfServerStats {fromTime, tknCreated, tknVerified, tknDeleted, subCreated, subDeleted, ntfReceived, ntfDelivered, activeTokens, activeSubs} <- asks serverStats
|
||||
let interval = 1000000 * logInterval
|
||||
@@ -442,7 +442,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
where
|
||||
processCommand :: NtfRequest -> M (Transmission NtfResponse)
|
||||
processCommand = \case
|
||||
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do
|
||||
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn token _ dhPubKey)) -> do
|
||||
logDebug "TNEW - new token"
|
||||
st <- asks store
|
||||
ks@(srvDhPubKey, srvDhPrivKey) <- atomically . C.generateKeyPair =<< asks random
|
||||
@@ -453,9 +453,9 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
atomically $ addNtfToken st tknId tkn
|
||||
atomically $ writeTBQueue pushQ (tkn, PNVerification regCode)
|
||||
withNtfLog (`logCreateToken` tkn)
|
||||
incNtfStat tknCreated
|
||||
incNtfStatT token tknCreated
|
||||
pure (corrId, "", NRTknId tknId srvDhPubKey)
|
||||
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey), tknCronInterval}) (corrId, tknId, cmd) -> do
|
||||
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {token, ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey), tknCronInterval}) (corrId, tknId, cmd) -> do
|
||||
status <- readTVarIO tknStatus
|
||||
(corrId,tknId,) <$> case cmd of
|
||||
TNEW (NewNtfTkn _ _ dhPubKey) -> do
|
||||
@@ -474,7 +474,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
updateTknStatus tkn NTActive
|
||||
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
|
||||
forM_ tIds cancelInvervalNotifications
|
||||
incNtfStat tknVerified
|
||||
incNtfStatT token tknVerified
|
||||
pure NROk
|
||||
| otherwise -> do
|
||||
logDebug "TVFY - incorrect code or token status"
|
||||
@@ -493,8 +493,8 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
addNtfToken st tknId tkn'
|
||||
writeTBQueue pushQ (tkn', PNVerification regCode)
|
||||
withNtfLog $ \s -> logUpdateToken s tknId token' regCode
|
||||
incNtfStat tknDeleted
|
||||
incNtfStat tknCreated
|
||||
incNtfStatT token tknDeleted
|
||||
incNtfStatT token tknCreated
|
||||
pure NROk
|
||||
TDEL -> do
|
||||
logDebug "TDEL"
|
||||
@@ -504,7 +504,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
atomically $ removeSubscription ca smpServer (SPNotifier, notifierId)
|
||||
cancelInvervalNotifications tknId
|
||||
withNtfLog (`logDeleteToken` tknId)
|
||||
incNtfStat tknDeleted
|
||||
incNtfStatT token tknDeleted
|
||||
pure NROk
|
||||
TCRN 0 -> do
|
||||
logDebug "TCRN 0"
|
||||
@@ -583,6 +583,10 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
withNtfLog :: (StoreLog 'WriteMode -> IO a) -> M ()
|
||||
withNtfLog action = liftIO . mapM_ action =<< asks storeLog
|
||||
|
||||
incNtfStatT :: DeviceToken -> (NtfServerStats -> TVar Int) -> M ()
|
||||
incNtfStatT (DeviceToken PPApnsNull _) _ = pure ()
|
||||
incNtfStatT _ statSel = incNtfStat statSel
|
||||
|
||||
incNtfStat :: (NtfServerStats -> TVar Int) -> M ()
|
||||
incNtfStat statSel = do
|
||||
stats <- asks serverStats
|
||||
|
||||
@@ -27,9 +27,8 @@ import Data.Aeson (ToJSON, (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Aeson.TH as JQ
|
||||
import Data.Base64.Types (extractBase64)
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as UP
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
@@ -47,7 +46,6 @@ import Network.HTTP2.Client (Request)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Encoding.Base64.URL as U
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
|
||||
@@ -56,7 +54,7 @@ import Simplex.Messaging.Parsers (defaultJSON)
|
||||
import Simplex.Messaging.Protocol (EncNMsgMeta)
|
||||
import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..))
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8)
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8, tshow)
|
||||
import System.Environment (getEnv)
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -93,8 +91,8 @@ signedJWTToken pk (JWTToken hdr claims) = do
|
||||
pure $ hc <> "." <> serialize sig
|
||||
where
|
||||
jwtEncode :: ToJSON a => a -> ByteString
|
||||
jwtEncode = extractBase64 . UP.encodeBase64Unpadded' . LB.toStrict . J.encode
|
||||
serialize sig = extractBase64 . UP.encodeBase64Unpadded' $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
|
||||
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
|
||||
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
|
||||
|
||||
readECPrivateKey :: FilePath -> IO EC.PrivateKey
|
||||
readECPrivateKey f = do
|
||||
@@ -260,11 +258,11 @@ mkApnsJWTToken appTeamId jwtHeader privateKey = do
|
||||
connectHTTPS2 :: HostName -> APNSPushClientConfig -> TVar (Maybe HTTP2Client) -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
connectHTTPS2 apnsHost APNSPushClientConfig {apnsPort, http2cfg, caStoreFile} https2Client = do
|
||||
caStore_ <- XS.readCertificateStore caStoreFile
|
||||
when (isNothing caStore_) $ putStrLn $ "Error loading CertificateStore from " <> caStoreFile
|
||||
when (isNothing caStore_) $ logError $ "Error loading CertificateStore from " <> T.pack caStoreFile
|
||||
r <- getHTTP2Client apnsHost apnsPort caStore_ http2cfg disconnected
|
||||
case r of
|
||||
Right client -> atomically . writeTVar https2Client $ Just client
|
||||
Left e -> putStrLn $ "Error connecting to APNS: " <> show e
|
||||
Left e -> logError $ "Error connecting to APNS: " <> tshow e
|
||||
pure r
|
||||
where
|
||||
disconnected = atomically $ writeTVar https2Client Nothing
|
||||
|
||||
@@ -24,9 +24,12 @@ module Simplex.Messaging.Notifications.Server.StoreLog
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import qualified Data.Text as T
|
||||
import Data.Word (Word16)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
@@ -34,7 +37,7 @@ import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Store
|
||||
import Simplex.Messaging.Protocol (NtfPrivateAuthKey)
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.Util (whenM)
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8, whenM)
|
||||
import System.Directory (doesFileExist, renameFile)
|
||||
import System.IO
|
||||
|
||||
@@ -189,10 +192,10 @@ readWriteNtfStore f st = do
|
||||
pure s
|
||||
|
||||
readNtfStore :: FilePath -> NtfStore -> IO ()
|
||||
readNtfStore f st = mapM_ addNtfLogRecord . B.lines =<< B.readFile f
|
||||
readNtfStore f st = mapM_ (addNtfLogRecord . LB.toStrict) . LB.lines =<< LB.readFile f
|
||||
where
|
||||
addNtfLogRecord s = case strDecode s of
|
||||
Left e -> B.putStrLn $ "Log parsing error (" <> B.pack e <> "): " <> B.take 100 s
|
||||
Left e -> logError $ "Log parsing error (" <> T.pack e <> "): " <> safeDecodeUtf8 (B.take 100 s)
|
||||
Right lr -> atomically $ case lr of
|
||||
CreateToken r@NtfTknRec {ntfTknId} -> do
|
||||
tkn <- mkTknData r
|
||||
|
||||
@@ -10,9 +10,10 @@ import qualified Data.Aeson as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (toLower)
|
||||
import Data.Char (isAlphaNum, toLower)
|
||||
import Data.String
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
@@ -23,8 +24,23 @@ import Database.SQLite.Simple (ResultError (..), SQLData (..))
|
||||
import Database.SQLite.Simple.FromField (FieldParser, returnError)
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
base64P :: Parser ByteString
|
||||
base64P = decode <$?> paddedBase64 rawBase64P
|
||||
|
||||
paddedBase64 :: Parser ByteString -> Parser ByteString
|
||||
paddedBase64 raw = (<>) <$> raw <*> pad
|
||||
where
|
||||
pad = A.takeWhile (== '=')
|
||||
|
||||
rawBase64P :: Parser ByteString
|
||||
rawBase64P = A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/')
|
||||
|
||||
-- rawBase64UriP :: Parser ByteString
|
||||
-- rawBase64UriP = A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
|
||||
|
||||
tsISO8601P :: Parser UTCTime
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
|
||||
|
||||
|
||||
@@ -176,6 +176,7 @@ import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isPrint, isSpace)
|
||||
@@ -193,7 +194,6 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import qualified Simplex.Messaging.Encoding.Base64 as B64
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.ServiceScheme
|
||||
|
||||
@@ -45,8 +45,10 @@ import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Either (fromRight, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
@@ -67,7 +69,6 @@ import Network.Socket (ServiceName, Socket, socketToHandle)
|
||||
import Simplex.Messaging.Agent.Lock
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding (Encoding (smpEncode))
|
||||
import Simplex.Messaging.Encoding.Base64 (encode)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Control
|
||||
@@ -983,7 +984,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case
|
||||
ms <- asks msgStore
|
||||
quota <- asks $ msgQueueQuota . config
|
||||
old_ <- asks (messageExpiration . config) $>>= (liftIO . fmap Just . expireBeforeEpoch)
|
||||
runExceptT (liftIO (B.readFile f) >>= foldM (\expired -> restoreMsg expired ms quota old_) 0 . B.lines) >>= \case
|
||||
runExceptT (liftIO (LB.readFile f) >>= foldM (\expired -> restoreMsg expired ms quota old_) 0 . LB.lines) >>= \case
|
||||
Left e -> do
|
||||
logError . T.pack $ "error restoring messages: " <> e
|
||||
liftIO exitFailure
|
||||
@@ -992,10 +993,11 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case
|
||||
logInfo "messages restored"
|
||||
pure expired
|
||||
where
|
||||
restoreMsg !expired ms quota old_ s = do
|
||||
restoreMsg !expired ms quota old_ s' = do
|
||||
MLRv3 rId msg <- liftEither . first (msgErr "parsing") $ strDecode s
|
||||
addToMsgQueue rId msg
|
||||
where
|
||||
s = LB.toStrict s'
|
||||
addToMsgQueue rId msg = do
|
||||
(isExpired, logFull) <- atomically $ do
|
||||
q <- getMsgQueue ms rId quota
|
||||
|
||||
@@ -25,8 +25,8 @@ where
|
||||
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Monad (foldM, unless, when)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Functor (($>))
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
@@ -148,13 +148,14 @@ writeQueues s = mapM_ $ \q -> when (active q) $ logCreateQueue s q
|
||||
active QueueRec {status} = status == QueueActive
|
||||
|
||||
readQueues :: FilePath -> IO (Map RecipientId QueueRec)
|
||||
readQueues f = foldM processLine M.empty . B.lines =<< B.readFile f
|
||||
readQueues f = foldM processLine M.empty . LB.lines =<< LB.readFile f
|
||||
where
|
||||
processLine :: Map RecipientId QueueRec -> ByteString -> IO (Map RecipientId QueueRec)
|
||||
processLine m s = case strDecode $ trimCR s of
|
||||
processLine :: Map RecipientId QueueRec -> LB.ByteString -> IO (Map RecipientId QueueRec)
|
||||
processLine m s' = case strDecode $ trimCR s of
|
||||
Right r -> pure $ procLogRecord r
|
||||
Left e -> printError e $> m
|
||||
where
|
||||
s = LB.toStrict s'
|
||||
procLogRecord :: StoreLogRecord -> Map RecipientId QueueRec
|
||||
procLogRecord = \case
|
||||
CreateQueue q -> M.insert (recipientId q) q m
|
||||
|
||||
@@ -54,6 +54,7 @@ module Simplex.Messaging.Transport
|
||||
-- * TLS Transport
|
||||
TLS (..),
|
||||
SessionId,
|
||||
ALPN,
|
||||
connectTLS,
|
||||
closeTLS,
|
||||
supportedParameters,
|
||||
@@ -228,10 +229,13 @@ data TLS = TLS
|
||||
tlsPeer :: TransportPeer,
|
||||
tlsUniq :: ByteString,
|
||||
tlsBuffer :: TBuffer,
|
||||
tlsALPN :: Maybe ALPN,
|
||||
tlsServerCerts :: X.CertificateChain,
|
||||
tlsTransportConfig :: TransportConfig
|
||||
}
|
||||
|
||||
type ALPN = ByteString
|
||||
|
||||
connectTLS :: T.TLSParams p => Maybe HostName -> TransportConfig -> p -> Socket -> IO T.Context
|
||||
connectTLS host_ TransportConfig {logTLSErrors} params sock =
|
||||
E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx ->
|
||||
@@ -246,7 +250,8 @@ getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS
|
||||
where
|
||||
newTLS tlsUniq = do
|
||||
tlsBuffer <- atomically newTBuffer
|
||||
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer}
|
||||
tlsALPN <- T.getNegotiatedProtocol cxt
|
||||
pure TLS {tlsContext = cxt, tlsALPN, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer}
|
||||
|
||||
withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c
|
||||
withTlsUnique peer cxt f =
|
||||
|
||||
@@ -17,6 +17,7 @@ module Simplex.Messaging.Transport.Client
|
||||
TransportHost (..),
|
||||
TransportHosts (..),
|
||||
TransportHosts_ (..),
|
||||
validateCertificateChain
|
||||
)
|
||||
where
|
||||
|
||||
@@ -113,12 +114,13 @@ data TransportClientConfig = TransportClientConfig
|
||||
{ socksProxy :: Maybe SocksProxy,
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
logTLSErrors :: Bool,
|
||||
clientCredentials :: Maybe (X.CertificateChain, T.PrivKey)
|
||||
clientCredentials :: Maybe (X.CertificateChain, T.PrivKey),
|
||||
alpn :: Maybe [ALPN]
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
defaultTransportClientConfig :: TransportClientConfig
|
||||
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True Nothing
|
||||
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True Nothing Nothing
|
||||
|
||||
clientTransportConfig :: TransportClientConfig -> TransportConfig
|
||||
clientTransportConfig TransportClientConfig {logTLSErrors} =
|
||||
@@ -129,10 +131,10 @@ runTransportClient :: Transport c => TransportClientConfig -> Maybe ByteString -
|
||||
runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a
|
||||
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do
|
||||
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials, alpn} proxyUsername host port keyHash client = do
|
||||
serverCert <- newEmptyTMVarIO
|
||||
let hostName = B.unpack $ strEncode host
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials serverCert
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials alpn serverCert
|
||||
connectTCP = case socksProxy of
|
||||
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
|
||||
_ -> connectTCPClient hostName
|
||||
@@ -215,14 +217,15 @@ instance ToJSON SocksProxy where
|
||||
instance FromJSON SocksProxy where
|
||||
parseJSON = strParseJSON "SocksProxy"
|
||||
|
||||
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> TMVar X.CertificateChain -> T.ClientParams
|
||||
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts =
|
||||
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> Maybe [ALPN] -> TMVar X.CertificateChain -> T.ClientParams
|
||||
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ alpn_ serverCerts =
|
||||
(T.defaultParamsClient host p)
|
||||
{ T.clientShared = def {T.sharedCAStore = fromMaybe (T.sharedCAStore def) caStore_},
|
||||
T.clientHooks =
|
||||
def
|
||||
{ T.onServerCertificate = onServerCert,
|
||||
T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_
|
||||
T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_,
|
||||
T.onSuggestALPN = pure alpn_
|
||||
},
|
||||
T.clientSupported = supported
|
||||
}
|
||||
@@ -237,7 +240,7 @@ mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts =
|
||||
validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason]
|
||||
validateCertificateChain _ _ _ (X.CertificateChain []) = pure [XV.EmptyChain]
|
||||
validateCertificateChain _ _ _ (X.CertificateChain [_]) = pure [XV.EmptyChain]
|
||||
validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain sc@[_, caCert]) =
|
||||
validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain [_, caCert]) =
|
||||
if Fingerprint kh == XV.getFingerprint caCert X.HashSHA256
|
||||
then x509validate
|
||||
else pure [XV.UnknownCA]
|
||||
@@ -247,7 +250,7 @@ validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain sc@[_,
|
||||
where
|
||||
hooks = XV.defaultHooks
|
||||
checks = XV.defaultChecks {XV.checkFQHN = False}
|
||||
certStore = XS.makeCertificateStore sc
|
||||
certStore = XS.makeCertificateStore [caCert]
|
||||
cache = XV.exceptionValidationCache [] -- we manually check fingerprint only of the identity certificate (ca.crt)
|
||||
serviceID = (host, port)
|
||||
validateCertificateChain _ _ _ _ = pure [XV.AuthorityTooDeep]
|
||||
|
||||
@@ -16,15 +16,15 @@ import qualified Network.HTTP2.Server as HS
|
||||
import Network.Socket (SockAddr (..))
|
||||
import qualified Network.TLS as T
|
||||
import qualified Network.TLS.Extra as TE
|
||||
import Simplex.Messaging.Transport (SessionId, TLS (tlsUniq), Transport (cGet, cPut))
|
||||
import Simplex.Messaging.Transport (TLS, Transport (cGet, cPut))
|
||||
import Simplex.Messaging.Transport.Buffer
|
||||
import qualified System.TimeManager as TI
|
||||
|
||||
defaultHTTP2BufferSize :: BufferSize
|
||||
defaultHTTP2BufferSize = 32768
|
||||
|
||||
withHTTP2 :: BufferSize -> (Config -> SessionId -> IO a) -> TLS -> IO a
|
||||
withHTTP2 sz run c = E.bracket (allocHTTP2Config c sz) freeSimpleConfig (`run` tlsUniq c)
|
||||
withHTTP2 :: BufferSize -> (Config -> IO a) -> IO () -> TLS -> IO a
|
||||
withHTTP2 sz run fin c = E.bracket (allocHTTP2Config c sz) (\cfg -> freeSimpleConfig cfg `E.finally` fin) run
|
||||
|
||||
allocHTTP2Config :: TLS -> BufferSize -> IO Config
|
||||
allocHTTP2Config c sz = do
|
||||
|
||||
@@ -23,15 +23,20 @@ import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport (SessionId, TLS)
|
||||
import Simplex.Messaging.Transport (ALPN, SessionId, TLS (tlsALPN), getServerCerts, getServerVerifyKey, tlsUniq)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), runTLSTransportClient)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Util (eitherToMaybe)
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout
|
||||
import qualified Data.X509 as X
|
||||
|
||||
data HTTP2Client = HTTP2Client
|
||||
{ action :: Maybe (Async HTTP2Response),
|
||||
sessionId :: SessionId,
|
||||
sessionALPN :: Maybe ALPN,
|
||||
serverKey :: Maybe C.APublicVerifyKey, -- may not always be a key we control (i.e. APNS with apple-mandated key types)
|
||||
serverCerts :: X.CertificateChain,
|
||||
sessionTs :: UTCTime,
|
||||
sendReq :: Request -> (Response -> IO HTTP2Response) -> IO HTTP2Response,
|
||||
client_ :: HClient
|
||||
@@ -66,7 +71,7 @@ defaultHTTP2ClientConfig =
|
||||
HTTP2ClientConfig
|
||||
{ qSize = 64,
|
||||
connTimeout = 10000000,
|
||||
transportConfig = TransportClientConfig Nothing Nothing True Nothing,
|
||||
transportConfig = TransportClientConfig Nothing Nothing True Nothing Nothing,
|
||||
bufferSize = defaultHTTP2BufferSize,
|
||||
bodyHeadSize = 16384,
|
||||
suportedTLSParams = http2TLSParams
|
||||
@@ -86,9 +91,10 @@ getVerifiedHTTP2Client proxyUsername host port keyHash caStore config disconnect
|
||||
attachHTTP2Client :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> Int -> TLS -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP2ClientWith config host port disconnected setup
|
||||
where
|
||||
setup :: (TLS -> H.Client HTTP2Response) -> IO HTTP2Response
|
||||
setup = runHTTP2ClientWith bufferSize host ($ tls)
|
||||
|
||||
getVerifiedHTTP2ClientWith :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((SessionId -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
getVerifiedHTTP2ClientWith :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
getVerifiedHTTP2ClientWith config host port disconnected setup =
|
||||
(atomically mkHTTPS2Client >>= runClient)
|
||||
`E.catch` \(e :: IOException) -> pure . Left $ HCIOError e
|
||||
@@ -104,15 +110,25 @@ getVerifiedHTTP2ClientWith config host port disconnected setup =
|
||||
cVar <- newEmptyTMVarIO
|
||||
action <- async $ setup (client c cVar) `E.finally` atomically (putTMVar cVar $ Left HCNetworkError)
|
||||
c_ <- connTimeout config `timeout` atomically (takeTMVar cVar)
|
||||
pure $ case c_ of
|
||||
Just (Right c') -> Right c' {action = Just action}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left HCNetworkError
|
||||
case c_ of
|
||||
Just (Right c') -> pure $ Right c' {action = Just action}
|
||||
Just (Left e) -> pure $ Left e
|
||||
Nothing -> cancel action $> Left HCNetworkError
|
||||
|
||||
client :: HClient -> TMVar (Either HTTP2ClientError HTTP2Client) -> SessionId -> H.Client HTTP2Response
|
||||
client c cVar sessionId sendReq = do
|
||||
client :: HClient -> TMVar (Either HTTP2ClientError HTTP2Client) -> TLS -> H.Client HTTP2Response
|
||||
client c cVar tls sendReq = do
|
||||
sessionTs <- getCurrentTime
|
||||
let c' = HTTP2Client {action = Nothing, client_ = c, sendReq, sessionId, sessionTs}
|
||||
let c' =
|
||||
HTTP2Client
|
||||
{ action = Nothing,
|
||||
client_ = c,
|
||||
serverKey = eitherToMaybe $ getServerVerifyKey tls,
|
||||
serverCerts = getServerCerts tls,
|
||||
sendReq,
|
||||
sessionTs,
|
||||
sessionId = tlsUniq tls,
|
||||
sessionALPN = tlsALPN tls
|
||||
}
|
||||
atomically $ do
|
||||
writeTVar (connected c) True
|
||||
putTMVar cVar (Right c')
|
||||
@@ -154,13 +170,14 @@ sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq
|
||||
http2RequestTimeout :: HTTP2ClientConfig -> Maybe Int -> Int
|
||||
http2RequestTimeout HTTP2ClientConfig {connTimeout} = maybe connTimeout (connTimeout +)
|
||||
|
||||
runHTTP2Client :: forall a. T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> BufferSize -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (SessionId -> H.Client a) -> IO a
|
||||
runHTTP2Client :: forall a. T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> BufferSize -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (TLS -> H.Client a) -> IO a
|
||||
runHTTP2Client tlsParams caStore tcConfig bufferSize proxyUsername host port keyHash = runHTTP2ClientWith bufferSize host setup
|
||||
where
|
||||
setup :: (TLS -> IO a) -> IO a
|
||||
setup = runTLSTransportClient tlsParams caStore tcConfig proxyUsername host port keyHash
|
||||
|
||||
runHTTP2ClientWith :: forall a. BufferSize -> TransportHost -> ((TLS -> IO a) -> IO a) -> (SessionId -> H.Client a) -> IO a
|
||||
runHTTP2ClientWith bufferSize host setup client = setup $ withHTTP2 bufferSize run
|
||||
runHTTP2ClientWith :: forall a. BufferSize -> TransportHost -> ((TLS -> IO a) -> IO a) -> (TLS -> H.Client a) -> IO a
|
||||
runHTTP2ClientWith bufferSize host setup client = setup $ \tls -> withHTTP2 bufferSize (run tls) (pure ()) tls
|
||||
where
|
||||
run :: H.Config -> SessionId -> IO a
|
||||
run cfg sessId = H.run (ClientConfig "https" (strEncode host) 20) cfg $ client sessId
|
||||
run :: TLS -> H.Config -> IO a
|
||||
run tls cfg = H.run (ClientConfig "https" (strEncode host) 20) cfg $ client tls
|
||||
|
||||
@@ -13,14 +13,14 @@ import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport (SessionId, TLS, closeConnection)
|
||||
import Simplex.Messaging.Transport (ALPN, SessionId, TLS, closeConnection, tlsALPN, tlsUniq)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), loadSupportedTLSServerParams, runTransportServer)
|
||||
import Simplex.Messaging.Util (threadDelay')
|
||||
import UnliftIO (finally)
|
||||
import UnliftIO.Concurrent (forkIO, killThread)
|
||||
|
||||
type HTTP2ServerFunc = SessionId -> Request -> (Response -> IO ()) -> IO ()
|
||||
type HTTP2ServerFunc = SessionId -> Maybe ALPN -> Request -> (Response -> IO ()) -> IO ()
|
||||
|
||||
data HTTP2ServerConfig = HTTP2ServerConfig
|
||||
{ qSize :: Natural,
|
||||
@@ -37,6 +37,7 @@ data HTTP2ServerConfig = HTTP2ServerConfig
|
||||
|
||||
data HTTP2Request = HTTP2Request
|
||||
{ sessionId :: SessionId,
|
||||
sessionALPN :: Maybe ALPN,
|
||||
request :: Request,
|
||||
reqBody :: HTTP2Body,
|
||||
sendResponse :: Response -> IO ()
|
||||
@@ -54,32 +55,32 @@ getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, se
|
||||
started <- newEmptyTMVarIO
|
||||
reqQ <- newTBQueueIO qSize
|
||||
action <- async $
|
||||
runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig Nothing $ \sessionId r sendResponse -> do
|
||||
runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig Nothing (const $ pure ()) $ \sessionId sessionALPN r sendResponse -> do
|
||||
reqBody <- getHTTP2Body r bodyHeadSize
|
||||
atomically $ writeTBQueue reqQ HTTP2Request {sessionId, request = r, reqBody, sendResponse}
|
||||
atomically $ writeTBQueue reqQ HTTP2Request {sessionId, sessionALPN, request = r, reqBody, sendResponse}
|
||||
void . atomically $ takeTMVar started
|
||||
pure HTTP2Server {action, reqQ}
|
||||
|
||||
closeHTTP2Server :: HTTP2Server -> IO ()
|
||||
closeHTTP2Server = uninterruptibleCancel . action
|
||||
|
||||
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> Maybe ExpirationConfig -> HTTP2ServerFunc -> IO ()
|
||||
runHTTP2Server started port bufferSize serverParams transportConfig expCfg_ = runHTTP2ServerWith_ expCfg_ bufferSize setup
|
||||
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> Maybe ExpirationConfig -> (SessionId -> IO ()) -> HTTP2ServerFunc -> IO ()
|
||||
runHTTP2Server started port bufferSize serverParams transportConfig expCfg_ clientFinished = runHTTP2ServerWith_ expCfg_ clientFinished bufferSize setup
|
||||
where
|
||||
setup = runTransportServer started port serverParams transportConfig
|
||||
|
||||
runHTTP2ServerWith :: BufferSize -> ((TLS -> IO ()) -> a) -> HTTP2ServerFunc -> a
|
||||
runHTTP2ServerWith = runHTTP2ServerWith_ Nothing
|
||||
runHTTP2ServerWith = runHTTP2ServerWith_ Nothing (\_sessId -> pure ())
|
||||
|
||||
runHTTP2ServerWith_ :: Maybe ExpirationConfig -> BufferSize -> ((TLS -> IO ()) -> a) -> HTTP2ServerFunc -> a
|
||||
runHTTP2ServerWith_ expCfg_ bufferSize setup http2Server = setup $ \tls -> do
|
||||
runHTTP2ServerWith_ :: Maybe ExpirationConfig -> (SessionId -> IO ()) -> BufferSize -> ((TLS -> IO ()) -> a) -> HTTP2ServerFunc -> a
|
||||
runHTTP2ServerWith_ expCfg_ clientFinished bufferSize setup http2Server = setup $ \tls -> do
|
||||
activeAt <- newTVarIO =<< getSystemTime
|
||||
tid_ <- mapM (forkIO . expireInactiveClient tls activeAt) expCfg_
|
||||
withHTTP2 bufferSize (run activeAt) tls `finally` mapM_ killThread tid_
|
||||
withHTTP2 bufferSize (run tls activeAt) (clientFinished $ tlsUniq tls) tls `finally` mapM_ killThread tid_
|
||||
where
|
||||
run activeAt cfg sessId = H.run cfg $ \req _aux sendResp -> do
|
||||
run tls activeAt cfg = H.run cfg $ \req _aux sendResp -> do
|
||||
getSystemTime >>= atomically . writeTVar activeAt
|
||||
http2Server sessId req (`sendResp` [])
|
||||
http2Server (tlsUniq tls) (tlsALPN tls) req (`sendResp` [])
|
||||
expireInactiveClient tls activeAt expCfg = loop
|
||||
where
|
||||
loop = do
|
||||
|
||||
Reference in New Issue
Block a user