mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-27 10:45:14 +00:00
Merge remote-tracking branch 'origin/master' into ep/sntrup761
This commit is contained in:
@@ -0,0 +1 @@
|
||||
- ignore: {name: "Use underscore"}
|
||||
+1
-1
@@ -19,7 +19,7 @@ source-repository-package
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/kazu-yamamoto/http2.git
|
||||
tag: b5a1b7200cf5bc7044af34ba325284271f6dff25
|
||||
tag: 804fa283f067bd3fd89b8c5f8d25b3047813a517
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
indentation: 2
|
||||
column-limit: none
|
||||
function-arrows: trailing
|
||||
comma-style: trailing
|
||||
import-export-style: trailing
|
||||
indent-wheres: true
|
||||
record-brace-space: true
|
||||
newlines-between-decls: 1
|
||||
haddock-style: single-line
|
||||
haddock-style-module: null
|
||||
let-style: inline
|
||||
in-style: right-align
|
||||
single-constraint-parens: never
|
||||
unicode: never
|
||||
respectful: true
|
||||
fixities:
|
||||
- infixr 9 .
|
||||
- infixr 8 .:, .:., .=
|
||||
- infixr 6 <>
|
||||
- infixr 5 ++
|
||||
- infixl 4 <$>, <$, $>, <$$>, <$?>
|
||||
- infixl 4 <*>, <*, *>, <**>
|
||||
- infix 4 ==, /=
|
||||
- infixr 3 &&
|
||||
- infixl 3 <|>
|
||||
- infixr 2 ||
|
||||
- infixl 1 >>, >>=
|
||||
- infixr 1 =<<, >=>, <=<
|
||||
- infixr 0 $, $!
|
||||
reexports: []
|
||||
+2
-3
@@ -46,8 +46,7 @@ dependencies:
|
||||
- filepath == 1.4.*
|
||||
- hourglass == 0.2.*
|
||||
- http-types == 0.12.*
|
||||
- http2 == 4.1.*
|
||||
- generic-random == 1.5.*
|
||||
- http2 >= 4.1.4 && < 4.2
|
||||
- ini == 0.4.1
|
||||
- iproute == 1.7.*
|
||||
- iso8601-time == 0.1.*
|
||||
@@ -56,7 +55,6 @@ dependencies:
|
||||
- network >= 3.1.2.7 && < 3.2
|
||||
- network-transport == 0.5.6
|
||||
- optparse-applicative >= 0.15 && < 0.17
|
||||
- QuickCheck == 2.14.*
|
||||
- process == 1.6.*
|
||||
- random >= 1.1 && < 1.3
|
||||
- simple-logger == 0.1.*
|
||||
@@ -152,6 +150,7 @@ tests:
|
||||
dependencies:
|
||||
- simplexmq
|
||||
- deepseq == 1.4.*
|
||||
- generic-random == 1.5.*
|
||||
- hspec == 2.11.*
|
||||
- hspec-core == 2.11.*
|
||||
- HUnit == 1.6.*
|
||||
|
||||
@@ -364,9 +364,9 @@ The clients can optionally instruct a dedicated push notification server to subs
|
||||
|
||||
[`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation.
|
||||
|
||||
## SMP Transmission andtransport block structure
|
||||
## SMP Transmission and transport block structure
|
||||
|
||||
Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity.
|
||||
Each transport block has a fixed size of 16384 bytes for traffic uniformity.
|
||||
|
||||
From SMP version 4 each block can contain multiple transmissions, version 3 blocks have 1 transmission.
|
||||
Some parts of SMP transmission are padded to a fixed size; this padding is uniformly added as a word16 encoded in network byte order - see `paddedString` syntax.
|
||||
@@ -387,17 +387,17 @@ Each transmission/block for SMP v3 between the client and the server must have t
|
||||
|
||||
```abnf
|
||||
paddedTransmission = <padded(transmission), 16384>
|
||||
transmission = [signature] SP signed
|
||||
signed = sessionIdentifier SP [corrId] SP [queueId] SP smpCommand
|
||||
transmission = signature signed
|
||||
signed = sessionIdentifier corrId queueId smpCommand
|
||||
; corrId is required in client commands and server responses,
|
||||
; it is empty in server notifications.
|
||||
corrId = 1*32(%x21-7F) ; any characters other than control/whitespace
|
||||
queueId = encoded ; max 32 bytes when decoded (24 bytes is used),
|
||||
corrId = length *OCTET
|
||||
queueId = length *OCTET
|
||||
; empty queue ID is used with "create" command and in some server responses
|
||||
signature = encoded
|
||||
signature = length *OCTET
|
||||
; empty signature can be used with "send" before the queue is secured with secure command
|
||||
; signature is always empty with "ping" and "serverMsg"
|
||||
encoded = <base64 encoded binary>
|
||||
length = 1*1 OCTET
|
||||
```
|
||||
|
||||
`base64` encoding should be used with padding, as defined in section 4 of [RFC 4648][9]
|
||||
|
||||
@@ -19,7 +19,9 @@ if [ ! -f "${confd}/file-server.ini" ]; then
|
||||
# Set quota
|
||||
case "${QUOTA}" in
|
||||
'') printf 'Please specify $QUOTA environment variable.\n'; exit 1 ;;
|
||||
*) set -- "$@" --quota "${QUOTA}" ;;
|
||||
*GB) QUOTA="$(printf ${QUOTA} | tr '[:upper:]' '[:lower:]')"; set -- "$@" --quota "${QUOTA}" ;;
|
||||
*gb) set -- "$@" --quota "${QUOTA}" ;;
|
||||
*) printf 'Wrong format. Format should be: 1gb, 10gb, 100gb.\n'; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Init the certificates and configs
|
||||
|
||||
+15
-25
@@ -114,6 +114,7 @@ library
|
||||
Simplex.Messaging.Notifications.Server.Env
|
||||
Simplex.Messaging.Notifications.Server.Main
|
||||
Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
Simplex.Messaging.Notifications.Server.Push.APNS.Internal
|
||||
Simplex.Messaging.Notifications.Server.Stats
|
||||
Simplex.Messaging.Notifications.Server.Store
|
||||
Simplex.Messaging.Notifications.Server.StoreLog
|
||||
@@ -140,6 +141,7 @@ library
|
||||
Simplex.Messaging.Transport.Credentials
|
||||
Simplex.Messaging.Transport.HTTP2
|
||||
Simplex.Messaging.Transport.HTTP2.Client
|
||||
Simplex.Messaging.Transport.HTTP2.File
|
||||
Simplex.Messaging.Transport.HTTP2.Server
|
||||
Simplex.Messaging.Transport.KeepAlive
|
||||
Simplex.Messaging.Transport.Server
|
||||
@@ -160,8 +162,7 @@ library
|
||||
extra-libraries:
|
||||
crypto
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.2.*
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
@@ -180,10 +181,9 @@ library
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random ==1.5.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
@@ -225,8 +225,7 @@ executable ntf-server
|
||||
apps/ntf-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.2.*
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
@@ -245,10 +244,9 @@ executable ntf-server
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random ==1.5.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
@@ -291,8 +289,7 @@ executable smp-agent
|
||||
apps/smp-agent
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.2.*
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
@@ -311,10 +308,9 @@ executable smp-agent
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random ==1.5.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
@@ -357,8 +353,7 @@ executable smp-server
|
||||
apps/smp-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.2.*
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
@@ -377,10 +372,9 @@ executable smp-server
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random ==1.5.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
@@ -423,8 +417,7 @@ executable xftp
|
||||
apps/xftp
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.2.*
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
@@ -443,10 +436,9 @@ executable xftp
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random ==1.5.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
@@ -489,8 +481,7 @@ executable xftp-server
|
||||
apps/xftp-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.2.*
|
||||
aeson ==2.2.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
@@ -509,10 +500,9 @@ executable xftp-server
|
||||
, direct-sqlcipher ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random ==1.5.*
|
||||
, hourglass ==0.2.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
@@ -610,7 +600,7 @@ test-suite simplexmq-test
|
||||
, hspec ==2.11.*
|
||||
, hspec-core ==2.11.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==4.1.*
|
||||
, http2 >=4.1.4 && <4.2
|
||||
, ini ==0.4.1
|
||||
, iproute ==1.7.*
|
||||
, iso8601-time ==0.1.*
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module Simplex.FileTransfer.Agent
|
||||
@@ -322,8 +321,8 @@ runXFTPSndPrepareWorker c doWork = do
|
||||
if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting
|
||||
then do
|
||||
fsEncPath <- toFSFilePath $ sndFileEncPath ppath
|
||||
when (status == SFSEncrypting) $
|
||||
whenM (doesFileExist fsEncPath) $ removeFile fsEncPath
|
||||
when (status == SFSEncrypting) . whenM (doesFileExist fsEncPath) $
|
||||
removeFile fsEncPath
|
||||
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
|
||||
(digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath
|
||||
withStore c $ \db -> do
|
||||
@@ -441,11 +440,11 @@ runXFTPSndWorker c srv doWork = do
|
||||
| length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients"
|
||||
| length rcvIdsKeys == numRecipients = pure cr
|
||||
| otherwise = do
|
||||
maxRecipients <- asks $ xftpMaxRecipientsPerRequest . config
|
||||
let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients
|
||||
rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients'
|
||||
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
|
||||
addRecipients ch cr'
|
||||
maxRecipients <- asks $ xftpMaxRecipientsPerRequest . config
|
||||
let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients
|
||||
rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients'
|
||||
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
|
||||
addRecipients ch cr'
|
||||
sndFileToDescrs :: SndFile -> m (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
|
||||
sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest"
|
||||
sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks"
|
||||
@@ -573,7 +572,7 @@ runXFTPDelWorker c srv doWork = do
|
||||
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
|
||||
atomically $ assertAgentForeground c
|
||||
loop
|
||||
retryDone e = delWorkerInternalError c deletedSndChunkReplicaId e
|
||||
retryDone = delWorkerInternalError c deletedSndChunkReplicaId
|
||||
deleteChunkReplica :: DeletedSndChunkReplica -> m ()
|
||||
deleteChunkReplica replica@DeletedSndChunkReplica {userId, deletedSndChunkReplicaId} = do
|
||||
agentXFTPDeleteChunk c userId replica
|
||||
|
||||
@@ -49,6 +49,7 @@ import Simplex.Messaging.Transport (supportedParameters)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import Simplex.Messaging.Transport.HTTP2.File
|
||||
import Simplex.Messaging.Util (bshow, liftEitherError, whenM)
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory
|
||||
@@ -153,7 +154,7 @@ sendXFTPCommand XFTPClient {config, http2Client = http2@HTTP2Client {sessionId}}
|
||||
forM_ chunkSpec_ $ \XFTPChunkSpec {filePath, chunkOffset, chunkSize} ->
|
||||
withFile filePath ReadMode $ \h -> do
|
||||
hSeek h AbsoluteSeek $ fromIntegral chunkOffset
|
||||
sendFile h send $ fromIntegral chunkSize
|
||||
hSendFile h send $ fromIntegral chunkSize
|
||||
done
|
||||
|
||||
createXFTPChunk ::
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module Simplex.FileTransfer.Client.Main
|
||||
@@ -527,8 +526,8 @@ prepareChunkSizes size' = prepareSizes size'
|
||||
(smallSize, bigSize)
|
||||
| size' > size34 chunkSize3 = (chunkSize2, chunkSize3)
|
||||
| otherwise = (chunkSize1, chunkSize2)
|
||||
-- | size' > size34 chunkSize2 = (chunkSize1, chunkSize2)
|
||||
-- | otherwise = (chunkSize0, chunkSize1)
|
||||
-- | size' > size34 chunkSize2 = (chunkSize1, chunkSize2)
|
||||
-- | otherwise = (chunkSize0, chunkSize1)
|
||||
size34 sz = (fromIntegral sz * 3) `div` 4
|
||||
prepareSizes 0 = []
|
||||
prepareSizes size
|
||||
@@ -571,11 +570,11 @@ withRetry retryCount = withRetry' retryCount . withExceptT (CLIError . show)
|
||||
removeFD :: Bool -> FilePath -> IO ()
|
||||
removeFD yes fd
|
||||
| yes = do
|
||||
removeFile fd
|
||||
putStrLn $ "\nFile description " <> fd <> " is deleted."
|
||||
removeFile fd
|
||||
putStrLn $ "\nFile description " <> fd <> " is deleted."
|
||||
| otherwise = do
|
||||
y <- liftIO . getConfirmation $ "\nFile description " <> fd <> " can't be used again. Delete it"
|
||||
when y $ removeFile fd
|
||||
y <- liftIO . getConfirmation $ "\nFile description " <> fd <> " can't be used again. Delete it"
|
||||
when y $ removeFile fd
|
||||
|
||||
getConfirmation :: String -> IO Bool
|
||||
getConfirmation prompt = do
|
||||
|
||||
@@ -46,12 +46,12 @@ encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do
|
||||
encryptChunks_ get w (!sb, !len)
|
||||
| len == 0 = pure sb
|
||||
| otherwise = do
|
||||
let chSize = min len 65536
|
||||
ch <- liftIO $ get chSize
|
||||
when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
|
||||
let (ch', sb') = LC.sbEncryptChunk sb ch
|
||||
liftIO $ B.hPut w ch'
|
||||
encryptChunks_ get w (sb', len - chSize)
|
||||
let chSize = min len 65536
|
||||
ch <- liftIO $ get chSize
|
||||
when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
|
||||
let (ch', sb') = LC.sbEncryptChunk sb ch
|
||||
liftIO $ B.hPut w ch'
|
||||
encryptChunks_ get w (sb', len - chSize)
|
||||
|
||||
decryptChunks :: Int64 -> [FilePath] -> C.SbKey -> C.CbNonce -> (String -> ExceptT String IO CryptoFile) -> ExceptT FTCryptoError IO CryptoFile
|
||||
decryptChunks _ [] _ _ _ = throwError $ FTCEInvalidHeader "empty"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
@@ -9,7 +8,7 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module Simplex.FileTransfer.Description
|
||||
@@ -38,7 +37,7 @@ where
|
||||
import Control.Applicative (optional)
|
||||
import Control.Monad ((<=<))
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
@@ -54,12 +53,11 @@ import Data.Word (Word32)
|
||||
import qualified Data.Yaml as Y
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import GHC.Generics (Generic)
|
||||
import Simplex.FileTransfer.Chunks
|
||||
import Simplex.FileTransfer.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Parsers (defaultJSON, parseAll)
|
||||
import Simplex.Messaging.Protocol (XFTPServer)
|
||||
import Simplex.Messaging.Util (bshow, groupAllOn, (<$?>))
|
||||
|
||||
@@ -150,21 +148,13 @@ data YAMLFileDescription = YAMLFileDescription
|
||||
chunkSize :: String,
|
||||
replicas :: [YAMLServerReplicas]
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON YAMLFileDescription where
|
||||
toJSON = J.genericToJSON J.defaultOptions
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
data YAMLServerReplicas = YAMLServerReplicas
|
||||
{ server :: XFTPServer,
|
||||
chunks :: [String]
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON YAMLServerReplicas where
|
||||
toJSON = J.genericToJSON J.defaultOptions
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
data FileServerReplica = FileServerReplica
|
||||
{ chunkNo :: Int,
|
||||
@@ -176,6 +166,13 @@ data FileServerReplica = FileServerReplica
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newtype FileSize a = FileSize {unFileSize :: a}
|
||||
deriving (Eq, Show)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''YAMLServerReplicas)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''YAMLFileDescription)
|
||||
|
||||
instance FilePartyI p => StrEncoding (ValidFileDescription p) where
|
||||
strEncode (ValidFD fd) = strEncode fd
|
||||
strDecode s = strDecode s >>= (\(AVFD fd) -> checkParty fd)
|
||||
@@ -217,9 +214,6 @@ encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSiz
|
||||
replicas = encodeFileReplicas chunkSize chunks
|
||||
}
|
||||
|
||||
newtype FileSize a = FileSize {unFileSize :: a}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance (Integral a, Show a) => StrEncoding (FileSize a) where
|
||||
strEncode (FileSize b)
|
||||
| b' /= 0 = bshow b
|
||||
@@ -242,9 +236,9 @@ instance (Integral a, Show a) => StrEncoding (FileSize a) where
|
||||
instance (Integral a, Show a) => IsString (FileSize a) where
|
||||
fromString = either error id . strDecode . B.pack
|
||||
|
||||
instance (FromField a) => FromField (FileSize a) where fromField f = FileSize <$> fromField f
|
||||
instance FromField a => FromField (FileSize a) where fromField f = FileSize <$> fromField f
|
||||
|
||||
instance (ToField a) => ToField (FileSize a) where toField (FileSize s) = toField s
|
||||
instance ToField a => ToField (FileSize a) where toField (FileSize s) = toField s
|
||||
|
||||
groupReplicasByServer :: FileSize Word32 -> [FileChunk] -> [[FileServerReplica]]
|
||||
groupReplicasByServer defChunkSize =
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@@ -7,6 +6,7 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
@@ -14,8 +14,7 @@
|
||||
module Simplex.FileTransfer.Protocol where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
@@ -25,8 +24,6 @@ import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word32)
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
@@ -59,7 +56,6 @@ import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport (SessionId, TransportError (..))
|
||||
import Simplex.Messaging.Util (bshow, (<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
|
||||
currentXFTPVersion :: Version
|
||||
currentXFTPVersion = 1
|
||||
@@ -69,14 +65,7 @@ xftpBlockSize = 16384
|
||||
|
||||
-- | File protocol clients
|
||||
data FileParty = FRecipient | FSender
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance FromJSON FileParty where
|
||||
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "F"
|
||||
|
||||
instance ToJSON FileParty where
|
||||
toJSON = J.genericToJSON . enumJSON $ dropPrefix "F"
|
||||
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "F"
|
||||
deriving (Eq, Show)
|
||||
|
||||
data SFileParty :: FileParty -> Type where
|
||||
SFRecipient :: SFileParty FRecipient
|
||||
@@ -355,14 +344,7 @@ data XFTPErrorType
|
||||
INTERNAL
|
||||
| -- | used internally, never returned by the server (to be removed)
|
||||
DUPLICATE_ -- not part of SMP protocol, used internally
|
||||
deriving (Eq, Generic, Read, Show)
|
||||
|
||||
instance ToJSON XFTPErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON XFTPErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
instance StrEncoding XFTPErrorType where
|
||||
strEncode = \case
|
||||
@@ -370,8 +352,6 @@ instance StrEncoding XFTPErrorType where
|
||||
e -> bshow e
|
||||
strP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1
|
||||
|
||||
instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Encoding XFTPErrorType where
|
||||
smpEncode = \case
|
||||
BLOCK -> "BLOCK"
|
||||
@@ -435,3 +415,7 @@ xftpDecodeTransmission sessionId t = do
|
||||
case tParse True t' of
|
||||
t'' :| [] -> Right $ tDecodeParseValidate sessionId currentXFTPVersion t''
|
||||
_ -> Left BLOCK
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType)
|
||||
|
||||
@@ -169,17 +169,17 @@ processRequest :: HTTP2Request -> M ()
|
||||
processRequest HTTP2Request {sessionId, reqBody = body@HTTP2Body {bodyHead}, sendResponse}
|
||||
| B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", "", FRErr BLOCK) Nothing
|
||||
| otherwise = do
|
||||
case xftpDecodeTransmission sessionId bodyHead of
|
||||
Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do
|
||||
case cmdOrErr of
|
||||
Right cmd -> do
|
||||
verifyXFTPTransmission sig_ signed fId cmd >>= \case
|
||||
VRVerified req -> uncurry send =<< processXFTPRequest body req
|
||||
VRFailed -> send (FRErr AUTH) Nothing
|
||||
Left e -> send (FRErr e) Nothing
|
||||
where
|
||||
send resp = sendXFTPResponse (corrId, fId, resp)
|
||||
Left e -> sendXFTPResponse ("", "", FRErr e) Nothing
|
||||
case xftpDecodeTransmission sessionId bodyHead of
|
||||
Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do
|
||||
case cmdOrErr of
|
||||
Right cmd -> do
|
||||
verifyXFTPTransmission sig_ signed fId cmd >>= \case
|
||||
VRVerified req -> uncurry send =<< processXFTPRequest body req
|
||||
VRFailed -> send (FRErr AUTH) Nothing
|
||||
Left e -> send (FRErr e) Nothing
|
||||
where
|
||||
send resp = sendXFTPResponse (corrId, fId, resp)
|
||||
Left e -> sendXFTPResponse ("", "", FRErr e) Nothing
|
||||
where
|
||||
sendXFTPResponse :: (CorrId, XFTPFileId, FileResponse) -> Maybe ServerFile -> M ()
|
||||
sendXFTPResponse (corrId, fId, resp) serverFile_ = do
|
||||
|
||||
@@ -26,7 +26,7 @@ import Simplex.FileTransfer.Server.StoreLog
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (BasicAuth, RcvPublicVerifyKey)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Util (tshow)
|
||||
import System.IO (IOMode (..))
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -19,7 +19,7 @@ import Options.Applicative
|
||||
import Simplex.FileTransfer.Chunks
|
||||
import Simplex.FileTransfer.Description (FileSize (..))
|
||||
import Simplex.FileTransfer.Server (runXFTPServer)
|
||||
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defaultFileExpiration, defFileExpirationHours)
|
||||
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defFileExpirationHours, defaultFileExpiration)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern XFTPServer)
|
||||
@@ -143,9 +143,10 @@ xftpServerCLI cfgPath logPath = do
|
||||
allowNewFiles = fromMaybe True $ iniOnOff "AUTH" "new_files" ini,
|
||||
newFileBasicAuth = either error id <$> strDecodeIni "AUTH" "create_password" ini,
|
||||
fileExpiration =
|
||||
Just defaultFileExpiration
|
||||
{ ttl = 3600 * readIniDefault defFileExpirationHours "STORE_LOG" "expire_files_hours" ini
|
||||
},
|
||||
Just
|
||||
defaultFileExpiration
|
||||
{ ttl = 3600 * readIniDefault defFileExpirationHours "STORE_LOG" "expire_files_hours" ini
|
||||
},
|
||||
caCertificateFile = c caCrtFile,
|
||||
privateKeyFile = c serverKeyFile,
|
||||
certificateFile = c serverCrtFile,
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.FileTransfer.Transport
|
||||
( supportedFileServerVRange,
|
||||
XFTPRcvChunkSpec (..),
|
||||
sendFile,
|
||||
receiveFile,
|
||||
sendEncFile,
|
||||
receiveEncFile,
|
||||
@@ -25,10 +22,10 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Word (Word32)
|
||||
import GHC.IO.Handle.Internals (ioe_EOF)
|
||||
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Transport.HTTP2.File
|
||||
import Simplex.Messaging.Version
|
||||
import System.IO (Handle, IOMode (..), withFile)
|
||||
|
||||
@@ -42,18 +39,6 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec
|
||||
supportedFileServerVRange :: VersionRange
|
||||
supportedFileServerVRange = mkVersionRange 1 1
|
||||
|
||||
fileBlockSize :: Int
|
||||
fileBlockSize = 16 * 1024
|
||||
|
||||
sendFile :: Handle -> (Builder -> IO ()) -> Word32 -> IO ()
|
||||
sendFile h send = go
|
||||
where
|
||||
go 0 = pure ()
|
||||
go sz =
|
||||
getFileChunk h sz >>= \ch -> do
|
||||
send $ byteString ch
|
||||
go $ sz - fromIntegral (B.length ch)
|
||||
|
||||
sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO ()
|
||||
sendEncFile h send = go
|
||||
where
|
||||
@@ -66,23 +51,10 @@ sendEncFile h send = go
|
||||
send (byteString encCh) `E.catch` \(e :: E.SomeException) -> print e >> E.throwIO e
|
||||
go sbState' $ sz - fromIntegral (B.length ch)
|
||||
|
||||
getFileChunk :: Handle -> Word32 -> IO ByteString
|
||||
getFileChunk h sz =
|
||||
B.hGet h fileBlockSize >>= \case
|
||||
"" -> ioe_EOF
|
||||
ch -> pure $ B.take (fromIntegral sz) ch -- sz >= xftpBlockSize
|
||||
|
||||
receiveFile :: (Int -> IO ByteString) -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO ()
|
||||
receiveFile getBody = receiveFile_ receive
|
||||
where
|
||||
receive h sz = do
|
||||
ch <- getBody fileBlockSize
|
||||
let chSize = fromIntegral $ B.length ch
|
||||
if
|
||||
| chSize > sz -> pure $ Left SIZE
|
||||
| chSize > 0 -> B.hPut h ch >> receive h (sz - chSize)
|
||||
| sz == 0 -> pure $ Right ()
|
||||
| otherwise -> pure $ Left SIZE
|
||||
receive h sz = hReceiveFile getBody h sz >>= \sz' -> pure $ if sz' == 0 then Right () else Left SIZE
|
||||
|
||||
receiveEncFile :: (Int -> IO ByteString) -> LC.SbState -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO ()
|
||||
receiveEncFile getBody = receiveFile_ . receive
|
||||
@@ -91,8 +63,8 @@ receiveEncFile getBody = receiveFile_ . receive
|
||||
ch <- getBody fileBlockSize
|
||||
let chSize = fromIntegral $ B.length ch
|
||||
if
|
||||
| chSize > sz + authSz -> pure $ Left SIZE
|
||||
| chSize > 0 -> do
|
||||
| chSize > sz + authSz -> pure $ Left SIZE
|
||||
| chSize > 0 -> do
|
||||
let (ch', rest) = B.splitAt (fromIntegral sz) ch
|
||||
(decCh, sbState') = LC.sbDecryptChunk sbState ch'
|
||||
sz' = sz - fromIntegral (B.length ch')
|
||||
@@ -105,7 +77,7 @@ receiveEncFile getBody = receiveFile_ . receive
|
||||
tag = LC.sbAuth sbState'
|
||||
tag'' <- if tagSz == C.authTagSize then pure tag' else (tag' <>) <$> getBody (C.authTagSize - tagSz)
|
||||
pure $ if BA.constEq tag'' tag then Right () else Left CRYPTO
|
||||
| otherwise -> pure $ Left SIZE
|
||||
| otherwise -> pure $ Left SIZE
|
||||
authSz = fromIntegral C.authTagSize
|
||||
|
||||
receiveFile_ :: (Handle -> Word32 -> IO (Either XFTPErrorType ())) -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO ()
|
||||
|
||||
+133
-130
@@ -14,7 +14,6 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
-- |
|
||||
@@ -43,6 +42,7 @@ module Simplex.Messaging.Agent
|
||||
disconnectAgentClient,
|
||||
resumeAgentClient,
|
||||
withConnLock,
|
||||
withInvLock,
|
||||
createUser,
|
||||
deleteUser,
|
||||
createConnectionAsync,
|
||||
@@ -151,7 +151,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SubscriptionMode (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, UserProtocol, XFTPServerWithAuth)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util
|
||||
@@ -462,8 +462,8 @@ deleteUser' c userId delSMPQueues = do
|
||||
atomically $ TM.delete userId $ smpServers c
|
||||
where
|
||||
delUser =
|
||||
whenM (withStore' c (`deleteUserWithoutConns` userId)) $
|
||||
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId)
|
||||
whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId)
|
||||
|
||||
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId
|
||||
newConnAsync c userId corrId enableNtfs cMode subMode = do
|
||||
@@ -481,16 +481,17 @@ newConnNoQueues c userId connId enableNtfs cMode = do
|
||||
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
|
||||
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case crAgentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
withInvLock c (strEncode cReqUri) "joinConnAsync" $ do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case crAgentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo =
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
@@ -554,11 +555,11 @@ switchConnectionAsync' c corrId connId =
|
||||
SomeConn _ (DuplexConnection cData rqs@(rq :| _rqs) sqs)
|
||||
| isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED
|
||||
| otherwise -> do
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
|
||||
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH
|
||||
let rqs' = updatedQs rq1 rqs
|
||||
pure . connectionStats $ DuplexConnection cData rqs' sqs
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
|
||||
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH
|
||||
let rqs' = updatedQs rq1 rqs
|
||||
pure . connectionStats $ DuplexConnection cData rqs' sqs
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c)
|
||||
@@ -615,24 +616,25 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId
|
||||
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do
|
||||
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
|
||||
g <- asks idsDrg
|
||||
connId' <- withStore c $ \db -> runExceptT $ do
|
||||
connId' <- ExceptT $ createSndConn db g cData q
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
duplexHS = connAgentVersion /= 1
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv =
|
||||
withInvLock c (strEncode inv) "joinConnSrv" $ do
|
||||
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
|
||||
g <- asks idsDrg
|
||||
connId' <- withStore c $ \db -> runExceptT $ do
|
||||
connId' <- ExceptT $ createSndConn db g cData q
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
Left e -> do
|
||||
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
|
||||
withStore' c (`deleteConn` connId')
|
||||
throwError e
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
duplexHS = connAgentVersion /= 1
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
pure connId'
|
||||
Left e -> do
|
||||
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
|
||||
withStore' c (`deleteConn` connId')
|
||||
throwError e
|
||||
joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
@@ -813,15 +815,15 @@ getNotificationMessage' c nonce encNtfInfo = do
|
||||
where
|
||||
getNtfMessages ntfConnId maxMs nMeta ms
|
||||
| length ms < maxMs =
|
||||
getConnectionMessage' c ntfConnId >>= \case
|
||||
Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of
|
||||
Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'}
|
||||
| msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms)
|
||||
| otherwise -> getMsg (m : ms)
|
||||
_
|
||||
| SMP.notification msgFlags -> pure $ reverse (m : ms)
|
||||
| otherwise -> getMsg (m : ms)
|
||||
_ -> pure $ reverse ms
|
||||
getConnectionMessage' c ntfConnId >>= \case
|
||||
Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of
|
||||
Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'}
|
||||
| msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms)
|
||||
| otherwise -> getMsg (m : ms)
|
||||
_
|
||||
| SMP.notification msgFlags -> pure $ reverse (m : ms)
|
||||
| otherwise -> getMsg (m : ms)
|
||||
_ -> pure $ reverse ms
|
||||
| otherwise = pure $ reverse ms
|
||||
where
|
||||
getMsg = getNtfMessages ntfConnId maxMs nMeta
|
||||
@@ -962,12 +964,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
Just (rq'@RcvQueue {primary}, rq'' : rqs')
|
||||
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
|
||||
| otherwise -> do
|
||||
checkRQSwchStatus rq' RSReceivedMessage
|
||||
tryError (deleteQueue c rq') >>= \case
|
||||
Right () -> finalizeSwitch
|
||||
Left e
|
||||
| temporaryOrHostError e -> throwError e
|
||||
| otherwise -> finalizeSwitch >> throwError e
|
||||
checkRQSwchStatus rq' RSReceivedMessage
|
||||
tryError (deleteQueue c rq') >>= \case
|
||||
Right () -> finalizeSwitch
|
||||
Left e
|
||||
| temporaryOrHostError e -> throwError e
|
||||
| otherwise -> finalizeSwitch >> throwError e
|
||||
where
|
||||
finalizeSwitch = do
|
||||
withStore' c $ \db -> deleteConnRcvQueue db rq'
|
||||
@@ -1123,7 +1125,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl
|
||||
-- because the queue must be secured by the time the confirmation or the first HELLO is received
|
||||
| duplexHandshake == Just True -> connErr
|
||||
| otherwise ->
|
||||
ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast)
|
||||
ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast)
|
||||
where
|
||||
connErr = case rq_ of
|
||||
-- party initiating connection
|
||||
@@ -1143,8 +1145,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl
|
||||
-- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server),
|
||||
-- the message sending would be retried
|
||||
| temporaryOrHostError e -> do
|
||||
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
|
||||
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast)
|
||||
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
|
||||
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast)
|
||||
| otherwise -> notifyDel msgId err
|
||||
where
|
||||
msgExpired timeoutSel = do
|
||||
@@ -1286,9 +1288,9 @@ switchConnection' c connId =
|
||||
SomeConn _ conn@(DuplexConnection cData rqs@(rq :| _rqs) _)
|
||||
| isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED
|
||||
| otherwise -> do
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
|
||||
rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
|
||||
switchDuplexConnection c conn rq'
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
|
||||
rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
|
||||
switchDuplexConnection c conn rq'
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
switchDuplexConnection :: AgentMonad m => AgentClient -> Connection 'CDuplex -> RcvQueue -> m ConnectionStats
|
||||
@@ -1314,19 +1316,19 @@ abortConnectionSwitch' c connId =
|
||||
SomeConn _ (DuplexConnection cData rqs sqs) -> case switchingRQ rqs of
|
||||
Just rq
|
||||
| canAbortRcvSwitch rq -> do
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
|
||||
-- multiple queues to which the connections switches were possible when repeating switch was allowed
|
||||
let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs
|
||||
case L.nonEmpty keepRqs of
|
||||
Just rqs' -> do
|
||||
rq' <- withStore' c $ \db -> do
|
||||
mapM_ (setRcvQueueDeleted db) delRqs
|
||||
setRcvSwitchStatus db rq Nothing
|
||||
forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId
|
||||
let rqs'' = updatedQs rq' rqs'
|
||||
conn' = DuplexConnection cData rqs'' sqs
|
||||
pure $ connectionStats conn'
|
||||
_ -> throwError $ INTERNAL "won't delete all rcv queues in connection"
|
||||
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
|
||||
-- multiple queues to which the connections switches were possible when repeating switch was allowed
|
||||
let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs
|
||||
case L.nonEmpty keepRqs of
|
||||
Just rqs' -> do
|
||||
rq' <- withStore' c $ \db -> do
|
||||
mapM_ (setRcvQueueDeleted db) delRqs
|
||||
setRcvSwitchStatus db rq Nothing
|
||||
forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId
|
||||
let rqs'' = updatedQs rq' rqs'
|
||||
conn' = DuplexConnection cData rqs'' sqs
|
||||
pure $ connectionStats conn'
|
||||
_ -> throwError $ INTERNAL "won't delete all rcv queues in connection"
|
||||
| otherwise -> throwError $ CMD PROHIBITED
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
@@ -1336,16 +1338,16 @@ synchronizeRatchet' c connId force = withConnLock c connId "synchronizeRatchet"
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection cData rqs sqs)
|
||||
| ratchetSyncAllowed cData || force -> do
|
||||
-- check queues are not switching?
|
||||
AgentConfig {e2eEncryptVRange} <- asks config
|
||||
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
|
||||
void $ enqueueRatchetKeyMsgs c cData sqs e2eParams
|
||||
withStore' c $ \db -> do
|
||||
setConnRatchetSync db connId RSStarted
|
||||
setRatchetX3dhKeys db connId pk1 pk2 k1 k2
|
||||
let cData' = cData {ratchetSyncState = RSStarted} :: ConnData
|
||||
conn' = DuplexConnection cData' rqs sqs
|
||||
pure $ connectionStats conn'
|
||||
-- check queues are not switching?
|
||||
AgentConfig {e2eEncryptVRange} <- asks config
|
||||
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
|
||||
void $ enqueueRatchetKeyMsgs c cData sqs e2eParams
|
||||
withStore' c $ \db -> do
|
||||
setConnRatchetSync db connId RSStarted
|
||||
setRatchetX3dhKeys db connId pk1 pk2 k1 k2
|
||||
let cData' = cData {ratchetSyncState = RSStarted} :: ConnData
|
||||
conn' = DuplexConnection cData' rqs sqs
|
||||
pure $ connectionStats conn'
|
||||
| otherwise -> throwError $ CMD PROHIBITED
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
@@ -1521,23 +1523,23 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
|
||||
-- possible improvement: add minimal time before repeat registration
|
||||
(Just tknId, Nothing)
|
||||
| savedDeviceToken == suppliedDeviceToken ->
|
||||
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
|
||||
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
|
||||
| otherwise -> replaceToken tknId
|
||||
(Just tknId, Just (NTAVerify code))
|
||||
| savedDeviceToken == suppliedDeviceToken ->
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
|
||||
| otherwise -> replaceToken tknId
|
||||
(Just tknId, Just NTACheck)
|
||||
| savedDeviceToken == suppliedDeviceToken -> do
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode}
|
||||
when (ntfTknStatus == NTActive) $ do
|
||||
cron <- asks $ ntfCron . config
|
||||
agentNtfEnableCron c tknId tkn cron
|
||||
when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c
|
||||
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
|
||||
-- possible improvement: get updated token status from the server, or maybe TCRON could return the current status
|
||||
pure ntfTknStatus
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode}
|
||||
when (ntfTknStatus == NTActive) $ do
|
||||
cron <- asks $ ntfCron . config
|
||||
agentNtfEnableCron c tknId tkn cron
|
||||
when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c
|
||||
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
|
||||
-- possible improvement: get updated token status from the server, or maybe TCRON could return the current status
|
||||
pure ntfTknStatus
|
||||
| otherwise -> replaceToken tknId
|
||||
(Just tknId, Just NTADelete) -> do
|
||||
agentNtfDeleteToken c tknId tkn
|
||||
@@ -1647,10 +1649,10 @@ toggleConnectionNtfs' c connId enable = do
|
||||
toggle cData
|
||||
| enableNtfs cData == enable = pure ()
|
||||
| otherwise = do
|
||||
withStore' c $ \db -> setConnectionNtfs db connId enable
|
||||
ns <- asks ntfSupervisor
|
||||
let cmd = if enable then NSCCreate else NSCDelete
|
||||
atomically $ sendNtfSubCommand ns (connId, cmd)
|
||||
withStore' c $ \db -> setConnectionNtfs db connId enable
|
||||
ns <- asks ntfSupervisor
|
||||
let cmd = if enable then NSCCreate else NSCDelete
|
||||
atomically $ sendNtfSubCommand ns (connId, cmd)
|
||||
|
||||
deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m ()
|
||||
deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do
|
||||
@@ -1743,11 +1745,12 @@ getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration]
|
||||
getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn)
|
||||
|
||||
debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks
|
||||
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do
|
||||
debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do
|
||||
connLocks <- getLocks cs
|
||||
invLocks <- getLocks is
|
||||
srvLocks <- getLocks rs
|
||||
delLock <- atomically $ tryReadTMVar d
|
||||
pure AgentLocks {connLocks, srvLocks, delLock}
|
||||
pure AgentLocks {connLocks, invLocks, srvLocks, delLock}
|
||||
where
|
||||
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
|
||||
|
||||
@@ -1912,11 +1915,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
resetRatchetSync :: m (Connection c)
|
||||
resetRatchetSync
|
||||
| rss `notElem` ([RSOk, RSStarted] :: [RatchetSyncState]) = do
|
||||
let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData
|
||||
conn'' = updateConnection cData'' conn'
|
||||
notify . RSYNC RSOk Nothing $ connectionStats conn''
|
||||
withStore' c $ \db -> setConnRatchetSync db connId RSOk
|
||||
pure conn''
|
||||
let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData
|
||||
conn'' = updateConnection cData'' conn'
|
||||
notify . RSYNC RSOk Nothing $ connectionStats conn''
|
||||
withStore' c $ \db -> setConnRatchetSync db connId RSOk
|
||||
pure conn''
|
||||
| otherwise = pure conn'
|
||||
Right _ -> prohibited >> ack
|
||||
Left e@(AGENT A_DUPLICATE) -> do
|
||||
@@ -1924,11 +1927,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
|
||||
| userAck -> ackDel internalId
|
||||
| otherwise -> do
|
||||
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
AgentMessage _ (A_MSG body) -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
_ -> pure ()
|
||||
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
AgentMessage _ (A_MSG body) -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
_ -> pure ()
|
||||
_ -> checkDuplicateHash e encryptedMsgHash >> ack
|
||||
Left (AGENT (A_CRYPTO e)) -> do
|
||||
exists <- withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash
|
||||
@@ -1976,9 +1979,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
case msgAVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible av)
|
||||
| av > connAgentVersion -> do
|
||||
withStore' c $ \db -> setConnAgentVersion db connId av
|
||||
let cData'' = cData' {connAgentVersion = av} :: ConnData
|
||||
pure $ updateConnection cData'' conn'
|
||||
withStore' c $ \db -> setConnAgentVersion db connId av
|
||||
let cData'' = cData' {connAgentVersion = av} :: ConnData
|
||||
pure $ updateConnection cData'' conn'
|
||||
| otherwise -> pure conn'
|
||||
Nothing -> pure conn'
|
||||
ack :: m ()
|
||||
@@ -1994,9 +1997,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
processEND = \case
|
||||
Just (Right clnt)
|
||||
| sessId == sessionId clnt -> do
|
||||
removeSubscription c connId
|
||||
notify' END
|
||||
pure "END"
|
||||
removeSubscription c connId
|
||||
notify' END
|
||||
pure "END"
|
||||
| otherwise -> ignored
|
||||
_ -> ignored
|
||||
ignored = pure "END from disconnected client - ignored"
|
||||
@@ -2186,12 +2189,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
case findRQ (smpServer, senderId) rqs of
|
||||
Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'}
|
||||
| status' == New || status' == Confirmed -> do
|
||||
checkRQSwchStatus rq RSSendingQADD
|
||||
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
|
||||
let dhSecret = C.dh' dhPublicKey dhPrivKey
|
||||
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
|
||||
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
|
||||
notify . SWITCH QDRcv SPConfirmed $ connectionStats conn'
|
||||
checkRQSwchStatus rq RSSendingQADD
|
||||
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
|
||||
let dhSecret = C.dh' dhPublicKey dhPrivKey
|
||||
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
|
||||
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
|
||||
notify . SWITCH QDRcv SPConfirmed $ connectionStats conn'
|
||||
| otherwise -> qError "QKEY: queue already secured"
|
||||
_ -> qError "QKEY: queue address not found in connection"
|
||||
where
|
||||
@@ -2227,8 +2230,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
ereadyMsg rcPrev (DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = do
|
||||
let CR.Ratchet {rcSnd} = rcPrev
|
||||
-- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation
|
||||
when (isNothing rcSnd) $
|
||||
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
|
||||
when (isNothing rcSnd) . void $
|
||||
enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId)
|
||||
|
||||
smpInvitation :: Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
smpInvitation conn' connReq@(CRInvitationUri crData _) cInfo = do
|
||||
@@ -2267,9 +2270,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
getSendRatchetKeys
|
||||
| rss == RSStarted = withStore c (`getRatchetX3dhKeys'` connId)
|
||||
| otherwise = do
|
||||
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams
|
||||
void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams
|
||||
pure (pk1, pk2, k1, k2)
|
||||
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams
|
||||
void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams
|
||||
pure (pk1, pk2, k1, k2)
|
||||
notifyAgreed :: m ()
|
||||
notifyAgreed = do
|
||||
let cData'' = cData' {ratchetSyncState = RSAgreed} :: ConnData
|
||||
@@ -2285,11 +2288,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
|
||||
initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, C.PublicKeyX448, C.PublicKeyX448) -> m ()
|
||||
initRatchet e2eEncryptVRange (pk1, pk2, k1, k2)
|
||||
| rkHash k1 k2 <= rkHashRcv = do
|
||||
recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams
|
||||
recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams
|
||||
| otherwise = do
|
||||
(_, rcDHRs) <- liftIO C.generateKeyPair'
|
||||
recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams
|
||||
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
|
||||
(_, rcDHRs) <- liftIO C.generateKeyPair'
|
||||
recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams
|
||||
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
|
||||
|
||||
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity
|
||||
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash
|
||||
@@ -2347,8 +2350,8 @@ mkAgentConfirmation :: AgentMonad m => Compatible Version -> AgentClient -> Conn
|
||||
mkAgentConfirmation (Compatible agentVersion) c cData sq srv connInfo subMode
|
||||
| agentVersion == 1 = pure $ AgentConnInfo connInfo
|
||||
| otherwise = do
|
||||
qInfo <- createReplyQueue c cData sq subMode srv
|
||||
pure $ AgentConnInfoReply (qInfo :| []) connInfo
|
||||
qInfo <- createReplyQueue c cData sq subMode srv
|
||||
pure $ AgentConnInfoReply (qInfo :| []) connInfo
|
||||
|
||||
enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
enqueueConfirmation c cData sq connInfo e2eEncryption_ = do
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
@@ -15,6 +14,7 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilyDependencies #-}
|
||||
@@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client
|
||||
ProtocolTestStep (..),
|
||||
newAgentClient,
|
||||
withConnLock,
|
||||
withInvLock,
|
||||
closeAgentClient,
|
||||
closeProtocolServerClients,
|
||||
closeXFTPServerClient,
|
||||
@@ -117,8 +118,7 @@ import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (getRandomBytes)
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
@@ -137,7 +137,6 @@ import Data.Text (Text)
|
||||
import Data.Text.Encoding
|
||||
import Data.Time (UTCTime, defaultTimeLocale, formatTime, getCurrentTime)
|
||||
import Data.Word (Word16)
|
||||
import GHC.Generics (Generic)
|
||||
import Network.Socket (HostName)
|
||||
import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError)
|
||||
import qualified Simplex.FileTransfer.Client as X
|
||||
@@ -164,7 +163,7 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse)
|
||||
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse)
|
||||
import Simplex.Messaging.Protocol
|
||||
( AProtocolType (..),
|
||||
BrokerMsg,
|
||||
@@ -249,6 +248,8 @@ data AgentClient = AgentClient
|
||||
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
|
||||
-- locks to prevent concurrent operations with connection
|
||||
connLocks :: TMap ConnId Lock,
|
||||
-- locks to prevent concurrent operations with connection request invitations
|
||||
invLocks :: TMap ByteString Lock,
|
||||
-- lock to prevent concurrency between periodic and async connection deletions
|
||||
deleteLock :: Lock,
|
||||
-- locks to prevent concurrent reconnections to SMP servers
|
||||
@@ -279,10 +280,13 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
|
||||
data AgentState = ASForeground | ASSuspending | ASSuspended
|
||||
deriving (Eq, Show)
|
||||
|
||||
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String}
|
||||
deriving (Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
data AgentLocks = AgentLocks
|
||||
{ connLocks :: Map String String,
|
||||
invLocks :: Map String String,
|
||||
srvLocks :: Map String String,
|
||||
delLock :: Maybe String
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data AgentStatsKey = AgentStatsKey
|
||||
{ userId :: UserId,
|
||||
@@ -325,6 +329,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
|
||||
agentState <- newTVar ASForeground
|
||||
getMsgLocks <- TM.empty
|
||||
connLocks <- TM.empty
|
||||
invLocks <- TM.empty
|
||||
deleteLock <- createLock
|
||||
reconnectLocks <- TM.empty
|
||||
reconnections <- newTAsyncs
|
||||
@@ -362,6 +367,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
|
||||
agentState,
|
||||
getMsgLocks,
|
||||
connLocks,
|
||||
invLocks,
|
||||
deleteLock,
|
||||
reconnectLocks,
|
||||
reconnections,
|
||||
@@ -645,6 +651,9 @@ withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
|
||||
withConnLock _ "" _ = id
|
||||
withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
|
||||
|
||||
withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a
|
||||
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
|
||||
|
||||
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
|
||||
where
|
||||
@@ -725,23 +734,14 @@ data ProtocolTestStep
|
||||
| TSDownloadFile
|
||||
| TSCompareFile
|
||||
| TSDeleteFile
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON ProtocolTestStep where
|
||||
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TS"
|
||||
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TS"
|
||||
|
||||
instance FromJSON ProtocolTestStep where
|
||||
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TS"
|
||||
deriving (Eq, Show)
|
||||
|
||||
data ProtocolTestFailure = ProtocolTestFailure
|
||||
{ testStep :: ProtocolTestStep,
|
||||
testError :: AgentErrorType
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance ToJSON ProtocolTestFailure where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe ProtocolTestFailure)
|
||||
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
@@ -901,8 +901,8 @@ subscribeQueues c qs = do
|
||||
subscribeQueues_ u smp qs' = do
|
||||
rs <- sendBatch subscribeSMPQueues smp qs'
|
||||
mapM_ (uncurry $ processSubResult c) rs
|
||||
when (any temporaryClientError . lefts . map snd $ L.toList rs) $
|
||||
unliftIO u $ reconnectServer c $ transportSession' smp
|
||||
when (any temporaryClientError . lefts . map snd $ L.toList rs) . unliftIO u $
|
||||
reconnectServer c (transportSession' smp)
|
||||
pure rs
|
||||
|
||||
type BatchResponses e r = (NonEmpty (RcvQueue, Either e r))
|
||||
@@ -989,7 +989,7 @@ sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer,
|
||||
mkInvitation = do
|
||||
let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo}
|
||||
agentCbEncryptOnce v dhPublicKey . smpEncode $
|
||||
SMP.ClientMessage SMP.PHEmpty $ smpEncode agentEnvelope
|
||||
SMP.ClientMessage SMP.PHEmpty (smpEncode agentEnvelope)
|
||||
|
||||
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
|
||||
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
|
||||
@@ -1324,7 +1324,7 @@ userServers c = case protocolTypeI @p of
|
||||
SPSMP -> smpServers c
|
||||
SPXFTP -> xftpServers c
|
||||
|
||||
pickServer :: forall p m. (AgentMonad' m) => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p)
|
||||
pickServer :: forall p m. AgentMonad' m => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p)
|
||||
pickServer = \case
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
@@ -1343,7 +1343,7 @@ withUserServers c userId action =
|
||||
Just srvs -> action srvs
|
||||
_ -> throwError $ INTERNAL "unknown userId - no user servers"
|
||||
|
||||
withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> ((ProtoServerWithAuth p) -> m a) -> m a
|
||||
withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> m a) -> m a
|
||||
withNextSrv c userId usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used
|
||||
@@ -1355,22 +1355,14 @@ withNextSrv c userId usedSrvs initUsed action = do
|
||||
action srvAuth
|
||||
|
||||
data SubInfo = SubInfo {userId :: UserId, server :: Text, rcvId :: Text, subError :: Maybe String}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON SubInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
instance FromJSON SubInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
deriving (Show)
|
||||
|
||||
data SubscriptionsInfo = SubscriptionsInfo
|
||||
{ activeSubscriptions :: [SubInfo],
|
||||
pendingSubscriptions :: [SubInfo],
|
||||
removedSubscriptions :: [SubInfo]
|
||||
}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON SubscriptionsInfo where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
instance FromJSON SubscriptionsInfo where parseJSON = J.genericParseJSON J.defaultOptions
|
||||
deriving (Show)
|
||||
|
||||
getAgentSubscriptions :: MonadIO m => AgentClient -> m SubscriptionsInfo
|
||||
getAgentSubscriptions c = do
|
||||
@@ -1382,6 +1374,16 @@ getAgentSubscriptions c = do
|
||||
getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c)
|
||||
getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c)
|
||||
subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo
|
||||
subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err}
|
||||
subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err}
|
||||
enc :: StrEncoding a => a -> Text
|
||||
enc = decodeLatin1 . strEncode
|
||||
|
||||
$(J.deriveJSON defaultJSON ''AgentLocks)
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "TS") ''ProtocolTestStep)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''ProtocolTestFailure)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''SubInfo)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''SubscriptionsInfo)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Lock where
|
||||
|
||||
import Control.Monad (void)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module Simplex.Messaging.Agent.NtfSubSupervisor
|
||||
@@ -100,14 +99,14 @@ processNtfSub c (connId, cmd) = do
|
||||
Just (action, _)
|
||||
-- subscription was marked for deletion / is being deleted
|
||||
| isDeleteNtfSubAction action -> do
|
||||
if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted
|
||||
then resetSubscription
|
||||
else withNtfServer c $ \ntfServer -> do
|
||||
withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate)
|
||||
addNtfNTFWorker ntfServer
|
||||
if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted
|
||||
then resetSubscription
|
||||
else withNtfServer c $ \ntfServer -> do
|
||||
withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate)
|
||||
addNtfNTFWorker ntfServer
|
||||
| otherwise -> case action of
|
||||
NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer
|
||||
NtfSubSMPAction _ -> addNtfSMPWorker smpServer
|
||||
NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer
|
||||
NtfSubSMPAction _ -> addNtfSMPWorker smpServer
|
||||
rotate :: m ()
|
||||
rotate = do
|
||||
withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate)
|
||||
@@ -291,11 +290,11 @@ rescheduleAction :: AgentMonad' m => TMVar () -> UTCTime -> UTCTime -> m Bool
|
||||
rescheduleAction doWork ts actionTs
|
||||
| actionTs <= ts = pure False
|
||||
| otherwise = do
|
||||
void . atomically $ tryTakeTMVar doWork
|
||||
void . forkIO $ do
|
||||
liftIO $ threadDelay' $ diffToMicroseconds $ diffUTCTime actionTs ts
|
||||
void . atomically $ tryPutTMVar doWork ()
|
||||
pure True
|
||||
void . atomically $ tryTakeTMVar doWork
|
||||
void . forkIO $ do
|
||||
liftIO $ threadDelay' $ diffToMicroseconds $ diffUTCTime actionTs ts
|
||||
void . atomically $ tryPutTMVar doWork ()
|
||||
pure True
|
||||
|
||||
retryOnError :: AgentMonad' m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m ()
|
||||
retryOnError c name loop done e = do
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
@@ -13,6 +12,7 @@
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
@@ -151,7 +151,7 @@ import Control.Monad (unless)
|
||||
import Control.Monad.Except (runExceptT, throwError)
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Base64
|
||||
@@ -176,8 +176,6 @@ import Data.Typeable ()
|
||||
import Data.Word (Word32)
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.ToField
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Simplex.FileTransfer.Description
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType)
|
||||
import Simplex.Messaging.Agent.QueryString
|
||||
@@ -213,7 +211,6 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra
|
||||
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..))
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
import Text.Read
|
||||
import UnliftIO.Exception (Exception)
|
||||
|
||||
@@ -614,15 +611,11 @@ data RcvQueueInfo = RcvQueueInfo
|
||||
rcvSwitchStatus :: Maybe RcvSwitchStatus,
|
||||
canAbortSwitch :: Bool
|
||||
}
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON RcvQueueInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
instance FromJSON RcvQueueInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding RcvQueueInfo where
|
||||
strEncode RcvQueueInfo {rcvServer, rcvSwitchStatus, canAbortSwitch} =
|
||||
"srv=" <> strEncode rcvServer
|
||||
("srv=" <> strEncode rcvServer)
|
||||
<> maybe "" (\switch -> ";switch=" <> strEncode switch) rcvSwitchStatus
|
||||
<> (";can_abort_switch=" <> strEncode canAbortSwitch)
|
||||
strP = do
|
||||
@@ -635,11 +628,7 @@ data SndQueueInfo = SndQueueInfo
|
||||
{ sndServer :: SMPServer,
|
||||
sndSwitchStatus :: Maybe SndSwitchStatus
|
||||
}
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON SndQueueInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
instance FromJSON SndQueueInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding SndQueueInfo where
|
||||
strEncode SndQueueInfo {sndServer, sndSwitchStatus} =
|
||||
@@ -656,13 +645,11 @@ data ConnectionStats = ConnectionStats
|
||||
ratchetSyncState :: RatchetSyncState,
|
||||
ratchetSyncSupported :: Bool
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON ConnectionStats where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding ConnectionStats where
|
||||
strEncode ConnectionStats {connAgentVersion, rcvQueuesInfo, sndQueuesInfo, ratchetSyncState, ratchetSyncSupported} =
|
||||
"agent_version=" <> strEncode connAgentVersion
|
||||
("agent_version=" <> strEncode connAgentVersion)
|
||||
<> (" rcv=" <> strEncodeList rcvQueuesInfo)
|
||||
<> (" snd=" <> strEncodeList sndQueuesInfo)
|
||||
<> (" sync=" <> strEncode ratchetSyncState)
|
||||
@@ -1048,7 +1035,7 @@ instance StrEncoding MsgReceiptStatus where
|
||||
MROk -> "ok"
|
||||
MRBadMsgHash -> "badMsgHash"
|
||||
strP =
|
||||
A.takeWhile1 (/= ' ') >>= \ case
|
||||
A.takeWhile1 (/= ' ') >>= \case
|
||||
"ok" -> pure MROk
|
||||
"badMsgHash" -> pure MRBadMsgHash
|
||||
_ -> fail "bad MsgReceiptStatus"
|
||||
@@ -1391,7 +1378,7 @@ type AgentMsgId = Int64
|
||||
|
||||
-- | Result of received message integrity validation.
|
||||
data MsgIntegrity = MsgOk | MsgError {errorInfo :: MsgErrorType}
|
||||
deriving (Eq, Show, Generic)
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding MsgIntegrity where
|
||||
strP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> strP)
|
||||
@@ -1399,20 +1386,13 @@ instance StrEncoding MsgIntegrity where
|
||||
MsgOk -> "OK"
|
||||
MsgError e -> "ERR " <> strEncode e
|
||||
|
||||
instance ToJSON MsgIntegrity where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON fstToLower
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON fstToLower
|
||||
|
||||
instance FromJSON MsgIntegrity where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON fstToLower
|
||||
|
||||
-- | Error of message integrity validation.
|
||||
data MsgErrorType
|
||||
= MsgSkipped {fromMsgId :: AgentMsgId, toMsgId :: AgentMsgId}
|
||||
| MsgBadId {msgId :: AgentMsgId}
|
||||
| MsgBadHash
|
||||
| MsgDuplicate
|
||||
deriving (Eq, Show, Generic)
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding MsgErrorType where
|
||||
strP =
|
||||
@@ -1427,13 +1407,6 @@ instance StrEncoding MsgErrorType where
|
||||
MsgBadHash -> "HASH"
|
||||
MsgDuplicate -> "DUPLICATE"
|
||||
|
||||
instance ToJSON MsgErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON fstToLower
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON fstToLower
|
||||
|
||||
instance FromJSON MsgErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON fstToLower
|
||||
|
||||
-- | Error type used in errors sent to agent clients.
|
||||
data AgentErrorType
|
||||
= -- | command or response error
|
||||
@@ -1454,14 +1427,7 @@ data AgentErrorType
|
||||
INTERNAL {internalErr :: String}
|
||||
| -- | agent inactive
|
||||
INACTIVE
|
||||
deriving (Eq, Generic, Show, Exception)
|
||||
|
||||
instance ToJSON AgentErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON AgentErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
-- | SMP agent protocol command or response error.
|
||||
data CommandErrorType
|
||||
@@ -1475,14 +1441,7 @@ data CommandErrorType
|
||||
SIZE
|
||||
| -- | message does not fit in SMP block
|
||||
LARGE
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON CommandErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON CommandErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
-- | Connection error.
|
||||
data ConnectionErrorType
|
||||
@@ -1496,14 +1455,7 @@ data ConnectionErrorType
|
||||
NOT_ACCEPTED
|
||||
| -- | connection not available on reply confirmation/HELLO after timeout
|
||||
NOT_AVAILABLE
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON ConnectionErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON ConnectionErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
-- | SMP server errors.
|
||||
data BrokerErrorType
|
||||
@@ -1519,14 +1471,7 @@ data BrokerErrorType
|
||||
TRANSPORT {transportErr :: TransportError}
|
||||
| -- | command response timeout
|
||||
TIMEOUT
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON BrokerErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON BrokerErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
-- | Errors of another SMP agent.
|
||||
data SMPAgentError
|
||||
@@ -1543,7 +1488,7 @@ data SMPAgentError
|
||||
A_DUPLICATE
|
||||
| -- | error in the message to add/delete/etc queue in connection
|
||||
A_QUEUE {queueErr :: String}
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
data AgentCryptoError
|
||||
= -- | AES decryption error
|
||||
@@ -1556,14 +1501,7 @@ data AgentCryptoError
|
||||
RATCHET_EARLIER Word32
|
||||
| -- | too many skipped messages
|
||||
RATCHET_SKIPPED Word32
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON AgentCryptoError where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON AgentCryptoError where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
instance StrEncoding AgentCryptoError where
|
||||
strP =
|
||||
@@ -1579,13 +1517,6 @@ instance StrEncoding AgentCryptoError where
|
||||
RATCHET_EARLIER n -> "RATCHET_EARLIER " <> strEncode n
|
||||
RATCHET_SKIPPED n -> "RATCHET_SKIPPED " <> strEncode n
|
||||
|
||||
instance ToJSON SMPAgentError where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON SMPAgentError where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
|
||||
instance StrEncoding AgentErrorType where
|
||||
strP =
|
||||
"CMD " *> (CMD <$> parseRead1)
|
||||
@@ -1620,18 +1551,6 @@ instance StrEncoding AgentErrorType where
|
||||
where
|
||||
text = encodeUtf8 . T.pack
|
||||
|
||||
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU
|
||||
|
||||
cryptoErrToSyncState :: AgentCryptoError -> RatchetSyncState
|
||||
cryptoErrToSyncState = \case
|
||||
DECRYPT_AES -> RSAllowed
|
||||
@@ -1957,3 +1876,25 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
unless (B.null s) $ throwError $ CMD SIZE
|
||||
pure body
|
||||
Nothing -> return . Left $ CMD SYNTAX
|
||||
|
||||
$(J.deriveJSON defaultJSON ''RcvQueueInfo)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''SndQueueInfo)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''ConnectionStats)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON fstToLower) ''MsgErrorType)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON fstToLower) ''MsgIntegrity)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''CommandErrorType)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''ConnectionErrorType)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''BrokerErrorType)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''AgentCryptoError)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''SMPAgentError)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''AgentErrorType)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Agent.RetryInterval
|
||||
|
||||
@@ -23,7 +23,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion)
|
||||
import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer, defaultTransportServerConfig)
|
||||
import Simplex.Messaging.Transport.Server (defaultTransportServerConfig, loadTLSServerParams, runTransportServer)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import UnliftIO.Async (race_)
|
||||
import qualified UnliftIO.Exception as E
|
||||
|
||||
@@ -576,6 +576,8 @@ data StoreError
|
||||
SEServerNotFound
|
||||
| -- | Connection already used.
|
||||
SEConnDuplicate
|
||||
| -- | Confirmed snd queue already exists.
|
||||
SESndQueueExists
|
||||
| -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa.
|
||||
-- 'upgradeRcvConnToDuplex' and 'upgradeSndConnToDuplex' do not allow duplex connections - they would also return this error.
|
||||
SEBadConnType ConnType
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
@@ -17,12 +16,12 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite
|
||||
( SQLiteStore (..),
|
||||
@@ -220,8 +219,7 @@ import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (second)
|
||||
import Data.ByteString (ByteString)
|
||||
@@ -247,7 +245,6 @@ import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import qualified Database.SQLite3 as SQLite3
|
||||
import GHC.Generics (Generic)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.FileTransfer.Client (XFTPChunkSpec (..))
|
||||
import Simplex.FileTransfer.Description
|
||||
@@ -267,7 +264,7 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (blobFieldParser, dropPrefix, fromTextField_, sumTypeJSON)
|
||||
import Simplex.Messaging.Parsers (blobFieldParser, defaultJSON, dropPrefix, fromTextField_, sumTypeJSON)
|
||||
import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
@@ -277,7 +274,7 @@ import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath (takeDirectory)
|
||||
import System.IO (hFlush, stdout)
|
||||
import UnliftIO.Exception (onException, bracketOnError)
|
||||
import UnliftIO.Exception (bracketOnError, onException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -287,7 +284,7 @@ data MigrationError
|
||||
= MEUpgrade {upMigrations :: [UpMigration]}
|
||||
| MEDowngrade {downMigrations :: [String]}
|
||||
| MigrationError {mtrError :: MTRError}
|
||||
deriving (Eq, Show, Generic)
|
||||
deriving (Eq, Show)
|
||||
|
||||
migrationErrorDescription :: MigrationError -> String
|
||||
migrationErrorDescription = \case
|
||||
@@ -297,18 +294,12 @@ migrationErrorDescription = \case
|
||||
"Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms
|
||||
MigrationError err -> mtrErrorDescription err
|
||||
|
||||
instance ToJSON MigrationError where
|
||||
toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "ME"
|
||||
toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "ME"
|
||||
|
||||
data UpMigration = UpMigration {upName :: String, withDown :: Bool}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
deriving (Eq, Show)
|
||||
|
||||
upMigration :: Migration -> UpMigration
|
||||
upMigration Migration {name, down} = UpMigration name $ isJust down
|
||||
|
||||
instance ToJSON UpMigration where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
data MigrationConfirmation = MCYesUp | MCYesUpDown | MCConsole | MCError
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -347,10 +338,10 @@ migrateSchema st migrations confirmMigrations = do
|
||||
Right ms@(MTRUp ums)
|
||||
| dbNew st -> Migrations.run st ms $> Right ()
|
||||
| otherwise -> case confirmMigrations of
|
||||
MCYesUp -> run ms
|
||||
MCYesUpDown -> run ms
|
||||
MCConsole -> confirm err >> run ms
|
||||
MCError -> pure $ Left err
|
||||
MCYesUp -> run ms
|
||||
MCYesUpDown -> run ms
|
||||
MCConsole -> confirm err >> run ms
|
||||
MCError -> pure $ Left err
|
||||
where
|
||||
err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums)
|
||||
Right ms@(MTRDown dms) -> case confirmMigrations of
|
||||
@@ -561,10 +552,23 @@ createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, dupl
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
||||
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertSndQueue_ db connId q serverKeyHash_
|
||||
-- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_
|
||||
ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertSndQueue_ db connId q serverKeyHash_
|
||||
|
||||
checkConfirmedSndQueueExists_ :: DB.Connection -> SndQueue -> IO Bool
|
||||
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do
|
||||
fromMaybe False
|
||||
<$> maybeFirstRow
|
||||
fromOnly
|
||||
( DB.query
|
||||
db
|
||||
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
|
||||
(host server, port server, sndId, New)
|
||||
)
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
|
||||
@@ -980,7 +984,7 @@ updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} =
|
||||
getPendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO [InternalId]
|
||||
getPendingMsgs db connId SndQueue {dbQueueId} =
|
||||
map fromOnly
|
||||
<$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
|
||||
<$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? ORDER BY internal_id ASC" (connId, dbQueueId)
|
||||
|
||||
deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
deletePendingMsgs db connId SndQueue {dbQueueId} =
|
||||
@@ -1079,12 +1083,12 @@ countPendingSndDeliveries_ db connId msgId = do
|
||||
|
||||
deleteRcvMsgHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
|
||||
deleteRcvMsgHashesExpired db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
DB.execute db "DELETE FROM encrypted_rcv_message_hashes WHERE created_at < ?" (Only cutoffTs)
|
||||
|
||||
deleteSndMsgsExpired :: DB.Connection -> NominalDiffTime -> IO ()
|
||||
deleteSndMsgsExpired db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
DB.execute
|
||||
db
|
||||
"DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL"
|
||||
@@ -1163,7 +1167,7 @@ getSkippedMsgKeys :: DB.Connection -> ConnId -> IO SkippedMsgKeys
|
||||
getSkippedMsgKeys db connId =
|
||||
skipped <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId)
|
||||
where
|
||||
skipped ms = foldl' addSkippedKey M.empty ms
|
||||
skipped = foldl' addSkippedKey M.empty
|
||||
addSkippedKey smks (hk, msgN, mk) = M.alter (Just . addMsgKey) hk smks
|
||||
where
|
||||
addMsgKey = maybe (M.singleton msgN mk) (M.insert msgN mk)
|
||||
@@ -1734,15 +1738,15 @@ getAnyConn deleted' dbConn connId =
|
||||
Just (cData@ConnData {deleted}, cMode)
|
||||
| deleted /= deleted' -> pure $ Left SEConnNotFound
|
||||
| otherwise -> do
|
||||
rQ <- getRcvQueuesByConnId_ dbConn connId
|
||||
sQ <- getSndQueuesByConnId_ dbConn connId
|
||||
pure $ case (rQ, sQ, cMode) of
|
||||
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
|
||||
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
|
||||
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
|
||||
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
|
||||
_ -> Left SEConnNotFound
|
||||
rQ <- getRcvQueuesByConnId_ dbConn connId
|
||||
sQ <- getSndQueuesByConnId_ dbConn connId
|
||||
pure $ case (rQ, sQ, cMode) of
|
||||
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
|
||||
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
|
||||
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
|
||||
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
|
||||
getConns = getAnyConns_ False
|
||||
@@ -1804,7 +1808,7 @@ checkRatchetKeyHashExists db connId hash = do
|
||||
|
||||
deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
|
||||
deleteRatchetKeyHashesExpired db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
DB.execute db "DELETE FROM processed_ratchet_key_hashes WHERE created_at < ?" (Only cutoffTs)
|
||||
|
||||
-- | returns all connection queues, the first queue is the primary one
|
||||
@@ -2253,7 +2257,7 @@ deleteRcvFile' db rcvFileId =
|
||||
|
||||
getNextRcvChunkToDownload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe RcvFileChunk)
|
||||
getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
maybeFirstRow toChunk $
|
||||
DB.query
|
||||
db
|
||||
@@ -2290,7 +2294,7 @@ getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = d
|
||||
|
||||
getNextRcvFileToDecrypt :: DB.Connection -> NominalDiffTime -> IO (Maybe RcvFile)
|
||||
getNextRcvFileToDecrypt db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
fileId_ :: Maybe DBRcvFileId <-
|
||||
maybeFirstRow fromOnly $
|
||||
DB.query
|
||||
@@ -2308,7 +2312,7 @@ getNextRcvFileToDecrypt db ttl = do
|
||||
|
||||
getPendingRcvFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
|
||||
getPendingRcvFilesServers db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
map toXFTPServer
|
||||
<$> DB.query
|
||||
db
|
||||
@@ -2350,7 +2354,7 @@ getCleanupRcvFilesDeleted db =
|
||||
|
||||
getRcvFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBRcvFileId, RcvFileId, FilePath)]
|
||||
getRcvFilesExpired db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
@@ -2458,7 +2462,7 @@ getChunkReplicaRecipients_ db replicaId =
|
||||
|
||||
getNextSndFileToPrepare :: DB.Connection -> NominalDiffTime -> IO (Maybe SndFile)
|
||||
getNextSndFileToPrepare db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
fileId_ :: Maybe DBSndFileId <-
|
||||
maybeFirstRow fromOnly $
|
||||
DB.query
|
||||
@@ -2539,7 +2543,7 @@ createSndFileReplica db SndFileChunk {sndChunkId} NewSndChunkReplica {server, re
|
||||
|
||||
getNextSndChunkToUpload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe SndFileChunk)
|
||||
getNextSndChunkToUpload db server@ProtocolServer {host, port, keyHash} ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
chunk_ <-
|
||||
maybeFirstRow toChunk $
|
||||
DB.query
|
||||
@@ -2608,7 +2612,7 @@ updateSndChunkReplicaStatus db replicaId status = do
|
||||
|
||||
getPendingSndFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
|
||||
getPendingSndFilesServers db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
map toXFTPServer
|
||||
<$> DB.query
|
||||
db
|
||||
@@ -2647,7 +2651,7 @@ getCleanupSndFilesDeleted db =
|
||||
|
||||
getSndFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBSndFileId, SndFileId, Maybe FilePath)]
|
||||
getSndFilesExpired db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
@@ -2687,7 +2691,7 @@ getDeletedSndChunkReplica db deletedSndChunkReplicaId =
|
||||
|
||||
getNextDeletedSndChunkReplica :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe DeletedSndChunkReplica)
|
||||
getNextDeletedSndChunkReplica db ProtocolServer {host, port, keyHash} ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
replicaId_ :: Maybe Int64 <-
|
||||
maybeFirstRow fromOnly $
|
||||
DB.query
|
||||
@@ -2716,7 +2720,7 @@ deleteDeletedSndChunkReplica db deletedSndChunkReplicaId =
|
||||
|
||||
getPendingDelFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
|
||||
getPendingDelFilesServers db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
map toXFTPServer
|
||||
<$> DB.query
|
||||
db
|
||||
@@ -2731,5 +2735,9 @@ getPendingDelFilesServers db ttl = do
|
||||
|
||||
deleteDeletedSndChunkReplicasExpired :: DB.Connection -> NominalDiffTime -> IO ()
|
||||
deleteDeletedSndChunkReplicasExpired db ttl = do
|
||||
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
|
||||
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
|
||||
DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE created_at < ?" (Only cutoffTs)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''UpMigration)
|
||||
|
||||
$(J.deriveToJSON (sumTypeJSON $ dropPrefix "ME") ''MigrationError)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE StrictData #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.DB
|
||||
( Connection (..),
|
||||
@@ -20,13 +20,12 @@ where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (when)
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Int (Int64)
|
||||
import Data.Time (diffUTCTime, getCurrentTime)
|
||||
import Database.SQLite.Simple (FromRow, NamedParam, Query, ToRow)
|
||||
import qualified Database.SQLite.Simple as SQL
|
||||
import GHC.Generics (Generic)
|
||||
import Simplex.Messaging.Parsers (defaultJSON)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (diffToMilliseconds)
|
||||
@@ -41,9 +40,7 @@ data SlowQueryStats = SlowQueryStats
|
||||
timeMax :: Int64,
|
||||
timeAvg :: Int64
|
||||
}
|
||||
deriving (Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON SlowQueryStats where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Show)
|
||||
|
||||
timeIt :: TMap Query SlowQueryStats -> Query -> IO a -> IO a
|
||||
timeIt slow sql a = do
|
||||
@@ -100,3 +97,5 @@ query_ Connection {conn, slow} sql = timeIt slow sql $ SQL.query_ conn sql
|
||||
queryNamed :: FromRow r => Connection -> Query -> [NamedParam] -> IO [r]
|
||||
queryNamed Connection {conn, slow} sql = timeIt slow sql . SQL.queryNamed conn sql
|
||||
{-# INLINE queryNamed #-}
|
||||
|
||||
$(J.deriveJSON defaultJSON ''SlowQueryStats)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
@@ -27,8 +26,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations
|
||||
where
|
||||
|
||||
import Control.Monad (forM_, when)
|
||||
import Data.Aeson (ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.List (intercalate, sortOn)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.Map as M
|
||||
@@ -40,7 +38,6 @@ import Database.SQLite.Simple (Connection, Only (..), Query (..))
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import qualified Database.SQLite3 as SQLite3
|
||||
import GHC.Generics (Generic)
|
||||
import Simplex.Messaging.Agent.Protocol (extraSMPServerHosts)
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial
|
||||
@@ -169,11 +166,7 @@ data MigrationsToRun = MTRUp [Migration] | MTRDown [DownMigration] | MTRNone
|
||||
data MTRError
|
||||
= MTRENoDown {dbMigrations :: [String]}
|
||||
| MTREDifferent {appMigration :: String, dbMigration :: String}
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON MTRError where
|
||||
toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "MTRE"
|
||||
toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "MTRE"
|
||||
deriving (Eq, Show)
|
||||
|
||||
mtrErrorDescription :: MTRError -> String
|
||||
mtrErrorDescription = \case
|
||||
@@ -192,3 +185,5 @@ migrationsToRun [] dbMs
|
||||
migrationsToRun (a : as) (d : ds)
|
||||
| name a == name d = migrationsToRun as ds
|
||||
| otherwise = Left $ MTREDifferent (name a) (name d)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON $ dropPrefix "MTRE") ''MTRError)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
module Simplex.Messaging.Agent.TRcvQueues
|
||||
( TRcvQueues (getRcvQueues),
|
||||
empty,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
@@ -10,6 +9,7 @@
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
@@ -86,8 +86,7 @@ import Control.Exception
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
@@ -97,13 +96,12 @@ import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Time.Clock (UTCTime, getCurrentTime)
|
||||
import GHC.Generics (Generic)
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (dropPrefix, enumJSON)
|
||||
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON)
|
||||
import Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
@@ -190,14 +188,7 @@ data HostMode
|
||||
HMOnion
|
||||
| -- | prefer (or require) public hosts
|
||||
HMPublic
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance FromJSON HostMode where
|
||||
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "HM"
|
||||
|
||||
instance ToJSON HostMode where
|
||||
toJSON = J.genericToJSON . enumJSON $ dropPrefix "HM"
|
||||
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "HM"
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | network configuration for the client
|
||||
data NetworkConfig = NetworkConfig
|
||||
@@ -223,21 +214,10 @@ data NetworkConfig = NetworkConfig
|
||||
smpPingCount :: Int,
|
||||
logTLSErrors :: Bool
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON NetworkConfig where
|
||||
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data TransportSessionMode = TSMUser | TSMEntity
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON TransportSessionMode where
|
||||
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM"
|
||||
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM"
|
||||
|
||||
instance FromJSON TransportSessionMode where
|
||||
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM"
|
||||
deriving (Eq, Show)
|
||||
|
||||
defaultNetworkConfig :: NetworkConfig
|
||||
defaultNetworkConfig =
|
||||
@@ -429,11 +409,11 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
|
||||
where
|
||||
response entityId
|
||||
| entityId == entId =
|
||||
case respOrErr of
|
||||
Left e -> Left $ PCEResponseError e
|
||||
Right r -> case protocolError r of
|
||||
Just e -> Left $ PCEProtocolError e
|
||||
_ -> Right r
|
||||
case respOrErr of
|
||||
Left e -> Left $ PCEResponseError e
|
||||
Right r -> case protocolError r of
|
||||
Just e -> Left $ PCEProtocolError e
|
||||
_ -> Right r
|
||||
| otherwise = Left . PCEUnexpectedResponse $ bshow respOrErr
|
||||
sendMsg :: Either err msg -> IO ()
|
||||
sendMsg = \case
|
||||
@@ -661,13 +641,13 @@ sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do
|
||||
validate rs
|
||||
| diff == 0 = pure $ L.fromList rs
|
||||
| diff > 0 = do
|
||||
putStrLn "send error: fewer responses than expected"
|
||||
pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock)
|
||||
putStrLn "send error: fewer responses than expected"
|
||||
pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock)
|
||||
| otherwise = do
|
||||
putStrLn "send error: more responses than expected"
|
||||
pure $ L.fromList $ take (L.length cs) rs
|
||||
where
|
||||
diff = L.length cs - length rs
|
||||
putStrLn "send error: more responses than expected"
|
||||
pure $ L.fromList $ take (L.length cs) rs
|
||||
where
|
||||
diff = L.length cs - length rs
|
||||
|
||||
streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
|
||||
streamProtocolCommands c@ProtocolClient {batch, blockSize} cs cb = do
|
||||
@@ -688,8 +668,8 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
|
||||
(: []) <$> getResponse c r
|
||||
|
||||
data ClientBatch err msg
|
||||
-- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch
|
||||
= CBTransmissions ByteString Int [Request err msg]
|
||||
= -- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch
|
||||
CBTransmissions ByteString Int [Request err msg]
|
||||
| CBTransmission ByteString (Request err msg)
|
||||
| CBLargeTransmission (Request err msg)
|
||||
|
||||
@@ -713,9 +693,9 @@ batchClientTransmissions batch blkSize
|
||||
encodeBatch :: ByteString -> Int -> [Request err msg] -> NonEmpty (PCTransmission err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmission err msg)))
|
||||
encodeBatch s n rs ts@((t, r) :| ts_)
|
||||
| B.length s' <= blkSize - 3 && n < 255 =
|
||||
case L.nonEmpty ts_ of
|
||||
Just ts' -> encodeBatch s' n' rs' ts'
|
||||
Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing)
|
||||
case L.nonEmpty ts_ of
|
||||
Just ts' -> encodeBatch s' n' rs' ts'
|
||||
Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing)
|
||||
| n == 0 = (CBLargeTransmission r, L.nonEmpty ts_)
|
||||
| otherwise = (CBTransmissions s n (reverse rs), Just ts)
|
||||
where
|
||||
@@ -765,3 +745,9 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo
|
||||
r <- Request entId <$> newEmptyTMVar
|
||||
TM.insert corrId r sentCommands
|
||||
pure r
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "HM") ''HostMode)
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "TSM") ''TransportSessionMode)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''NetworkConfig)
|
||||
|
||||
@@ -18,7 +18,7 @@ import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first, bimap)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (partitionEithers)
|
||||
@@ -36,11 +36,11 @@ import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer, NtfPrivateSignKey, NotifierId, RcvPrivateSignKey, RecipientId)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateSignKey, ProtocolServer (..), QueueId, RcvPrivateSignKey, RecipientId, SMPServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util (catchAll_, ($>>=), toChunks)
|
||||
import Simplex.Messaging.Util (catchAll_, toChunks, ($>>=))
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async)
|
||||
import UnliftIO.Exception (Exception)
|
||||
@@ -238,7 +238,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
(tempErrs, finalErrs) = partition (temporaryClientError . snd) errs
|
||||
mapM_ (atomically . addSubscription ca srv) oks
|
||||
mapM_ (liftIO . notify . CAResubscribed srv) $ L.nonEmpty $ map fst oks
|
||||
mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
|
||||
mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
|
||||
mapM_ (liftIO . notify . CASubError srv) $ L.nonEmpty finalErrs
|
||||
mapM_ (throwE . snd) $ listToMaybe tempErrs
|
||||
|
||||
@@ -281,7 +281,7 @@ subscribeQueue ca srv sub = do
|
||||
|
||||
handleErr e = do
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv $ fst sub
|
||||
removePendingSubscription ca srv (fst sub)
|
||||
throwE e
|
||||
|
||||
subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateSignKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ()))
|
||||
@@ -300,14 +300,15 @@ subscribeQueues_ party ca srv subs = do
|
||||
smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateSignKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
|
||||
smpSubscribeQueues party ca smp srv subs = do
|
||||
rs <- L.zip subs <$> subscribe smp (L.map swap subs)
|
||||
atomically $ forM rs $ \(sub, r) -> (fst sub,) <$> case r of
|
||||
Right () -> do
|
||||
addSubscription ca srv $ first (party,) sub
|
||||
pure $ Right ()
|
||||
Left e -> do
|
||||
when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv $ (party,) $ fst sub
|
||||
pure $ Left e
|
||||
atomically $ forM rs $ \(sub, r) ->
|
||||
(fst sub,) <$> case r of
|
||||
Right () -> do
|
||||
addSubscription ca srv $ first (party,) sub
|
||||
pure $ Right ()
|
||||
Left e -> do
|
||||
when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv (party, fst sub)
|
||||
pure $ Left e
|
||||
where
|
||||
subscribe = case party of
|
||||
SPRecipient -> subscribeSMPQueues
|
||||
|
||||
@@ -680,9 +680,9 @@ instance CryptoSignature ASignature where
|
||||
signatureBytes (ASignature _ sig) = signatureBytes sig
|
||||
decodeSignature s
|
||||
| B.length s == Ed25519.signatureSize =
|
||||
ASignature SEd25519 . SignatureEd25519 <$> ed Ed25519.signature s
|
||||
ASignature SEd25519 . SignatureEd25519 <$> ed Ed25519.signature s
|
||||
| B.length s == Ed448.signatureSize =
|
||||
ASignature SEd448 . SignatureEd448 <$> ed Ed448.signature s
|
||||
ASignature SEd448 . SignatureEd448 <$> ed Ed448.signature s
|
||||
| otherwise = Left "bad signature size"
|
||||
where
|
||||
ed alg = first show . CE.eitherCryptoError . alg
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
module Simplex.Messaging.Crypto.File
|
||||
( CryptoFile (..),
|
||||
@@ -24,19 +23,18 @@ where
|
||||
import Control.Exception
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy as LB
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.Maybe (isJust)
|
||||
import GHC.Generics (Generic)
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Lazy (LazyByteString)
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Parsers (defaultJSON)
|
||||
import Simplex.Messaging.Util (liftEitherWith)
|
||||
import System.Directory (getFileSize)
|
||||
import UnliftIO (Handle, IOMode (..), liftIO)
|
||||
@@ -45,16 +43,10 @@ import UnliftIO.STM
|
||||
|
||||
-- Possibly encrypted local file
|
||||
data CryptoFile = CryptoFile {filePath :: FilePath, cryptoArgs :: Maybe CryptoFileArgs}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON CryptoFile where
|
||||
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data CryptoFileArgs = CFArgs {fileKey :: C.SbKey, fileNonce :: C.CbNonce}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON CryptoFileArgs where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
data CryptoFileHandle = CFHandle Handle (Maybe (TVar LC.SbState))
|
||||
|
||||
@@ -124,3 +116,7 @@ getFileContentsSize :: CryptoFile -> IO Integer
|
||||
getFileContentsSize (CryptoFile path cfArgs) = do
|
||||
size <- getFileSize path
|
||||
pure $ if isJust cfArgs then size - fromIntegral C.authTagSize else size
|
||||
|
||||
$(J.deriveJSON defaultJSON ''CryptoFileArgs)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''CryptoFile)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
@@ -90,10 +89,10 @@ unPad = fmap snd . splitLen
|
||||
splitLen :: LazyByteString -> Either CryptoError (Int64, LazyByteString)
|
||||
splitLen padded
|
||||
| LB.length lenStr == 8 = case smpDecode $ LB.toStrict lenStr of
|
||||
Right len
|
||||
| len < 0 -> Left CryptoInvalidMsgError
|
||||
| otherwise -> Right (len, LB.take len rest)
|
||||
Left _ -> Left CryptoInvalidMsgError
|
||||
Right len
|
||||
| len < 0 -> Left CryptoInvalidMsgError
|
||||
| otherwise -> Right (len, LB.take len rest)
|
||||
Left _ -> Left CryptoInvalidMsgError
|
||||
| otherwise = Left CryptoInvalidMsgError
|
||||
where
|
||||
(lenStr, rest) = LB.splitAt 8 padded
|
||||
@@ -112,10 +111,10 @@ sbDecrypt :: SbKey -> CbNonce -> LazyByteString -> Either CryptoError LazyByteSt
|
||||
sbDecrypt (SbKey key) (CbNonce nonce) packet
|
||||
| LB.length tag' < 16 = Left CBDecryptError
|
||||
| otherwise = case secretBox sbDecryptChunk key nonce c of
|
||||
Right (tag :| cs)
|
||||
| BA.constEq (LB.toStrict tag') tag -> unPad $ LB.fromChunks cs
|
||||
| otherwise -> Left CBDecryptError
|
||||
Left e -> Left e
|
||||
Right (tag :| cs)
|
||||
| BA.constEq (LB.toStrict tag') tag -> unPad $ LB.fromChunks cs
|
||||
| otherwise -> Left CBDecryptError
|
||||
Left e -> Left e
|
||||
where
|
||||
(tag', c) = LB.splitAt 16 packet
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
@@ -20,8 +20,9 @@ import Control.Monad.Trans.Except
|
||||
import Crypto.Cipher.AES (AES256)
|
||||
import Crypto.Hash (SHA512)
|
||||
import qualified Crypto.KDF.HKDF as H
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as JQ
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy as LB
|
||||
@@ -33,12 +34,11 @@ import Data.Typeable (Typeable)
|
||||
import Data.Word (Word32)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import GHC.Generics
|
||||
import Simplex.Messaging.Agent.QueryString
|
||||
import Simplex.Messaging.Crypto
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (blobFieldDecoder, parseE, parseE')
|
||||
import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE')
|
||||
import Simplex.Messaging.Version
|
||||
|
||||
currentE2EEncryptVersion :: Version
|
||||
@@ -112,11 +112,11 @@ x3dh v (sk1, rk1) dh1 dh2 dh3 =
|
||||
(hk, nhk, sk)
|
||||
-- for backwards compatibility with clients using agent version before 3.4.0
|
||||
| v == 1 =
|
||||
let (hk', rest) = B.splitAt 32 dhs
|
||||
in uncurry (hk',,) $ B.splitAt 32 rest
|
||||
let (hk', rest) = B.splitAt 32 dhs
|
||||
in uncurry (hk',,) $ B.splitAt 32 rest
|
||||
| otherwise =
|
||||
let salt = B.replicate 64 '\0'
|
||||
in hkdf3 salt dhs "SimpleXX3DH"
|
||||
let salt = B.replicate 64 '\0'
|
||||
in hkdf3 salt dhs "SimpleXX3DH"
|
||||
|
||||
type RatchetX448 = Ratchet 'X448
|
||||
|
||||
@@ -135,29 +135,20 @@ data Ratchet a = Ratchet
|
||||
rcNHKs :: HeaderKey,
|
||||
rcNHKr :: HeaderKey
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance AlgorithmI a => ToJSON (Ratchet a) where
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
data SndRatchet a = SndRatchet
|
||||
{ rcDHRr :: PublicKey a,
|
||||
rcCKs :: RatchetKey,
|
||||
rcHKs :: HeaderKey
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance AlgorithmI a => ToJSON (SndRatchet a) where
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
data RcvRatchet = RcvRatchet
|
||||
{ rcCKr :: RatchetKey,
|
||||
rcHKr :: HeaderKey
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON RcvRatchet where
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
type SkippedMsgKeys = Map HeaderKey SkippedHdrMsgKeys
|
||||
|
||||
@@ -204,10 +195,6 @@ instance ToJSON RatchetKey where
|
||||
instance FromJSON RatchetKey where
|
||||
parseJSON = fmap RatchetKey . strParseJSON "Key"
|
||||
|
||||
instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode
|
||||
|
||||
instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict'
|
||||
|
||||
instance ToField MessageKey where toField = toField . smpEncode
|
||||
|
||||
instance FromField MessageKey where fromField = blobFieldDecoder smpDecode
|
||||
@@ -428,9 +415,9 @@ rcDecrypt rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do
|
||||
| rcNr + maxSkip < untilN = Left $ CERatchetTooManySkipped (untilN + 1 - rcNr)
|
||||
| rcNr == untilN = Right (r, M.empty)
|
||||
| otherwise =
|
||||
let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty
|
||||
r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'}
|
||||
in Right (r', M.singleton rcHKr mks)
|
||||
let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty
|
||||
r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'}
|
||||
in Right (r', M.singleton rcHKr mks)
|
||||
advanceRcvRatchet :: Word32 -> RatchetKey -> Word32 -> SkippedHdrMsgKeys -> (RatchetKey, Word32, SkippedHdrMsgKeys)
|
||||
advanceRcvRatchet 0 ck msgNs mks = (ck, msgNs, mks)
|
||||
advanceRcvRatchet n ck msgNs mks =
|
||||
@@ -493,3 +480,23 @@ hkdf3 salt ikm info = (s1, s2, s3)
|
||||
out = H.expand prk info 96
|
||||
(s1, rest) = B.splitAt 32 out
|
||||
(s2, s3) = B.splitAt 32 rest
|
||||
|
||||
$(JQ.deriveJSON defaultJSON ''RcvRatchet)
|
||||
|
||||
instance AlgorithmI a => ToJSON (SndRatchet a) where
|
||||
toEncoding = $(JQ.mkToEncoding defaultJSON ''SndRatchet)
|
||||
toJSON = $(JQ.mkToJSON defaultJSON ''SndRatchet)
|
||||
|
||||
instance AlgorithmI a => FromJSON (SndRatchet a) where
|
||||
parseJSON = $(JQ.mkParseJSON defaultJSON ''SndRatchet)
|
||||
|
||||
instance AlgorithmI a => ToJSON (Ratchet a) where
|
||||
toEncoding = $(JQ.mkToEncoding defaultJSON ''Ratchet)
|
||||
toJSON = $(JQ.mkToJSON defaultJSON ''Ratchet)
|
||||
|
||||
instance AlgorithmI a => FromJSON (Ratchet a) where
|
||||
parseJSON = $(JQ.mkParseJSON defaultJSON ''Ratchet)
|
||||
|
||||
instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode
|
||||
|
||||
instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict'
|
||||
|
||||
@@ -109,7 +109,7 @@ lenP = fromIntegral . c2w <$> A.anyChar
|
||||
{-# INLINE lenP #-}
|
||||
|
||||
instance Encoding a => Encoding (Maybe a) where
|
||||
smpEncode s = maybe "0" (("1" <>) . smpEncode) s
|
||||
smpEncode = maybe "0" (("1" <>) . smpEncode)
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP =
|
||||
smpP >>= \case
|
||||
|
||||
@@ -518,7 +518,7 @@ instance ToJSON NtfTknStatus where
|
||||
|
||||
instance FromJSON NtfTknStatus where
|
||||
parseJSON = J.withText "NtfTknStatus" $ either fail pure . smpDecode . encodeUtf8
|
||||
|
||||
|
||||
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
|
||||
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
|
||||
Just Refl -> Right c
|
||||
|
||||
@@ -16,13 +16,12 @@ import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Reader
|
||||
import Data.Bifunctor (second)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.List (intercalate, sort)
|
||||
import Data.List.NonEmpty (NonEmpty(..))
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
@@ -205,10 +204,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
Just err -> (subs, oks, err : errs) -- permanent error, log and don't retry subscription
|
||||
Nothing -> (sub : subs, oks, errs) -- temporary error, retry subscription
|
||||
|
||||
-- | Subscribe to queues. The list of results can have a different order.
|
||||
-- \| Subscribe to queues. The list of results can have a different order.
|
||||
subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO (NonEmpty (NtfSubData, Either SMPClientError ()))
|
||||
subscribeQueues srv subs =
|
||||
L.map (second snd) . L.zip subs <$> subscribeQueuesNtfs ca srv (L.map sub subs)
|
||||
L.zipWith (\s r -> (s, snd r)) subs <$> subscribeQueuesNtfs ca srv (L.map sub subs)
|
||||
where
|
||||
sub NtfSubData {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey)
|
||||
|
||||
@@ -248,10 +247,11 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
forM errs (\((_, ntfId), err) -> handleSubError (SMPQueueNtf srv ntfId) err)
|
||||
>>= logSubErrors srv . catMaybes . L.toList
|
||||
|
||||
logSubStatus srv event n = when (n > 0) $
|
||||
logInfo $ "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)"
|
||||
logSubStatus srv event n =
|
||||
when (n > 0) . logInfo $
|
||||
"SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)"
|
||||
|
||||
logSubErrors :: SMPServer -> [NtfSubStatus] -> M ()
|
||||
logSubErrors :: SMPServer -> [NtfSubStatus] -> M ()
|
||||
logSubErrors srv errs = forM_ (L.group $ sort errs) $ \errs' -> do
|
||||
logError $ "SMP subscription errors on server " <> showServer' srv <> ": " <> tshow (L.head errs') <> " (" <> tshow (length errs') <> " errors)"
|
||||
|
||||
@@ -289,14 +289,14 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
|
||||
case ntf of
|
||||
PNVerification _
|
||||
| status /= NTInvalid && status /= NTExpired ->
|
||||
deliverNotification pp tkn ntf >>= \case
|
||||
Right _ -> do
|
||||
status_ <- atomically $ stateTVar tknStatus $ \case
|
||||
NTActive -> (Nothing, NTActive)
|
||||
NTConfirmed -> (Nothing, NTConfirmed)
|
||||
_ -> (Just NTConfirmed, NTConfirmed)
|
||||
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
|
||||
_ -> pure ()
|
||||
deliverNotification pp tkn ntf >>= \case
|
||||
Right _ -> do
|
||||
status_ <- atomically $ stateTVar tknStatus $ \case
|
||||
NTActive -> (Nothing, NTActive)
|
||||
NTConfirmed -> (Nothing, NTConfirmed)
|
||||
_ -> (Just NTConfirmed, NTConfirmed)
|
||||
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
|
||||
_ -> pure ()
|
||||
| otherwise -> logError "bad notification token status"
|
||||
PNCheckMessages -> checkActiveTkn status $ do
|
||||
void $ deliverNotification pp tkn ntf
|
||||
@@ -345,7 +345,8 @@ runNtfClientTransport th@THandle {sessionId} = do
|
||||
raceAny_ ([liftIO $ send th c, client c s ps, receive th c] <> disconnectThread_ c expCfg)
|
||||
`finally` liftIO (clientDisconnected c)
|
||||
where
|
||||
disconnectThread_ c expCfg = maybe [] ((: []) . liftIO . disconnectTransport th c activeAt) expCfg
|
||||
disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport th c activeAt expCfg]
|
||||
disconnectThread_ _ _ = []
|
||||
|
||||
clientDisconnected :: NtfServerClient -> IO ()
|
||||
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False
|
||||
@@ -463,16 +464,16 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
else pure $ NRErr AUTH
|
||||
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
|
||||
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
|
||||
logDebug "TVFY - token verified"
|
||||
st <- asks store
|
||||
updateTknStatus tkn NTActive
|
||||
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
|
||||
forM_ tIds cancelInvervalNotifications
|
||||
incNtfStat tknVerified
|
||||
pure NROk
|
||||
logDebug "TVFY - token verified"
|
||||
st <- asks store
|
||||
updateTknStatus tkn NTActive
|
||||
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
|
||||
forM_ tIds cancelInvervalNotifications
|
||||
incNtfStat tknVerified
|
||||
pure NROk
|
||||
| otherwise -> do
|
||||
logDebug "TVFY - incorrect code or token status"
|
||||
pure $ NRErr AUTH
|
||||
logDebug "TVFY - incorrect code or token status"
|
||||
pure $ NRErr AUTH
|
||||
TCHK -> do
|
||||
logDebug "TCHK"
|
||||
pure $ NRTkn status
|
||||
@@ -509,16 +510,16 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
TCRN int
|
||||
| int < 20 -> pure $ NRErr QUOTA
|
||||
| otherwise -> do
|
||||
logDebug "TCRN"
|
||||
atomically $ writeTVar tknCronInterval int
|
||||
atomically (TM.lookup tknId intervalNotifiers) >>= \case
|
||||
Nothing -> runIntervalNotifier int
|
||||
Just IntervalNotifier {interval, action} ->
|
||||
unless (interval == int) $ do
|
||||
uninterruptibleCancel action
|
||||
runIntervalNotifier int
|
||||
withNtfLog $ \s -> logTokenCron s tknId int
|
||||
pure NROk
|
||||
logDebug "TCRN"
|
||||
atomically $ writeTVar tknCronInterval int
|
||||
atomically (TM.lookup tknId intervalNotifiers) >>= \case
|
||||
Nothing -> runIntervalNotifier int
|
||||
Just IntervalNotifier {interval, action} ->
|
||||
unless (interval == int) $ do
|
||||
uninterruptibleCancel action
|
||||
runIntervalNotifier int
|
||||
withNtfLog $ \s -> logTokenCron s tknId int
|
||||
pure NROk
|
||||
where
|
||||
runIntervalNotifier interval = do
|
||||
action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60
|
||||
|
||||
@@ -34,7 +34,7 @@ import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Main where
|
||||
|
||||
import Data.Either (fromRight)
|
||||
import Data.Functor (($>))
|
||||
import Data.Ini (lookupValue, readIniFile)
|
||||
import Data.Maybe (fromMaybe)
|
||||
@@ -92,7 +91,7 @@ ntfServerCLI cfgPath logPath =
|
||||
hSetBuffering stdout LineBuffering
|
||||
hSetBuffering stderr LineBuffering
|
||||
fp <- checkSavedFingerprint cfgPath defaultX509Config
|
||||
let host = fromRight "<hostnames>" $ T.unpack <$> lookupValue "TRANSPORT" "host" ini
|
||||
let host = either (const "<hostnames>") T.unpack $ lookupValue "TRANSPORT" "host" ini
|
||||
port = T.unpack $ strictIni "TRANSPORT" "port" ini
|
||||
cfg@NtfServerConfig {transports, storeLogFile} = serverConfig
|
||||
srv = ProtoServerWithAuth (NtfServer [THDomainName host] (if port == "443" then "" else port) (C.KeyHash fp)) Nothing
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
@@ -23,15 +23,15 @@ import qualified Crypto.Store.PKCS8 as PK
|
||||
import Data.ASN1.BinaryEncoding (DER (..))
|
||||
import Data.ASN1.Encoding
|
||||
import Data.ASN1.Types
|
||||
import Data.Aeson (FromJSON, ToJSON, (.=))
|
||||
import Data.Aeson (ToJSON, (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Aeson.TH as JQ
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import qualified Data.CaseInsensitive as CI
|
||||
import Data.Int (Int64)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (isNothing)
|
||||
@@ -40,8 +40,7 @@ import qualified Data.Text as T
|
||||
import Data.Time.Clock.System
|
||||
import qualified Data.X509 as X
|
||||
import qualified Data.X509.CertificateStore as XS
|
||||
import GHC.Generics (Generic)
|
||||
import Network.HTTP.Types (HeaderName, Status)
|
||||
import Network.HTTP.Types (Status)
|
||||
import qualified Network.HTTP.Types as N
|
||||
import Network.HTTP2.Client (Request)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
@@ -49,7 +48,9 @@ import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
|
||||
import Simplex.Messaging.Notifications.Server.Store (NtfTknData (..))
|
||||
import Simplex.Messaging.Parsers (defaultJSON)
|
||||
import Simplex.Messaging.Protocol (EncNMsgMeta)
|
||||
import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..))
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
@@ -61,17 +62,13 @@ data JWTHeader = JWTHeader
|
||||
{ alg :: Text, -- key algorithm, ES256 for APNS
|
||||
kid :: Text -- key ID
|
||||
}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON JWTHeader where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Show)
|
||||
|
||||
data JWTClaims = JWTClaims
|
||||
{ iss :: Text, -- issuer, team ID for APNS
|
||||
iat :: Int64 -- issue time, seconds from epoch
|
||||
}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON JWTClaims where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Show)
|
||||
|
||||
data JWTToken = JWTToken JWTHeader JWTClaims
|
||||
deriving (Show)
|
||||
@@ -83,6 +80,10 @@ mkJWTToken hdr iss = do
|
||||
|
||||
type SignedJWTToken = ByteString
|
||||
|
||||
$(JQ.deriveToJSON defaultJSON ''JWTHeader)
|
||||
|
||||
$(JQ.deriveToJSON defaultJSON ''JWTClaims)
|
||||
|
||||
signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken
|
||||
signedJWTToken pk (JWTToken hdr claims) = do
|
||||
let hc = jwtEncode hdr <> "." <> jwtEncode claims
|
||||
@@ -122,24 +123,13 @@ instance StrEncoding PNMessageData where
|
||||
pure PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta}
|
||||
|
||||
data APNSNotification = APNSNotification {aps :: APNSNotificationBody, notificationData :: Maybe J.Value}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON APNSNotification where
|
||||
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
deriving (Show)
|
||||
|
||||
data APNSNotificationBody
|
||||
= APNSBackground {contentAvailable :: Int}
|
||||
| APNSMutableContent {mutableContent :: Int, alert :: APNSAlertBody, category :: Maybe Text}
|
||||
| APNSAlert {alert :: APNSAlertBody, badge :: Maybe Int, sound :: Maybe Text, category :: Maybe Text}
|
||||
deriving (Show, Generic)
|
||||
|
||||
apnsJSONOptions :: J.Options
|
||||
apnsJSONOptions = J.defaultOptions {J.omitNothingFields = True, J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'}
|
||||
|
||||
instance ToJSON APNSNotificationBody where
|
||||
toJSON = J.genericToJSON apnsJSONOptions
|
||||
toEncoding = J.genericToEncoding apnsJSONOptions
|
||||
deriving (Show)
|
||||
|
||||
type APNSNotificationData = Map Text Text
|
||||
|
||||
@@ -305,6 +295,10 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
|
||||
|
||||
-- apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
|
||||
|
||||
$(JQ.deriveToJSON apnsJSONOptions ''APNSNotificationBody)
|
||||
|
||||
$(JQ.deriveToJSON defaultJSON ''APNSNotification)
|
||||
|
||||
apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request
|
||||
apnsRequest c tkn ntf@APNSNotification {aps} = do
|
||||
signedJWT <- getApnsJWTToken c
|
||||
@@ -337,7 +331,8 @@ type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProvider
|
||||
|
||||
-- this is not a newtype on purpose to have a correct JSON encoding as a record
|
||||
data APNSErrorResponse = APNSErrorResponse {reason :: Text}
|
||||
deriving (Generic, FromJSON)
|
||||
|
||||
$(JQ.deriveFromJSON defaultJSON ''APNSErrorResponse)
|
||||
|
||||
apnsPushProviderClient :: APNSPushClient -> PushProviderClient
|
||||
apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken _ tknStr} pn = do
|
||||
@@ -356,15 +351,15 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke
|
||||
result status reason'
|
||||
| status == Just N.ok200 = pure ()
|
||||
| status == Just N.badRequest400 =
|
||||
case reason' of
|
||||
"BadDeviceToken" -> throwError PPTokenInvalid
|
||||
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
|
||||
"TopicDisallowed" -> throwError PPPermanentError
|
||||
_ -> err status reason'
|
||||
case reason' of
|
||||
"BadDeviceToken" -> throwError PPTokenInvalid
|
||||
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
|
||||
"TopicDisallowed" -> throwError PPPermanentError
|
||||
_ -> err status reason'
|
||||
| status == Just N.forbidden403 = case reason' of
|
||||
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
|
||||
"InvalidProviderToken" -> throwError PPPermanentError
|
||||
_ -> err status reason'
|
||||
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
|
||||
"InvalidProviderToken" -> throwError PPPermanentError
|
||||
_ -> err status reason'
|
||||
| status == Just N.gone410 = throwError PPTokenInvalid
|
||||
| status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater
|
||||
-- Just tooManyRequests429 -> TooManyRequests - too many requests for the same token
|
||||
@@ -372,12 +367,3 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke
|
||||
err :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
|
||||
err s r = throwError $ PPResponseError s r
|
||||
liftHTTPS2 a = ExceptT $ first PPConnection <$> a
|
||||
|
||||
hApnsTopic :: HeaderName
|
||||
hApnsTopic = CI.mk "apns-topic"
|
||||
|
||||
hApnsPushType :: HeaderName
|
||||
hApnsPushType = CI.mk "apns-push-type"
|
||||
|
||||
hApnsPriority :: HeaderName
|
||||
hApnsPriority = CI.mk "apns-priority"
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Push.APNS.Internal where
|
||||
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.CaseInsensitive as CI
|
||||
import Network.HTTP.Types (HeaderName)
|
||||
import Simplex.Messaging.Parsers (defaultJSON)
|
||||
|
||||
hApnsTopic :: HeaderName
|
||||
hApnsTopic = CI.mk "apns-topic"
|
||||
|
||||
hApnsPushType :: HeaderName
|
||||
hApnsPushType = CI.mk "apns-push-type"
|
||||
|
||||
hApnsPriority :: HeaderName
|
||||
hApnsPriority = CI.mk "apns-priority"
|
||||
|
||||
apnsJSONOptions :: J.Options
|
||||
apnsJSONOptions = defaultJSON {J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'}
|
||||
@@ -4,7 +4,6 @@
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Store where
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE StrictData #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.StoreLog
|
||||
|
||||
@@ -50,9 +50,9 @@ ntfServerHandshake c kh ntfVRange = do
|
||||
getHandshake th >>= \case
|
||||
NtfClientHandshake {ntfVersion, keyHash}
|
||||
| keyHash /= kh ->
|
||||
throwError $ TEHandshake IDENTITY
|
||||
| ntfVersion `isCompatible` ntfVRange -> do
|
||||
pure (th :: THandle c) {thVersion = ntfVersion}
|
||||
throwError $ TEHandshake IDENTITY
|
||||
| ntfVersion `isCompatible` ntfVRange ->
|
||||
pure (th :: THandle c) {thVersion = ntfVersion}
|
||||
| otherwise -> throwError $ TEHandshake VERSION
|
||||
|
||||
-- | Notifcations server client transport handshake.
|
||||
|
||||
@@ -85,7 +85,7 @@ blobFieldDecoder dec = \case
|
||||
Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
|
||||
f -> returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
|
||||
fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ :: Typeable a => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ fromText = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
case fromText t of
|
||||
@@ -151,3 +151,6 @@ singleFieldJSON_ objectTag tagModifier =
|
||||
J.nullaryToObject = True,
|
||||
J.omitNothingFields = True
|
||||
}
|
||||
|
||||
defaultJSON :: J.Options
|
||||
defaultJSON = J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
@@ -17,6 +16,7 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE StrictData #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilyDependencies #-}
|
||||
@@ -135,7 +135,7 @@ module Simplex.Messaging.Protocol
|
||||
noAuthSrv,
|
||||
|
||||
-- * TCP transport functions
|
||||
TransportBatch(..),
|
||||
TransportBatch (..),
|
||||
tPut,
|
||||
tPutLog,
|
||||
tGet,
|
||||
@@ -155,7 +155,7 @@ import Control.Applicative (optional, (<|>))
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Monad.Except
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
@@ -170,9 +170,7 @@ import Data.Maybe (isJust, isNothing)
|
||||
import Data.String
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Type.Equality
|
||||
import GHC.Generics (Generic)
|
||||
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
@@ -182,7 +180,6 @@ import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..))
|
||||
import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
|
||||
currentSMPClientVersion :: Version
|
||||
currentSMPClientVersion = 2
|
||||
@@ -309,17 +306,19 @@ instance StrEncoding SubscriptionMode where
|
||||
SMSubscribe -> "subscribe"
|
||||
SMOnlyCreate -> "only-create"
|
||||
strP =
|
||||
(A.string "subscribe" $> SMSubscribe) <|> (A.string "only-create" $> SMOnlyCreate)
|
||||
<?> "SubscriptionMode"
|
||||
(A.string "subscribe" $> SMSubscribe)
|
||||
<|> (A.string "only-create" $> SMOnlyCreate)
|
||||
<?> "SubscriptionMode"
|
||||
|
||||
instance Encoding SubscriptionMode where
|
||||
smpEncode = \case
|
||||
SMSubscribe -> "S"
|
||||
SMOnlyCreate -> "C"
|
||||
smpP = A.anyChar >>= \case
|
||||
'S' -> pure SMSubscribe
|
||||
'C' -> pure SMOnlyCreate
|
||||
_ -> fail "bad SubscriptionMode"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'S' -> pure SMSubscribe
|
||||
'C' -> pure SMOnlyCreate
|
||||
_ -> fail "bad SubscriptionMode"
|
||||
|
||||
data BrokerMsg where
|
||||
-- SMP broker messages (responses, client messages, notifications)
|
||||
@@ -472,9 +471,7 @@ instance Encoding NMsgMeta where
|
||||
|
||||
-- it must be data for correct JSON encoding
|
||||
data MsgFlags = MsgFlags {notification :: Bool}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON MsgFlags where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- this encoding should not become bigger than 7 bytes (currently it is 1 byte)
|
||||
instance Encoding MsgFlags where
|
||||
@@ -997,14 +994,7 @@ data ErrorType
|
||||
INTERNAL
|
||||
| -- | used internally, never returned by the server (to be removed)
|
||||
DUPLICATE_ -- not part of SMP protocol, used internally
|
||||
deriving (Eq, Generic, Read, Show)
|
||||
|
||||
instance ToJSON ErrorType where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON ErrorType where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
instance StrEncoding ErrorType where
|
||||
strEncode = \case
|
||||
@@ -1026,18 +1016,7 @@ data CommandError
|
||||
HAS_AUTH
|
||||
| -- | transmission has no required entity ID (e.g. SMP queue)
|
||||
NO_ENTITY
|
||||
deriving (Eq, Generic, Read, Show)
|
||||
|
||||
instance ToJSON CommandError where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON CommandError where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
|
||||
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary CommandError where arbitrary = genericArbitraryU
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
-- | SMP transmission parser.
|
||||
transmissionP :: Parser RawTransmission
|
||||
@@ -1306,7 +1285,7 @@ tPutLog th s = do
|
||||
pure r
|
||||
|
||||
-- ByteString does not include length byte, it is added by tEncodeBatch
|
||||
data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission
|
||||
data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission
|
||||
|
||||
-- | encodes and batches transmissions into blocks,
|
||||
batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch]
|
||||
@@ -1319,22 +1298,22 @@ batchTransmissions batch bSize
|
||||
let (n, s, ts_) = encodeBatch 0 "" ts
|
||||
r = if n == 0 then TBLargeTransmission else TBTransmissions n s
|
||||
rs' = r : rs
|
||||
in case ts_ of
|
||||
Just ts' -> mkBatch rs' ts'
|
||||
_ -> rs'
|
||||
in case ts_ of
|
||||
Just ts' -> mkBatch rs' ts'
|
||||
_ -> rs'
|
||||
mkBatch1 :: ByteString -> TransportBatch
|
||||
mkBatch1 s = if B.length s > bSize - 2 then TBLargeTransmission else TBTransmission s
|
||||
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
|
||||
encodeBatch n s ts@(t :| ts_)
|
||||
| n == 255 = (n, s, Just ts)
|
||||
| otherwise =
|
||||
let s' = s <> smpEncode (Large t)
|
||||
n' = n + 1
|
||||
in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch
|
||||
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
|
||||
else case L.nonEmpty ts_ of
|
||||
Just ts' -> encodeBatch n' s' ts'
|
||||
_ -> (n', s', Nothing)
|
||||
let s' = s <> smpEncode (Large t)
|
||||
n' = n + 1
|
||||
in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch
|
||||
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
|
||||
else case L.nonEmpty ts_ of
|
||||
Just ts' -> encodeBatch n' s' ts'
|
||||
_ -> (n', s', Nothing)
|
||||
|
||||
tEncode :: SentRawTransmission -> ByteString
|
||||
tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t
|
||||
@@ -1373,8 +1352,8 @@ tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId ->
|
||||
tDecodeParseValidate sessionId v = \case
|
||||
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
| sessId == sessionId ->
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
|
||||
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
|
||||
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
|
||||
Left _ -> tError ""
|
||||
where
|
||||
@@ -1385,3 +1364,9 @@ tDecodeParseValidate sessionId v = \case
|
||||
tParseValidate signed t@(sig, corrId, entityId, command) =
|
||||
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
|
||||
in (sig, signed, (CorrId corrId, entityId, cmd))
|
||||
|
||||
$(J.deriveJSON defaultJSON ''MsgFlags)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''CommandError)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''ErrorType)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
@@ -8,6 +7,7 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
@@ -116,9 +116,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
restoreServerMessages
|
||||
restoreServerStats
|
||||
raceAny_
|
||||
( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub :
|
||||
serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ()) :
|
||||
map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
|
||||
( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub
|
||||
: serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ())
|
||||
: map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
|
||||
)
|
||||
`finally` withLock (savingLock s) "final" (saveServer False)
|
||||
where
|
||||
@@ -148,7 +148,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
updateSubscribers :: STM (Maybe (QueueId, Client))
|
||||
updateSubscribers = do
|
||||
(qId, clnt) <- readTQueue $ subQ s
|
||||
let clientToBeNotified = \c' ->
|
||||
let clientToBeNotified c' =
|
||||
if sameClientSession clnt c'
|
||||
then pure Nothing
|
||||
else do
|
||||
@@ -277,9 +277,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
hPutStrLn h $ "Clients: " <> show (length clients)
|
||||
forM_ (M.toList clients) $ \(cid, Client {sessionId, connected, activeAt, subscriptions}) -> do
|
||||
hPutStrLn h . B.unpack $ "Client " <> encode cid <> " $" <> encode sessionId
|
||||
readTVarIO connected >>= hPutStrLn h . (" connected: " <>) . show
|
||||
readTVarIO activeAt >>= hPutStrLn h . (" activeAt: " <>) . B.unpack . strEncode
|
||||
readTVarIO subscriptions >>= hPutStrLn h . (" subscriptions: " <>) . show . M.size
|
||||
readTVarIO connected >>= hPutStrLn h . (" connected: " <>) . show
|
||||
readTVarIO activeAt >>= hPutStrLn h . (" activeAt: " <>) . B.unpack . strEncode
|
||||
readTVarIO subscriptions >>= hPutStrLn h . (" subscriptions: " <>) . show . M.size
|
||||
CPStats -> do
|
||||
ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, msgSentNtf, msgRecvNtf, qCount, msgCount} <- unliftIO u $ asks serverStats
|
||||
putStat "fromTime" fromTime
|
||||
@@ -666,27 +666,27 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
sendMessage qr msgFlags msgBody
|
||||
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
|
||||
| otherwise = case status qr of
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
case C.maxLenBS msgBody of
|
||||
Left _ -> pure $ err LARGE_MSG
|
||||
Right body -> do
|
||||
msg_ <- time "SEND" $ do
|
||||
q <- getStoreMsgQueue "SEND" $ recipientId qr
|
||||
expireMessages q
|
||||
atomically . writeMsg q =<< mkMessage body
|
||||
case msg_ of
|
||||
Nothing -> pure $ err QUOTA
|
||||
Just msg -> time "SEND ok" $ do
|
||||
stats <- asks serverStats
|
||||
when (notification msgFlags) $ do
|
||||
atomically . trySendNotification msg =<< asks idsDrg
|
||||
atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
|
||||
atomically $ modifyTVar' (msgSent stats) (+ 1)
|
||||
atomically $ modifyTVar' (msgCount stats) (subtract 1)
|
||||
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
||||
pure ok
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
case C.maxLenBS msgBody of
|
||||
Left _ -> pure $ err LARGE_MSG
|
||||
Right body -> do
|
||||
msg_ <- time "SEND" $ do
|
||||
q <- getStoreMsgQueue "SEND" $ recipientId qr
|
||||
expireMessages q
|
||||
atomically . writeMsg q =<< mkMessage body
|
||||
case msg_ of
|
||||
Nothing -> pure $ err QUOTA
|
||||
Just msg -> time "SEND ok" $ do
|
||||
stats <- asks serverStats
|
||||
when (notification msgFlags) $ do
|
||||
atomically . trySendNotification msg =<< asks idsDrg
|
||||
atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
|
||||
atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
|
||||
atomically $ modifyTVar' (msgSent stats) (+ 1)
|
||||
atomically $ modifyTVar' (msgCount stats) (subtract 1)
|
||||
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
|
||||
pure ok
|
||||
where
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
|
||||
mkMessage body = do
|
||||
@@ -767,7 +767,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
||||
msgTs' = msg.msgTs
|
||||
|
||||
setDelivered :: Sub -> Message -> STM Bool
|
||||
setDelivered s msg = tryPutTMVar (delivered s) $ msg.msgId
|
||||
setDelivered s msg = tryPutTMVar (delivered s) msg.msgId
|
||||
|
||||
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
|
||||
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
|
||||
|
||||
@@ -31,7 +31,7 @@ import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Version
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
|
||||
@@ -11,7 +11,6 @@ module Simplex.Messaging.Server.Main where
|
||||
import Control.Monad (void)
|
||||
import Crypto.Random (getRandomBytes)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.Functor (($>))
|
||||
import Data.Ini (lookupValue, readIniFile)
|
||||
import Data.Maybe (fromMaybe)
|
||||
@@ -24,7 +23,7 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (BasicAuth (..), ProtoServerWithAuth (ProtoServerWithAuth), pattern SMPServer)
|
||||
import Simplex.Messaging.Server (runSMPServer)
|
||||
import Simplex.Messaging.Server.CLI
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClientExpiration, defaultMessageExpiration, defMsgExpirationDays)
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defMsgExpirationDays, defaultInactiveClientExpiration, defaultMessageExpiration)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost (..))
|
||||
@@ -60,15 +59,15 @@ smpServerCLI cfgPath logPath =
|
||||
initializeServer opts
|
||||
| scripted opts = initialize opts
|
||||
| otherwise = do
|
||||
putStrLn "Use `smp-server init -h` for available options."
|
||||
void $ withPrompt "SMP server will be initialized (press Enter)" getLine
|
||||
enableStoreLog <- onOffPrompt "Enable store log to restore queues and messages on server restart" True
|
||||
logStats <- onOffPrompt "Enable logging daily statistics" False
|
||||
putStrLn "Require a password to create new messaging queues?"
|
||||
password <- withPrompt "'r' for random (default), 'n' - no password, or enter password: " serverPassword
|
||||
let host = fromMaybe (ip opts) (fqdn opts)
|
||||
host' <- withPrompt ("Enter server FQDN or IP address for certificate (" <> host <> "): ") getLine
|
||||
initialize opts {enableStoreLog, logStats, fqdn = if null host' then fqdn opts else Just host', password}
|
||||
putStrLn "Use `smp-server init -h` for available options."
|
||||
void $ withPrompt "SMP server will be initialized (press Enter)" getLine
|
||||
enableStoreLog <- onOffPrompt "Enable store log to restore queues and messages on server restart" True
|
||||
logStats <- onOffPrompt "Enable logging daily statistics" False
|
||||
putStrLn "Require a password to create new messaging queues?"
|
||||
password <- withPrompt "'r' for random (default), 'n' - no password, or enter password: " serverPassword
|
||||
let host = fromMaybe (ip opts) (fqdn opts)
|
||||
host' <- withPrompt ("Enter server FQDN or IP address for certificate (" <> host <> "): ") getLine
|
||||
initialize opts {enableStoreLog, logStats, fqdn = if null host' then fqdn opts else Just host', password}
|
||||
where
|
||||
serverPassword =
|
||||
getLine >>= \case
|
||||
@@ -121,8 +120,8 @@ smpServerCLI cfgPath logPath =
|
||||
\# The password will not be shared with the connecting contacts, you must share it only\n\
|
||||
\# with the users who you want to allow creating messaging queues on your server.\n"
|
||||
<> ( case basicAuth of
|
||||
Just auth -> "create_password: " <> T.unpack (safeDecodeUtf8 $ strEncode auth)
|
||||
_ -> "# create_password: password to create new queues (any printable ASCII characters without whitespace, '@', ':' and '/')"
|
||||
Just auth -> "create_password: " <> T.unpack (safeDecodeUtf8 $ strEncode auth)
|
||||
_ -> "# create_password: password to create new queues (any printable ASCII characters without whitespace, '@', ':' and '/')"
|
||||
)
|
||||
<> "\n\n\
|
||||
\[TRANSPORT]\n\
|
||||
@@ -141,7 +140,7 @@ smpServerCLI cfgPath logPath =
|
||||
hSetBuffering stdout LineBuffering
|
||||
hSetBuffering stderr LineBuffering
|
||||
fp <- checkSavedFingerprint cfgPath defaultX509Config
|
||||
let host = fromRight "<hostnames>" $ T.unpack <$> lookupValue "TRANSPORT" "host" ini
|
||||
let host = either (const "<hostnames>") T.unpack $ lookupValue "TRANSPORT" "host" ini
|
||||
port = T.unpack $ strictIni "TRANSPORT" "port" ini
|
||||
cfg@ServerConfig {transports, storeLogFile, newQueueBasicAuth, messageExpiration, inactiveClientExpiration} = serverConfig
|
||||
srv = ProtoServerWithAuth (SMPServer [THDomainName host] (if port == "5223" then "" else port) (C.KeyHash fp)) newQueueBasicAuth
|
||||
@@ -186,9 +185,10 @@ smpServerCLI cfgPath logPath =
|
||||
allowNewQueues = fromMaybe True $ iniOnOff "AUTH" "new_queues" ini,
|
||||
newQueueBasicAuth = either error id <$> strDecodeIni "AUTH" "create_password" ini,
|
||||
messageExpiration =
|
||||
Just defaultMessageExpiration
|
||||
{ ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini
|
||||
},
|
||||
Just
|
||||
defaultMessageExpiration
|
||||
{ ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini
|
||||
},
|
||||
inactiveClientExpiration =
|
||||
settingIsOn "INACTIVE_CLIENTS" "disconnect" ini
|
||||
$> ExpirationConfig
|
||||
|
||||
@@ -64,7 +64,7 @@ flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
|
||||
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
|
||||
|
||||
snapshotMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
|
||||
snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue)
|
||||
snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue)
|
||||
where
|
||||
snapshotTQueue q = do
|
||||
msgs <- flushTQueue q
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Simplex.Messaging.Server.StoreLog
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
@@ -10,6 +9,7 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
-- |
|
||||
@@ -63,8 +63,7 @@ where
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except (throwE)
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bitraversable (bimapM)
|
||||
@@ -74,9 +73,7 @@ import qualified Data.ByteString.Lazy as BL
|
||||
import Data.Default (def)
|
||||
import Data.Functor (($>))
|
||||
import Data.Version (showVersion)
|
||||
import GHC.Generics (Generic)
|
||||
import GHC.IO.Handle.Internals (ioe_EOF)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import qualified Network.TLS.Extra as TE
|
||||
@@ -87,7 +84,6 @@ import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport.Buffer
|
||||
import Simplex.Messaging.Util (bshow, catchAll, catchAll_)
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
@@ -284,14 +280,7 @@ data TransportError
|
||||
TEBadSession
|
||||
| -- | transport handshake error
|
||||
TEHandshake {handshakeErr :: HandshakeError}
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON TransportError where
|
||||
toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "TE"
|
||||
toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "TE"
|
||||
|
||||
instance FromJSON TransportError where
|
||||
parseJSON = J.genericParseJSON . sumTypeJSON $ dropPrefix "TE"
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
-- | Transport handshake error.
|
||||
data HandshakeError
|
||||
@@ -301,18 +290,7 @@ data HandshakeError
|
||||
VERSION
|
||||
| -- | incorrect server identity
|
||||
IDENTITY
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON HandshakeError where
|
||||
toJSON = J.genericToJSON $ sumTypeJSON id
|
||||
toEncoding = J.genericToEncoding $ sumTypeJSON id
|
||||
|
||||
instance FromJSON HandshakeError where
|
||||
parseJSON = J.genericParseJSON $ sumTypeJSON id
|
||||
|
||||
instance Arbitrary TransportError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
-- | SMP encrypted transport error parser.
|
||||
transportErrorP :: Parser TransportError
|
||||
@@ -354,9 +332,9 @@ smpServerHandshake c kh smpVRange = do
|
||||
getHandshake th >>= \case
|
||||
ClientHandshake {smpVersion, keyHash}
|
||||
| keyHash /= kh ->
|
||||
throwE $ TEHandshake IDENTITY
|
||||
throwE $ TEHandshake IDENTITY
|
||||
| smpVersion `isCompatible` smpVRange -> do
|
||||
pure $ smpThHandle th smpVersion
|
||||
pure $ smpThHandle th smpVersion
|
||||
| otherwise -> throwE $ TEHandshake VERSION
|
||||
|
||||
-- | Client SMP transport handshake.
|
||||
@@ -385,3 +363,7 @@ getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock t
|
||||
|
||||
smpTHandle :: Transport c => c -> THandle c
|
||||
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, batch = False}
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''HandshakeError)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON $ dropPrefix "TE") ''TransportError)
|
||||
|
||||
@@ -8,8 +8,8 @@ import Control.Concurrent.STM
|
||||
import qualified Control.Exception as E
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import GHC.IO.Exception (IOErrorType (..), IOException (..), ioException)
|
||||
import System.Timeout (timeout)
|
||||
import GHC.IO.Exception (ioException, IOException (..), IOErrorType (..))
|
||||
|
||||
data TBuffer = TBuffer
|
||||
{ buffer :: TVar ByteString,
|
||||
@@ -41,9 +41,9 @@ getBuffered tb@TBuffer {buffer} n t_ getChunk = withBufferLock tb $ do
|
||||
readChunks firstChunk b
|
||||
| B.length b >= n = pure b
|
||||
| otherwise =
|
||||
get >>= \case
|
||||
"" -> pure b
|
||||
s -> readChunks False $ b <> s
|
||||
get >>= \case
|
||||
"" -> pure b
|
||||
s -> readChunks False $ b <> s
|
||||
where
|
||||
get
|
||||
| firstChunk = getChunk
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Transport.Credentials
|
||||
( tlsCredentials,
|
||||
|
||||
@@ -22,10 +22,9 @@ import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport (SessionId)
|
||||
import Simplex.Messaging.Transport (SessionId, TLS)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), runTLSTransportClient)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport (TLS)
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
|
||||
module Simplex.Messaging.Transport.HTTP2.File where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
import Data.ByteString.Builder (Builder, byteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Word (Word32)
|
||||
import GHC.IO.Handle.Internals (ioe_EOF)
|
||||
import System.IO (Handle)
|
||||
|
||||
fileBlockSize :: Int
|
||||
fileBlockSize = 16384
|
||||
|
||||
hReceiveFile :: (Int -> IO ByteString) -> Handle -> Word32 -> IO Int64
|
||||
hReceiveFile _ _ 0 = pure 0
|
||||
hReceiveFile getBody h size = get $ fromIntegral size
|
||||
where
|
||||
get sz = do
|
||||
ch <- getBody fileBlockSize
|
||||
let chSize = fromIntegral $ B.length ch
|
||||
if
|
||||
| chSize > sz -> pure (chSize - sz)
|
||||
| chSize > 0 -> B.hPut h ch >> get (sz - chSize)
|
||||
| otherwise -> pure (-fromIntegral sz)
|
||||
|
||||
hSendFile :: Handle -> (Builder -> IO ()) -> Word32 -> IO ()
|
||||
hSendFile h send = go
|
||||
where
|
||||
go 0 = pure ()
|
||||
go sz =
|
||||
getFileChunk h sz >>= \ch -> do
|
||||
send $ byteString ch
|
||||
go $ sz - fromIntegral (B.length ch)
|
||||
|
||||
getFileChunk :: Handle -> Word32 -> IO ByteString
|
||||
getFileChunk h sz = do
|
||||
ch <- B.hGet h fileBlockSize
|
||||
if B.null ch
|
||||
then ioe_EOF
|
||||
else pure $ B.take (fromIntegral sz) ch -- sz >= xftpBlockSize
|
||||
@@ -1,5 +1,4 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Transport.HTTP2.Server where
|
||||
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
{-# LANGUAGE CApiFFI #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
module Simplex.Messaging.Transport.KeepAlive where
|
||||
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Foreign.C (CInt (..))
|
||||
import GHC.Generics (Generic)
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Parsers (defaultJSON)
|
||||
|
||||
data KeepAliveOpts = KeepAliveOpts
|
||||
{ keepIdle :: Int,
|
||||
keepIntvl :: Int,
|
||||
keepCnt :: Int
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON KeepAliveOpts where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
deriving (Eq, Show)
|
||||
|
||||
defaultKeepAliveOpts :: KeepAliveOpts
|
||||
defaultKeepAliveOpts =
|
||||
@@ -68,3 +65,5 @@ setSocketKeepAlive sock KeepAliveOpts {keepCnt, keepIdle, keepIntvl} = do
|
||||
setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPIDLE) keepIdle
|
||||
setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPINTVL) keepIntvl
|
||||
setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPCNT) keepCnt
|
||||
|
||||
$(J.deriveJSON defaultJSON ''KeepAliveOpts)
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Transport.Server
|
||||
( runTransportServer,
|
||||
runTCPServer,
|
||||
TransportServerConfig (..),
|
||||
( TransportServerConfig (..),
|
||||
defaultTransportServerConfig,
|
||||
runTransportServer,
|
||||
runTransportServerSocket,
|
||||
runTCPServer,
|
||||
runTCPServerSocket,
|
||||
startTCPServer,
|
||||
loadSupportedTLSServerParams,
|
||||
loadTLSServerParams,
|
||||
loadFingerprint,
|
||||
@@ -46,10 +49,11 @@ data TransportServerConfig = TransportServerConfig
|
||||
deriving (Eq, Show)
|
||||
|
||||
defaultTransportServerConfig :: TransportServerConfig
|
||||
defaultTransportServerConfig = TransportServerConfig
|
||||
{ logTLSErrors = True,
|
||||
transportTimeout = 40000000
|
||||
}
|
||||
defaultTransportServerConfig =
|
||||
TransportServerConfig
|
||||
{ logTLSErrors = True,
|
||||
transportTimeout = 40000000
|
||||
}
|
||||
|
||||
serverTransportConfig :: TransportServerConfig -> TransportConfig
|
||||
serverTransportConfig TransportServerConfig {logTLSErrors} =
|
||||
@@ -60,11 +64,15 @@ serverTransportConfig TransportServerConfig {logTLSErrors} =
|
||||
--
|
||||
-- All accepted connections are passed to the passed function.
|
||||
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
|
||||
runTransportServer started port serverParams cfg server = do
|
||||
runTransportServer started port = runTransportServerSocket started (startTCPServer started port) (transportName (TProxy :: TProxy c))
|
||||
|
||||
-- | Run a transport server with provided connection setup and handler.
|
||||
runTransportServerSocket :: (MonadUnliftIO m, T.TLSParams p, Transport a) => TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocket started getSocket threadLabel serverParams cfg server = do
|
||||
u <- askUnliftIO
|
||||
let tCfg = serverTransportConfig cfg
|
||||
labelMyThread $ "transport server for " <> transportName (TProxy :: TProxy c)
|
||||
liftIO . runTCPServer started port $ \conn ->
|
||||
labelMyThread $ "transport server for " <> threadLabel
|
||||
liftIO . runTCPServerSocket started getSocket $ \conn ->
|
||||
E.bracket
|
||||
(connectTLS Nothing tCfg serverParams conn >>= getServerConnection tCfg)
|
||||
closeConnection
|
||||
@@ -72,11 +80,15 @@ runTransportServer started port serverParams cfg server = do
|
||||
|
||||
-- | Run TCP server without TLS
|
||||
runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()
|
||||
runTCPServer started port server = do
|
||||
runTCPServer started port = runTCPServerSocket started $ startTCPServer started port
|
||||
|
||||
-- | Wrap socket provider in a TCP server bracket.
|
||||
runTCPServerSocket :: TMVar Bool -> IO Socket -> (Socket -> IO ()) -> IO ()
|
||||
runTCPServerSocket started getSocket server = do
|
||||
clients <- atomically TM.empty
|
||||
clientId <- newTVarIO 0
|
||||
E.bracket
|
||||
(startTCPServer started port)
|
||||
getSocket
|
||||
(closeServer started clients)
|
||||
$ \sock -> forever . E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) -> do
|
||||
-- catchAll_ is needed here in case the connection was closed earlier
|
||||
|
||||
@@ -15,9 +15,9 @@ import qualified Network.WebSockets.Stream as S
|
||||
import Simplex.Messaging.Transport
|
||||
( TProxy,
|
||||
Transport (..),
|
||||
TransportConfig (..),
|
||||
TransportError (..),
|
||||
TransportPeer (..),
|
||||
TransportConfig (..),
|
||||
closeTLS,
|
||||
smpBlockSize,
|
||||
withTlsUnique,
|
||||
|
||||
@@ -62,7 +62,7 @@ liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a)
|
||||
liftEitherError f a = liftIOEither (first f <$> a)
|
||||
{-# INLINE liftEitherError #-}
|
||||
|
||||
liftEitherWith :: (MonadError e' m) => (e -> e') -> Either e a -> m a
|
||||
liftEitherWith :: MonadError e' m => (e -> e') -> Either e a -> m a
|
||||
liftEitherWith f = liftEither . first f
|
||||
{-# INLINE liftEitherWith #-}
|
||||
|
||||
@@ -102,7 +102,7 @@ catchAllErrors err action handle = tryAllErrors err action >>= either handle pur
|
||||
{-# INLINE catchAllErrors #-}
|
||||
|
||||
catchThrow :: (MonadUnliftIO m, MonadError e m) => m a -> (E.SomeException -> e) -> m a
|
||||
catchThrow action err = catchAllErrors err action throwError
|
||||
catchThrow action err = catchAllErrors err action throwError
|
||||
{-# INLINE catchThrow #-}
|
||||
|
||||
allFinally :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m a -> m b -> m a
|
||||
@@ -115,12 +115,12 @@ eitherToMaybe = either (const Nothing) Just
|
||||
|
||||
groupOn :: Eq k => (a -> k) -> [a] -> [[a]]
|
||||
groupOn = groupBy . eqOn
|
||||
-- it is equivalent to groupBy ((==) `on` f),
|
||||
-- but it redefines `on` to avoid duplicate computation for most values.
|
||||
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
|
||||
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
|
||||
where
|
||||
eqOn f = \x -> let fx = f x in \y -> fx == f y
|
||||
-- it is equivalent to groupBy ((==) `on` f),
|
||||
-- but it redefines `on` to avoid duplicate computation for most values.
|
||||
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
|
||||
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
|
||||
eqOn f x = let fx = f x in \y -> fx == f y
|
||||
|
||||
groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]]
|
||||
groupAllOn f = groupOn f . sortOn f
|
||||
@@ -129,7 +129,7 @@ toChunks :: Int -> [a] -> [NonEmpty a]
|
||||
toChunks _ [] = []
|
||||
toChunks n xs =
|
||||
let (ys, xs') = splitAt n xs
|
||||
in maybe id (:) (L.nonEmpty ys) (toChunks n xs')
|
||||
in maybe id (:) (L.nonEmpty ys) (toChunks n xs')
|
||||
|
||||
safeDecodeUtf8 :: ByteString -> Text
|
||||
safeDecodeUtf8 = decodeUtf8With onError
|
||||
|
||||
+1
-1
@@ -49,7 +49,7 @@ extra-deps:
|
||||
- github: simplex-chat/aeson
|
||||
commit: aab7b5a14d6c5ea64c64dcaee418de1bb00dcc2b
|
||||
- github: kazu-yamamoto/http2
|
||||
commit: b5a1b7200cf5bc7044af34ba325284271f6dff25
|
||||
commit: 804fa283f067bd3fd89b8c5f8d25b3047813a517
|
||||
# - ../direct-sqlcipher
|
||||
- github: simplex-chat/direct-sqlcipher
|
||||
commit: f814ee68b16a9447fbb467ccc8f29bdd3546bfd9
|
||||
|
||||
+17
-17
@@ -45,42 +45,42 @@ agentTests (ATransport t) = do
|
||||
describe "Migration tests" migrationTests
|
||||
describe "SMP agent protocol syntax" $ syntaxTests t
|
||||
describe "Establishing duplex connection" $ do
|
||||
it "should connect via one server and one agent" $
|
||||
it "should connect via one server and one agent" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnection t
|
||||
it "should connect via one server and one agent (random IDs)" $
|
||||
it "should connect via one server and one agent (random IDs)" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via one server and 2 agents" $
|
||||
it "should connect via one server and 2 agents" $ do
|
||||
smpAgentTest2_2_1 $ testDuplexConnection t
|
||||
it "should connect via one server and 2 agents (random IDs)" $
|
||||
it "should connect via one server and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via 2 servers and 2 agents" $
|
||||
it "should connect via 2 servers and 2 agents" $ do
|
||||
smpAgentTest2_2_2 $ testDuplexConnection t
|
||||
it "should connect via 2 servers and 2 agents (random IDs)" $
|
||||
it "should connect via 2 servers and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
|
||||
describe "Establishing connections via `contact connection`" $ do
|
||||
it "should connect via contact connection with one server and 3 agents" $
|
||||
it "should connect via contact connection with one server and 3 agents" $ do
|
||||
smpAgentTest3 $ testContactConnection t
|
||||
it "should connect via contact connection with one server and 2 agents (random IDs)" $
|
||||
it "should connect via contact connection with one server and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_1 $ testContactConnRandomIds t
|
||||
it "should support rejecting contact request" $
|
||||
it "should support rejecting contact request" $ do
|
||||
smpAgentTest2_2_1 $ testRejectContactRequest t
|
||||
describe "Connection subscriptions" $ do
|
||||
it "should connect via one server and one agent" $
|
||||
it "should connect via one server and one agent" $ do
|
||||
smpAgentTest3_1_1 $ testSubscription t
|
||||
it "should send notifications to client when server disconnects" $
|
||||
it "should send notifications to client when server disconnects" $ do
|
||||
smpAgentServerTest $ testSubscrNotification t
|
||||
describe "Message delivery and server reconnection" $ do
|
||||
it "should deliver messages after losing server connection and re-connecting" $
|
||||
it "should deliver messages after losing server connection and re-connecting" $ do
|
||||
smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t
|
||||
it "should connect to the server when server goes up if it initially was down" $
|
||||
it "should connect to the server when server goes up if it initially was down" $ do
|
||||
smpAgentTestN [] $ testServerConnectionAfterError t
|
||||
it "should deliver pending messages after agent restarting" $
|
||||
it "should deliver pending messages after agent restarting" $ do
|
||||
smpAgentTest1_1_1 $ testMsgDeliveryAgentRestart t
|
||||
it "should concurrently deliver messages to connections without blocking" $
|
||||
it "should concurrently deliver messages to connections without blocking" $ do
|
||||
smpAgentTest2_2_1 $ testConcurrentMsgDelivery t
|
||||
it "should deliver messages if one of connections has quota exceeded" $
|
||||
it "should deliver messages if one of connections has quota exceeded" $ do
|
||||
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
|
||||
it "should resume delivering messages after exceeding quota once all messages are received" $
|
||||
it "should resume delivering messages after exceeding quota once all messages are received" $ do
|
||||
smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t
|
||||
|
||||
type AEntityTransmission p e = (ACorrId, ConnId, ACommand p e)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module AgentTests.ConnectionRequestTests where
|
||||
@@ -113,24 +112,24 @@ connectionRequestTests =
|
||||
it "should serialize connection requests" $ do
|
||||
strEncode connectionRequest
|
||||
`shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
strEncode connectionRequestCurrentRange
|
||||
`shouldBe` "simplex:/invitation#/?v=1-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1-2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1-2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
strEncode connectionRequestClientDataEmpty
|
||||
`shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&data=%7B%7D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&data=%7B%7D"
|
||||
strEncode connectionRequestClientData
|
||||
`shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&data=%7B%22type%22%3A%22group_link%22%2C%20%22group_link_id%22%3A%22abc%22%7D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&data=%7B%22type%22%3A%22group_link%22%2C%20%22group_link_id%22%3A%22abc%22%7D"
|
||||
it "should parse connection requests" $ do
|
||||
strDecode
|
||||
( "https://simplex.chat/invitation#/?smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23"
|
||||
|
||||
@@ -43,6 +43,7 @@ import Data.Int (Int64)
|
||||
import qualified Data.Map as M
|
||||
import Data.Maybe (isNothing)
|
||||
import qualified Data.Set as S
|
||||
import Data.Time.Clock (diffUTCTime, getCurrentTime)
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import Data.Type.Equality
|
||||
import SMPAgentClient
|
||||
@@ -310,6 +311,7 @@ functionalAPITests t = do
|
||||
describe "Delivery receipts" $ do
|
||||
it "should send and receive delivery receipt" $ withSmpServer t testDeliveryReceipts
|
||||
it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t
|
||||
it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t
|
||||
|
||||
testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int
|
||||
testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do
|
||||
@@ -1159,7 +1161,7 @@ testBatchedSubscriptions nCreate nDel t = do
|
||||
a <- getSMPAgentClient' agentCfg initAgentServers2 testDB
|
||||
b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2
|
||||
conns <- runServers $ do
|
||||
conns <- forM [1 .. nCreate :: Int] . const $ makeConnection a b
|
||||
conns <- replicateM (nCreate :: Int) $ makeConnection a b
|
||||
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
|
||||
let (aIds', bIds') = unzip $ take nDel conns
|
||||
delete a bIds'
|
||||
@@ -1928,6 +1930,57 @@ testDeliveryReceiptsVersion t = do
|
||||
disconnectAgentClient a'
|
||||
disconnectAgentClient b'
|
||||
|
||||
testDeliveryReceiptsConcurrent :: HasCallStack => ATransport -> IO ()
|
||||
testDeliveryReceiptsConcurrent t =
|
||||
withSmpServerConfigOn t cfg {msgQueueQuota = 128} testPort $ \_ -> do
|
||||
withAgentClients2 $ \a b -> do
|
||||
(aId, bId) <- runRight $ makeConnection a b
|
||||
t1 <- liftIO getCurrentTime
|
||||
concurrently_ (runClient "a" a bId) (runClient "b" b aId)
|
||||
t2 <- liftIO getCurrentTime
|
||||
diffUTCTime t2 t1 `shouldSatisfy` (< 15)
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
liftIO $ noMessages b "nothing else should be delivered to bob"
|
||||
where
|
||||
runClient :: String -> AgentClient -> ConnId -> IO ()
|
||||
runClient _cName client connId = do
|
||||
concurrently_ send receive
|
||||
where
|
||||
numMsgs = 100
|
||||
send = runRight_ $
|
||||
replicateM_ numMsgs $ do
|
||||
-- liftIO $ print $ cName <> ": sendMessage"
|
||||
void $ sendMessage client connId SMP.noMsgFlags "hello"
|
||||
receive =
|
||||
runRight_ $
|
||||
-- for each sent message: 1 SENT, 1 RCVD, 1 OK for acknowledging RCVD
|
||||
-- for each received message: 1 MSG, 1 OK for acknowledging MSG
|
||||
receiveLoop (numMsgs * 5)
|
||||
receiveLoop :: Int -> ExceptT AgentErrorType IO ()
|
||||
receiveLoop 0 = pure ()
|
||||
receiveLoop n = do
|
||||
r <- getWithTimeout
|
||||
case r of
|
||||
(_, _, SENT _) -> do
|
||||
-- liftIO $ print $ cName <> ": SENT"
|
||||
pure ()
|
||||
(_, _, MSG MsgMeta {recipient = (msgId, _), integrity = MsgOk} _ _) -> do
|
||||
-- liftIO $ print $ cName <> ": MSG " <> show msgId
|
||||
ackMessageAsync client (B.pack . show $ n) connId msgId (Just "")
|
||||
(_, _, RCVD MsgMeta {recipient = (msgId, _), integrity = MsgOk} _) -> do
|
||||
-- liftIO $ print $ cName <> ": RCVD " <> show msgId
|
||||
ackMessageAsync client (B.pack . show $ n) connId msgId Nothing
|
||||
(_, _, OK) -> do
|
||||
-- liftIO $ print $ cName <> ": OK"
|
||||
pure ()
|
||||
r' -> error $ "unexpected event: " <> show r'
|
||||
receiveLoop (n - 1)
|
||||
getWithTimeout :: ExceptT AgentErrorType IO (AEntityTransmission 'AEConn)
|
||||
getWithTimeout = do
|
||||
1000000 `timeout` get client >>= \case
|
||||
Just r -> pure r
|
||||
_ -> error "timeout"
|
||||
|
||||
testTwoUsers :: HasCallStack => IO ()
|
||||
testTwoUsers = withAgentClients2 $ \a b -> do
|
||||
let nc = netCfg initAgentServers
|
||||
|
||||
@@ -504,7 +504,7 @@ testNotificationsSMPRestartBatch n t APNSMockServer {apnsQ} = do
|
||||
a <- getSMPAgentClient' agentCfg initAgentServers2 testDB
|
||||
b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2
|
||||
conns <- runServers $ do
|
||||
conns <- forM [1 .. n :: Int] . const $ makeConnection a b
|
||||
conns <- replicateM (n :: Int) $ makeConnection a b
|
||||
_ <- registerTestToken a "abcd" NMInstant apnsQ
|
||||
liftIO $ threadDelay 1500000
|
||||
forM_ conns $ \(aliceId, bobId) -> do
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module AgentTests.SQLiteTests (storeTests) where
|
||||
|
||||
@@ -4,7 +4,7 @@ module CoreTests.BatchingTests (batchingTests) where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad
|
||||
import Crypto.Random (MonadRandom(..))
|
||||
import Crypto.Random (MonadRandom (..))
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.List.NonEmpty as L
|
||||
@@ -12,7 +12,7 @@ import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Version (VersionRange(..))
|
||||
import Simplex.Messaging.Version (VersionRange (..))
|
||||
import Test.Hspec
|
||||
|
||||
batchingTests :: Spec
|
||||
|
||||
@@ -195,7 +195,7 @@ testAESGCM = it "should encrypt / decrypt string with a random symmetric key" $
|
||||
cipher `shouldNotBe` plain
|
||||
s `shouldBe` plain
|
||||
|
||||
testEncoding :: (C.AlgorithmI a) => C.SAlgorithm a -> Spec
|
||||
testEncoding :: C.AlgorithmI a => C.SAlgorithm a -> Spec
|
||||
testEncoding alg = it "should encode / decode key" . ioProperty $ do
|
||||
(k, pk) <- C.generateKeyPair alg
|
||||
pure $ \(_ :: Int) ->
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module CoreTests.ProtocolErrorTests where
|
||||
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..))
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol (CommandError (..), ErrorType (..))
|
||||
import Simplex.Messaging.Transport (HandshakeError (..), TransportError (..))
|
||||
import Test.Hspec
|
||||
import Test.Hspec.QuickCheck (modifyMaxSuccess)
|
||||
import Test.QuickCheck
|
||||
@@ -16,15 +24,58 @@ import Test.QuickCheck
|
||||
protocolErrorTests :: Spec
|
||||
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
|
||||
describe "errors parsing / serializing" $ do
|
||||
it "should parse SMP protocol errors" . property $ \(err :: AgentErrorType) ->
|
||||
errHasSpaces err
|
||||
|| parseAll strP (strEncode err) == Right err
|
||||
it "should parse SMP protocol errors" . property $ \(err :: ErrorType) ->
|
||||
smpDecode (smpEncode err) == Right err
|
||||
it "should parse SMP agent errors" . property $ \(err :: AgentErrorType) ->
|
||||
errHasSpaces err
|
||||
|| parseAll strP (strEncode err) == Right err
|
||||
|| strDecode (strEncode err) == Right err
|
||||
where
|
||||
errHasSpaces = \case
|
||||
BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e
|
||||
BROKER srv _ -> hasSpaces srv
|
||||
_ -> False
|
||||
hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s)
|
||||
|
||||
deriving instance Generic AgentErrorType
|
||||
|
||||
deriving instance Generic CommandErrorType
|
||||
|
||||
deriving instance Generic ConnectionErrorType
|
||||
|
||||
deriving instance Generic BrokerErrorType
|
||||
|
||||
deriving instance Generic SMPAgentError
|
||||
|
||||
deriving instance Generic AgentCryptoError
|
||||
|
||||
deriving instance Generic ErrorType
|
||||
|
||||
deriving instance Generic CommandError
|
||||
|
||||
deriving instance Generic TransportError
|
||||
|
||||
deriving instance Generic HandshakeError
|
||||
|
||||
deriving instance Generic XFTPErrorType
|
||||
|
||||
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary CommandError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary TransportError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module CoreTests.RetryIntervalTests where
|
||||
|
||||
@@ -6,8 +6,8 @@ import Control.Exception (Exception, SomeException, throwIO)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Data.IORef
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import Simplex.Messaging.Util
|
||||
import Test.Hspec
|
||||
import qualified UnliftIO.Exception as UE
|
||||
|
||||
|
||||
@@ -39,13 +39,14 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do
|
||||
isCompatible (1 :: Version) (vr 2 2) `shouldBe` False
|
||||
it "compatibleVersion should pass isCompatible check" . property $
|
||||
\((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) ->
|
||||
min1 > max1 || min2 > max2 -- one of ranges is invalid, skip testing it
|
||||
min1 > max1
|
||||
|| min2 > max2 -- one of ranges is invalid, skip testing it
|
||||
|| let w = fromIntegral . fromEnum
|
||||
vr1 = mkVersionRange (w min1) (w max1) :: VersionRange
|
||||
vr2 = mkVersionRange (w min2) (w max2) :: VersionRange
|
||||
in case compatibleVersion vr1 vr2 of
|
||||
Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2
|
||||
_ -> True
|
||||
Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2
|
||||
_ -> True
|
||||
where
|
||||
vr = mkVersionRange
|
||||
compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation
|
||||
|
||||
@@ -38,6 +38,7 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
@@ -186,10 +187,16 @@ instance FromJSON APNSAlertBody where
|
||||
parseJSON (J.String v) = pure $ APNSAlertText v
|
||||
parseJSON invalid = JT.prependFailure "parsing Coord failed, " (JT.typeMismatch "Object" invalid)
|
||||
|
||||
deriving instance Generic APNSNotificationBody
|
||||
|
||||
instance FromJSON APNSNotificationBody where parseJSON = J.genericParseJSON apnsJSONOptions {J.rejectUnknownFields = True}
|
||||
|
||||
deriving instance Generic APNSNotification
|
||||
|
||||
deriving instance FromJSON APNSNotification
|
||||
|
||||
deriving instance Generic APNSErrorResponse
|
||||
|
||||
deriving instance ToJSON APNSErrorResponse
|
||||
|
||||
getAPNSMockServer :: HTTP2ServerConfig -> IO APNSMockServer
|
||||
|
||||
@@ -139,7 +139,6 @@ testNotificationSubscription (ATransport t) =
|
||||
mTs `shouldBe` msgTs
|
||||
(msgBody, "hello") #== "delivered from queue"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, ACK mId1)
|
||||
pure ()
|
||||
-- replace token
|
||||
let tkn' = DeviceToken PPApnsTest "efgh"
|
||||
RespNtf "7" tId' NROk <- signSendRecvNtf nh tknKey ("7", tId, TRPL tkn')
|
||||
|
||||
@@ -832,7 +832,6 @@ testRestoreExpireMessages at@(ATransport t) =
|
||||
msgs'' <- B.readFile testStoreMsgsFile
|
||||
length (B.lines msgs'') `shouldBe` 2
|
||||
B.lines msgs'' `shouldBe` drop 2 (B.lines msgs)
|
||||
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
|
||||
+6
-6
@@ -292,8 +292,8 @@ testXFTPAgentSendRestore = withGlobalLogging logCfgNoLogs $ do
|
||||
|
||||
-- receive file
|
||||
rcp <- getSMPAgentClient' agentCfg initAgentServers testDB2
|
||||
runRight_ $
|
||||
void $ testReceive rcp rfd1 filePath
|
||||
runRight_ . void $
|
||||
testReceive rcp rfd1 filePath
|
||||
|
||||
testXFTPAgentSendCleanup :: HasCallStack => IO ()
|
||||
testXFTPAgentSendCleanup = withGlobalLogging logCfgNoLogs $ do
|
||||
@@ -342,8 +342,8 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $
|
||||
|
||||
-- receive file
|
||||
rcp1 <- getSMPAgentClient' agentCfg initAgentServers testDB2
|
||||
runRight_ $
|
||||
void $ testReceive rcp1 rfd1 filePath
|
||||
runRight_ . void $
|
||||
testReceive rcp1 rfd1 filePath
|
||||
|
||||
length <$> listDirectory xftpServerFiles `shouldReturn` 6
|
||||
|
||||
@@ -377,8 +377,8 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do
|
||||
|
||||
-- receive file
|
||||
rcp1 <- getSMPAgentClient' agentCfg initAgentServers testDB2
|
||||
runRight_ $
|
||||
void $ testReceive rcp1 rfd1 filePath
|
||||
runRight_ . void $
|
||||
testReceive rcp1 rfd1 filePath
|
||||
disconnectAgentClient rcp1
|
||||
disconnectAgentClient sndr
|
||||
pure (sfId, sndDescr, rfd2)
|
||||
|
||||
Reference in New Issue
Block a user