Merge remote-tracking branch 'origin/master' into ab/bench-target

This commit is contained in:
Alexander Bondarenko
2024-04-11 19:58:49 +03:00
54 changed files with 1669 additions and 1160 deletions
+9 -8
View File
@@ -179,7 +179,8 @@ runXFTPRcvWorker c srv Worker {doWork} = do
RcvFileChunk {rcvFileId, rcvFileEntityId, fileTmpPath, replicas = []} -> rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) "chunk has no replicas"
fc@RcvFileChunk {userId, rcvFileId, rcvFileEntityId, digest, fileTmpPath, replicas = replica@RcvFileChunkReplica {rcvChunkReplicaId, server, delay} : _} -> do
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop ->
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
lift $ waitForUserNetwork c
downloadFileChunk fc replica
`catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
where
@@ -389,7 +390,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
let numRecipients' = min numRecipients maxRecipients
-- concurrently?
-- separate worker to create chunks? record retries and delay on snd_file_chunks?
forM_ (filter (not . chunkCreated) chunks) $ createChunk numRecipients'
forM_ (filter (\SndFileChunk {replicas} -> null replicas) chunks) $ createChunk numRecipients'
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading
where
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
@@ -413,9 +414,6 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes
chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs
pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests)
chunkCreated :: SndFileChunk -> Bool
chunkCreated SndFileChunk {replicas} =
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas
createChunk :: Int -> SndFileChunk -> AM ()
createChunk numRecipients' ch = do
atomically $ assertAgentForeground c
@@ -425,7 +423,8 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
where
tryCreate = do
usedSrvs <- newTVarIO ([] :: [XFTPServer])
withRetryInterval (riFast ri) $ \_ loop ->
withRetryInterval (riFast ri) $ \_ loop -> do
lift $ waitForUserNetwork c
createWithNextSrv usedSrvs
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop) (throwError e) e
where
@@ -457,7 +456,8 @@ runXFTPSndWorker c srv Worker {doWork} = do
SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) "chunk has no replicas"
fc@SndFileChunk {userId, sndFileId, sndFileEntityId, filePrefixPath, digest, replicas = replica@SndFileChunkReplica {sndChunkReplicaId, server, delay} : _} -> do
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop ->
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
lift $ waitForUserNetwork c
uploadFileChunk cfg fc replica
`catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
where
@@ -623,7 +623,8 @@ runXFTPDelWorker c srv Worker {doWork} = do
where
processDeletedReplica replica@DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, chunkDigest, delay} = do
let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop ->
withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do
lift $ waitForUserNetwork c
deleteChunkReplica
`catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
where
+66 -18
View File
@@ -4,11 +4,14 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Simplex.FileTransfer.Client where
import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Crypto.Random (ChaChaDRG)
@@ -20,9 +23,10 @@ import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Time (UTCTime)
import Data.Word (Word32)
import qualified Data.X509 as X
import qualified Data.X509.Validation as XV
import qualified Network.HTTP.Types as N
import qualified Network.HTTP2.Client as H
import Simplex.FileTransfer.Description (mb)
import Simplex.FileTransfer.Protocol
import Simplex.FileTransfer.Transport
import Simplex.Messaging.Client
@@ -37,6 +41,7 @@ import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Encoding (smpDecode, smpEncode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
( BasicAuth,
@@ -45,12 +50,13 @@ import Simplex.Messaging.Protocol
RecipientId,
SenderId,
)
import Simplex.Messaging.Transport (THandleParams (..), supportedParameters)
import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost)
import Simplex.Messaging.Transport (HandshakeError (VERSION), THandleAuth (..), THandleParams (..), TransportError (..), supportedParameters)
import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost, alpn)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.HTTP2.Client
import Simplex.Messaging.Transport.HTTP2.File
import Simplex.Messaging.Util (bshow, whenM)
import Simplex.Messaging.Util (bshow, liftEitherWith, liftError', tshow, whenM)
import Simplex.Messaging.Version (compatibleVersion, pattern Compatible)
import UnliftIO
import UnliftIO.Directory
@@ -63,7 +69,7 @@ data XFTPClient = XFTPClient
data XFTPClientConfig = XFTPClientConfig
{ xftpNetworkConfig :: NetworkConfig,
uploadTimeoutPerMb :: Int64
serverVRange :: VersionRangeXFTP
}
data XFTPChunkBody = XFTPChunkBody
@@ -85,12 +91,12 @@ defaultXFTPClientConfig :: XFTPClientConfig
defaultXFTPClientConfig =
XFTPClientConfig
{ xftpNetworkConfig = defaultNetworkConfig,
uploadTimeoutPerMb = 10000000 -- 10 seconds
serverVRange = supportedFileServerVRange
}
getXFTPClient :: TransportSession FileResponse -> XFTPClientConfig -> (XFTPClient -> IO ()) -> IO (Either XFTPClientError XFTPClient)
getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkConfig} disconnected = runExceptT $ do
let tcConfig = transportClientConfig xftpNetworkConfig
getXFTPClient :: TVar ChaChaDRG -> TransportSession FileResponse -> XFTPClientConfig -> (XFTPClient -> IO ()) -> IO (Either XFTPClientError XFTPClient)
getXFTPClient g transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkConfig, serverVRange} disconnected = runExceptT $ do
let tcConfig = (transportClientConfig xftpNetworkConfig) {alpn = Just ["xftp/1"]}
http2Config = xftpHTTP2Config tcConfig config
username = proxyUsername transportSession
ProtocolServer _ host port keyHash = srv
@@ -98,13 +104,50 @@ getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkC
clientVar <- newTVarIO Nothing
let usePort = if null port then "443" else port
clientDisconnected = readTVarIO clientVar >>= mapM_ disconnected
http2Client <- withExceptT xftpClientError . ExceptT $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
let HTTP2Client {sessionId} = http2Client
thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True}
c = XFTPClient {http2Client, thParams, transportSession, config}
http2Client <- liftError' xftpClientError $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected
let HTTP2Client {sessionId, sessionALPN} = http2Client
thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = VersionXFTP 1, thAuth = Nothing, implySessId = False, batch = True}
logDebug $ "Client negotiated handshake protocol: " <> tshow sessionALPN
thParams <- case sessionALPN of
Just "xftp/1" -> xftpClientHandshakeV1 g serverVRange keyHash http2Client thParams0
Nothing -> pure thParams0
_ -> throwError $ PCETransportError (TEHandshake VERSION)
let c = XFTPClient {http2Client, thParams, transportSession, config}
atomically $ writeTVar clientVar $ Just c
pure c
xftpClientHandshakeV1 :: TVar ChaChaDRG -> VersionRangeXFTP -> C.KeyHash -> HTTP2Client -> THandleParamsXFTP -> ExceptT XFTPClientError IO THandleParamsXFTP
xftpClientHandshakeV1 g serverVRange keyHash@(C.KeyHash kh) c@HTTP2Client {sessionId, serverKey} thParams0 = do
shs <- getServerHandshake
(v, sk) <- processServerHandshake shs
(k, pk) <- atomically $ C.generateKeyPair g
sendClientHandshake XFTPClientHandshake {xftpVersion = v, keyHash, authPubKey = k}
pure thParams0 {thAuth = Just THandleAuth {peerPubKey = sk, privKey = pk}, thVersion = v}
where
getServerHandshake = do
let helloReq = H.requestNoBody "POST" "/" []
HTTP2Response {respBody = HTTP2Body {bodyHead = shsBody}} <-
liftError' (const $ PCEResponseError HANDSHAKE) $ sendRequest c helloReq Nothing
liftHS . smpDecode =<< liftHS (C.unPad shsBody)
processServerHandshake XFTPServerHandshake {xftpVersionRange, sessionId = serverSessId, authPubKey = serverAuth} = do
unless (sessionId == serverSessId) $ throwError $ PCEResponseError SESSION
case xftpVersionRange `compatibleVersion` serverVRange of
Nothing -> throwError $ PCEResponseError HANDSHAKE
Just (Compatible v) ->
fmap (v,) . liftHS $ do
let (X.CertificateChain cert, exact) = serverAuth
case cert of
[_leaf, ca] | XV.Fingerprint kh == XV.getFingerprint ca X.HashSHA256 -> pure ()
_ -> throwError "bad certificate"
pubKey <- maybe (throwError "bad server key type") (`C.verifyX509` exact) serverKey
C.x509ToPublic (pubKey, []) >>= C.pubKey
sendClientHandshake chs = do
chs' <- liftHS $ C.pad (smpEncode chs) xftpBlockSize
let chsReq = H.requestBuilder "POST" "/" [] $ byteString chs'
HTTP2Response {respBody = HTTP2Body {bodyHead}} <- liftError' (const $ PCEResponseError HANDSHAKE) $ sendRequest c chsReq Nothing
unless (B.null bodyHead) $ throwError $ PCEResponseError HANDSHAKE
liftHS = liftEitherWith (const $ PCEResponseError HANDSHAKE)
closeXFTPClient :: XFTPClient -> IO ()
closeXFTPClient XFTPClient {http2Client} = closeHTTP2Client http2Client
@@ -144,8 +187,8 @@ sendXFTPCommand c@XFTPClient {thParams} pKey fId cmd chunkSpec_ = do
sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body)
sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = do
let req = H.requestStreaming N.methodPost "/" [] streamBody
reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req reqTimeout
reqTimeout = xftpReqTimeout config $ (\XFTPChunkSpec {chunkSize} -> chunkSize) <$> chunkSpec_
HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req (Just reqTimeout)
when (B.length bodyHead /= xftpBlockSize) $ throwError $ PCEResponseError BLOCK
-- TODO validate that the file ID is the same as in the request?
(_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead
@@ -198,15 +241,20 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
let t = chunkTimeout config chunkSize
ExceptT (sequence <$> (t `timeout` download cbState)) >>= maybe (throwError PCEResponseTimeout) pure
where
download cbState = runExceptT $
withExceptT PCEResponseError $
download cbState =
runExceptT . withExceptT PCEResponseError $
receiveEncFile chunkPart cbState chunkSpec `catchError` \e ->
whenM (doesFileExist filePath) (removeFile filePath) >> throwError e
_ -> throwError $ PCEResponseError NO_FILE
(r, _) -> throwError . PCEUnexpectedResponse $ bshow r
xftpReqTimeout :: XFTPClientConfig -> Maybe Word32 -> Int
xftpReqTimeout cfg@XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpTimeout}} chunkSize_ =
maybe tcpTimeout (chunkTimeout cfg) chunkSize_
chunkTimeout :: XFTPClientConfig -> Word32 -> Int
chunkTimeout config chunkSize = fromIntegral $ (fromIntegral chunkSize * uploadTimeoutPerMb config) `div` mb 1
chunkTimeout XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpTimeout, tcpTimeoutPerKb}} sz =
tcpTimeout + fromIntegral (min ((fromIntegral sz `div` 1024) * tcpTimeoutPerKb) (fromIntegral (maxBound :: Int)))
deleteXFTPChunk :: XFTPClient -> C.APrivateAuthKey -> SenderId -> ExceptT XFTPClientError IO ()
deleteXFTPChunk c spKey sId = sendXFTPCommand c spKey sId FDEL Nothing >>= okResponse
+4 -3
View File
@@ -11,6 +11,7 @@ import Control.Logger.Simple (logInfo)
import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans (lift)
import Crypto.Random (ChaChaDRG)
import Data.Bifunctor (first)
import qualified Data.ByteString.Char8 as B
import Data.Text (Text)
@@ -60,15 +61,15 @@ newXFTPAgent config = do
type ME a = ExceptT XFTPClientAgentError IO a
getXFTPServerClient :: XFTPClientAgent -> XFTPServer -> ME XFTPClient
getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do
getXFTPServerClient :: TVar ChaChaDRG -> XFTPClientAgent -> XFTPServer -> ME XFTPClient
getXFTPServerClient g XFTPClientAgent {xftpClients, config} srv = do
atomically getClientVar >>= either newXFTPClient waitForXFTPClient
where
connectClient :: ME XFTPClient
connectClient =
ExceptT $
first (XFTPClientAgentError srv)
<$> getXFTPClient (1, srv, Nothing) (xftpConfig config) clientDisconnected
<$> getXFTPClient g (1, srv, Nothing) (xftpConfig config) clientDisconnected
clientDisconnected :: XFTPClient -> IO ()
clientDisconnected _ = do
+13 -12
View File
@@ -333,9 +333,9 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re
rKeys <- atomically $ L.fromList <$> replicateM numRecipients (C.generateAuthKeyPair C.SEd25519 g)
digest <- liftIO $ getChunkDigest chunkSpec
let ch = FileInfo {sndKey, size = fromIntegral chunkSize, digest}
c <- withRetry retryCount $ getXFTPServerClient a xftpServer
c <- withRetry retryCount $ getXFTPServerClient g a xftpServer
(sndId, rIds) <- withRetry retryCount $ createXFTPChunk c spKey ch (L.map fst rKeys) auth
withReconnect a xftpServer retryCount $ \c' -> uploadXFTPChunk c' spKey sndId chunkSpec
withReconnect g a xftpServer retryCount $ \c' -> uploadXFTPChunk c' spKey sndId chunkSpec
logInfo $ "uploaded chunk " <> tshow chunkNo
uploaded <- atomically . stateTVar uploadedChunks $ \cs ->
let cs' = fromIntegral chunkSize : cs in (sum cs', cs')
@@ -445,7 +445,7 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
when (FileSize encSize /= size) $ throwError $ CLIError "File size mismatch"
liftIO $ printNoNewLine "Decrypting file..."
CryptoFile path _ <- withExceptT cliCryptoError $ decryptChunks encSize chunkPaths key nonce $ fmap CF.plain . getFilePath
forM_ chunks $ acknowledgeFileChunk a
forM_ chunks $ acknowledgeFileChunk g a
whenM (doesPathExist encPath) $ removeDirectoryRecursive encPath
liftIO $ do
printNoNewLine $ "File downloaded: " <> path
@@ -456,7 +456,7 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
logInfo $ "downloading chunk " <> tshow chunkNo <> " from " <> showServer server <> "..."
chunkPath <- uniqueCombine encPath $ show chunkNo
let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest)
withReconnect a server retryCount $ \c -> downloadXFTPChunk g c replicaKey (unChunkReplicaId replicaId) chunkSpec
withReconnect g a server retryCount $ \c -> downloadXFTPChunk g c replicaKey (unChunkReplicaId replicaId) chunkSpec
logInfo $ "downloaded chunk " <> tshow chunkNo <> " to " <> T.pack chunkPath
downloaded <- atomically . stateTVar downloadedChunks $ \cs ->
let cs' = fromIntegral (unFileSize chunkSize) : cs in (sum cs', cs')
@@ -472,12 +472,12 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath,
ifM (doesDirectoryExist path) (uniqueCombine path name) $
ifM (doesFileExist path) (throwError "File already exists") (pure path)
_ -> (`uniqueCombine` name) . (</> "Downloads") =<< getHomeDirectory
acknowledgeFileChunk :: XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
acknowledgeFileChunk a FileChunk {replicas = replica : _} = do
acknowledgeFileChunk :: TVar ChaChaDRG -> XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
acknowledgeFileChunk g a FileChunk {replicas = replica : _} = do
let FileChunkReplica {server, replicaId, replicaKey} = replica
c <- withRetry retryCount $ getXFTPServerClient a server
c <- withRetry retryCount $ getXFTPServerClient g a server
withRetry retryCount $ ackXFTPChunk c replicaKey (unChunkReplicaId replicaId)
acknowledgeFileChunk _ _ = throwError $ CLIError "chunk has no replicas"
acknowledgeFileChunk _ _ _ = throwError $ CLIError "chunk has no replicas"
printProgress :: String -> Int64 -> Int64 -> IO ()
printProgress s part total = printNoNewLine $ s <> " " <> show ((part * 100) `div` total) <> "%"
@@ -501,7 +501,8 @@ cliDeleteFile DeleteOptions {fileDescription, retryCount, yes} = do
deleteFileChunk :: XFTPClientAgent -> FileChunk -> ExceptT CLIError IO ()
deleteFileChunk a FileChunk {chunkNo, replicas = replica : _} = do
let FileChunkReplica {server, replicaId, replicaKey} = replica
withReconnect a server retryCount $ \c -> deleteXFTPChunk c replicaKey (unChunkReplicaId replicaId)
g <- liftIO C.newRandom
withReconnect g a server retryCount $ \c -> deleteXFTPChunk c replicaKey (unChunkReplicaId replicaId)
logInfo $ "deleted chunk " <> tshow chunkNo <> " from " <> showServer server
deleteFileChunk _ _ = throwError $ CLIError "chunk has no replicas"
@@ -569,9 +570,9 @@ prepareChunkSpecs filePath chunkSizes = reverse . snd $ foldl' addSpec (0, []) c
getEncPath :: MonadIO m => Maybe FilePath -> String -> m FilePath
getEncPath path name = (`uniqueCombine` (name <> ".encrypted")) =<< maybe (liftIO getCanonicalTemporaryDirectory) pure path
withReconnect :: Show e => XFTPClientAgent -> XFTPServer -> Int -> (XFTPClient -> ExceptT e IO a) -> ExceptT CLIError IO a
withReconnect a srv n run = withRetry n $ do
c <- withRetry n $ getXFTPServerClient a srv
withReconnect :: Show e => TVar ChaChaDRG -> XFTPClientAgent -> XFTPServer -> Int -> (XFTPClient -> ExceptT e IO a) -> ExceptT CLIError IO a
withReconnect g a srv n run = withRetry n $ do
c <- withRetry n $ getXFTPServerClient g a srv
withExceptT (CLIError . show) (run c) `catchError` \e -> do
liftIO $ closeXFTPServerClient a srv
throwError e
+4 -4
View File
@@ -25,7 +25,7 @@ import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (isNothing)
import Data.Type.Equality
import Data.Word (Word32)
import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake)
import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, xftpClientHandshakeStub, pattern VersionXFTP)
import Simplex.Messaging.Client (authTransmission)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
@@ -144,7 +144,7 @@ instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where
instance Protocol XFTPVersion XFTPErrorType FileResponse where
type ProtoCommand FileResponse = FileCmd
type ProtoType FileResponse = 'PXFTP
protocolClientHandshake = xftpClientHandshake
protocolClientHandshake = xftpClientHandshakeStub
protocolPing = FileCmd SFRecipient PING
protocolError = \case
FRErr e -> Just e
@@ -329,9 +329,9 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of
_ -> Nothing
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do
xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey (corrId, fId, msg) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg)
xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth
xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth (Just pKey) corrId tForAuth
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString
xftpEncodeTransmission thParams (corrId, fId, msg) = do
+130 -58
View File
@@ -1,4 +1,3 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
@@ -10,6 +9,7 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Simplex.FileTransfer.Server where
@@ -18,7 +18,8 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.ByteString.Builder (byteString)
import qualified Data.ByteString.Base64.URL as B64
import Data.ByteString.Builder (Builder, byteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Int (Int64)
@@ -32,6 +33,7 @@ import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import Data.Word (Word32)
import qualified Data.X509 as X
import GHC.IO.Handle (hSetNewlineMode)
import GHC.Stats (getRTSStats)
import qualified Network.HTTP.Types as N
@@ -46,19 +48,22 @@ import Simplex.FileTransfer.Server.StoreLog
import Simplex.FileTransfer.Transport
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (CorrId, RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth)
import Simplex.Messaging.Protocol (CorrId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth)
import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Transport (THandleParams (..))
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SessionId, THandleAuth (..), THandleParams (..))
import Simplex.Messaging.Transport.Buffer (trimCR)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.HTTP2.File (fileBlockSize)
import Simplex.Messaging.Transport.HTTP2.Server
import Simplex.Messaging.Transport.Server (runTCPServer)
import Simplex.Messaging.Transport.Server (runTCPServer, tlsServerCredentials)
import Simplex.Messaging.Util
import Simplex.Messaging.Version (isCompatible)
import System.Exit (exitFailure)
import System.FilePath ((</>))
import System.IO (hPrint, hPutStrLn, universalNewlineMode)
@@ -70,7 +75,7 @@ import qualified UnliftIO.Exception as E
type M a = ReaderT XFTPEnv IO a
data XFTPTransportRequest = XFTPTransportRequest
{ thParams :: THandleParams XFTPVersion,
{ thParams :: THandleParamsXFTP,
reqBody :: HTTP2Body,
request :: H.Request,
sendResponse :: H.Response -> IO ()
@@ -84,20 +89,75 @@ runXFTPServer cfg = do
runXFTPServerBlocking :: TMVar Bool -> XFTPServerConfig -> IO ()
runXFTPServerBlocking started cfg = newXFTPServerEnv cfg >>= runReaderT (xftpServer cfg started)
data Handshake
= HandshakeSent C.PrivateKeyX25519
| HandshakeAccepted THandleAuth VersionXFTP
xftpServer :: XFTPServerConfig -> TMVar Bool -> M ()
xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpiration} started = do
xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpiration, fileExpiration} started = do
mapM_ (expireServerFiles Nothing) fileExpiration
restoreServerStats
raceAny_ (runServer : expireFilesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg) `finally` stopServer
where
runServer :: M ()
runServer = do
serverParams <- asks tlsServerParams
let (chain, pk) = tlsServerCredentials serverParams
signKey <- liftIO $ case C.x509ToPrivate (pk, []) >>= C.privKey of
Right pk' -> pure pk'
Left e -> putStrLn ("servers has no valid key: " <> show e) >> exitFailure
env <- ask
liftIO $
runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig inactiveClientExpiration $ \sessionId r sendResponse -> do
reqBody <- getHTTP2Body r xftpBlockSize
let thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True}
processRequest XFTPTransportRequest {thParams, request = r, reqBody, sendResponse} `runReaderT` env
sessions <- atomically TM.empty
let cleanup sessionId = atomically $ TM.delete sessionId sessions
liftIO . runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig inactiveClientExpiration cleanup $ \sessionId sessionALPN r sendResponse -> do
reqBody <- getHTTP2Body r xftpBlockSize
let thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = VersionXFTP 1, thAuth = Nothing, implySessId = False, batch = True}
req0 = XFTPTransportRequest {thParams = thParams0, request = r, reqBody, sendResponse}
flip runReaderT env $ case sessionALPN of
Nothing -> processRequest req0
Just "xftp/1" ->
xftpServerHandshakeV1 chain signKey sessions req0 >>= \case
Nothing -> pure () -- handshake response sent
Just thParams -> processRequest req0 {thParams} -- proceed with new version (XXX: may as well switch the request handler here)
_ -> liftIO . sendResponse $ H.responseNoBody N.ok200 [] -- shouldn't happen: means server picked handshake protocol it doesn't know about
xftpServerHandshakeV1 :: X.CertificateChain -> C.APrivateSignKey -> TMap SessionId Handshake -> XFTPTransportRequest -> M (Maybe (THandleParams XFTPVersion))
xftpServerHandshakeV1 chain serverSignKey sessions XFTPTransportRequest {thParams = thParams@THandleParams {sessionId}, reqBody = HTTP2Body {bodyHead}, sendResponse} = do
s <- atomically $ TM.lookup sessionId sessions
r <- runExceptT $ case s of
Nothing -> processHello
Just (HandshakeSent pk) -> processClientHandshake pk
Just (HandshakeAccepted auth v) -> pure $ Just thParams {thAuth = Just auth, thVersion = v}
either sendError pure r
where
processHello = do
unless (B.null bodyHead) $ throwError HANDSHAKE
(k, pk) <- atomically . C.generateKeyPair =<< asks random
atomically $ TM.insert sessionId (HandshakeSent pk) sessions
let authPubKey = (chain, C.signX509 serverSignKey $ C.publicToX509 k)
let hs = XFTPServerHandshake {xftpVersionRange = supportedFileServerVRange, sessionId, authPubKey}
shs <- encodeXftp hs
liftIO . sendResponse $ H.responseBuilder N.ok200 [] shs
pure Nothing
processClientHandshake privKey = do
unless (B.length bodyHead == xftpBlockSize) $ throwError HANDSHAKE
body <- liftHS $ C.unPad bodyHead
XFTPClientHandshake {xftpVersion, keyHash, authPubKey} <- liftHS $ smpDecode body
kh <- asks serverIdentity
unless (keyHash == kh) $ throwError HANDSHAKE
unless (xftpVersion `isCompatible` supportedFileServerVRange) $ throwError HANDSHAKE
let auth = THandleAuth {peerPubKey = authPubKey, privKey}
atomically $ TM.insert sessionId (HandshakeAccepted auth xftpVersion) sessions
liftIO . sendResponse $ H.responseNoBody N.ok200 []
pure Nothing
sendError :: XFTPErrorType -> M (Maybe (THandleParams XFTPVersion))
sendError err = do
runExceptT (encodeXftp err) >>= \case
Right bs -> liftIO . sendResponse $ H.responseBuilder N.ok200 [] bs
Left _ -> logError $ "Error encoding handshake error: " <> tshow err
pure Nothing
encodeXftp :: Encoding a => a -> ExceptT XFTPErrorType (ReaderT XFTPEnv IO) Builder
encodeXftp a = byteString <$> liftHS (C.pad (smpEncode a) xftpBlockSize)
liftHS = liftEitherWith (const HANDSHAKE)
stopServer :: M ()
stopServer = do
@@ -110,28 +170,10 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
expireFiles :: ExpirationConfig -> M ()
expireFiles expCfg = do
st <- asks store
let interval = checkInterval expCfg * 1000000
forever $ do
liftIO $ threadDelay' interval
old <- liftIO $ expireBeforeEpoch expCfg
sIds <- M.keysSet <$> readTVarIO (files st)
forM_ sIds $ \sId -> do
threadDelay 100000
atomically (expiredFilePath st sId old)
>>= mapM_ (maybeRemove $ delete st sId)
where
maybeRemove del = maybe del (remove del)
remove del filePath =
ifM
(doesFileExist filePath)
((removeFile filePath >> del) `catch` \(e :: SomeException) -> logError $ "failed to remove expired file " <> tshow filePath <> ": " <> tshow e)
del
delete st sId = do
withFileLog (`logDeleteFile` sId)
void $ atomically $ deleteFile st sId
FileServerStats {filesExpired} <- asks serverStats
atomically $ modifyTVar' filesExpired (+ 1)
expireServerFiles (Just 100000) expCfg
serverStatsThread_ :: XFTPServerConfig -> [M ()]
serverStatsThread_ XFTPServerConfig {logStatsInterval = Just interval, logStatsStartTime, serverStatsLogFile} =
@@ -201,7 +243,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
role <- newTVarIO CPRNone
cpLoop h role
where
cpLoop h role = do
cpLoop h role = do
s <- trimCR <$> B.hGetLine h
case strDecode s of
Right CPQuit -> hClose h
@@ -235,12 +277,13 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira
CPQuit -> pure ()
CPSkip -> pure ()
where
withUserRole action = readTVarIO role >>= \case
CPRAdmin -> action
CPRUser -> action
_ -> do
logError "Unauthorized control port command"
hPutStrLn h "AUTH"
withUserRole action =
readTVarIO role >>= \case
CPRAdmin -> action
CPRUser -> action
_ -> do
logError "Unauthorized control port command"
hPutStrLn h "AUTH"
data ServerFile = ServerFile
{ filePath :: FilePath,
@@ -253,10 +296,11 @@ processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHea
| B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", "", FRErr BLOCK) Nothing
| otherwise = do
case xftpDecodeTransmission thParams bodyHead of
Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do
Right (sig_, signed, (corrId, fId, cmdOrErr)) ->
case cmdOrErr of
Right cmd -> do
verifyXFTPTransmission sig_ signed fId cmd >>= \case
let THandleParams {thAuth} = thParams
verifyXFTPTransmission ((,C.cbNonce (bs corrId)) <$> thAuth) sig_ signed fId cmd >>= \case
VRVerified req -> uncurry send =<< processXFTPRequest body req
VRFailed -> send (FRErr AUTH) Nothing
Left e -> send (FRErr e) Nothing
@@ -264,7 +308,6 @@ processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHea
send resp = sendXFTPResponse (corrId, fId, resp)
Left e -> sendXFTPResponse ("", "", FRErr e) Nothing
where
sendXFTPResponse :: (CorrId, XFTPFileId, FileResponse) -> Maybe ServerFile -> M ()
sendXFTPResponse (corrId, fId, resp) serverFile_ = do
let t_ = xftpEncodeTransmission thParams (corrId, fId, resp)
liftIO $ sendResponse $ H.responseStreaming N.ok200 [] $ streamBody t_
@@ -283,8 +326,8 @@ processRequest XFTPTransportRequest {thParams, reqBody = body@HTTP2Body {bodyHea
data VerificationResult = VRVerified XFTPRequest | VRFailed
verifyXFTPTransmission :: Maybe TransmissionAuth -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult
verifyXFTPTransmission tAuth authorized fId cmd =
verifyXFTPTransmission :: Maybe (THandleAuth, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult
verifyXFTPTransmission auth_ tAuth authorized fId cmd =
case cmd of
FileCmd SFSender (FNEW file rcps auth') -> pure $ XFTPReqNew file rcps auth' `verifyWith` sndKey file
FileCmd SFRecipient PING -> pure $ VRVerified XFTPReqPing
@@ -299,7 +342,7 @@ verifyXFTPTransmission tAuth authorized fId cmd =
Right (fr, k) -> XFTPReqCmd fId fr cmd `verifyWith` k
_ -> maybe False (dummyVerifyCmd Nothing authorized) tAuth `seq` VRFailed
-- TODO verify with DH authorization
req `verifyWith` k = if verifyCmdAuthorization Nothing tAuth authorized k then VRVerified req else VRFailed
req `verifyWith` k = if verifyCmdAuthorization auth_ tAuth authorized k then VRVerified req else VRFailed
processXFTPRequest :: HTTP2Body -> XFTPRequest -> M (FileResponse, Maybe ServerFile)
processXFTPRequest HTTP2Body {bodyPart} = \case
@@ -392,7 +435,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
\used -> let used' = used + fromIntegral size in if used' <= quota then (True, used') else (False, used)
receive = do
path <- asks $ filesPath . config
let fPath = path </> B.unpack (U.encode senderId)
let fPath = path </> B.unpack (B64.encode senderId)
receiveChunk (XFTPRcvChunkSpec fPath size digest) >>= \case
Right () -> do
stats <- asks serverStats
@@ -413,18 +456,20 @@ processXFTPRequest HTTP2Body {bodyPart} = \case
sendServerFile :: FileRec -> RcvPublicDhKey -> M (FileResponse, Maybe ServerFile)
sendServerFile FileRec {senderId, filePath, fileInfo = FileInfo {size}} rDhKey = do
readTVarIO filePath >>= \case
Just path -> do
g <- asks random
(sDhKey, spDhKey) <- atomically $ C.generateKeyPair g
let dhSecret = C.dh' rDhKey spDhKey
cbNonce <- atomically $ C.randomCbNonce g
case LC.cbInit dhSecret cbNonce of
Right sbState -> do
stats <- asks serverStats
atomically $ modifyTVar' (fileDownloads stats) (+ 1)
atomically $ updatePeriodStats (filesDownloaded stats) senderId
pure (FRFile sDhKey cbNonce, Just ServerFile {filePath = path, fileSize = size, sbState})
_ -> pure (FRErr INTERNAL, Nothing)
Just path -> ifM (doesFileExist path) sendFile (pure (FRErr AUTH, Nothing))
where
sendFile = do
g <- asks random
(sDhKey, spDhKey) <- atomically $ C.generateKeyPair g
let dhSecret = C.dh' rDhKey spDhKey
cbNonce <- atomically $ C.randomCbNonce g
case LC.cbInit dhSecret cbNonce of
Right sbState -> do
stats <- asks serverStats
atomically $ modifyTVar' (fileDownloads stats) (+ 1)
atomically $ updatePeriodStats (filesDownloaded stats) senderId
pure (FRFile sDhKey cbNonce, Just ServerFile {filePath = path, fileSize = size, sbState})
_ -> pure (FRErr INTERNAL, Nothing)
_ -> pure (FRErr NO_FILE, Nothing)
deleteServerFile :: FileRec -> M FileResponse
@@ -457,6 +502,33 @@ deleteServerFile_ FileRec {senderId, fileInfo, filePath} = do
atomically $ modifyTVar' (filesCount stats) (subtract 1)
atomically $ modifyTVar' (filesSize stats) (subtract $ fromIntegral $ size fileInfo)
expireServerFiles :: Maybe Int -> ExpirationConfig -> M ()
expireServerFiles itemDelay expCfg = do
st <- asks store
usedStart <- readTVarIO $ usedStorage st
old <- liftIO $ expireBeforeEpoch expCfg
files' <- readTVarIO (files st)
logInfo $ "Expiration check: " <> tshow (M.size files') <> " files"
forM_ (M.keys files') $ \sId -> do
mapM_ threadDelay itemDelay
atomically (expiredFilePath st sId old)
>>= mapM_ (maybeRemove $ delete st sId)
usedEnd <- readTVarIO $ usedStorage st
logInfo $ "Used " <> mbs usedStart <> " -> " <> mbs usedEnd <> ", " <> mbs (usedStart - usedEnd) <> " reclaimed."
where
mbs bs = tshow (bs `div` 1048576) <> "mb"
maybeRemove del = maybe del (remove del)
remove del filePath =
ifM
(doesFileExist filePath)
((removeFile filePath >> del) `catch` \(e :: SomeException) -> logError $ "failed to remove expired file " <> tshow filePath <> ": " <> tshow e)
del
delete st sId = do
withFileLog (`logDeleteFile` sId)
void . atomically $ deleteFile st sId -- will not update usedStorage if sId isn't in store
FileServerStats {filesExpired} <- asks serverStats
atomically $ modifyTVar' filesExpired (+ 1)
randomId :: Int -> M ByteString
randomId n = atomically . C.randomBytes n =<< asks random
+17 -2
View File
@@ -2,19 +2,23 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StrictData #-}
module Simplex.FileTransfer.Server.Env where
import Control.Logger.Simple (logInfo)
import Control.Logger.Simple
import Control.Monad
import Control.Monad.IO.Unlift
import Crypto.Random
import Data.Default (def)
import Data.Int (Int64)
import Data.List (find)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Time.Clock (getCurrentTime)
import Data.Word (Word32)
import Data.X509.Validation (Fingerprint (..))
@@ -27,6 +31,7 @@ import Simplex.FileTransfer.Server.StoreLog
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (BasicAuth, RcvPublicAuthKey)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport (ALPN)
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Util (tshow)
import System.IO (IOMode (..))
@@ -94,6 +99,9 @@ defaultFileExpiration =
checkInterval = 2 * 3600 -- seconds, 2 hours
}
supportedXFTPhandshakes :: [ALPN]
supportedXFTPhandshakes = ["xftp/1"]
newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv
newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile} = do
random <- liftIO C.newRandom
@@ -104,7 +112,14 @@ newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertifi
forM_ fileSizeQuota $ \quota -> do
logInfo $ "Total / available storage: " <> tshow quota <> " / " <> tshow (quota - used)
when (quota < used) $ logInfo "WARNING: storage quota is less than used storage, no files can be uploaded!"
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
tlsServerParams' <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
let tlsServerParams =
tlsServerParams'
{ T.serverHooks =
def
{ T.onALPNClientSuggest = Just $ pure . fromMaybe "" . find (`elem` supportedXFTPhandshakes)
}
}
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
serverStats <- atomically . newFileServerStats =<< liftIO getCurrentTime
pure XFTPEnv {config, store, storeLog, random, tlsServerParams, serverIdentity = C.KeyHash fp, serverStats}
+2 -1
View File
@@ -22,6 +22,7 @@ import Control.Concurrent.STM
import Control.Monad.Except
import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Composition ((.:), (.:.))
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
@@ -88,7 +89,7 @@ readWriteFileStore f st = do
pure s
readFileStore :: FilePath -> FileStore -> IO ()
readFileStore f st = mapM_ addFileLogRecord . B.lines =<< B.readFile f
readFileStore f st = mapM_ (addFileLogRecord . LB.toStrict) . LB.lines =<< LB.readFile f
where
addFileLogRecord s = case strDecode s of
Left e -> B.putStrLn $ "Log parsing error (" <> B.pack e <> "): " <> B.take 100 s
+58 -6
View File
@@ -9,9 +9,16 @@
module Simplex.FileTransfer.Transport
( supportedFileServerVRange,
xftpClientHandshake, -- stub
XFTPVersion,
xftpClientHandshakeStub,
XFTPClientHandshake (..),
-- xftpClientHandshake,
XFTPServerHandshake (..),
-- xftpServerHandshake,
THandleXFTP,
THandleParamsXFTP,
VersionXFTP,
VersionRangeXFTP,
XFTPVersion,
pattern VersionXFTP,
XFTPErrorType (..),
XFTPRcvChunkSpec (..),
@@ -30,20 +37,21 @@ import Control.Monad.Except
import Control.Monad.IO.Class
import qualified Data.Aeson.TH as J
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.Bifunctor (bimap, first)
import qualified Data.ByteArray as BA
import Data.ByteString.Builder (Builder, byteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Word (Word16, Word32)
import qualified Data.X509 as X
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol (CommandError)
import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..))
import Simplex.Messaging.Transport (HandshakeError (..), SessionId, THandle (..), THandleParams (..), TransportError (..))
import Simplex.Messaging.Transport.HTTP2.File
import Simplex.Messaging.Util (bshow)
import Simplex.Messaging.Version
@@ -68,6 +76,9 @@ type VersionRangeXFTP = VersionRange XFTPVersion
pattern VersionXFTP :: Word16 -> VersionXFTP
pattern VersionXFTP v = Version v
type THandleXFTP c = THandle XFTPVersion c
type THandleParamsXFTP = THandleParams XFTPVersion
initialXFTPVersion :: VersionXFTP
initialXFTPVersion = VersionXFTP 1
@@ -75,8 +86,45 @@ supportedFileServerVRange :: VersionRangeXFTP
supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion
-- XFTP protocol does not support handshake
xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c)
xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION
xftpClientHandshakeStub :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c)
xftpClientHandshakeStub _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION
data XFTPServerHandshake = XFTPServerHandshake
{ xftpVersionRange :: VersionRangeXFTP,
sessionId :: SessionId,
-- | pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: (X.CertificateChain, X.SignedExact X.PubKey)
}
data XFTPClientHandshake = XFTPClientHandshake
{ -- | agreed XFTP server protocol version
xftpVersion :: VersionXFTP,
-- | server identity - CA certificate fingerprint
keyHash :: C.KeyHash,
-- | pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
authPubKey :: C.PublicKeyX25519
}
instance Encoding XFTPClientHandshake where
smpEncode XFTPClientHandshake {xftpVersion, keyHash, authPubKey} =
smpEncode (xftpVersion, keyHash, authPubKey)
smpP = do
(xftpVersion, keyHash) <- smpP
authPubKey <- smpP
Tail _compat <- smpP
pure XFTPClientHandshake {xftpVersion, keyHash, authPubKey}
instance Encoding XFTPServerHandshake where
smpEncode XFTPServerHandshake {xftpVersionRange, sessionId, authPubKey} =
smpEncode (xftpVersionRange, sessionId, auth)
where
auth = bimap C.encodeCertChain C.SignedObject authPubKey
smpP = do
(xftpVersionRange, sessionId) <- smpP
cert <- C.certChainP
C.SignedObject key <- smpP
Tail _compat <- smpP
pure XFTPServerHandshake {xftpVersionRange, sessionId, authPubKey = (cert, key)}
sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO ()
sendEncFile h send = go
@@ -139,6 +187,8 @@ data XFTPErrorType
BLOCK
| -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929)
SESSION
| -- | incorrect handshake command
HANDSHAKE
| -- | SMP command is unknown or has invalid syntax
CMD {cmdErr :: CommandError}
| -- | command authorization error - bad signature or non-existing SMP queue
@@ -181,6 +231,7 @@ instance Encoding XFTPErrorType where
smpEncode = \case
BLOCK -> "BLOCK"
SESSION -> "SESSION"
HANDSHAKE -> "HANDSHAKE"
CMD err -> "CMD " <> smpEncode err
AUTH -> "AUTH"
SIZE -> "SIZE"
@@ -199,6 +250,7 @@ instance Encoding XFTPErrorType where
A.takeTill (== ' ') >>= \case
"BLOCK" -> pure BLOCK
"SESSION" -> pure SESSION
"HANDSHAKE" -> pure HANDSHAKE
"CMD" -> CMD <$> _smpP
"AUTH" -> pure AUTH
"SIZE" -> pure SIZE
+29 -11
View File
@@ -82,6 +82,7 @@ module Simplex.Messaging.Agent
setNtfServers,
setNetworkConfig,
getNetworkConfig,
setUserNetworkInfo,
reconnectAllServers,
registerNtfToken,
verifyNtfToken,
@@ -402,17 +403,32 @@ testProtocolServer c userId srv = withAgentEnv' c $ case protocolTypeI @p of
SPXFTP -> runXFTPServerTest c userId srv
SPNTF -> runNTFServerTest c userId srv
-- | set SOCKS5 proxy on/off and optionally set TCP timeout
-- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network
setNetworkConfig :: AgentClient -> NetworkConfig -> IO ()
setNetworkConfig c cfg' = do
cfg <- atomically $ do
swapTVar (useNetworkConfig c) cfg'
when (cfg /= cfg') $ reconnectAllServers c
setNetworkConfig c@AgentClient {useNetworkConfig} cfg' = do
changed <- atomically $ do
(_, cfg) <- readTVar useNetworkConfig
if cfg == cfg'
then pure False
else True <$ (writeTVar useNetworkConfig $! (slowNetworkConfig cfg', cfg'))
when changed $ reconnectAllServers c
-- returns fast network config
getNetworkConfig :: AgentClient -> IO NetworkConfig
getNetworkConfig = readTVarIO . useNetworkConfig
getNetworkConfig = fmap snd . readTVarIO . useNetworkConfig
{-# INLINE getNetworkConfig #-}
setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO ()
setUserNetworkInfo c@AgentClient {userNetworkState} UserNetworkInfo {networkType = nt'} = withAgentEnv' c $ do
d <- asks $ initialInterval . userNetworkInterval . config
ts <- liftIO getCurrentTime
atomically $ do
ns@UserNetworkState {networkType = nt} <- readTVar userNetworkState
when (nt' /= nt) $
writeTVar userNetworkState $! case nt' of
UNNone -> ns {networkType = nt', offline = Just UNSOffline {offlineDelay = d, offlineFrom = ts}}
_ -> ns {networkType = nt', offline = Nothing}
reconnectAllServers :: AgentClient -> IO ()
reconnectAllServers c = do
reconnectServerClients c smpClients
@@ -1267,6 +1283,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork
let mId = unId msgId
ri' = maybe id updateRetryInterval2 msgRetryState ri
withRetryLock2 ri' qLock $ \riState loop -> do
lift $ waitForUserNetwork c
resp <- tryError $ case msgType of
AM_CONN_INFO -> sendConfirmation c sq msgBody
AM_CONN_INFO_REPLY -> sendConfirmation c sq msgBody
@@ -2047,7 +2064,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
pure ack'
where
queueDrained = case conn of
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq)
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ A_QCONT (sndAddress rq)
_ -> pure ()
processClientMsg srvTs msgFlags msgBody = do
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
@@ -2096,7 +2113,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
notify $ MSG msgMeta msgFlags body
pure ACKPending
A_RCVD rcpts -> qDuplex conn'' "RCVD" $ messagesRcvd rcpts msgMeta
QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending srvMsgId addr
A_QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending srvMsgId addr
QADD qs -> qDuplexAckDel conn'' "QADD" $ qAddMsg srvMsgId qs
QKEY qs -> qDuplexAckDel conn'' "QKEY" $ qKeyMsg srvMsgId qs
QUSE qs -> qDuplexAckDel conn'' "QUSE" $ qUseMsg srvMsgId qs
@@ -2310,6 +2327,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
atomically $
TM.lookup (qAddress sq) (smpDeliveryWorkers c)
>>= mapM_ (\(_, retryLock) -> tryPutTMVar retryLock ())
notify QCONT
Nothing -> qError "QCONT: queue address not found"
messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd
@@ -2351,10 +2369,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v,
case L.nonEmpty keepSqs of
Just sqs' -> do
-- move inside case?
withStore' c $ \db -> mapM_ (deleteConnSndQueue db connId) delSqs
sq_@SndQueue {sndPublicKey, e2ePubKey} <- lift $ newSndQueue userId connId qInfo
let sq'' = (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
sq2 <- withStore c $ \db -> addConnSndQueue db connId sq''
sq2 <- withStore c $ \db -> do
liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs
addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
case (sndPublicKey, e2ePubKey) of
(Just sndPubKey, Just dhPublicKey) -> do
logServer "<--" c srv rId $ "MSG <QADD>:" <> logSecret srvMsgId <> " " <> logSecret (senderId queueAddress)
+114 -30
View File
@@ -27,6 +27,7 @@ module Simplex.Messaging.Agent.Client
withConnLock,
withConnLocks,
withInvLock,
withLockMap,
closeAgentClient,
closeProtocolServerClients,
reconnectServerClients,
@@ -80,6 +81,7 @@ module Simplex.Messaging.Agent.Client
agentClientStore,
agentDRG,
getAgentSubscriptions,
slowNetworkConfig,
Worker (..),
SessionVar (..),
SubscriptionsInfo (..),
@@ -99,6 +101,11 @@ module Simplex.Messaging.Agent.Client
agentOperations,
agentOperationBracket,
waitUntilActive,
UserNetworkInfo (..),
UserNetworkType (..),
UserNetworkState (..),
UNSOffline (..),
waitForUserNetwork,
throwWhenInactive,
throwWhenNoDelivery,
beginAgentOperation,
@@ -132,7 +139,7 @@ import Control.Applicative ((<|>))
import Control.Concurrent (ThreadId, forkIO, threadDelay)
import Control.Concurrent.Async (Async, uninterruptibleCancel)
import Control.Concurrent.STM (retry, throwSTM)
import Control.Exception (AsyncException (..))
import Control.Exception (AsyncException (..), BlockedIndefinitelyOnSTM (..))
import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
@@ -142,11 +149,13 @@ import Crypto.Random (ChaChaDRG)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:.))
import Data.Either (lefts, partitionEithers)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (deleteFirstsBy, foldl', partition, (\\))
import Data.List.NonEmpty (NonEmpty (..), (<|))
import qualified Data.List.NonEmpty as L
@@ -157,7 +166,7 @@ import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (Text)
import Data.Text.Encoding
import Data.Time (UTCTime, defaultTimeLocale, formatTime, getCurrentTime)
import Data.Time (UTCTime, defaultTimeLocale, diffUTCTime, formatTime, getCurrentTime)
import Data.Time.Clock.System (getSystemTime)
import Data.Word (Word16)
import Network.Socket (HostName)
@@ -165,7 +174,7 @@ import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientCo
import qualified Simplex.FileTransfer.Client as X
import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb)
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse)
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion)
import Simplex.FileTransfer.Transport (XFTPErrorType (DIGEST), XFTPRcvChunkSpec (..), XFTPVersion)
import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..))
import Simplex.FileTransfer.Util (uniqueCombine)
import Simplex.Messaging.Agent.Env.SQLite
@@ -181,7 +190,6 @@ import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
@@ -195,6 +203,7 @@ import Simplex.Messaging.Protocol
ErrorType,
MsgFlags (..),
MsgId,
NtfPublicAuthKey,
NtfServer,
NtfServerWithAuth,
ProtoServer,
@@ -206,22 +215,21 @@ import Simplex.Messaging.Protocol
QueueIdsKeys (..),
RcvMessage (..),
RcvNtfPublicDhKey,
NtfPublicAuthKey,
SMPMsgMeta (..),
SProtocolType (..),
SndPublicAuthKey,
SubscriptionMode (..),
UserProtocol,
VersionRangeSMPC,
VersionSMPC,
XFTPServer,
XFTPServerWithAuth,
VersionSMPC,
VersionRangeSMPC,
sameSrvAddr',
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
@@ -264,7 +272,8 @@ data AgentClient = AgentClient
ntfClients :: TMap NtfTransportSession NtfClientVar,
xftpServers :: TMap UserId (NonEmpty XFTPServerWithAuth),
xftpClients :: TMap XFTPTransportSession XFTPClientVar,
useNetworkConfig :: TVar NetworkConfig,
useNetworkConfig :: TVar (NetworkConfig, NetworkConfig), -- (slow, fast) networks
userNetworkState :: TVar UserNetworkState,
subscrConns :: TVar (Set ConnId),
activeSubs :: TRcvQueues,
pendingSubs :: TRcvQueues,
@@ -395,6 +404,23 @@ data AgentStatsKey = AgentStatsKey
}
deriving (Eq, Ord, Show)
data UserNetworkInfo = UserNetworkInfo
{ networkType :: UserNetworkType
}
deriving (Show)
data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther
deriving (Eq, Show)
data UserNetworkState = UserNetworkState
{ networkType :: UserNetworkType,
offline :: Maybe UNSOffline
}
deriving (Show)
data UNSOffline = UNSOffline {offlineDelay :: Int64, offlineFrom :: UTCTime}
deriving (Show)
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
newAgentClient :: Int -> InitialAgentServers -> Env -> STM AgentClient
newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
@@ -410,7 +436,8 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
ntfClients <- TM.empty
xftpServers <- newTVar xftp
xftpClients <- TM.empty
useNetworkConfig <- newTVar netCfg
useNetworkConfig <- newTVar (slowNetworkConfig netCfg, netCfg)
userNetworkState <- newTVar $ UserNetworkState UNOther Nothing
subscrConns <- newTVar S.empty
activeSubs <- RQ.empty
pendingSubs <- RQ.empty
@@ -445,6 +472,7 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
xftpServers,
xftpClients,
useNetworkConfig,
userNetworkState,
subscrConns,
activeSubs,
pendingSubs,
@@ -469,6 +497,13 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv =
agentEnv
}
slowNetworkConfig :: NetworkConfig -> NetworkConfig
slowNetworkConfig cfg@NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpTimeoutPerKb} =
cfg {tcpConnectTimeout = slow tcpConnectTimeout, tcpTimeout = slow tcpTimeout, tcpTimeoutPerKb = slow tcpTimeoutPerKb}
where
slow :: Integral a => a -> a
slow t = (t * 3) `div` 2
agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
{-# INLINE agentClientStore #-}
@@ -531,7 +566,8 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
g <- asks random
env <- ask
liftError' (protocolClientError SMP $ B.unpack $ strEncode srv) $
getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v
getProtocolClient g tSess cfg (Just msgQ) $
clientDisconnected env v
clientDisconnected :: Env -> SMPClientVar -> SMPClient -> IO ()
clientDisconnected env v client = do
@@ -580,6 +616,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
withRetryInterval ri $ \_ loop -> do
pending <- atomically getPending
forM_ (L.nonEmpty pending) $ \qs -> do
waitForUserNetwork c
void . tryAgentError' $ reconnectSMPClient timeoutCounts c tSess qs
loop
getPending = RQ.getSessQueues tSess $ pendingSubs c
@@ -592,7 +629,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM ()
reconnectSMPClient tc c tSess@(_, srv, _) qs = do
NetworkConfig {tcpTimeout} <- readTVarIO $ useNetworkConfig c
NetworkConfig {tcpTimeout} <- atomically $ getNetworkConfig c
-- this allows 3x of timeout per batch of subscription (90 queues per batch empirically)
let t = (length qs `div` 90 + 1) * tcpTimeout * 3
ExceptT (sequence <$> (t `timeout` runExceptT resubscribe)) >>= \case
@@ -634,7 +671,8 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
cfg <- lift $ getClientConfig c ntfCfg
g <- asks random
liftError' (protocolClientError NTF $ B.unpack $ strEncode srv) $
getProtocolClient g tSess cfg Nothing $ clientDisconnected v
getProtocolClient g tSess cfg Nothing $
clientDisconnected v
clientDisconnected :: NtfClientVar -> NtfClient -> IO ()
clientDisconnected v client = do
@@ -644,7 +682,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient
getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@(userId, srv, _) = do
getXFTPServerClient c@AgentClient {active, xftpClients} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess xftpClients)
>>= either
@@ -654,9 +692,11 @@ getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@
connectClient :: XFTPClientVar -> AM XFTPClient
connectClient v = do
cfg <- asks $ xftpCfg . config
xftpNetworkConfig <- readTVarIO useNetworkConfig
g <- asks random
xftpNetworkConfig <- atomically $ getNetworkConfig c
liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $
X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v
X.getXFTPClient g tSess cfg {xftpNetworkConfig} $
clientDisconnected v
clientDisconnected :: XFTPClientVar -> XFTPClient -> IO ()
clientDisconnected v client = do
@@ -688,7 +728,7 @@ removeTSessVar' v tSess vs =
waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg)
waitForProtocolClient c (_, srv, _) v = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
liftEither $ case client_ of
Just (Right smpClient) -> Right smpClient
@@ -724,11 +764,51 @@ hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerCli
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v)
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
getClientConfig c cfgSel = do
cfg <- asks $ cfgSel . config
networkConfig <- readTVarIO useNetworkConfig
networkConfig <- atomically $ getNetworkConfig c
pure cfg {networkConfig}
getNetworkConfig :: AgentClient -> STM NetworkConfig
getNetworkConfig c = do
(slowCfg, fastCfg) <- readTVar (useNetworkConfig c)
UserNetworkState {networkType} <- readTVar (userNetworkState c)
pure $ case networkType of
UNCellular -> slowCfg
UNNone -> slowCfg
_ -> fastCfg
waitForUserNetwork :: AgentClient -> AM' ()
waitForUserNetwork AgentClient {userNetworkState} =
(offline <$> readTVarIO userNetworkState) >>= mapM_ waitWhileOffline
where
waitWhileOffline UNSOffline {offlineDelay = d} =
unlessM (liftIO $ waitOnline d False) $ do -- network delay reached, increase delay
ts' <- liftIO getCurrentTime
ni <- asks $ userNetworkInterval . config
atomically $ do
ns@UserNetworkState {offline} <- readTVar userNetworkState
forM_ offline $ \UNSOffline {offlineDelay = d', offlineFrom = ts} ->
-- Using `min` to avoid multiple updates in a short period of time
-- and to reset `offlineDelay` if network went `on` and `off` again.
writeTVar userNetworkState $!
let d'' = nextRetryDelay (diffToMicroseconds $ diffUTCTime ts' ts) (min d d') ni
in ns {offline = Just UNSOffline {offlineDelay = d'', offlineFrom = ts}}
waitOnline :: Int64 -> Bool -> IO Bool
waitOnline t online'
| t <= 0 = pure online'
| otherwise =
registerDelay (fromIntegral maxWait)
>>= atomically . onlineOrDelay
>>= waitOnline (t - maxWait)
where
maxWait = min t $ fromIntegral (maxBound :: Int)
onlineOrDelay delay = do
online <- isNothing . offline <$> readTVar userNetworkState
expired <- readTVar delay
unless (online || expired) retry
pure online
closeAgentClient :: AgentClient -> IO ()
closeAgentClient c = do
atomically $ writeTVar (active c) False
@@ -784,8 +864,8 @@ closeClient c clientSel tSess =
closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO ()
closeClient_ c v = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c
E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
Just (Right client) -> closeProtocolServerClient client `catchAll_` pure ()
_ -> pure ()
@@ -799,7 +879,7 @@ withConnLock c connId name = ExceptT . withConnLock' c connId name . runExceptT
withConnLock' :: AgentClient -> ConnId -> String -> AM' a -> AM' a
withConnLock' _ "" _ = id
withConnLock' AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
withConnLock' AgentClient {connLocks} connId name = withLockMap connLocks connId name
{-# INLINE withConnLock' #-}
withInvLock :: AgentClient -> ByteString -> String -> AM a -> AM a
@@ -807,16 +887,16 @@ withInvLock c key name = ExceptT . withInvLock' c key name . runExceptT
{-# INLINE withInvLock #-}
withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
withInvLock' AgentClient {invLocks} = withLockMap_ invLocks
withInvLock' AgentClient {invLocks} = withLockMap invLocks
{-# INLINE withInvLock' #-}
withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
{-# INLINE withConnLocks #-}
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap_ = withGetLock . getMapLock
{-# INLINE withLockMap_ #-}
withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap = withGetLock . getMapLock
{-# INLINE withLockMap #-}
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
withLocksMap_ = withGetLocks . getMapLock
@@ -945,13 +1025,13 @@ runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe P
runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do
cfg <- asks $ xftpCfg . config
g <- asks random
xftpNetworkConfig <- readTVarIO $ useNetworkConfig c
xftpNetworkConfig <- atomically $ getNetworkConfig c
workDir <- getXFTPWorkPath
filePath <- getTempFilePath workDir
rcvPath <- getTempFilePath workDir
liftIO $ do
let tSess = (userId, srv, Nothing)
X.getXFTPClient tSess cfg {xftpNetworkConfig} (\_ -> pure ()) >>= \case
X.getXFTPClient g tSess cfg {xftpNetworkConfig} (\_ -> pure ()) >>= \case
Right xftp -> withTestChunk filePath $ do
(sndKey, spKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
(rcvKey, rpKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
@@ -1035,7 +1115,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
{-# INLINE mkSMPTSession #-}
getSessionMode :: AgentClient -> IO TransportSessionMode
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
getSessionMode = atomically . fmap sessionMode . getNetworkConfig
{-# INLINE getSessionMode #-}
newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri)
@@ -1127,7 +1207,7 @@ sendTSessionBatches statCmd statBatchSize toRQ action c qs =
where
batchQueues :: AM' [(SMPTransportSession, NonEmpty q)]
batchQueues = do
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
mode <- atomically $ sessionMode <$> getNetworkConfig c
pure . M.assocs $ foldl' (batch mode) M.empty qs
where
batch mode m q =
@@ -1770,3 +1850,7 @@ $(J.deriveJSON defaultJSON ''WorkersSummary)
$(J.deriveJSON defaultJSON {J.fieldLabelModifier = takeWhile (/= '_')} ''AgentWorkersDetails)
$(J.deriveJSON defaultJSON ''AgentWorkersSummary)
$(J.deriveJSON (enumJSON $ dropPrefix "UN") ''UserNetworkType)
$(J.deriveJSON defaultJSON ''UserNetworkInfo)
+14 -4
View File
@@ -92,6 +92,7 @@ data AgentConfig = AgentConfig
xftpCfg :: XFTPClientConfig,
reconnectInterval :: RetryInterval,
messageRetryInterval :: RetryInterval2,
userNetworkInterval :: RetryInterval,
messageTimeout :: NominalDiffTime,
connDeleteDeliveryTimeout :: NominalDiffTime,
helloTimeout :: NominalDiffTime,
@@ -126,7 +127,7 @@ defaultReconnectInterval =
RetryInterval
{ initialInterval = 2_000000,
increaseAfter = 10_000000,
maxInterval = 180_000000
maxInterval = 60_000000
}
defaultMessageRetryInterval :: RetryInterval2
@@ -134,18 +135,26 @@ defaultMessageRetryInterval =
RetryInterval2
{ riFast =
RetryInterval
{ initialInterval = 1_000000,
{ initialInterval = 2_000000,
increaseAfter = 10_000000,
maxInterval = 60_000000
},
riSlow =
RetryInterval
{ initialInterval = 180_000000, -- 3 minutes
{ initialInterval = 300_000000, -- 5 minutes
increaseAfter = 60_000000,
maxInterval = 3 * 3600_000000 -- 3 hours
maxInterval = 6 * 3600_000000 -- 6 hours
}
}
defaultUserNetworkInterval :: RetryInterval
defaultUserNetworkInterval =
RetryInterval
{ initialInterval = 1200_000000, -- 20 minutes
increaseAfter = 0,
maxInterval = 7200_000000 -- 2 hours
}
defaultAgentConfig :: AgentConfig
defaultAgentConfig =
AgentConfig
@@ -161,6 +170,7 @@ defaultAgentConfig =
xftpCfg = defaultXFTPClientConfig,
reconnectInterval = defaultReconnectInterval,
messageRetryInterval = defaultMessageRetryInterval,
userNetworkInterval = defaultUserNetworkInterval,
messageTimeout = 2 * nominalDay,
connDeleteDeliveryTimeout = 2 * nominalDay,
helloTimeout = 2 * nominalDay,
@@ -160,7 +160,8 @@ runNtfWorker c srv Worker {doWork} = do
\nextSub@(NtfSubscription {connId}, _, _) -> do
logInfo $ "runNtfWorker, nextSub " <> tshow nextSub
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \_ loop ->
withRetryInterval ri $ \_ loop -> do
lift $ waitForUserNetwork c
processSub nextSub
`catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show)
processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM ()
@@ -243,7 +244,8 @@ runNtfSMPWorker c srv Worker {doWork} = do
\nextSub@(NtfSubscription {connId}, _, _) -> do
logInfo $ "runNtfSMPWorker, nextSub " <> tshow nextSub
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \_ loop ->
withRetryInterval ri $ \_ loop -> do
lift $ waitForUserNetwork c
processSub nextSub
`catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show)
processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> AM ()
+15 -8
View File
@@ -163,6 +163,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
@@ -201,7 +202,6 @@ import Simplex.Messaging.Crypto.Ratchet
SndE2ERatchetParams
)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (base64P, encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol
@@ -405,6 +405,7 @@ data ACommand (p :: AParty) (e :: AEntity) where
MSGNTF :: SMPMsgMeta -> ACommand Agent AEConn
ACK :: AgentMsgId -> Maybe MsgReceiptInfo -> ACommand Client AEConn
RCVD :: MsgMeta -> NonEmpty MsgReceipt -> ACommand Agent AEConn
QCONT :: ACommand Agent AEConn
SWCH :: ACommand Client AEConn
OFF :: ACommand Client AEConn
DEL :: ACommand Client AEConn
@@ -467,6 +468,7 @@ data ACommandTag (p :: AParty) (e :: AEntity) where
MSGNTF_ :: ACommandTag Agent AEConn
ACK_ :: ACommandTag Client AEConn
RCVD_ :: ACommandTag Agent AEConn
QCONT_ :: ACommandTag Agent AEConn
SWCH_ :: ACommandTag Client AEConn
OFF_ :: ACommandTag Client AEConn
DEL_ :: ACommandTag Client AEConn
@@ -522,6 +524,7 @@ aCommandTag = \case
MSGNTF {} -> MSGNTF_
ACK {} -> ACK_
RCVD {} -> RCVD_
QCONT -> QCONT_
SWCH -> SWCH_
OFF -> OFF_
DEL -> DEL_
@@ -996,7 +999,7 @@ agentMessageType = \case
HELLO -> AM_HELLO_
A_MSG _ -> AM_A_MSG_
A_RCVD {} -> AM_A_RCVD_
QCONT _ -> AM_QCONT_
A_QCONT _ -> AM_QCONT_
QADD _ -> AM_QADD_
QKEY _ -> AM_QKEY_
QUSE _ -> AM_QUSE_
@@ -1020,7 +1023,7 @@ data AMsgType
= HELLO_
| A_MSG_
| A_RCVD_
| QCONT_
| A_QCONT_
| QADD_
| QKEY_
| QUSE_
@@ -1033,7 +1036,7 @@ instance Encoding AMsgType where
HELLO_ -> "H"
A_MSG_ -> "M"
A_RCVD_ -> "V"
QCONT_ -> "QC"
A_QCONT_ -> "QC"
QADD_ -> "QA"
QKEY_ -> "QK"
QUSE_ -> "QU"
@@ -1046,7 +1049,7 @@ instance Encoding AMsgType where
'V' -> pure A_RCVD_
'Q' ->
A.anyChar >>= \case
'C' -> pure QCONT_
'C' -> pure A_QCONT_
'A' -> pure QADD_
'K' -> pure QKEY_
'U' -> pure QUSE_
@@ -1066,7 +1069,7 @@ data AMessage
| -- | agent envelope for delivery receipt
A_RCVD (NonEmpty AMessageReceipt)
| -- | the message instructing the client to continue sending messages (after ERR QUOTA)
QCONT SndQAddr
A_QCONT SndQAddr
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
QADD (NonEmpty (SMPQueueUri, Maybe SndQAddr))
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
@@ -1124,7 +1127,7 @@ instance Encoding AMessage where
HELLO -> smpEncode HELLO_
A_MSG body -> smpEncode (A_MSG_, Tail body)
A_RCVD mrs -> smpEncode (A_RCVD_, mrs)
QCONT addr -> smpEncode (QCONT_, addr)
A_QCONT addr -> smpEncode (A_QCONT_, addr)
QADD qs -> smpEncode (QADD_, qs)
QKEY qs -> smpEncode (QKEY_, qs)
QUSE qs -> smpEncode (QUSE_, qs)
@@ -1136,7 +1139,7 @@ instance Encoding AMessage where
HELLO_ -> pure HELLO
A_MSG_ -> A_MSG . unTail <$> smpP
A_RCVD_ -> A_RCVD <$> smpP
QCONT_ -> QCONT <$> smpP
A_QCONT_ -> A_QCONT <$> smpP
QADD_ -> QADD <$> smpP
QKEY_ -> QKEY <$> smpP
QUSE_ -> QUSE <$> smpP
@@ -1668,6 +1671,7 @@ instance StrEncoding ACmdTag where
"MSGNTF" -> ct MSGNTF_
"ACK" -> t ACK_
"RCVD" -> ct RCVD_
"QCONT" -> ct QCONT_
"SWCH" -> t SWCH_
"OFF" -> t OFF_
"DEL" -> t DEL_
@@ -1725,6 +1729,7 @@ instance (APartyI p, AEntityI e) => StrEncoding (ACommandTag p e) where
MSGNTF_ -> "MSGNTF"
ACK_ -> "ACK"
RCVD_ -> "RCVD"
QCONT_ -> "QCONT"
SWCH_ -> "SWCH"
OFF_ -> "OFF"
DEL_ -> "DEL"
@@ -1794,6 +1799,7 @@ commandP binaryP =
MSG_ -> s (MSG <$> strP <* A.space <*> smpP <* A.space <*> binaryP)
MSGNTF_ -> s (MSGNTF <$> strP)
RCVD_ -> s (RCVD <$> strP <* A.space <*> strP)
QCONT_ -> pure QCONT
DEL_RCVQ_ -> s (DEL_RCVQ <$> strP_ <*> strP_ <*> strP)
DEL_CONN_ -> pure DEL_CONN
DEL_USER_ -> s (DEL_USER <$> strP)
@@ -1857,6 +1863,7 @@ serializeCommand = \case
MSGNTF smpMsgMeta -> s (MSGNTF_, smpMsgMeta)
ACK mId rcptInfo_ -> s (ACK_, mId) <> maybe "" (B.cons ' ' . serializeBinary) rcptInfo_
RCVD msgMeta rcpts -> s (RCVD_, msgMeta, rcpts)
QCONT -> s QCONT_
SWCH -> s SWCH_
OFF -> s OFF_
DEL -> s DEL_
+5 -4
View File
@@ -11,6 +11,7 @@ module Simplex.Messaging.Agent.RetryInterval
withRetryIntervalCount,
withRetryLock2,
updateRetryInterval2,
nextRetryDelay,
)
where
@@ -60,7 +61,7 @@ withRetryIntervalCount ri action = callAction 0 0 $ initialInterval ri
loop = do
liftIO $ threadDelay' delay
let elapsed' = elapsed + delay
callAction (n + 1) elapsed' $ nextDelay elapsed' delay ri
callAction (n + 1) elapsed' $ nextRetryDelay elapsed' delay ri
-- This function allows action to toggle between slow and fast retry intervals.
withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> (RI2State -> (RetryIntervalMode -> m ()) -> m ()) -> m ()
@@ -76,7 +77,7 @@ withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
run (elapsed, delay) ri call = do
wait delay
let elapsed' = elapsed + delay
delay' = nextDelay elapsed' delay ri
delay' = nextRetryDelay elapsed' delay ri
call (elapsed', delay')
wait delay = do
waiting <- newTVarIO True
@@ -87,8 +88,8 @@ withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
takeTMVar lock
writeTVar waiting False
nextDelay :: Int64 -> Int64 -> RetryInterval -> Int64
nextDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
nextRetryDelay :: Int64 -> Int64 -> RetryInterval -> Int64
nextRetryDelay elapsed delay RetryInterval {increaseAfter, maxInterval} =
if elapsed < increaseAfter || delay == maxInterval
then delay
else min (delay * 3 `div` 2) maxInterval
+2 -2
View File
@@ -231,6 +231,7 @@ import Data.Bifunctor (first, second)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString.Base64.URL as U
import qualified Data.ByteString.Char8 as B
import Data.Char (toLower)
import Data.Functor (($>))
@@ -270,7 +271,6 @@ import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..))
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
import Simplex.Messaging.Notifications.Types
@@ -1214,7 +1214,7 @@ setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
db
[sql|
UPDATE ratchets
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
WHERE conn_id = ?
|]
(x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, pqPrivKem, connId)
+3 -3
View File
@@ -207,7 +207,7 @@ data NetworkConfig = NetworkConfig
-- | timeout of protocol commands (microseconds)
tcpTimeout :: Int,
-- | additional timeout per kilobyte (1024 bytes) to be sent
tcpTimeoutPerKb :: Int,
tcpTimeoutPerKb :: Int64,
-- | TCP keep-alive options, Nothing to skip enabling keep-alive
tcpKeepAlive :: Maybe KeepAliveOpts,
-- | period for SMP ping commands (microseconds, 0 to disable)
@@ -230,7 +230,7 @@ defaultNetworkConfig =
sessionMode = TSMUser,
tcpConnectTimeout = 20_000_000,
tcpTimeout = 15_000_000,
tcpTimeoutPerKb = 45_000, -- 45ms, should be less than 130ms to avoid Int overflow on 32 bit systems
tcpTimeoutPerKb = 5_000,
tcpKeepAlive = Just defaultKeepAliveOpts,
smpPingInterval = 600_000_000, -- 10min
smpPingCount = 3,
@@ -239,7 +239,7 @@ defaultNetworkConfig =
transportClientConfig :: NetworkConfig -> TransportClientConfig
transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing, alpn = Nothing}
{-# INLINE transportClientConfig #-}
-- | protocol client configuration.
+2 -2
View File
@@ -211,6 +211,8 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (bimap, first)
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as BA
import Data.ByteString.Base64 (decode, encode)
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.ByteString.Lazy (fromStrict, toStrict)
@@ -228,8 +230,6 @@ import Database.SQLite.Simple.ToField (ToField (..))
import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+))
import Network.Transport.Internal (decodeWord16, encodeWord16)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.Base64 (decode, encode)
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString)
import Simplex.Messaging.Util ((<$?>))
-29
View File
@@ -1,29 +0,0 @@
{-# LANGUAGE OverloadedStrings #-}
-- | Compatibility wrappers for base64 package, Base64 (padded) variant.
module Simplex.Messaging.Encoding.Base64 where
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Base64.Types (extractBase64)
import Data.Bifunctor (first)
import Data.ByteString.Base64 (decodeBase64Untyped, encodeBase64')
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
encode :: ByteString -> ByteString
encode = extractBase64 . encodeBase64'
{-# INLINE encode #-}
decode :: ByteString -> Either String ByteString
decode = first T.unpack . decodeBase64Untyped
{-# INLINE decode #-}
base64P :: A.Parser ByteString
base64P = do
str <- A.takeWhile1 (`B.elem` base64Alphabet)
pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
either (fail . T.unpack) pure $ decodeBase64Untyped (str <> pad)
base64Alphabet :: ByteString
base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
@@ -1,33 +0,0 @@
{-# LANGUAGE OverloadedStrings #-}
-- | Compatibility wrappers for base64 package, Base64URL-padded variant.
module Simplex.Messaging.Encoding.Base64.URL where
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Base64.Types (extractBase64)
import Data.Bifunctor (first)
import Data.ByteString.Base64.URL (decodeBase64Lenient, decodeBase64UnpaddedUntyped, decodeBase64Untyped, encodeBase64')
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
encode :: ByteString -> ByteString
encode = extractBase64 . encodeBase64'
{-# INLINE encode #-}
decode :: ByteString -> Either String ByteString
decode = first T.unpack . decodeBase64Untyped
{-# INLINE decode #-}
decodeLenient :: ByteString -> ByteString
decodeLenient = decodeBase64Lenient
{-# INLINE decodeLenient #-}
base64urlP :: A.Parser ByteString
base64urlP = do
str <- A.takeWhile1 (`B.elem` base64AlphabetURL)
_pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length
either (fail . T.unpack) pure $ decodeBase64UnpaddedUntyped str
base64AlphabetURL :: ByteString
base64AlphabetURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+11 -6
View File
@@ -10,6 +10,7 @@ module Simplex.Messaging.Encoding.String
strToJSON,
strToJEncoding,
strParseJSON,
base64urlP,
strEncodeList,
strListP,
)
@@ -22,8 +23,10 @@ import qualified Data.Aeson.Encoding as JE
import qualified Data.Aeson.Types as JT
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isAlphaNum)
import Data.Int (Int64)
import qualified Data.List.NonEmpty as L
import Data.Set (Set)
@@ -35,7 +38,6 @@ import Data.Time.Clock.System (SystemTime (..))
import Data.Time.Format.ISO8601
import Data.Word (Word16, Word32)
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Util ((<$?>))
@@ -52,16 +54,19 @@ class StrEncoding a where
strDecode :: ByteString -> Either String a
strDecode = parseAll strP
strP :: Parser a
strP = strDecode <$?> U.base64urlP
strP = strDecode <$?> base64urlP
-- base64url encoding/decoding of ByteStrings - the parser only allows non-empty strings
instance StrEncoding ByteString where
strEncode = U.encode
{-# INLINE strEncode #-}
strDecode = U.decode
{-# INLINE strDecode #-}
strP = U.base64urlP
{-# INLINE strP #-}
strP = base64urlP
base64urlP :: Parser ByteString
base64urlP = do
str <- A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
pad <- A.takeWhile (== '=')
either fail pure $ U.decode (str <> pad)
newtype Str = Str {unStr :: ByteString}
deriving (Eq, Show)
@@ -11,12 +11,14 @@
module Simplex.Messaging.Notifications.Protocol where
import Control.Applicative ((<|>))
import Data.Aeson (FromJSON (..), ToJSON (..), (.:), (.=))
import qualified Data.Aeson as J
import qualified Data.Aeson.Encoding as JE
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Kind
import Data.Maybe (isNothing)
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
@@ -406,10 +408,12 @@ instance Encoding DeviceToken where
instance StrEncoding DeviceToken where
strEncode (DeviceToken p t) = strEncode p <> " " <> t
strP = DeviceToken <$> strP <* A.space <*> hexStringP
strP = nullToken <|> hexToken
where
nullToken = "apns_null test_ntf_token" $> DeviceToken PPApnsNull "test_ntf_token"
hexToken = DeviceToken <$> strP <* A.space <*> hexStringP
hexStringP =
A.takeWhile (\c -> A.isDigit c || (c >= 'a' && c <= 'f')) >>= \s ->
A.takeWhile (`B.elem` "0123456789abcdef") >>= \s ->
if even (B.length s) then pure s else fail "odd number of hex characters"
instance ToJSON DeviceToken where
+12 -8
View File
@@ -108,7 +108,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
logServerStats :: Int64 -> Int64 -> FilePath -> M ()
logServerStats startAt logInterval statsFilePath = do
initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
logInfo $ "server stats log enabled: " <> T.pack statsFilePath
liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
NtfServerStats {fromTime, tknCreated, tknVerified, tknDeleted, subCreated, subDeleted, ntfReceived, ntfDelivered, activeTokens, activeSubs} <- asks serverStats
let interval = 1000000 * logInterval
@@ -442,7 +442,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
where
processCommand :: NtfRequest -> M (Transmission NtfResponse)
processCommand = \case
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn token _ dhPubKey)) -> do
logDebug "TNEW - new token"
st <- asks store
ks@(srvDhPubKey, srvDhPrivKey) <- atomically . C.generateKeyPair =<< asks random
@@ -453,9 +453,9 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
atomically $ addNtfToken st tknId tkn
atomically $ writeTBQueue pushQ (tkn, PNVerification regCode)
withNtfLog (`logCreateToken` tkn)
incNtfStat tknCreated
incNtfStatT token tknCreated
pure (corrId, "", NRTknId tknId srvDhPubKey)
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey), tknCronInterval}) (corrId, tknId, cmd) -> do
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {token, ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey), tknCronInterval}) (corrId, tknId, cmd) -> do
status <- readTVarIO tknStatus
(corrId,tknId,) <$> case cmd of
TNEW (NewNtfTkn _ _ dhPubKey) -> do
@@ -474,7 +474,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
updateTknStatus tkn NTActive
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
forM_ tIds cancelInvervalNotifications
incNtfStat tknVerified
incNtfStatT token tknVerified
pure NROk
| otherwise -> do
logDebug "TVFY - incorrect code or token status"
@@ -493,8 +493,8 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
addNtfToken st tknId tkn'
writeTBQueue pushQ (tkn', PNVerification regCode)
withNtfLog $ \s -> logUpdateToken s tknId token' regCode
incNtfStat tknDeleted
incNtfStat tknCreated
incNtfStatT token tknDeleted
incNtfStatT token tknCreated
pure NROk
TDEL -> do
logDebug "TDEL"
@@ -504,7 +504,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
atomically $ removeSubscription ca smpServer (SPNotifier, notifierId)
cancelInvervalNotifications tknId
withNtfLog (`logDeleteToken` tknId)
incNtfStat tknDeleted
incNtfStatT token tknDeleted
pure NROk
TCRN 0 -> do
logDebug "TCRN 0"
@@ -583,6 +583,10 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
withNtfLog :: (StoreLog 'WriteMode -> IO a) -> M ()
withNtfLog action = liftIO . mapM_ action =<< asks storeLog
incNtfStatT :: DeviceToken -> (NtfServerStats -> TVar Int) -> M ()
incNtfStatT (DeviceToken PPApnsNull _) _ = pure ()
incNtfStatT _ statSel = incNtfStat statSel
incNtfStat :: (NtfServerStats -> TVar Int) -> M ()
incNtfStat statSel = do
stats <- asks serverStats
@@ -27,9 +27,8 @@ import Data.Aeson (ToJSON, (.=))
import qualified Data.Aeson as J
import qualified Data.Aeson.Encoding as JE
import qualified Data.Aeson.TH as JQ
import Data.Base64.Types (extractBase64)
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64.URL as UP
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Builder (lazyByteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as LB
@@ -47,7 +46,6 @@ import Network.HTTP2.Client (Request)
import qualified Network.HTTP2.Client as H
import Network.Socket (HostName, ServiceName)
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Encoding.Base64.URL as U
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
@@ -56,7 +54,7 @@ import Simplex.Messaging.Parsers (defaultJSON)
import Simplex.Messaging.Protocol (EncNMsgMeta)
import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..))
import Simplex.Messaging.Transport.HTTP2.Client
import Simplex.Messaging.Util (safeDecodeUtf8)
import Simplex.Messaging.Util (safeDecodeUtf8, tshow)
import System.Environment (getEnv)
import UnliftIO.STM
@@ -93,8 +91,8 @@ signedJWTToken pk (JWTToken hdr claims) = do
pure $ hc <> "." <> serialize sig
where
jwtEncode :: ToJSON a => a -> ByteString
jwtEncode = extractBase64 . UP.encodeBase64Unpadded' . LB.toStrict . J.encode
serialize sig = extractBase64 . UP.encodeBase64Unpadded' $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
readECPrivateKey :: FilePath -> IO EC.PrivateKey
readECPrivateKey f = do
@@ -260,11 +258,11 @@ mkApnsJWTToken appTeamId jwtHeader privateKey = do
connectHTTPS2 :: HostName -> APNSPushClientConfig -> TVar (Maybe HTTP2Client) -> IO (Either HTTP2ClientError HTTP2Client)
connectHTTPS2 apnsHost APNSPushClientConfig {apnsPort, http2cfg, caStoreFile} https2Client = do
caStore_ <- XS.readCertificateStore caStoreFile
when (isNothing caStore_) $ putStrLn $ "Error loading CertificateStore from " <> caStoreFile
when (isNothing caStore_) $ logError $ "Error loading CertificateStore from " <> T.pack caStoreFile
r <- getHTTP2Client apnsHost apnsPort caStore_ http2cfg disconnected
case r of
Right client -> atomically . writeTVar https2Client $ Just client
Left e -> putStrLn $ "Error connecting to APNS: " <> show e
Left e -> logError $ "Error connecting to APNS: " <> tshow e
pure r
where
disconnected = atomically $ writeTVar https2Client Nothing
@@ -24,9 +24,12 @@ module Simplex.Messaging.Notifications.Server.StoreLog
where
import Control.Concurrent.STM
import Control.Logger.Simple
import Control.Monad
import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import qualified Data.Text as T
import Data.Word (Word16)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
@@ -34,7 +37,7 @@ import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Store
import Simplex.Messaging.Protocol (NtfPrivateAuthKey)
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.Util (whenM)
import Simplex.Messaging.Util (safeDecodeUtf8, whenM)
import System.Directory (doesFileExist, renameFile)
import System.IO
@@ -189,10 +192,10 @@ readWriteNtfStore f st = do
pure s
readNtfStore :: FilePath -> NtfStore -> IO ()
readNtfStore f st = mapM_ addNtfLogRecord . B.lines =<< B.readFile f
readNtfStore f st = mapM_ (addNtfLogRecord . LB.toStrict) . LB.lines =<< LB.readFile f
where
addNtfLogRecord s = case strDecode s of
Left e -> B.putStrLn $ "Log parsing error (" <> B.pack e <> "): " <> B.take 100 s
Left e -> logError $ "Log parsing error (" <> T.pack e <> "): " <> safeDecodeUtf8 (B.take 100 s)
Right lr -> atomically $ case lr of
CreateToken r@NtfTknRec {ntfTknId} -> do
tkn <- mkTknData r
+17 -1
View File
@@ -10,9 +10,10 @@ import qualified Data.Aeson as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (toLower)
import Data.Char (isAlphaNum, toLower)
import Data.String
import Data.Text (Text)
import qualified Data.Text as T
@@ -23,8 +24,23 @@ import Database.SQLite.Simple (ResultError (..), SQLData (..))
import Database.SQLite.Simple.FromField (FieldParser, returnError)
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok (Ok (Ok))
import Simplex.Messaging.Util ((<$?>))
import Text.Read (readMaybe)
base64P :: Parser ByteString
base64P = decode <$?> paddedBase64 rawBase64P
paddedBase64 :: Parser ByteString -> Parser ByteString
paddedBase64 raw = (<>) <$> raw <*> pad
where
pad = A.takeWhile (== '=')
rawBase64P :: Parser ByteString
rawBase64P = A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/')
-- rawBase64UriP :: Parser ByteString
-- rawBase64UriP = A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_')
tsISO8601P :: Parser UTCTime
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
+1 -1
View File
@@ -176,6 +176,7 @@ import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isPrint, isSpace)
@@ -193,7 +194,6 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
import Network.Socket (ServiceName)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import qualified Simplex.Messaging.Encoding.Base64 as B64
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.ServiceScheme
+5 -3
View File
@@ -45,8 +45,10 @@ import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import Data.Bifunctor (first)
import Data.ByteString.Base64 (encode)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Either (fromRight, partitionEithers)
import Data.Functor (($>))
import Data.Int (Int64)
@@ -67,7 +69,6 @@ import Network.Socket (ServiceName, Socket, socketToHandle)
import Simplex.Messaging.Agent.Lock
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding (Encoding (smpEncode))
import Simplex.Messaging.Encoding.Base64 (encode)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Control
@@ -983,7 +984,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case
ms <- asks msgStore
quota <- asks $ msgQueueQuota . config
old_ <- asks (messageExpiration . config) $>>= (liftIO . fmap Just . expireBeforeEpoch)
runExceptT (liftIO (B.readFile f) >>= foldM (\expired -> restoreMsg expired ms quota old_) 0 . B.lines) >>= \case
runExceptT (liftIO (LB.readFile f) >>= foldM (\expired -> restoreMsg expired ms quota old_) 0 . LB.lines) >>= \case
Left e -> do
logError . T.pack $ "error restoring messages: " <> e
liftIO exitFailure
@@ -992,10 +993,11 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case
logInfo "messages restored"
pure expired
where
restoreMsg !expired ms quota old_ s = do
restoreMsg !expired ms quota old_ s' = do
MLRv3 rId msg <- liftEither . first (msgErr "parsing") $ strDecode s
addToMsgQueue rId msg
where
s = LB.toStrict s'
addToMsgQueue rId msg = do
(isExpired, logFull) <- atomically $ do
q <- getMsgQueue ms rId quota
+5 -4
View File
@@ -25,8 +25,8 @@ where
import Control.Applicative (optional, (<|>))
import Control.Monad (foldM, unless, when)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Functor (($>))
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
@@ -148,13 +148,14 @@ writeQueues s = mapM_ $ \q -> when (active q) $ logCreateQueue s q
active QueueRec {status} = status == QueueActive
readQueues :: FilePath -> IO (Map RecipientId QueueRec)
readQueues f = foldM processLine M.empty . B.lines =<< B.readFile f
readQueues f = foldM processLine M.empty . LB.lines =<< LB.readFile f
where
processLine :: Map RecipientId QueueRec -> ByteString -> IO (Map RecipientId QueueRec)
processLine m s = case strDecode $ trimCR s of
processLine :: Map RecipientId QueueRec -> LB.ByteString -> IO (Map RecipientId QueueRec)
processLine m s' = case strDecode $ trimCR s of
Right r -> pure $ procLogRecord r
Left e -> printError e $> m
where
s = LB.toStrict s'
procLogRecord :: StoreLogRecord -> Map RecipientId QueueRec
procLogRecord = \case
CreateQueue q -> M.insert (recipientId q) q m
+6 -1
View File
@@ -54,6 +54,7 @@ module Simplex.Messaging.Transport
-- * TLS Transport
TLS (..),
SessionId,
ALPN,
connectTLS,
closeTLS,
supportedParameters,
@@ -228,10 +229,13 @@ data TLS = TLS
tlsPeer :: TransportPeer,
tlsUniq :: ByteString,
tlsBuffer :: TBuffer,
tlsALPN :: Maybe ALPN,
tlsServerCerts :: X.CertificateChain,
tlsTransportConfig :: TransportConfig
}
type ALPN = ByteString
connectTLS :: T.TLSParams p => Maybe HostName -> TransportConfig -> p -> Socket -> IO T.Context
connectTLS host_ TransportConfig {logTLSErrors} params sock =
E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx ->
@@ -246,7 +250,8 @@ getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS
where
newTLS tlsUniq = do
tlsBuffer <- atomically newTBuffer
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer}
tlsALPN <- T.getNegotiatedProtocol cxt
pure TLS {tlsContext = cxt, tlsALPN, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer}
withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c
withTlsUnique peer cxt f =
+12 -9
View File
@@ -17,6 +17,7 @@ module Simplex.Messaging.Transport.Client
TransportHost (..),
TransportHosts (..),
TransportHosts_ (..),
validateCertificateChain
)
where
@@ -113,12 +114,13 @@ data TransportClientConfig = TransportClientConfig
{ socksProxy :: Maybe SocksProxy,
tcpKeepAlive :: Maybe KeepAliveOpts,
logTLSErrors :: Bool,
clientCredentials :: Maybe (X.CertificateChain, T.PrivKey)
clientCredentials :: Maybe (X.CertificateChain, T.PrivKey),
alpn :: Maybe [ALPN]
}
deriving (Eq, Show)
defaultTransportClientConfig :: TransportClientConfig
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True Nothing
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True Nothing Nothing
clientTransportConfig :: TransportClientConfig -> TransportConfig
clientTransportConfig TransportClientConfig {logTLSErrors} =
@@ -129,10 +131,10 @@ runTransportClient :: Transport c => TransportClientConfig -> Maybe ByteString -
runTransportClient = runTLSTransportClient supportedParameters Nothing
runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials, alpn} proxyUsername host port keyHash client = do
serverCert <- newEmptyTMVarIO
let hostName = B.unpack $ strEncode host
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials serverCert
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials alpn serverCert
connectTCP = case socksProxy of
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
_ -> connectTCPClient hostName
@@ -215,14 +217,15 @@ instance ToJSON SocksProxy where
instance FromJSON SocksProxy where
parseJSON = strParseJSON "SocksProxy"
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> TMVar X.CertificateChain -> T.ClientParams
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts =
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> Maybe [ALPN] -> TMVar X.CertificateChain -> T.ClientParams
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ alpn_ serverCerts =
(T.defaultParamsClient host p)
{ T.clientShared = def {T.sharedCAStore = fromMaybe (T.sharedCAStore def) caStore_},
T.clientHooks =
def
{ T.onServerCertificate = onServerCert,
T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_
T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_,
T.onSuggestALPN = pure alpn_
},
T.clientSupported = supported
}
@@ -237,7 +240,7 @@ mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts =
validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason]
validateCertificateChain _ _ _ (X.CertificateChain []) = pure [XV.EmptyChain]
validateCertificateChain _ _ _ (X.CertificateChain [_]) = pure [XV.EmptyChain]
validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain sc@[_, caCert]) =
validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain [_, caCert]) =
if Fingerprint kh == XV.getFingerprint caCert X.HashSHA256
then x509validate
else pure [XV.UnknownCA]
@@ -247,7 +250,7 @@ validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain sc@[_,
where
hooks = XV.defaultHooks
checks = XV.defaultChecks {XV.checkFQHN = False}
certStore = XS.makeCertificateStore sc
certStore = XS.makeCertificateStore [caCert]
cache = XV.exceptionValidationCache [] -- we manually check fingerprint only of the identity certificate (ca.crt)
serviceID = (host, port)
validateCertificateChain _ _ _ _ = pure [XV.AuthorityTooDeep]
+3 -3
View File
@@ -16,15 +16,15 @@ import qualified Network.HTTP2.Server as HS
import Network.Socket (SockAddr (..))
import qualified Network.TLS as T
import qualified Network.TLS.Extra as TE
import Simplex.Messaging.Transport (SessionId, TLS (tlsUniq), Transport (cGet, cPut))
import Simplex.Messaging.Transport (TLS, Transport (cGet, cPut))
import Simplex.Messaging.Transport.Buffer
import qualified System.TimeManager as TI
defaultHTTP2BufferSize :: BufferSize
defaultHTTP2BufferSize = 32768
withHTTP2 :: BufferSize -> (Config -> SessionId -> IO a) -> TLS -> IO a
withHTTP2 sz run c = E.bracket (allocHTTP2Config c sz) freeSimpleConfig (`run` tlsUniq c)
withHTTP2 :: BufferSize -> (Config -> IO a) -> IO () -> TLS -> IO a
withHTTP2 sz run fin c = E.bracket (allocHTTP2Config c sz) (\cfg -> freeSimpleConfig cfg `E.finally` fin) run
allocHTTP2Config :: TLS -> BufferSize -> IO Config
allocHTTP2Config c sz = do
+32 -15
View File
@@ -23,15 +23,20 @@ import qualified Network.TLS as T
import Numeric.Natural (Natural)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport (SessionId, TLS)
import Simplex.Messaging.Transport (ALPN, SessionId, TLS (tlsALPN), getServerCerts, getServerVerifyKey, tlsUniq)
import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), runTLSTransportClient)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Util (eitherToMaybe)
import UnliftIO.STM
import UnliftIO.Timeout
import qualified Data.X509 as X
data HTTP2Client = HTTP2Client
{ action :: Maybe (Async HTTP2Response),
sessionId :: SessionId,
sessionALPN :: Maybe ALPN,
serverKey :: Maybe C.APublicVerifyKey, -- may not always be a key we control (i.e. APNS with apple-mandated key types)
serverCerts :: X.CertificateChain,
sessionTs :: UTCTime,
sendReq :: Request -> (Response -> IO HTTP2Response) -> IO HTTP2Response,
client_ :: HClient
@@ -66,7 +71,7 @@ defaultHTTP2ClientConfig =
HTTP2ClientConfig
{ qSize = 64,
connTimeout = 10000000,
transportConfig = TransportClientConfig Nothing Nothing True Nothing,
transportConfig = TransportClientConfig Nothing Nothing True Nothing Nothing,
bufferSize = defaultHTTP2BufferSize,
bodyHeadSize = 16384,
suportedTLSParams = http2TLSParams
@@ -86,9 +91,10 @@ getVerifiedHTTP2Client proxyUsername host port keyHash caStore config disconnect
attachHTTP2Client :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> Int -> TLS -> IO (Either HTTP2ClientError HTTP2Client)
attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP2ClientWith config host port disconnected setup
where
setup :: (TLS -> H.Client HTTP2Response) -> IO HTTP2Response
setup = runHTTP2ClientWith bufferSize host ($ tls)
getVerifiedHTTP2ClientWith :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((SessionId -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client)
getVerifiedHTTP2ClientWith :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client)
getVerifiedHTTP2ClientWith config host port disconnected setup =
(atomically mkHTTPS2Client >>= runClient)
`E.catch` \(e :: IOException) -> pure . Left $ HCIOError e
@@ -104,15 +110,25 @@ getVerifiedHTTP2ClientWith config host port disconnected setup =
cVar <- newEmptyTMVarIO
action <- async $ setup (client c cVar) `E.finally` atomically (putTMVar cVar $ Left HCNetworkError)
c_ <- connTimeout config `timeout` atomically (takeTMVar cVar)
pure $ case c_ of
Just (Right c') -> Right c' {action = Just action}
Just (Left e) -> Left e
Nothing -> Left HCNetworkError
case c_ of
Just (Right c') -> pure $ Right c' {action = Just action}
Just (Left e) -> pure $ Left e
Nothing -> cancel action $> Left HCNetworkError
client :: HClient -> TMVar (Either HTTP2ClientError HTTP2Client) -> SessionId -> H.Client HTTP2Response
client c cVar sessionId sendReq = do
client :: HClient -> TMVar (Either HTTP2ClientError HTTP2Client) -> TLS -> H.Client HTTP2Response
client c cVar tls sendReq = do
sessionTs <- getCurrentTime
let c' = HTTP2Client {action = Nothing, client_ = c, sendReq, sessionId, sessionTs}
let c' =
HTTP2Client
{ action = Nothing,
client_ = c,
serverKey = eitherToMaybe $ getServerVerifyKey tls,
serverCerts = getServerCerts tls,
sendReq,
sessionTs,
sessionId = tlsUniq tls,
sessionALPN = tlsALPN tls
}
atomically $ do
writeTVar (connected c) True
putTMVar cVar (Right c')
@@ -154,13 +170,14 @@ sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq
http2RequestTimeout :: HTTP2ClientConfig -> Maybe Int -> Int
http2RequestTimeout HTTP2ClientConfig {connTimeout} = maybe connTimeout (connTimeout +)
runHTTP2Client :: forall a. T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> BufferSize -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (SessionId -> H.Client a) -> IO a
runHTTP2Client :: forall a. T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> BufferSize -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (TLS -> H.Client a) -> IO a
runHTTP2Client tlsParams caStore tcConfig bufferSize proxyUsername host port keyHash = runHTTP2ClientWith bufferSize host setup
where
setup :: (TLS -> IO a) -> IO a
setup = runTLSTransportClient tlsParams caStore tcConfig proxyUsername host port keyHash
runHTTP2ClientWith :: forall a. BufferSize -> TransportHost -> ((TLS -> IO a) -> IO a) -> (SessionId -> H.Client a) -> IO a
runHTTP2ClientWith bufferSize host setup client = setup $ withHTTP2 bufferSize run
runHTTP2ClientWith :: forall a. BufferSize -> TransportHost -> ((TLS -> IO a) -> IO a) -> (TLS -> H.Client a) -> IO a
runHTTP2ClientWith bufferSize host setup client = setup $ \tls -> withHTTP2 bufferSize (run tls) (pure ()) tls
where
run :: H.Config -> SessionId -> IO a
run cfg sessId = H.run (ClientConfig "https" (strEncode host) 20) cfg $ client sessId
run :: TLS -> H.Config -> IO a
run tls cfg = H.run (ClientConfig "https" (strEncode host) 20) cfg $ client tls
+13 -12
View File
@@ -13,14 +13,14 @@ import Network.Socket
import qualified Network.TLS as T
import Numeric.Natural (Natural)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport (SessionId, TLS, closeConnection)
import Simplex.Messaging.Transport (ALPN, SessionId, TLS, closeConnection, tlsALPN, tlsUniq)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), loadSupportedTLSServerParams, runTransportServer)
import Simplex.Messaging.Util (threadDelay')
import UnliftIO (finally)
import UnliftIO.Concurrent (forkIO, killThread)
type HTTP2ServerFunc = SessionId -> Request -> (Response -> IO ()) -> IO ()
type HTTP2ServerFunc = SessionId -> Maybe ALPN -> Request -> (Response -> IO ()) -> IO ()
data HTTP2ServerConfig = HTTP2ServerConfig
{ qSize :: Natural,
@@ -37,6 +37,7 @@ data HTTP2ServerConfig = HTTP2ServerConfig
data HTTP2Request = HTTP2Request
{ sessionId :: SessionId,
sessionALPN :: Maybe ALPN,
request :: Request,
reqBody :: HTTP2Body,
sendResponse :: Response -> IO ()
@@ -54,32 +55,32 @@ getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, se
started <- newEmptyTMVarIO
reqQ <- newTBQueueIO qSize
action <- async $
runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig Nothing $ \sessionId r sendResponse -> do
runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig Nothing (const $ pure ()) $ \sessionId sessionALPN r sendResponse -> do
reqBody <- getHTTP2Body r bodyHeadSize
atomically $ writeTBQueue reqQ HTTP2Request {sessionId, request = r, reqBody, sendResponse}
atomically $ writeTBQueue reqQ HTTP2Request {sessionId, sessionALPN, request = r, reqBody, sendResponse}
void . atomically $ takeTMVar started
pure HTTP2Server {action, reqQ}
closeHTTP2Server :: HTTP2Server -> IO ()
closeHTTP2Server = uninterruptibleCancel . action
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> Maybe ExpirationConfig -> HTTP2ServerFunc -> IO ()
runHTTP2Server started port bufferSize serverParams transportConfig expCfg_ = runHTTP2ServerWith_ expCfg_ bufferSize setup
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> Maybe ExpirationConfig -> (SessionId -> IO ()) -> HTTP2ServerFunc -> IO ()
runHTTP2Server started port bufferSize serverParams transportConfig expCfg_ clientFinished = runHTTP2ServerWith_ expCfg_ clientFinished bufferSize setup
where
setup = runTransportServer started port serverParams transportConfig
runHTTP2ServerWith :: BufferSize -> ((TLS -> IO ()) -> a) -> HTTP2ServerFunc -> a
runHTTP2ServerWith = runHTTP2ServerWith_ Nothing
runHTTP2ServerWith = runHTTP2ServerWith_ Nothing (\_sessId -> pure ())
runHTTP2ServerWith_ :: Maybe ExpirationConfig -> BufferSize -> ((TLS -> IO ()) -> a) -> HTTP2ServerFunc -> a
runHTTP2ServerWith_ expCfg_ bufferSize setup http2Server = setup $ \tls -> do
runHTTP2ServerWith_ :: Maybe ExpirationConfig -> (SessionId -> IO ()) -> BufferSize -> ((TLS -> IO ()) -> a) -> HTTP2ServerFunc -> a
runHTTP2ServerWith_ expCfg_ clientFinished bufferSize setup http2Server = setup $ \tls -> do
activeAt <- newTVarIO =<< getSystemTime
tid_ <- mapM (forkIO . expireInactiveClient tls activeAt) expCfg_
withHTTP2 bufferSize (run activeAt) tls `finally` mapM_ killThread tid_
withHTTP2 bufferSize (run tls activeAt) (clientFinished $ tlsUniq tls) tls `finally` mapM_ killThread tid_
where
run activeAt cfg sessId = H.run cfg $ \req _aux sendResp -> do
run tls activeAt cfg = H.run cfg $ \req _aux sendResp -> do
getSystemTime >>= atomically . writeTVar activeAt
http2Server sessId req (`sendResp` [])
http2Server (tlsUniq tls) (tlsALPN tls) req (`sendResp` [])
expireInactiveClient tls activeAt expCfg = loop
where
loop = do