Merge remote-tracking branch 'origin/master' into ep/sntrup761

This commit is contained in:
IC Rainbow
2023-10-28 19:27:24 +03:00
78 changed files with 948 additions and 913 deletions
+1
View File
@@ -0,0 +1 @@
- ignore: {name: "Use underscore"}
+1 -1
View File
@@ -19,7 +19,7 @@ source-repository-package
source-repository-package
type: git
location: https://github.com/kazu-yamamoto/http2.git
tag: b5a1b7200cf5bc7044af34ba325284271f6dff25
tag: 804fa283f067bd3fd89b8c5f8d25b3047813a517
source-repository-package
type: git
+30
View File
@@ -0,0 +1,30 @@
indentation: 2
column-limit: none
function-arrows: trailing
comma-style: trailing
import-export-style: trailing
indent-wheres: true
record-brace-space: true
newlines-between-decls: 1
haddock-style: single-line
haddock-style-module: null
let-style: inline
in-style: right-align
single-constraint-parens: never
unicode: never
respectful: true
fixities:
- infixr 9 .
- infixr 8 .:, .:., .=
- infixr 6 <>
- infixr 5 ++
- infixl 4 <$>, <$, $>, <$$>, <$?>
- infixl 4 <*>, <*, *>, <**>
- infix 4 ==, /=
- infixr 3 &&
- infixl 3 <|>
- infixr 2 ||
- infixl 1 >>, >>=
- infixr 1 =<<, >=>, <=<
- infixr 0 $, $!
reexports: []
+2 -3
View File
@@ -46,8 +46,7 @@ dependencies:
- filepath == 1.4.*
- hourglass == 0.2.*
- http-types == 0.12.*
- http2 == 4.1.*
- generic-random == 1.5.*
- http2 >= 4.1.4 && < 4.2
- ini == 0.4.1
- iproute == 1.7.*
- iso8601-time == 0.1.*
@@ -56,7 +55,6 @@ dependencies:
- network >= 3.1.2.7 && < 3.2
- network-transport == 0.5.6
- optparse-applicative >= 0.15 && < 0.17
- QuickCheck == 2.14.*
- process == 1.6.*
- random >= 1.1 && < 1.3
- simple-logger == 0.1.*
@@ -152,6 +150,7 @@ tests:
dependencies:
- simplexmq
- deepseq == 1.4.*
- generic-random == 1.5.*
- hspec == 2.11.*
- hspec-core == 2.11.*
- HUnit == 1.6.*
+8 -8
View File
@@ -364,9 +364,9 @@ The clients can optionally instruct a dedicated push notification server to subs
[`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation.
## SMP Transmission andtransport block structure
## SMP Transmission and transport block structure
Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity.
Each transport block has a fixed size of 16384 bytes for traffic uniformity.
From SMP version 4 each block can contain multiple transmissions, version 3 blocks have 1 transmission.
Some parts of SMP transmission are padded to a fixed size; this padding is uniformly added as a word16 encoded in network byte order - see `paddedString` syntax.
@@ -387,17 +387,17 @@ Each transmission/block for SMP v3 between the client and the server must have t
```abnf
paddedTransmission = <padded(transmission), 16384>
transmission = [signature] SP signed
signed = sessionIdentifier SP [corrId] SP [queueId] SP smpCommand
transmission = signature signed
signed = sessionIdentifier corrId queueId smpCommand
; corrId is required in client commands and server responses,
; it is empty in server notifications.
corrId = 1*32(%x21-7F) ; any characters other than control/whitespace
queueId = encoded ; max 32 bytes when decoded (24 bytes is used),
corrId = length *OCTET
queueId = length *OCTET
; empty queue ID is used with "create" command and in some server responses
signature = encoded
signature = length *OCTET
; empty signature can be used with "send" before the queue is secured with secure command
; signature is always empty with "ping" and "serverMsg"
encoded = <base64 encoded binary>
length = 1*1 OCTET
```
`base64` encoding should be used with padding, as defined in section 4 of [RFC 4648][9]
+3 -1
View File
@@ -19,7 +19,9 @@ if [ ! -f "${confd}/file-server.ini" ]; then
# Set quota
case "${QUOTA}" in
'') printf 'Please specify $QUOTA environment variable.\n'; exit 1 ;;
*) set -- "$@" --quota "${QUOTA}" ;;
*GB) QUOTA="$(printf ${QUOTA} | tr '[:upper:]' '[:lower:]')"; set -- "$@" --quota "${QUOTA}" ;;
*gb) set -- "$@" --quota "${QUOTA}" ;;
*) printf 'Wrong format. Format should be: 1gb, 10gb, 100gb.\n'; exit 1 ;;
esac
# Init the certificates and configs
+15 -25
View File
@@ -114,6 +114,7 @@ library
Simplex.Messaging.Notifications.Server.Env
Simplex.Messaging.Notifications.Server.Main
Simplex.Messaging.Notifications.Server.Push.APNS
Simplex.Messaging.Notifications.Server.Push.APNS.Internal
Simplex.Messaging.Notifications.Server.Stats
Simplex.Messaging.Notifications.Server.Store
Simplex.Messaging.Notifications.Server.StoreLog
@@ -140,6 +141,7 @@ library
Simplex.Messaging.Transport.Credentials
Simplex.Messaging.Transport.HTTP2
Simplex.Messaging.Transport.HTTP2.Client
Simplex.Messaging.Transport.HTTP2.File
Simplex.Messaging.Transport.HTTP2.Server
Simplex.Messaging.Transport.KeepAlive
Simplex.Messaging.Transport.Server
@@ -160,8 +162,7 @@ library
extra-libraries:
crypto
build-depends:
QuickCheck ==2.14.*
, aeson ==2.2.*
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
, asn1-encoding ==0.9.*
, asn1-types ==0.3.*
@@ -180,10 +181,9 @@ library
, direct-sqlcipher ==2.3.*
, directory ==1.3.*
, filepath ==1.4.*
, generic-random ==1.5.*
, hourglass ==0.2.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
@@ -225,8 +225,7 @@ executable ntf-server
apps/ntf-server
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
build-depends:
QuickCheck ==2.14.*
, aeson ==2.2.*
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
, asn1-encoding ==0.9.*
, asn1-types ==0.3.*
@@ -245,10 +244,9 @@ executable ntf-server
, direct-sqlcipher ==2.3.*
, directory ==1.3.*
, filepath ==1.4.*
, generic-random ==1.5.*
, hourglass ==0.2.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
@@ -291,8 +289,7 @@ executable smp-agent
apps/smp-agent
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
build-depends:
QuickCheck ==2.14.*
, aeson ==2.2.*
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
, asn1-encoding ==0.9.*
, asn1-types ==0.3.*
@@ -311,10 +308,9 @@ executable smp-agent
, direct-sqlcipher ==2.3.*
, directory ==1.3.*
, filepath ==1.4.*
, generic-random ==1.5.*
, hourglass ==0.2.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
@@ -357,8 +353,7 @@ executable smp-server
apps/smp-server
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
build-depends:
QuickCheck ==2.14.*
, aeson ==2.2.*
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
, asn1-encoding ==0.9.*
, asn1-types ==0.3.*
@@ -377,10 +372,9 @@ executable smp-server
, direct-sqlcipher ==2.3.*
, directory ==1.3.*
, filepath ==1.4.*
, generic-random ==1.5.*
, hourglass ==0.2.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
@@ -423,8 +417,7 @@ executable xftp
apps/xftp
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
build-depends:
QuickCheck ==2.14.*
, aeson ==2.2.*
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
, asn1-encoding ==0.9.*
, asn1-types ==0.3.*
@@ -443,10 +436,9 @@ executable xftp
, direct-sqlcipher ==2.3.*
, directory ==1.3.*
, filepath ==1.4.*
, generic-random ==1.5.*
, hourglass ==0.2.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
@@ -489,8 +481,7 @@ executable xftp-server
apps/xftp-server
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts
build-depends:
QuickCheck ==2.14.*
, aeson ==2.2.*
aeson ==2.2.*
, ansi-terminal >=0.10 && <0.12
, asn1-encoding ==0.9.*
, asn1-types ==0.3.*
@@ -509,10 +500,9 @@ executable xftp-server
, direct-sqlcipher ==2.3.*
, directory ==1.3.*
, filepath ==1.4.*
, generic-random ==1.5.*
, hourglass ==0.2.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
@@ -610,7 +600,7 @@ test-suite simplexmq-test
, hspec ==2.11.*
, hspec-core ==2.11.*
, http-types ==0.12.*
, http2 ==4.1.*
, http2 >=4.1.4 && <4.2
, ini ==0.4.1
, iproute ==1.7.*
, iso8601-time ==0.1.*
+8 -9
View File
@@ -10,7 +10,6 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module Simplex.FileTransfer.Agent
@@ -322,8 +321,8 @@ runXFTPSndPrepareWorker c doWork = do
if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting
then do
fsEncPath <- toFSFilePath $ sndFileEncPath ppath
when (status == SFSEncrypting) $
whenM (doesFileExist fsEncPath) $ removeFile fsEncPath
when (status == SFSEncrypting) . whenM (doesFileExist fsEncPath) $
removeFile fsEncPath
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
(digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath
withStore c $ \db -> do
@@ -441,11 +440,11 @@ runXFTPSndWorker c srv doWork = do
| length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients"
| length rcvIdsKeys == numRecipients = pure cr
| otherwise = do
maxRecipients <- asks $ xftpMaxRecipientsPerRequest . config
let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients
rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients'
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
addRecipients ch cr'
maxRecipients <- asks $ xftpMaxRecipientsPerRequest . config
let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients
rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients'
cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys'
addRecipients ch cr'
sndFileToDescrs :: SndFile -> m (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient])
sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest"
sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks"
@@ -573,7 +572,7 @@ runXFTPDelWorker c srv doWork = do
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
atomically $ assertAgentForeground c
loop
retryDone e = delWorkerInternalError c deletedSndChunkReplicaId e
retryDone = delWorkerInternalError c deletedSndChunkReplicaId
deleteChunkReplica :: DeletedSndChunkReplica -> m ()
deleteChunkReplica replica@DeletedSndChunkReplica {userId, deletedSndChunkReplicaId} = do
agentXFTPDeleteChunk c userId replica
+2 -1
View File
@@ -49,6 +49,7 @@ import Simplex.Messaging.Transport (supportedParameters)
import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.HTTP2.Client
import Simplex.Messaging.Transport.HTTP2.File
import Simplex.Messaging.Util (bshow, liftEitherError, whenM)
import UnliftIO
import UnliftIO.Directory
@@ -153,7 +154,7 @@ sendXFTPCommand XFTPClient {config, http2Client = http2@HTTP2Client {sessionId}}
forM_ chunkSpec_ $ \XFTPChunkSpec {filePath, chunkOffset, chunkSize} ->
withFile filePath ReadMode $ \h -> do
hSeek h AbsoluteSeek $ fromIntegral chunkOffset
sendFile h send $ fromIntegral chunkSize
hSendFile h send $ fromIntegral chunkSize
done
createXFTPChunk ::
+6 -7
View File
@@ -9,7 +9,6 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module Simplex.FileTransfer.Client.Main
@@ -527,8 +526,8 @@ prepareChunkSizes size' = prepareSizes size'
(smallSize, bigSize)
| size' > size34 chunkSize3 = (chunkSize2, chunkSize3)
| otherwise = (chunkSize1, chunkSize2)
-- | size' > size34 chunkSize2 = (chunkSize1, chunkSize2)
-- | otherwise = (chunkSize0, chunkSize1)
-- | size' > size34 chunkSize2 = (chunkSize1, chunkSize2)
-- | otherwise = (chunkSize0, chunkSize1)
size34 sz = (fromIntegral sz * 3) `div` 4
prepareSizes 0 = []
prepareSizes size
@@ -571,11 +570,11 @@ withRetry retryCount = withRetry' retryCount . withExceptT (CLIError . show)
removeFD :: Bool -> FilePath -> IO ()
removeFD yes fd
| yes = do
removeFile fd
putStrLn $ "\nFile description " <> fd <> " is deleted."
removeFile fd
putStrLn $ "\nFile description " <> fd <> " is deleted."
| otherwise = do
y <- liftIO . getConfirmation $ "\nFile description " <> fd <> " can't be used again. Delete it"
when y $ removeFile fd
y <- liftIO . getConfirmation $ "\nFile description " <> fd <> " can't be used again. Delete it"
when y $ removeFile fd
getConfirmation :: String -> IO Bool
getConfirmation prompt = do
+6 -6
View File
@@ -46,12 +46,12 @@ encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do
encryptChunks_ get w (!sb, !len)
| len == 0 = pure sb
| otherwise = do
let chSize = min len 65536
ch <- liftIO $ get chSize
when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
let (ch', sb') = LC.sbEncryptChunk sb ch
liftIO $ B.hPut w ch'
encryptChunks_ get w (sb', len - chSize)
let chSize = min len 65536
ch <- liftIO $ get chSize
when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF"
let (ch', sb') = LC.sbEncryptChunk sb ch
liftIO $ B.hPut w ch'
encryptChunks_ get w (sb', len - chSize)
decryptChunks :: Int64 -> [FilePath] -> C.SbKey -> C.CbNonce -> (String -> ExceptT String IO CryptoFile) -> ExceptT FTCryptoError IO CryptoFile
decryptChunks _ [] _ _ _ = throwError $ FTCEInvalidHeader "empty"
+14 -20
View File
@@ -1,6 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
@@ -9,7 +8,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module Simplex.FileTransfer.Description
@@ -38,7 +37,7 @@ where
import Control.Applicative (optional)
import Control.Monad ((<=<))
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
@@ -54,12 +53,11 @@ import Data.Word (Word32)
import qualified Data.Yaml as Y
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import GHC.Generics (Generic)
import Simplex.FileTransfer.Chunks
import Simplex.FileTransfer.Protocol
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Parsers (defaultJSON, parseAll)
import Simplex.Messaging.Protocol (XFTPServer)
import Simplex.Messaging.Util (bshow, groupAllOn, (<$?>))
@@ -150,21 +148,13 @@ data YAMLFileDescription = YAMLFileDescription
chunkSize :: String,
replicas :: [YAMLServerReplicas]
}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON YAMLFileDescription where
toJSON = J.genericToJSON J.defaultOptions
toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
data YAMLServerReplicas = YAMLServerReplicas
{ server :: XFTPServer,
chunks :: [String]
}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON YAMLServerReplicas where
toJSON = J.genericToJSON J.defaultOptions
toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
data FileServerReplica = FileServerReplica
{ chunkNo :: Int,
@@ -176,6 +166,13 @@ data FileServerReplica = FileServerReplica
}
deriving (Show)
newtype FileSize a = FileSize {unFileSize :: a}
deriving (Eq, Show)
$(J.deriveJSON defaultJSON ''YAMLServerReplicas)
$(J.deriveJSON defaultJSON ''YAMLFileDescription)
instance FilePartyI p => StrEncoding (ValidFileDescription p) where
strEncode (ValidFD fd) = strEncode fd
strDecode s = strDecode s >>= (\(AVFD fd) -> checkParty fd)
@@ -217,9 +214,6 @@ encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSiz
replicas = encodeFileReplicas chunkSize chunks
}
newtype FileSize a = FileSize {unFileSize :: a}
deriving (Eq, Show)
instance (Integral a, Show a) => StrEncoding (FileSize a) where
strEncode (FileSize b)
| b' /= 0 = bshow b
@@ -242,9 +236,9 @@ instance (Integral a, Show a) => StrEncoding (FileSize a) where
instance (Integral a, Show a) => IsString (FileSize a) where
fromString = either error id . strDecode . B.pack
instance (FromField a) => FromField (FileSize a) where fromField f = FileSize <$> fromField f
instance FromField a => FromField (FileSize a) where fromField f = FileSize <$> fromField f
instance (ToField a) => ToField (FileSize a) where toField (FileSize s) = toField s
instance ToField a => ToField (FileSize a) where toField (FileSize s) = toField s
groupReplicasByServer :: FileSize Word32 -> [FileChunk] -> [[FileServerReplica]]
groupReplicasByServer defChunkSize =
+8 -24
View File
@@ -1,5 +1,4 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@@ -7,6 +6,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
@@ -14,8 +14,7 @@
module Simplex.FileTransfer.Protocol where
import Control.Applicative ((<|>))
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteString.Char8 (ByteString)
@@ -25,8 +24,6 @@ import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (isNothing)
import Data.Type.Equality
import Data.Word (Word32)
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
@@ -59,7 +56,6 @@ import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport (SessionId, TransportError (..))
import Simplex.Messaging.Util (bshow, (<$?>))
import Simplex.Messaging.Version
import Test.QuickCheck (Arbitrary (..))
currentXFTPVersion :: Version
currentXFTPVersion = 1
@@ -69,14 +65,7 @@ xftpBlockSize = 16384
-- | File protocol clients
data FileParty = FRecipient | FSender
deriving (Eq, Show, Generic)
instance FromJSON FileParty where
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "F"
instance ToJSON FileParty where
toJSON = J.genericToJSON . enumJSON $ dropPrefix "F"
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "F"
deriving (Eq, Show)
data SFileParty :: FileParty -> Type where
SFRecipient :: SFileParty FRecipient
@@ -355,14 +344,7 @@ data XFTPErrorType
INTERNAL
| -- | used internally, never returned by the server (to be removed)
DUPLICATE_ -- not part of SMP protocol, used internally
deriving (Eq, Generic, Read, Show)
instance ToJSON XFTPErrorType where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON XFTPErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Read, Show)
instance StrEncoding XFTPErrorType where
strEncode = \case
@@ -370,8 +352,6 @@ instance StrEncoding XFTPErrorType where
e -> bshow e
strP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1
instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU
instance Encoding XFTPErrorType where
smpEncode = \case
BLOCK -> "BLOCK"
@@ -435,3 +415,7 @@ xftpDecodeTransmission sessionId t = do
case tParse True t' of
t'' :| [] -> Right $ tDecodeParseValidate sessionId currentXFTPVersion t''
_ -> Left BLOCK
$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty)
$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType)
+11 -11
View File
@@ -169,17 +169,17 @@ processRequest :: HTTP2Request -> M ()
processRequest HTTP2Request {sessionId, reqBody = body@HTTP2Body {bodyHead}, sendResponse}
| B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", "", FRErr BLOCK) Nothing
| otherwise = do
case xftpDecodeTransmission sessionId bodyHead of
Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do
case cmdOrErr of
Right cmd -> do
verifyXFTPTransmission sig_ signed fId cmd >>= \case
VRVerified req -> uncurry send =<< processXFTPRequest body req
VRFailed -> send (FRErr AUTH) Nothing
Left e -> send (FRErr e) Nothing
where
send resp = sendXFTPResponse (corrId, fId, resp)
Left e -> sendXFTPResponse ("", "", FRErr e) Nothing
case xftpDecodeTransmission sessionId bodyHead of
Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do
case cmdOrErr of
Right cmd -> do
verifyXFTPTransmission sig_ signed fId cmd >>= \case
VRVerified req -> uncurry send =<< processXFTPRequest body req
VRFailed -> send (FRErr AUTH) Nothing
Left e -> send (FRErr e) Nothing
where
send resp = sendXFTPResponse (corrId, fId, resp)
Left e -> sendXFTPResponse ("", "", FRErr e) Nothing
where
sendXFTPResponse :: (CorrId, XFTPFileId, FileResponse) -> Maybe ServerFile -> M ()
sendXFTPResponse (corrId, fId, resp) serverFile_ = do
+1 -1
View File
@@ -26,7 +26,7 @@ import Simplex.FileTransfer.Server.StoreLog
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (BasicAuth, RcvPublicVerifyKey)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Util (tshow)
import System.IO (IOMode (..))
import UnliftIO.STM
+5 -4
View File
@@ -19,7 +19,7 @@ import Options.Applicative
import Simplex.FileTransfer.Chunks
import Simplex.FileTransfer.Description (FileSize (..))
import Simplex.FileTransfer.Server (runXFTPServer)
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defaultFileExpiration, defFileExpirationHours)
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defFileExpirationHours, defaultFileExpiration)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern XFTPServer)
@@ -143,9 +143,10 @@ xftpServerCLI cfgPath logPath = do
allowNewFiles = fromMaybe True $ iniOnOff "AUTH" "new_files" ini,
newFileBasicAuth = either error id <$> strDecodeIni "AUTH" "create_password" ini,
fileExpiration =
Just defaultFileExpiration
{ ttl = 3600 * readIniDefault defFileExpirationHours "STORE_LOG" "expire_files_hours" ini
},
Just
defaultFileExpiration
{ ttl = 3600 * readIniDefault defFileExpirationHours "STORE_LOG" "expire_files_hours" ini
},
caCertificateFile = c caCrtFile,
privateKeyFile = c serverKeyFile,
certificateFile = c serverCrtFile,
+5 -33
View File
@@ -1,14 +1,11 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.FileTransfer.Transport
( supportedFileServerVRange,
XFTPRcvChunkSpec (..),
sendFile,
receiveFile,
sendEncFile,
receiveEncFile,
@@ -25,10 +22,10 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Word (Word32)
import GHC.IO.Handle.Internals (ioe_EOF)
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Transport.HTTP2.File
import Simplex.Messaging.Version
import System.IO (Handle, IOMode (..), withFile)
@@ -42,18 +39,6 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec
supportedFileServerVRange :: VersionRange
supportedFileServerVRange = mkVersionRange 1 1
fileBlockSize :: Int
fileBlockSize = 16 * 1024
sendFile :: Handle -> (Builder -> IO ()) -> Word32 -> IO ()
sendFile h send = go
where
go 0 = pure ()
go sz =
getFileChunk h sz >>= \ch -> do
send $ byteString ch
go $ sz - fromIntegral (B.length ch)
sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO ()
sendEncFile h send = go
where
@@ -66,23 +51,10 @@ sendEncFile h send = go
send (byteString encCh) `E.catch` \(e :: E.SomeException) -> print e >> E.throwIO e
go sbState' $ sz - fromIntegral (B.length ch)
getFileChunk :: Handle -> Word32 -> IO ByteString
getFileChunk h sz =
B.hGet h fileBlockSize >>= \case
"" -> ioe_EOF
ch -> pure $ B.take (fromIntegral sz) ch -- sz >= xftpBlockSize
receiveFile :: (Int -> IO ByteString) -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO ()
receiveFile getBody = receiveFile_ receive
where
receive h sz = do
ch <- getBody fileBlockSize
let chSize = fromIntegral $ B.length ch
if
| chSize > sz -> pure $ Left SIZE
| chSize > 0 -> B.hPut h ch >> receive h (sz - chSize)
| sz == 0 -> pure $ Right ()
| otherwise -> pure $ Left SIZE
receive h sz = hReceiveFile getBody h sz >>= \sz' -> pure $ if sz' == 0 then Right () else Left SIZE
receiveEncFile :: (Int -> IO ByteString) -> LC.SbState -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO ()
receiveEncFile getBody = receiveFile_ . receive
@@ -91,8 +63,8 @@ receiveEncFile getBody = receiveFile_ . receive
ch <- getBody fileBlockSize
let chSize = fromIntegral $ B.length ch
if
| chSize > sz + authSz -> pure $ Left SIZE
| chSize > 0 -> do
| chSize > sz + authSz -> pure $ Left SIZE
| chSize > 0 -> do
let (ch', rest) = B.splitAt (fromIntegral sz) ch
(decCh, sbState') = LC.sbDecryptChunk sbState ch'
sz' = sz - fromIntegral (B.length ch')
@@ -105,7 +77,7 @@ receiveEncFile getBody = receiveFile_ . receive
tag = LC.sbAuth sbState'
tag'' <- if tagSz == C.authTagSize then pure tag' else (tag' <>) <$> getBody (C.authTagSize - tagSz)
pure $ if BA.constEq tag'' tag then Right () else Left CRYPTO
| otherwise -> pure $ Left SIZE
| otherwise -> pure $ Left SIZE
authSz = fromIntegral C.authTagSize
receiveFile_ :: (Handle -> Word32 -> IO (Either XFTPErrorType ())) -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO ()
+133 -130
View File
@@ -14,7 +14,6 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
-- |
@@ -43,6 +42,7 @@ module Simplex.Messaging.Agent
disconnectAgentClient,
resumeAgentClient,
withConnLock,
withInvLock,
createUser,
deleteUser,
createConnectionAsync,
@@ -151,7 +151,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SubscriptionMode (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, UserProtocol, XFTPServerWithAuth)
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth)
import qualified Simplex.Messaging.Protocol as SMP
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util
@@ -462,8 +462,8 @@ deleteUser' c userId delSMPQueues = do
atomically $ TM.delete userId $ smpServers c
where
delUser =
whenM (withStore' c (`deleteUserWithoutConns` userId)) $
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId)
whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $
writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId)
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId
newConnAsync c userId corrId enableNtfs cMode subMode = do
@@ -481,16 +481,17 @@ newConnNoQueues c userId connId enableNtfs cMode = do
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
withInvLock c (strEncode cReqUri) "joinConnAsync" $ do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo =
throwError $ CMD PROHIBITED
@@ -554,11 +555,11 @@ switchConnectionAsync' c corrId connId =
SomeConn _ (DuplexConnection cData rqs@(rq :| _rqs) sqs)
| isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED
| otherwise -> do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH
let rqs' = updatedQs rq1 rqs
pure . connectionStats $ DuplexConnection cData rqs' sqs
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH
let rqs' = updatedQs rq1 rqs
pure . connectionStats $ DuplexConnection cData rqs' sqs
_ -> throwError $ CMD PROHIBITED
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c)
@@ -615,24 +616,25 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr
_ -> throwError $ AGENT A_VERSION
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
g <- asks idsDrg
connId' <- withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData q
liftIO $ createRatchet db connId' rc
pure connId'
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
duplexHS = connAgentVersion /= 1
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv =
withInvLock c (strEncode inv) "joinConnSrv" $ do
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
g <- asks idsDrg
connId' <- withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData q
liftIO $ createRatchet db connId' rc
pure connId'
Left e -> do
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
withStore' c (`deleteConn` connId')
throwError e
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
duplexHS = connAgentVersion /= 1
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
pure connId'
Left e -> do
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
withStore' c (`deleteConn` connId')
throwError e
joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
@@ -813,15 +815,15 @@ getNotificationMessage' c nonce encNtfInfo = do
where
getNtfMessages ntfConnId maxMs nMeta ms
| length ms < maxMs =
getConnectionMessage' c ntfConnId >>= \case
Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of
Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'}
| msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms)
| otherwise -> getMsg (m : ms)
_
| SMP.notification msgFlags -> pure $ reverse (m : ms)
| otherwise -> getMsg (m : ms)
_ -> pure $ reverse ms
getConnectionMessage' c ntfConnId >>= \case
Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of
Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'}
| msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms)
| otherwise -> getMsg (m : ms)
_
| SMP.notification msgFlags -> pure $ reverse (m : ms)
| otherwise -> getMsg (m : ms)
_ -> pure $ reverse ms
| otherwise = pure $ reverse ms
where
getMsg = getNtfMessages ntfConnId maxMs nMeta
@@ -962,12 +964,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
Just (rq'@RcvQueue {primary}, rq'' : rqs')
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
| otherwise -> do
checkRQSwchStatus rq' RSReceivedMessage
tryError (deleteQueue c rq') >>= \case
Right () -> finalizeSwitch
Left e
| temporaryOrHostError e -> throwError e
| otherwise -> finalizeSwitch >> throwError e
checkRQSwchStatus rq' RSReceivedMessage
tryError (deleteQueue c rq') >>= \case
Right () -> finalizeSwitch
Left e
| temporaryOrHostError e -> throwError e
| otherwise -> finalizeSwitch >> throwError e
where
finalizeSwitch = do
withStore' c $ \db -> deleteConnRcvQueue db rq'
@@ -1123,7 +1125,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl
-- because the queue must be secured by the time the confirmation or the first HELLO is received
| duplexHandshake == Just True -> connErr
| otherwise ->
ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast)
ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast)
where
connErr = case rq_ of
-- party initiating connection
@@ -1143,8 +1145,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl
-- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server),
-- the message sending would be retried
| temporaryOrHostError e -> do
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast)
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast)
| otherwise -> notifyDel msgId err
where
msgExpired timeoutSel = do
@@ -1286,9 +1288,9 @@ switchConnection' c connId =
SomeConn _ conn@(DuplexConnection cData rqs@(rq :| _rqs) _)
| isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED
| otherwise -> do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
switchDuplexConnection c conn rq'
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
switchDuplexConnection c conn rq'
_ -> throwError $ CMD PROHIBITED
switchDuplexConnection :: AgentMonad m => AgentClient -> Connection 'CDuplex -> RcvQueue -> m ConnectionStats
@@ -1314,19 +1316,19 @@ abortConnectionSwitch' c connId =
SomeConn _ (DuplexConnection cData rqs sqs) -> case switchingRQ rqs of
Just rq
| canAbortRcvSwitch rq -> do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
-- multiple queues to which the connections switches were possible when repeating switch was allowed
let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs
case L.nonEmpty keepRqs of
Just rqs' -> do
rq' <- withStore' c $ \db -> do
mapM_ (setRcvQueueDeleted db) delRqs
setRcvSwitchStatus db rq Nothing
forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId
let rqs'' = updatedQs rq' rqs'
conn' = DuplexConnection cData rqs'' sqs
pure $ connectionStats conn'
_ -> throwError $ INTERNAL "won't delete all rcv queues in connection"
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
-- multiple queues to which the connections switches were possible when repeating switch was allowed
let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs
case L.nonEmpty keepRqs of
Just rqs' -> do
rq' <- withStore' c $ \db -> do
mapM_ (setRcvQueueDeleted db) delRqs
setRcvSwitchStatus db rq Nothing
forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId
let rqs'' = updatedQs rq' rqs'
conn' = DuplexConnection cData rqs'' sqs
pure $ connectionStats conn'
_ -> throwError $ INTERNAL "won't delete all rcv queues in connection"
| otherwise -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD PROHIBITED
@@ -1336,16 +1338,16 @@ synchronizeRatchet' c connId force = withConnLock c connId "synchronizeRatchet"
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData rqs sqs)
| ratchetSyncAllowed cData || force -> do
-- check queues are not switching?
AgentConfig {e2eEncryptVRange} <- asks config
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
void $ enqueueRatchetKeyMsgs c cData sqs e2eParams
withStore' c $ \db -> do
setConnRatchetSync db connId RSStarted
setRatchetX3dhKeys db connId pk1 pk2 k1 k2
let cData' = cData {ratchetSyncState = RSStarted} :: ConnData
conn' = DuplexConnection cData' rqs sqs
pure $ connectionStats conn'
-- check queues are not switching?
AgentConfig {e2eEncryptVRange} <- asks config
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
void $ enqueueRatchetKeyMsgs c cData sqs e2eParams
withStore' c $ \db -> do
setConnRatchetSync db connId RSStarted
setRatchetX3dhKeys db connId pk1 pk2 k1 k2
let cData' = cData {ratchetSyncState = RSStarted} :: ConnData
conn' = DuplexConnection cData' rqs sqs
pure $ connectionStats conn'
| otherwise -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD PROHIBITED
@@ -1521,23 +1523,23 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
-- possible improvement: add minimal time before repeat registration
(Just tknId, Nothing)
| savedDeviceToken == suppliedDeviceToken ->
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
| otherwise -> replaceToken tknId
(Just tknId, Just (NTAVerify code))
| savedDeviceToken == suppliedDeviceToken ->
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
| otherwise -> replaceToken tknId
(Just tknId, Just NTACheck)
| savedDeviceToken == suppliedDeviceToken -> do
ns <- asks ntfSupervisor
atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode}
when (ntfTknStatus == NTActive) $ do
cron <- asks $ ntfCron . config
agentNtfEnableCron c tknId tkn cron
when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
-- possible improvement: get updated token status from the server, or maybe TCRON could return the current status
pure ntfTknStatus
ns <- asks ntfSupervisor
atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode}
when (ntfTknStatus == NTActive) $ do
cron <- asks $ ntfCron . config
agentNtfEnableCron c tknId tkn cron
when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
-- possible improvement: get updated token status from the server, or maybe TCRON could return the current status
pure ntfTknStatus
| otherwise -> replaceToken tknId
(Just tknId, Just NTADelete) -> do
agentNtfDeleteToken c tknId tkn
@@ -1647,10 +1649,10 @@ toggleConnectionNtfs' c connId enable = do
toggle cData
| enableNtfs cData == enable = pure ()
| otherwise = do
withStore' c $ \db -> setConnectionNtfs db connId enable
ns <- asks ntfSupervisor
let cmd = if enable then NSCCreate else NSCDelete
atomically $ sendNtfSubCommand ns (connId, cmd)
withStore' c $ \db -> setConnectionNtfs db connId enable
ns <- asks ntfSupervisor
let cmd = if enable then NSCCreate else NSCDelete
atomically $ sendNtfSubCommand ns (connId, cmd)
deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m ()
deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do
@@ -1743,11 +1745,12 @@ getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration]
getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn)
debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do
debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do
connLocks <- getLocks cs
invLocks <- getLocks is
srvLocks <- getLocks rs
delLock <- atomically $ tryReadTMVar d
pure AgentLocks {connLocks, srvLocks, delLock}
pure AgentLocks {connLocks, invLocks, srvLocks, delLock}
where
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
@@ -1912,11 +1915,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
resetRatchetSync :: m (Connection c)
resetRatchetSync
| rss `notElem` ([RSOk, RSStarted] :: [RatchetSyncState]) = do
let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData
conn'' = updateConnection cData'' conn'
notify . RSYNC RSOk Nothing $ connectionStats conn''
withStore' c $ \db -> setConnRatchetSync db connId RSOk
pure conn''
let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData
conn'' = updateConnection cData'' conn'
notify . RSYNC RSOk Nothing $ connectionStats conn''
withStore' c $ \db -> setConnRatchetSync db connId RSOk
pure conn''
| otherwise = pure conn'
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
@@ -1924,11 +1927,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
| userAck -> ackDel internalId
| otherwise -> do
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
_ -> checkDuplicateHash e encryptedMsgHash >> ack
Left (AGENT (A_CRYPTO e)) -> do
exists <- withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash
@@ -1976,9 +1979,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
case msgAVRange `compatibleVersion` aVRange of
Just (Compatible av)
| av > connAgentVersion -> do
withStore' c $ \db -> setConnAgentVersion db connId av
let cData'' = cData' {connAgentVersion = av} :: ConnData
pure $ updateConnection cData'' conn'
withStore' c $ \db -> setConnAgentVersion db connId av
let cData'' = cData' {connAgentVersion = av} :: ConnData
pure $ updateConnection cData'' conn'
| otherwise -> pure conn'
Nothing -> pure conn'
ack :: m ()
@@ -1994,9 +1997,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
processEND = \case
Just (Right clnt)
| sessId == sessionId clnt -> do
removeSubscription c connId
notify' END
pure "END"
removeSubscription c connId
notify' END
pure "END"
| otherwise -> ignored
_ -> ignored
ignored = pure "END from disconnected client - ignored"
@@ -2186,12 +2189,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
case findRQ (smpServer, senderId) rqs of
Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'}
| status' == New || status' == Confirmed -> do
checkRQSwchStatus rq RSSendingQADD
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
let dhSecret = C.dh' dhPublicKey dhPrivKey
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
notify . SWITCH QDRcv SPConfirmed $ connectionStats conn'
checkRQSwchStatus rq RSSendingQADD
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
let dhSecret = C.dh' dhPublicKey dhPrivKey
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
notify . SWITCH QDRcv SPConfirmed $ connectionStats conn'
| otherwise -> qError "QKEY: queue already secured"
_ -> qError "QKEY: queue address not found in connection"
where
@@ -2227,8 +2230,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
ereadyMsg rcPrev (DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = do
let CR.Ratchet {rcSnd} = rcPrev
-- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation
when (isNothing rcSnd) $
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
when (isNothing rcSnd) . void $
enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId)
smpInvitation :: Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
smpInvitation conn' connReq@(CRInvitationUri crData _) cInfo = do
@@ -2267,9 +2270,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
getSendRatchetKeys
| rss == RSStarted = withStore c (`getRatchetX3dhKeys'` connId)
| otherwise = do
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams
void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams
pure (pk1, pk2, k1, k2)
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams
void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams
pure (pk1, pk2, k1, k2)
notifyAgreed :: m ()
notifyAgreed = do
let cData'' = cData' {ratchetSyncState = RSAgreed} :: ConnData
@@ -2285,11 +2288,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, C.PublicKeyX448, C.PublicKeyX448) -> m ()
initRatchet e2eEncryptVRange (pk1, pk2, k1, k2)
| rkHash k1 k2 <= rkHashRcv = do
recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams
recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams
| otherwise = do
(_, rcDHRs) <- liftIO C.generateKeyPair'
recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
(_, rcDHRs) <- liftIO C.generateKeyPair'
recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash
@@ -2347,8 +2350,8 @@ mkAgentConfirmation :: AgentMonad m => Compatible Version -> AgentClient -> Conn
mkAgentConfirmation (Compatible agentVersion) c cData sq srv connInfo subMode
| agentVersion == 1 = pure $ AgentConnInfo connInfo
| otherwise = do
qInfo <- createReplyQueue c cData sq subMode srv
pure $ AgentConnInfoReply (qInfo :| []) connInfo
qInfo <- createReplyQueue c cData sq subMode srv
pure $ AgentConnInfoReply (qInfo :| []) connInfo
enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
enqueueConfirmation c cData sq connInfo e2eEncryption_ = do
+38 -36
View File
@@ -2,7 +2,6 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -15,6 +14,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
@@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client
ProtocolTestStep (..),
newAgentClient,
withConnLock,
withInvLock,
closeAgentClient,
closeProtocolServerClients,
closeXFTPServerClient,
@@ -117,8 +118,7 @@ import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random (getRandomBytes)
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
@@ -137,7 +137,6 @@ import Data.Text (Text)
import Data.Text.Encoding
import Data.Time (UTCTime, defaultTimeLocale, formatTime, getCurrentTime)
import Data.Word (Word16)
import GHC.Generics (Generic)
import Network.Socket (HostName)
import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError)
import qualified Simplex.FileTransfer.Client as X
@@ -164,7 +163,7 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse)
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse)
import Simplex.Messaging.Protocol
( AProtocolType (..),
BrokerMsg,
@@ -249,6 +248,8 @@ data AgentClient = AgentClient
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
-- locks to prevent concurrent operations with connection
connLocks :: TMap ConnId Lock,
-- locks to prevent concurrent operations with connection request invitations
invLocks :: TMap ByteString Lock,
-- lock to prevent concurrency between periodic and async connection deletions
deleteLock :: Lock,
-- locks to prevent concurrent reconnections to SMP servers
@@ -279,10 +280,13 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
data AgentState = ASForeground | ASSuspending | ASSuspended
deriving (Eq, Show)
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String}
deriving (Show, Generic, FromJSON)
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
data AgentLocks = AgentLocks
{ connLocks :: Map String String,
invLocks :: Map String String,
srvLocks :: Map String String,
delLock :: Maybe String
}
deriving (Show)
data AgentStatsKey = AgentStatsKey
{ userId :: UserId,
@@ -325,6 +329,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
agentState <- newTVar ASForeground
getMsgLocks <- TM.empty
connLocks <- TM.empty
invLocks <- TM.empty
deleteLock <- createLock
reconnectLocks <- TM.empty
reconnections <- newTAsyncs
@@ -362,6 +367,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
agentState,
getMsgLocks,
connLocks,
invLocks,
deleteLock,
reconnectLocks,
reconnections,
@@ -645,6 +651,9 @@ withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
withConnLock _ "" _ = id
withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
where
@@ -725,23 +734,14 @@ data ProtocolTestStep
| TSDownloadFile
| TSCompareFile
| TSDeleteFile
deriving (Eq, Show, Generic)
instance ToJSON ProtocolTestStep where
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TS"
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TS"
instance FromJSON ProtocolTestStep where
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TS"
deriving (Eq, Show)
data ProtocolTestFailure = ProtocolTestFailure
{ testStep :: ProtocolTestStep,
testError :: AgentErrorType
}
deriving (Eq, Show, Generic, FromJSON)
deriving (Eq, Show)
instance ToJSON ProtocolTestFailure where toEncoding = J.genericToEncoding J.defaultOptions
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe ProtocolTestFailure)
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
cfg <- getClientConfig c smpCfg
@@ -901,8 +901,8 @@ subscribeQueues c qs = do
subscribeQueues_ u smp qs' = do
rs <- sendBatch subscribeSMPQueues smp qs'
mapM_ (uncurry $ processSubResult c) rs
when (any temporaryClientError . lefts . map snd $ L.toList rs) $
unliftIO u $ reconnectServer c $ transportSession' smp
when (any temporaryClientError . lefts . map snd $ L.toList rs) . unliftIO u $
reconnectServer c (transportSession' smp)
pure rs
type BatchResponses e r = (NonEmpty (RcvQueue, Either e r))
@@ -989,7 +989,7 @@ sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer,
mkInvitation = do
let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo}
agentCbEncryptOnce v dhPublicKey . smpEncode $
SMP.ClientMessage SMP.PHEmpty $ smpEncode agentEnvelope
SMP.ClientMessage SMP.PHEmpty (smpEncode agentEnvelope)
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
@@ -1324,7 +1324,7 @@ userServers c = case protocolTypeI @p of
SPSMP -> smpServers c
SPXFTP -> xftpServers c
pickServer :: forall p m. (AgentMonad' m) => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p)
pickServer :: forall p m. AgentMonad' m => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p)
pickServer = \case
srv :| [] -> pure srv
servers -> do
@@ -1343,7 +1343,7 @@ withUserServers c userId action =
Just srvs -> action srvs
_ -> throwError $ INTERNAL "unknown userId - no user servers"
withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> ((ProtoServerWithAuth p) -> m a) -> m a
withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> m a) -> m a
withNextSrv c userId usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used
@@ -1355,22 +1355,14 @@ withNextSrv c userId usedSrvs initUsed action = do
action srvAuth
data SubInfo = SubInfo {userId :: UserId, server :: Text, rcvId :: Text, subError :: Maybe String}
deriving (Show, Generic)
instance ToJSON SubInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
instance FromJSON SubInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True}
deriving (Show)
data SubscriptionsInfo = SubscriptionsInfo
{ activeSubscriptions :: [SubInfo],
pendingSubscriptions :: [SubInfo],
removedSubscriptions :: [SubInfo]
}
deriving (Show, Generic)
instance ToJSON SubscriptionsInfo where toEncoding = J.genericToEncoding J.defaultOptions
instance FromJSON SubscriptionsInfo where parseJSON = J.genericParseJSON J.defaultOptions
deriving (Show)
getAgentSubscriptions :: MonadIO m => AgentClient -> m SubscriptionsInfo
getAgentSubscriptions c = do
@@ -1382,6 +1374,16 @@ getAgentSubscriptions c = do
getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c)
getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c)
subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo
subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err}
subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err}
enc :: StrEncoding a => a -> Text
enc = decodeLatin1 . strEncode
$(J.deriveJSON defaultJSON ''AgentLocks)
$(J.deriveJSON (enumJSON $ dropPrefix "TS") ''ProtocolTestStep)
$(J.deriveJSON defaultJSON ''ProtocolTestFailure)
$(J.deriveJSON defaultJSON ''SubInfo)
$(J.deriveJSON defaultJSON ''SubscriptionsInfo)
-2
View File
@@ -1,5 +1,3 @@
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.Agent.Lock where
import Control.Monad (void)
+12 -13
View File
@@ -5,7 +5,6 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module Simplex.Messaging.Agent.NtfSubSupervisor
@@ -100,14 +99,14 @@ processNtfSub c (connId, cmd) = do
Just (action, _)
-- subscription was marked for deletion / is being deleted
| isDeleteNtfSubAction action -> do
if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted
then resetSubscription
else withNtfServer c $ \ntfServer -> do
withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate)
addNtfNTFWorker ntfServer
if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted
then resetSubscription
else withNtfServer c $ \ntfServer -> do
withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate)
addNtfNTFWorker ntfServer
| otherwise -> case action of
NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer
NtfSubSMPAction _ -> addNtfSMPWorker smpServer
NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer
NtfSubSMPAction _ -> addNtfSMPWorker smpServer
rotate :: m ()
rotate = do
withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate)
@@ -291,11 +290,11 @@ rescheduleAction :: AgentMonad' m => TMVar () -> UTCTime -> UTCTime -> m Bool
rescheduleAction doWork ts actionTs
| actionTs <= ts = pure False
| otherwise = do
void . atomically $ tryTakeTMVar doWork
void . forkIO $ do
liftIO $ threadDelay' $ diffToMicroseconds $ diffUTCTime actionTs ts
void . atomically $ tryPutTMVar doWork ()
pure True
void . atomically $ tryTakeTMVar doWork
void . forkIO $ do
liftIO $ threadDelay' $ diffToMicroseconds $ diffUTCTime actionTs ts
void . atomically $ tryPutTMVar doWork ()
pure True
retryOnError :: AgentMonad' m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m ()
retryOnError c name loop done e = do
+38 -97
View File
@@ -1,6 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
@@ -13,6 +12,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
@@ -151,7 +151,7 @@ import Control.Monad (unless)
import Control.Monad.Except (runExceptT, throwError)
import Control.Monad.IO.Class
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Base64
@@ -176,8 +176,6 @@ import Data.Typeable ()
import Data.Word (Word32)
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.ToField
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import Simplex.FileTransfer.Description
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType)
import Simplex.Messaging.Agent.QueryString
@@ -213,7 +211,6 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..))
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import Test.QuickCheck (Arbitrary (..))
import Text.Read
import UnliftIO.Exception (Exception)
@@ -614,15 +611,11 @@ data RcvQueueInfo = RcvQueueInfo
rcvSwitchStatus :: Maybe RcvSwitchStatus,
canAbortSwitch :: Bool
}
deriving (Eq, Show, Generic)
instance ToJSON RcvQueueInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
instance FromJSON RcvQueueInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True}
deriving (Eq, Show)
instance StrEncoding RcvQueueInfo where
strEncode RcvQueueInfo {rcvServer, rcvSwitchStatus, canAbortSwitch} =
"srv=" <> strEncode rcvServer
("srv=" <> strEncode rcvServer)
<> maybe "" (\switch -> ";switch=" <> strEncode switch) rcvSwitchStatus
<> (";can_abort_switch=" <> strEncode canAbortSwitch)
strP = do
@@ -635,11 +628,7 @@ data SndQueueInfo = SndQueueInfo
{ sndServer :: SMPServer,
sndSwitchStatus :: Maybe SndSwitchStatus
}
deriving (Eq, Show, Generic)
instance ToJSON SndQueueInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
instance FromJSON SndQueueInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True}
deriving (Eq, Show)
instance StrEncoding SndQueueInfo where
strEncode SndQueueInfo {sndServer, sndSwitchStatus} =
@@ -656,13 +645,11 @@ data ConnectionStats = ConnectionStats
ratchetSyncState :: RatchetSyncState,
ratchetSyncSupported :: Bool
}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON ConnectionStats where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
instance StrEncoding ConnectionStats where
strEncode ConnectionStats {connAgentVersion, rcvQueuesInfo, sndQueuesInfo, ratchetSyncState, ratchetSyncSupported} =
"agent_version=" <> strEncode connAgentVersion
("agent_version=" <> strEncode connAgentVersion)
<> (" rcv=" <> strEncodeList rcvQueuesInfo)
<> (" snd=" <> strEncodeList sndQueuesInfo)
<> (" sync=" <> strEncode ratchetSyncState)
@@ -1048,7 +1035,7 @@ instance StrEncoding MsgReceiptStatus where
MROk -> "ok"
MRBadMsgHash -> "badMsgHash"
strP =
A.takeWhile1 (/= ' ') >>= \ case
A.takeWhile1 (/= ' ') >>= \case
"ok" -> pure MROk
"badMsgHash" -> pure MRBadMsgHash
_ -> fail "bad MsgReceiptStatus"
@@ -1391,7 +1378,7 @@ type AgentMsgId = Int64
-- | Result of received message integrity validation.
data MsgIntegrity = MsgOk | MsgError {errorInfo :: MsgErrorType}
deriving (Eq, Show, Generic)
deriving (Eq, Show)
instance StrEncoding MsgIntegrity where
strP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> strP)
@@ -1399,20 +1386,13 @@ instance StrEncoding MsgIntegrity where
MsgOk -> "OK"
MsgError e -> "ERR " <> strEncode e
instance ToJSON MsgIntegrity where
toJSON = J.genericToJSON $ sumTypeJSON fstToLower
toEncoding = J.genericToEncoding $ sumTypeJSON fstToLower
instance FromJSON MsgIntegrity where
parseJSON = J.genericParseJSON $ sumTypeJSON fstToLower
-- | Error of message integrity validation.
data MsgErrorType
= MsgSkipped {fromMsgId :: AgentMsgId, toMsgId :: AgentMsgId}
| MsgBadId {msgId :: AgentMsgId}
| MsgBadHash
| MsgDuplicate
deriving (Eq, Show, Generic)
deriving (Eq, Show)
instance StrEncoding MsgErrorType where
strP =
@@ -1427,13 +1407,6 @@ instance StrEncoding MsgErrorType where
MsgBadHash -> "HASH"
MsgDuplicate -> "DUPLICATE"
instance ToJSON MsgErrorType where
toJSON = J.genericToJSON $ sumTypeJSON fstToLower
toEncoding = J.genericToEncoding $ sumTypeJSON fstToLower
instance FromJSON MsgErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON fstToLower
-- | Error type used in errors sent to agent clients.
data AgentErrorType
= -- | command or response error
@@ -1454,14 +1427,7 @@ data AgentErrorType
INTERNAL {internalErr :: String}
| -- | agent inactive
INACTIVE
deriving (Eq, Generic, Show, Exception)
instance ToJSON AgentErrorType where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON AgentErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Show, Exception)
-- | SMP agent protocol command or response error.
data CommandErrorType
@@ -1475,14 +1441,7 @@ data CommandErrorType
SIZE
| -- | message does not fit in SMP block
LARGE
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON CommandErrorType where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON CommandErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Read, Show, Exception)
-- | Connection error.
data ConnectionErrorType
@@ -1496,14 +1455,7 @@ data ConnectionErrorType
NOT_ACCEPTED
| -- | connection not available on reply confirmation/HELLO after timeout
NOT_AVAILABLE
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON ConnectionErrorType where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON ConnectionErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Read, Show, Exception)
-- | SMP server errors.
data BrokerErrorType
@@ -1519,14 +1471,7 @@ data BrokerErrorType
TRANSPORT {transportErr :: TransportError}
| -- | command response timeout
TIMEOUT
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON BrokerErrorType where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON BrokerErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Read, Show, Exception)
-- | Errors of another SMP agent.
data SMPAgentError
@@ -1543,7 +1488,7 @@ data SMPAgentError
A_DUPLICATE
| -- | error in the message to add/delete/etc queue in connection
A_QUEUE {queueErr :: String}
deriving (Eq, Generic, Read, Show, Exception)
deriving (Eq, Read, Show, Exception)
data AgentCryptoError
= -- | AES decryption error
@@ -1556,14 +1501,7 @@ data AgentCryptoError
RATCHET_EARLIER Word32
| -- | too many skipped messages
RATCHET_SKIPPED Word32
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON AgentCryptoError where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON AgentCryptoError where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Read, Show, Exception)
instance StrEncoding AgentCryptoError where
strP =
@@ -1579,13 +1517,6 @@ instance StrEncoding AgentCryptoError where
RATCHET_EARLIER n -> "RATCHET_EARLIER " <> strEncode n
RATCHET_SKIPPED n -> "RATCHET_SKIPPED " <> strEncode n
instance ToJSON SMPAgentError where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON SMPAgentError where
parseJSON = J.genericParseJSON $ sumTypeJSON id
instance StrEncoding AgentErrorType where
strP =
"CMD " *> (CMD <$> parseRead1)
@@ -1620,18 +1551,6 @@ instance StrEncoding AgentErrorType where
where
text = encodeUtf8 . T.pack
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU
cryptoErrToSyncState :: AgentCryptoError -> RatchetSyncState
cryptoErrToSyncState = \case
DECRYPT_AES -> RSAllowed
@@ -1957,3 +1876,25 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
unless (B.null s) $ throwError $ CMD SIZE
pure body
Nothing -> return . Left $ CMD SYNTAX
$(J.deriveJSON defaultJSON ''RcvQueueInfo)
$(J.deriveJSON defaultJSON ''SndQueueInfo)
$(J.deriveJSON defaultJSON ''ConnectionStats)
$(J.deriveJSON (sumTypeJSON fstToLower) ''MsgErrorType)
$(J.deriveJSON (sumTypeJSON fstToLower) ''MsgIntegrity)
$(J.deriveJSON (sumTypeJSON id) ''CommandErrorType)
$(J.deriveJSON (sumTypeJSON id) ''ConnectionErrorType)
$(J.deriveJSON (sumTypeJSON id) ''BrokerErrorType)
$(J.deriveJSON (sumTypeJSON id) ''AgentCryptoError)
$(J.deriveJSON (sumTypeJSON id) ''SMPAgentError)
$(J.deriveJSON (sumTypeJSON id) ''AgentErrorType)
@@ -1,6 +1,5 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Agent.RetryInterval
+1 -1
View File
@@ -23,7 +23,7 @@ import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion)
import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer, defaultTransportServerConfig)
import Simplex.Messaging.Transport.Server (defaultTransportServerConfig, loadTLSServerParams, runTransportServer)
import Simplex.Messaging.Util (bshow)
import UnliftIO.Async (race_)
import qualified UnliftIO.Exception as E
+2
View File
@@ -576,6 +576,8 @@ data StoreError
SEServerNotFound
| -- | Connection already used.
SEConnDuplicate
| -- | Confirmed snd queue already exists.
SESndQueueExists
| -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa.
-- 'upgradeRcvConnToDuplex' and 'upgradeSndConnToDuplex' do not allow duplex connections - they would also return this error.
SEBadConnType ConnType
+57 -49
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -17,12 +16,12 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Simplex.Messaging.Agent.Store.SQLite
( SQLiteStore (..),
@@ -220,8 +219,7 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (second)
import Data.ByteString (ByteString)
@@ -247,7 +245,6 @@ import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.QQ (sql)
import Database.SQLite.Simple.ToField (ToField (..))
import qualified Database.SQLite3 as SQLite3
import GHC.Generics (Generic)
import Network.Socket (ServiceName)
import Simplex.FileTransfer.Client (XFTPChunkSpec (..))
import Simplex.FileTransfer.Description
@@ -267,7 +264,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (blobFieldParser, dropPrefix, fromTextField_, sumTypeJSON)
import Simplex.Messaging.Parsers (blobFieldParser, defaultJSON, dropPrefix, fromTextField_, sumTypeJSON)
import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -277,7 +274,7 @@ import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
import System.Exit (exitFailure)
import System.FilePath (takeDirectory)
import System.IO (hFlush, stdout)
import UnliftIO.Exception (onException, bracketOnError)
import UnliftIO.Exception (bracketOnError, onException)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -287,7 +284,7 @@ data MigrationError
= MEUpgrade {upMigrations :: [UpMigration]}
| MEDowngrade {downMigrations :: [String]}
| MigrationError {mtrError :: MTRError}
deriving (Eq, Show, Generic)
deriving (Eq, Show)
migrationErrorDescription :: MigrationError -> String
migrationErrorDescription = \case
@@ -297,18 +294,12 @@ migrationErrorDescription = \case
"Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms
MigrationError err -> mtrErrorDescription err
instance ToJSON MigrationError where
toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "ME"
toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "ME"
data UpMigration = UpMigration {upName :: String, withDown :: Bool}
deriving (Eq, Show, Generic, FromJSON)
deriving (Eq, Show)
upMigration :: Migration -> UpMigration
upMigration Migration {name, down} = UpMigration name $ isJust down
instance ToJSON UpMigration where toEncoding = J.genericToEncoding J.defaultOptions
data MigrationConfirmation = MCYesUp | MCYesUpDown | MCConsole | MCError
deriving (Eq, Show)
@@ -347,10 +338,10 @@ migrateSchema st migrations confirmMigrations = do
Right ms@(MTRUp ums)
| dbNew st -> Migrations.run st ms $> Right ()
| otherwise -> case confirmMigrations of
MCYesUp -> run ms
MCYesUpDown -> run ms
MCConsole -> confirm err >> run ms
MCError -> pure $ Left err
MCYesUp -> run ms
MCYesUpDown -> run ms
MCConsole -> confirm err >> run ms
MCError -> pure $ Left err
where
err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums)
Right ms@(MTRDown dms) -> case confirmMigrations of
@@ -561,10 +552,23 @@ createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, dupl
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createConn_ gVar cData $ \connId -> do
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertSndQueue_ db connId q serverKeyHash_
-- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_
ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $
createConn_ gVar cData $ \connId -> do
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertSndQueue_ db connId q serverKeyHash_
checkConfirmedSndQueueExists_ :: DB.Connection -> SndQueue -> IO Bool
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do
fromMaybe False
<$> maybeFirstRow
fromOnly
( DB.query
db
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
(host server, port server, sndId, New)
)
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
@@ -980,7 +984,7 @@ updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} =
getPendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO [InternalId]
getPendingMsgs db connId SndQueue {dbQueueId} =
map fromOnly
<$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
<$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? ORDER BY internal_id ASC" (connId, dbQueueId)
deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO ()
deletePendingMsgs db connId SndQueue {dbQueueId} =
@@ -1079,12 +1083,12 @@ countPendingSndDeliveries_ db connId msgId = do
deleteRcvMsgHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteRcvMsgHashesExpired db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
DB.execute db "DELETE FROM encrypted_rcv_message_hashes WHERE created_at < ?" (Only cutoffTs)
deleteSndMsgsExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteSndMsgsExpired db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
DB.execute
db
"DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL"
@@ -1163,7 +1167,7 @@ getSkippedMsgKeys :: DB.Connection -> ConnId -> IO SkippedMsgKeys
getSkippedMsgKeys db connId =
skipped <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId)
where
skipped ms = foldl' addSkippedKey M.empty ms
skipped = foldl' addSkippedKey M.empty
addSkippedKey smks (hk, msgN, mk) = M.alter (Just . addMsgKey) hk smks
where
addMsgKey = maybe (M.singleton msgN mk) (M.insert msgN mk)
@@ -1734,15 +1738,15 @@ getAnyConn deleted' dbConn connId =
Just (cData@ConnData {deleted}, cMode)
| deleted /= deleted' -> pure $ Left SEConnNotFound
| otherwise -> do
rQ <- getRcvQueuesByConnId_ dbConn connId
sQ <- getSndQueuesByConnId_ dbConn connId
pure $ case (rQ, sQ, cMode) of
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
_ -> Left SEConnNotFound
rQ <- getRcvQueuesByConnId_ dbConn connId
sQ <- getSndQueuesByConnId_ dbConn connId
pure $ case (rQ, sQ, cMode) of
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
_ -> Left SEConnNotFound
getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getConns = getAnyConns_ False
@@ -1804,7 +1808,7 @@ checkRatchetKeyHashExists db connId hash = do
deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteRatchetKeyHashesExpired db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
DB.execute db "DELETE FROM processed_ratchet_key_hashes WHERE created_at < ?" (Only cutoffTs)
-- | returns all connection queues, the first queue is the primary one
@@ -2253,7 +2257,7 @@ deleteRcvFile' db rcvFileId =
getNextRcvChunkToDownload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe RcvFileChunk)
getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
maybeFirstRow toChunk $
DB.query
db
@@ -2290,7 +2294,7 @@ getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = d
getNextRcvFileToDecrypt :: DB.Connection -> NominalDiffTime -> IO (Maybe RcvFile)
getNextRcvFileToDecrypt db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
fileId_ :: Maybe DBRcvFileId <-
maybeFirstRow fromOnly $
DB.query
@@ -2308,7 +2312,7 @@ getNextRcvFileToDecrypt db ttl = do
getPendingRcvFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
getPendingRcvFilesServers db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
map toXFTPServer
<$> DB.query
db
@@ -2350,7 +2354,7 @@ getCleanupRcvFilesDeleted db =
getRcvFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBRcvFileId, RcvFileId, FilePath)]
getRcvFilesExpired db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
DB.query
db
[sql|
@@ -2458,7 +2462,7 @@ getChunkReplicaRecipients_ db replicaId =
getNextSndFileToPrepare :: DB.Connection -> NominalDiffTime -> IO (Maybe SndFile)
getNextSndFileToPrepare db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
fileId_ :: Maybe DBSndFileId <-
maybeFirstRow fromOnly $
DB.query
@@ -2539,7 +2543,7 @@ createSndFileReplica db SndFileChunk {sndChunkId} NewSndChunkReplica {server, re
getNextSndChunkToUpload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe SndFileChunk)
getNextSndChunkToUpload db server@ProtocolServer {host, port, keyHash} ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
chunk_ <-
maybeFirstRow toChunk $
DB.query
@@ -2608,7 +2612,7 @@ updateSndChunkReplicaStatus db replicaId status = do
getPendingSndFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
getPendingSndFilesServers db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
map toXFTPServer
<$> DB.query
db
@@ -2647,7 +2651,7 @@ getCleanupSndFilesDeleted db =
getSndFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBSndFileId, SndFileId, Maybe FilePath)]
getSndFilesExpired db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
DB.query
db
[sql|
@@ -2687,7 +2691,7 @@ getDeletedSndChunkReplica db deletedSndChunkReplicaId =
getNextDeletedSndChunkReplica :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe DeletedSndChunkReplica)
getNextDeletedSndChunkReplica db ProtocolServer {host, port, keyHash} ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
replicaId_ :: Maybe Int64 <-
maybeFirstRow fromOnly $
DB.query
@@ -2716,7 +2720,7 @@ deleteDeletedSndChunkReplica db deletedSndChunkReplicaId =
getPendingDelFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer]
getPendingDelFilesServers db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
map toXFTPServer
<$> DB.query
db
@@ -2731,5 +2735,9 @@ getPendingDelFilesServers db ttl = do
deleteDeletedSndChunkReplicasExpired :: DB.Connection -> NominalDiffTime -> IO ()
deleteDeletedSndChunkReplicasExpired db ttl = do
cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime
cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime
DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE created_at < ?" (Only cutoffTs)
$(J.deriveJSON defaultJSON ''UpMigration)
$(J.deriveToJSON (sumTypeJSON $ dropPrefix "ME") ''MigrationError)
@@ -1,7 +1,7 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TemplateHaskell #-}
module Simplex.Messaging.Agent.Store.SQLite.DB
( Connection (..),
@@ -20,13 +20,12 @@ where
import Control.Concurrent.STM
import Control.Monad (when)
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Int (Int64)
import Data.Time (diffUTCTime, getCurrentTime)
import Database.SQLite.Simple (FromRow, NamedParam, Query, ToRow)
import qualified Database.SQLite.Simple as SQL
import GHC.Generics (Generic)
import Simplex.Messaging.Parsers (defaultJSON)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (diffToMilliseconds)
@@ -41,9 +40,7 @@ data SlowQueryStats = SlowQueryStats
timeMax :: Int64,
timeAvg :: Int64
}
deriving (Show, Generic, FromJSON)
instance ToJSON SlowQueryStats where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Show)
timeIt :: TMap Query SlowQueryStats -> Query -> IO a -> IO a
timeIt slow sql a = do
@@ -100,3 +97,5 @@ query_ Connection {conn, slow} sql = timeIt slow sql $ SQL.query_ conn sql
queryNamed :: FromRow r => Connection -> Query -> [NamedParam] -> IO [r]
queryNamed Connection {conn, slow} sql = timeIt slow sql . SQL.queryNamed conn sql
{-# INLINE queryNamed #-}
$(J.deriveJSON defaultJSON ''SlowQueryStats)
@@ -1,4 +1,3 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
@@ -27,8 +26,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations
where
import Control.Monad (forM_, when)
import Data.Aeson (ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.List (intercalate, sortOn)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.Map as M
@@ -40,7 +38,6 @@ import Database.SQLite.Simple (Connection, Only (..), Query (..))
import qualified Database.SQLite.Simple as DB
import Database.SQLite.Simple.QQ (sql)
import qualified Database.SQLite3 as SQLite3
import GHC.Generics (Generic)
import Simplex.Messaging.Agent.Protocol (extraSMPServerHosts)
import Simplex.Messaging.Agent.Store.SQLite.Common
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial
@@ -169,11 +166,7 @@ data MigrationsToRun = MTRUp [Migration] | MTRDown [DownMigration] | MTRNone
data MTRError
= MTRENoDown {dbMigrations :: [String]}
| MTREDifferent {appMigration :: String, dbMigration :: String}
deriving (Eq, Show, Generic)
instance ToJSON MTRError where
toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "MTRE"
toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "MTRE"
deriving (Eq, Show)
mtrErrorDescription :: MTRError -> String
mtrErrorDescription = \case
@@ -192,3 +185,5 @@ migrationsToRun [] dbMs
migrationsToRun (a : as) (d : ds)
| name a == name d = migrationsToRun as ds
| otherwise = Left $ MTREDifferent (name a) (name d)
$(J.deriveJSON (sumTypeJSON $ dropPrefix "MTRE") ''MTRError)
@@ -1,4 +1,3 @@
{-# LANGUAGE LambdaCase #-}
module Simplex.Messaging.Agent.TRcvQueues
( TRcvQueues (getRcvQueues),
empty,
+28 -42
View File
@@ -1,6 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
@@ -10,6 +9,7 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
@@ -86,8 +86,7 @@ import Control.Exception
import Control.Monad
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
@@ -97,13 +96,12 @@ import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Maybe (fromMaybe)
import Data.Time.Clock (UTCTime, getCurrentTime)
import GHC.Generics (Generic)
import Network.Socket (ServiceName)
import Numeric.Natural
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (dropPrefix, enumJSON)
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON)
import Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
@@ -190,14 +188,7 @@ data HostMode
HMOnion
| -- | prefer (or require) public hosts
HMPublic
deriving (Eq, Show, Generic)
instance FromJSON HostMode where
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "HM"
instance ToJSON HostMode where
toJSON = J.genericToJSON . enumJSON $ dropPrefix "HM"
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "HM"
deriving (Eq, Show)
-- | network configuration for the client
data NetworkConfig = NetworkConfig
@@ -223,21 +214,10 @@ data NetworkConfig = NetworkConfig
smpPingCount :: Int,
logTLSErrors :: Bool
}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON NetworkConfig where
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
deriving (Eq, Show)
data TransportSessionMode = TSMUser | TSMEntity
deriving (Eq, Show, Generic)
instance ToJSON TransportSessionMode where
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM"
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM"
instance FromJSON TransportSessionMode where
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM"
deriving (Eq, Show)
defaultNetworkConfig :: NetworkConfig
defaultNetworkConfig =
@@ -429,11 +409,11 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize,
where
response entityId
| entityId == entId =
case respOrErr of
Left e -> Left $ PCEResponseError e
Right r -> case protocolError r of
Just e -> Left $ PCEProtocolError e
_ -> Right r
case respOrErr of
Left e -> Left $ PCEResponseError e
Right r -> case protocolError r of
Just e -> Left $ PCEProtocolError e
_ -> Right r
| otherwise = Left . PCEUnexpectedResponse $ bshow respOrErr
sendMsg :: Either err msg -> IO ()
sendMsg = \case
@@ -661,13 +641,13 @@ sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do
validate rs
| diff == 0 = pure $ L.fromList rs
| diff > 0 = do
putStrLn "send error: fewer responses than expected"
pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock)
putStrLn "send error: fewer responses than expected"
pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock)
| otherwise = do
putStrLn "send error: more responses than expected"
pure $ L.fromList $ take (L.length cs) rs
where
diff = L.length cs - length rs
putStrLn "send error: more responses than expected"
pure $ L.fromList $ take (L.length cs) rs
where
diff = L.length cs - length rs
streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
streamProtocolCommands c@ProtocolClient {batch, blockSize} cs cb = do
@@ -688,8 +668,8 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
(: []) <$> getResponse c r
data ClientBatch err msg
-- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch
= CBTransmissions ByteString Int [Request err msg]
= -- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch
CBTransmissions ByteString Int [Request err msg]
| CBTransmission ByteString (Request err msg)
| CBLargeTransmission (Request err msg)
@@ -713,9 +693,9 @@ batchClientTransmissions batch blkSize
encodeBatch :: ByteString -> Int -> [Request err msg] -> NonEmpty (PCTransmission err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmission err msg)))
encodeBatch s n rs ts@((t, r) :| ts_)
| B.length s' <= blkSize - 3 && n < 255 =
case L.nonEmpty ts_ of
Just ts' -> encodeBatch s' n' rs' ts'
Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing)
case L.nonEmpty ts_ of
Just ts' -> encodeBatch s' n' rs' ts'
Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing)
| n == 0 = (CBLargeTransmission r, L.nonEmpty ts_)
| otherwise = (CBTransmissions s n (reverse rs), Just ts)
where
@@ -765,3 +745,9 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo
r <- Request entId <$> newEmptyTMVar
TM.insert corrId r sentCommands
pure r
$(J.deriveJSON (enumJSON $ dropPrefix "HM") ''HostMode)
$(J.deriveJSON (enumJSON $ dropPrefix "TSM") ''TransportSessionMode)
$(J.deriveJSON defaultJSON ''NetworkConfig)
+14 -13
View File
@@ -18,7 +18,7 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Trans.Except
import Data.Bifunctor (first, bimap)
import Data.Bifunctor (bimap, first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (partitionEithers)
@@ -36,11 +36,11 @@ import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer, NtfPrivateSignKey, NotifierId, RcvPrivateSignKey, RecipientId)
import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateSignKey, ProtocolServer (..), QueueId, RcvPrivateSignKey, RecipientId, SMPServer)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Util (catchAll_, ($>>=), toChunks)
import Simplex.Messaging.Util (catchAll_, toChunks, ($>>=))
import System.Timeout (timeout)
import UnliftIO (async)
import UnliftIO.Exception (Exception)
@@ -238,7 +238,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
(tempErrs, finalErrs) = partition (temporaryClientError . snd) errs
mapM_ (atomically . addSubscription ca srv) oks
mapM_ (liftIO . notify . CAResubscribed srv) $ L.nonEmpty $ map fst oks
mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
mapM_ (liftIO . notify . CASubError srv) $ L.nonEmpty finalErrs
mapM_ (throwE . snd) $ listToMaybe tempErrs
@@ -281,7 +281,7 @@ subscribeQueue ca srv sub = do
handleErr e = do
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
removePendingSubscription ca srv $ fst sub
removePendingSubscription ca srv (fst sub)
throwE e
subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateSignKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ()))
@@ -300,14 +300,15 @@ subscribeQueues_ party ca srv subs = do
smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateSignKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
smpSubscribeQueues party ca smp srv subs = do
rs <- L.zip subs <$> subscribe smp (L.map swap subs)
atomically $ forM rs $ \(sub, r) -> (fst sub,) <$> case r of
Right () -> do
addSubscription ca srv $ first (party,) sub
pure $ Right ()
Left e -> do
when (e /= PCENetworkError && e /= PCEResponseTimeout) $
removePendingSubscription ca srv $ (party,) $ fst sub
pure $ Left e
atomically $ forM rs $ \(sub, r) ->
(fst sub,) <$> case r of
Right () -> do
addSubscription ca srv $ first (party,) sub
pure $ Right ()
Left e -> do
when (e /= PCENetworkError && e /= PCEResponseTimeout) $
removePendingSubscription ca srv (party, fst sub)
pure $ Left e
where
subscribe = case party of
SPRecipient -> subscribeSMPQueues
+2 -2
View File
@@ -680,9 +680,9 @@ instance CryptoSignature ASignature where
signatureBytes (ASignature _ sig) = signatureBytes sig
decodeSignature s
| B.length s == Ed25519.signatureSize =
ASignature SEd25519 . SignatureEd25519 <$> ed Ed25519.signature s
ASignature SEd25519 . SignatureEd25519 <$> ed Ed25519.signature s
| B.length s == Ed448.signatureSize =
ASignature SEd448 . SignatureEd448 <$> ed Ed448.signature s
ASignature SEd448 . SignatureEd448 <$> ed Ed448.signature s
| otherwise = Left "bad signature size"
where
ed alg = first show . CE.eitherCryptoError . alg
+9 -13
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Simplex.Messaging.Crypto.File
( CryptoFile (..),
@@ -24,19 +23,18 @@ where
import Control.Exception
import Control.Monad
import Control.Monad.Except
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteArray as BA
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as LB
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (isJust)
import GHC.Generics (Generic)
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Lazy (LazyByteString)
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Parsers (defaultJSON)
import Simplex.Messaging.Util (liftEitherWith)
import System.Directory (getFileSize)
import UnliftIO (Handle, IOMode (..), liftIO)
@@ -45,16 +43,10 @@ import UnliftIO.STM
-- Possibly encrypted local file
data CryptoFile = CryptoFile {filePath :: FilePath, cryptoArgs :: Maybe CryptoFileArgs}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON CryptoFile where
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
deriving (Eq, Show)
data CryptoFileArgs = CFArgs {fileKey :: C.SbKey, fileNonce :: C.CbNonce}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON CryptoFileArgs where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
data CryptoFileHandle = CFHandle Handle (Maybe (TVar LC.SbState))
@@ -124,3 +116,7 @@ getFileContentsSize :: CryptoFile -> IO Integer
getFileContentsSize (CryptoFile path cfArgs) = do
size <- getFileSize path
pure $ if isJust cfArgs then size - fromIntegral C.authTagSize else size
$(J.deriveJSON defaultJSON ''CryptoFileArgs)
$(J.deriveJSON defaultJSON ''CryptoFile)
+8 -9
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
@@ -90,10 +89,10 @@ unPad = fmap snd . splitLen
splitLen :: LazyByteString -> Either CryptoError (Int64, LazyByteString)
splitLen padded
| LB.length lenStr == 8 = case smpDecode $ LB.toStrict lenStr of
Right len
| len < 0 -> Left CryptoInvalidMsgError
| otherwise -> Right (len, LB.take len rest)
Left _ -> Left CryptoInvalidMsgError
Right len
| len < 0 -> Left CryptoInvalidMsgError
| otherwise -> Right (len, LB.take len rest)
Left _ -> Left CryptoInvalidMsgError
| otherwise = Left CryptoInvalidMsgError
where
(lenStr, rest) = LB.splitAt 8 padded
@@ -112,10 +111,10 @@ sbDecrypt :: SbKey -> CbNonce -> LazyByteString -> Either CryptoError LazyByteSt
sbDecrypt (SbKey key) (CbNonce nonce) packet
| LB.length tag' < 16 = Left CBDecryptError
| otherwise = case secretBox sbDecryptChunk key nonce c of
Right (tag :| cs)
| BA.constEq (LB.toStrict tag') tag -> unPad $ LB.fromChunks cs
| otherwise -> Left CBDecryptError
Left e -> Left e
Right (tag :| cs)
| BA.constEq (LB.toStrict tag') tag -> unPad $ LB.fromChunks cs
| otherwise -> Left CBDecryptError
Left e -> Left e
where
(tag', c) = LB.splitAt 16 packet
+34 -27
View File
@@ -1,12 +1,12 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
@@ -20,8 +20,9 @@ import Control.Monad.Trans.Except
import Crypto.Cipher.AES (AES256)
import Crypto.Hash (SHA512)
import qualified Crypto.KDF.HKDF as H
import Data.Aeson (FromJSON, ToJSON)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as JQ
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as LB
@@ -33,12 +34,11 @@ import Data.Typeable (Typeable)
import Data.Word (Word32)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import GHC.Generics
import Simplex.Messaging.Agent.QueryString
import Simplex.Messaging.Crypto
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (blobFieldDecoder, parseE, parseE')
import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE')
import Simplex.Messaging.Version
currentE2EEncryptVersion :: Version
@@ -112,11 +112,11 @@ x3dh v (sk1, rk1) dh1 dh2 dh3 =
(hk, nhk, sk)
-- for backwards compatibility with clients using agent version before 3.4.0
| v == 1 =
let (hk', rest) = B.splitAt 32 dhs
in uncurry (hk',,) $ B.splitAt 32 rest
let (hk', rest) = B.splitAt 32 dhs
in uncurry (hk',,) $ B.splitAt 32 rest
| otherwise =
let salt = B.replicate 64 '\0'
in hkdf3 salt dhs "SimpleXX3DH"
let salt = B.replicate 64 '\0'
in hkdf3 salt dhs "SimpleXX3DH"
type RatchetX448 = Ratchet 'X448
@@ -135,29 +135,20 @@ data Ratchet a = Ratchet
rcNHKs :: HeaderKey,
rcNHKr :: HeaderKey
}
deriving (Eq, Show, Generic, FromJSON)
instance AlgorithmI a => ToJSON (Ratchet a) where
toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
data SndRatchet a = SndRatchet
{ rcDHRr :: PublicKey a,
rcCKs :: RatchetKey,
rcHKs :: HeaderKey
}
deriving (Eq, Show, Generic, FromJSON)
instance AlgorithmI a => ToJSON (SndRatchet a) where
toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
data RcvRatchet = RcvRatchet
{ rcCKr :: RatchetKey,
rcHKr :: HeaderKey
}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON RcvRatchet where
toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
type SkippedMsgKeys = Map HeaderKey SkippedHdrMsgKeys
@@ -204,10 +195,6 @@ instance ToJSON RatchetKey where
instance FromJSON RatchetKey where
parseJSON = fmap RatchetKey . strParseJSON "Key"
instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode
instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict'
instance ToField MessageKey where toField = toField . smpEncode
instance FromField MessageKey where fromField = blobFieldDecoder smpDecode
@@ -428,9 +415,9 @@ rcDecrypt rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do
| rcNr + maxSkip < untilN = Left $ CERatchetTooManySkipped (untilN + 1 - rcNr)
| rcNr == untilN = Right (r, M.empty)
| otherwise =
let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty
r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'}
in Right (r', M.singleton rcHKr mks)
let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty
r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'}
in Right (r', M.singleton rcHKr mks)
advanceRcvRatchet :: Word32 -> RatchetKey -> Word32 -> SkippedHdrMsgKeys -> (RatchetKey, Word32, SkippedHdrMsgKeys)
advanceRcvRatchet 0 ck msgNs mks = (ck, msgNs, mks)
advanceRcvRatchet n ck msgNs mks =
@@ -493,3 +480,23 @@ hkdf3 salt ikm info = (s1, s2, s3)
out = H.expand prk info 96
(s1, rest) = B.splitAt 32 out
(s2, s3) = B.splitAt 32 rest
$(JQ.deriveJSON defaultJSON ''RcvRatchet)
instance AlgorithmI a => ToJSON (SndRatchet a) where
toEncoding = $(JQ.mkToEncoding defaultJSON ''SndRatchet)
toJSON = $(JQ.mkToJSON defaultJSON ''SndRatchet)
instance AlgorithmI a => FromJSON (SndRatchet a) where
parseJSON = $(JQ.mkParseJSON defaultJSON ''SndRatchet)
instance AlgorithmI a => ToJSON (Ratchet a) where
toEncoding = $(JQ.mkToEncoding defaultJSON ''Ratchet)
toJSON = $(JQ.mkToJSON defaultJSON ''Ratchet)
instance AlgorithmI a => FromJSON (Ratchet a) where
parseJSON = $(JQ.mkParseJSON defaultJSON ''Ratchet)
instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode
instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict'
+1 -1
View File
@@ -109,7 +109,7 @@ lenP = fromIntegral . c2w <$> A.anyChar
{-# INLINE lenP #-}
instance Encoding a => Encoding (Maybe a) where
smpEncode s = maybe "0" (("1" <>) . smpEncode) s
smpEncode = maybe "0" (("1" <>) . smpEncode)
{-# INLINE smpEncode #-}
smpP =
smpP >>= \case
@@ -518,7 +518,7 @@ instance ToJSON NtfTknStatus where
instance FromJSON NtfTknStatus where
parseJSON = J.withText "NtfTknStatus" $ either fail pure . smpDecode . encodeUtf8
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
Just Refl -> Right c
+36 -35
View File
@@ -16,13 +16,12 @@ import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Data.Bifunctor (second)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (intercalate, sort)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
@@ -205,10 +204,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
Just err -> (subs, oks, err : errs) -- permanent error, log and don't retry subscription
Nothing -> (sub : subs, oks, errs) -- temporary error, retry subscription
-- | Subscribe to queues. The list of results can have a different order.
-- \| Subscribe to queues. The list of results can have a different order.
subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO (NonEmpty (NtfSubData, Either SMPClientError ()))
subscribeQueues srv subs =
L.map (second snd) . L.zip subs <$> subscribeQueuesNtfs ca srv (L.map sub subs)
L.zipWith (\s r -> (s, snd r)) subs <$> subscribeQueuesNtfs ca srv (L.map sub subs)
where
sub NtfSubData {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey)
@@ -248,10 +247,11 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
forM errs (\((_, ntfId), err) -> handleSubError (SMPQueueNtf srv ntfId) err)
>>= logSubErrors srv . catMaybes . L.toList
logSubStatus srv event n = when (n > 0) $
logInfo $ "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)"
logSubStatus srv event n =
when (n > 0) . logInfo $
"SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)"
logSubErrors :: SMPServer -> [NtfSubStatus] -> M ()
logSubErrors :: SMPServer -> [NtfSubStatus] -> M ()
logSubErrors srv errs = forM_ (L.group $ sort errs) $ \errs' -> do
logError $ "SMP subscription errors on server " <> showServer' srv <> ": " <> tshow (L.head errs') <> " (" <> tshow (length errs') <> " errors)"
@@ -289,14 +289,14 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
case ntf of
PNVerification _
| status /= NTInvalid && status /= NTExpired ->
deliverNotification pp tkn ntf >>= \case
Right _ -> do
status_ <- atomically $ stateTVar tknStatus $ \case
NTActive -> (Nothing, NTActive)
NTConfirmed -> (Nothing, NTConfirmed)
_ -> (Just NTConfirmed, NTConfirmed)
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
_ -> pure ()
deliverNotification pp tkn ntf >>= \case
Right _ -> do
status_ <- atomically $ stateTVar tknStatus $ \case
NTActive -> (Nothing, NTActive)
NTConfirmed -> (Nothing, NTConfirmed)
_ -> (Just NTConfirmed, NTConfirmed)
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
_ -> pure ()
| otherwise -> logError "bad notification token status"
PNCheckMessages -> checkActiveTkn status $ do
void $ deliverNotification pp tkn ntf
@@ -345,7 +345,8 @@ runNtfClientTransport th@THandle {sessionId} = do
raceAny_ ([liftIO $ send th c, client c s ps, receive th c] <> disconnectThread_ c expCfg)
`finally` liftIO (clientDisconnected c)
where
disconnectThread_ c expCfg = maybe [] ((: []) . liftIO . disconnectTransport th c activeAt) expCfg
disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport th c activeAt expCfg]
disconnectThread_ _ _ = []
clientDisconnected :: NtfServerClient -> IO ()
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False
@@ -463,16 +464,16 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
else pure $ NRErr AUTH
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
logDebug "TVFY - token verified"
st <- asks store
updateTknStatus tkn NTActive
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
forM_ tIds cancelInvervalNotifications
incNtfStat tknVerified
pure NROk
logDebug "TVFY - token verified"
st <- asks store
updateTknStatus tkn NTActive
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
forM_ tIds cancelInvervalNotifications
incNtfStat tknVerified
pure NROk
| otherwise -> do
logDebug "TVFY - incorrect code or token status"
pure $ NRErr AUTH
logDebug "TVFY - incorrect code or token status"
pure $ NRErr AUTH
TCHK -> do
logDebug "TCHK"
pure $ NRTkn status
@@ -509,16 +510,16 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
TCRN int
| int < 20 -> pure $ NRErr QUOTA
| otherwise -> do
logDebug "TCRN"
atomically $ writeTVar tknCronInterval int
atomically (TM.lookup tknId intervalNotifiers) >>= \case
Nothing -> runIntervalNotifier int
Just IntervalNotifier {interval, action} ->
unless (interval == int) $ do
uninterruptibleCancel action
runIntervalNotifier int
withNtfLog $ \s -> logTokenCron s tknId int
pure NROk
logDebug "TCRN"
atomically $ writeTVar tknCronInterval int
atomically (TM.lookup tknId intervalNotifiers) >>= \case
Nothing -> runIntervalNotifier int
Just IntervalNotifier {interval, action} ->
unless (interval == int) $ do
uninterruptibleCancel action
runIntervalNotifier int
withNtfLog $ \s -> logTokenCron s tknId int
pure NROk
where
runIntervalNotifier interval = do
action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60
@@ -34,7 +34,7 @@ import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
import UnliftIO.STM
@@ -7,7 +7,6 @@
module Simplex.Messaging.Notifications.Server.Main where
import Data.Either (fromRight)
import Data.Functor (($>))
import Data.Ini (lookupValue, readIniFile)
import Data.Maybe (fromMaybe)
@@ -92,7 +91,7 @@ ntfServerCLI cfgPath logPath =
hSetBuffering stdout LineBuffering
hSetBuffering stderr LineBuffering
fp <- checkSavedFingerprint cfgPath defaultX509Config
let host = fromRight "<hostnames>" $ T.unpack <$> lookupValue "TRANSPORT" "host" ini
let host = either (const "<hostnames>") T.unpack $ lookupValue "TRANSPORT" "host" ini
port = T.unpack $ strictIni "TRANSPORT" "port" ini
cfg@NtfServerConfig {transports, storeLogFile} = serverConfig
srv = ProtoServerWithAuth (NtfServer [THDomainName host] (if port == "443" then "" else port) (C.KeyHash fp)) Nothing
@@ -1,9 +1,9 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use newtype instead of data" #-}
@@ -23,15 +23,15 @@ import qualified Crypto.Store.PKCS8 as PK
import Data.ASN1.BinaryEncoding (DER (..))
import Data.ASN1.Encoding
import Data.ASN1.Types
import Data.Aeson (FromJSON, ToJSON, (.=))
import Data.Aeson (ToJSON, (.=))
import qualified Data.Aeson as J
import qualified Data.Aeson.Encoding as JE
import qualified Data.Aeson.TH as JQ
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Builder (lazyByteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as LB
import qualified Data.CaseInsensitive as CI
import Data.Int (Int64)
import Data.Map.Strict (Map)
import Data.Maybe (isNothing)
@@ -40,8 +40,7 @@ import qualified Data.Text as T
import Data.Time.Clock.System
import qualified Data.X509 as X
import qualified Data.X509.CertificateStore as XS
import GHC.Generics (Generic)
import Network.HTTP.Types (HeaderName, Status)
import Network.HTTP.Types (Status)
import qualified Network.HTTP.Types as N
import Network.HTTP2.Client (Request)
import qualified Network.HTTP2.Client as H
@@ -49,7 +48,9 @@ import Network.Socket (HostName, ServiceName)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
import Simplex.Messaging.Notifications.Server.Store (NtfTknData (..))
import Simplex.Messaging.Parsers (defaultJSON)
import Simplex.Messaging.Protocol (EncNMsgMeta)
import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..))
import Simplex.Messaging.Transport.HTTP2.Client
@@ -61,17 +62,13 @@ data JWTHeader = JWTHeader
{ alg :: Text, -- key algorithm, ES256 for APNS
kid :: Text -- key ID
}
deriving (Show, Generic)
instance ToJSON JWTHeader where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Show)
data JWTClaims = JWTClaims
{ iss :: Text, -- issuer, team ID for APNS
iat :: Int64 -- issue time, seconds from epoch
}
deriving (Show, Generic)
instance ToJSON JWTClaims where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Show)
data JWTToken = JWTToken JWTHeader JWTClaims
deriving (Show)
@@ -83,6 +80,10 @@ mkJWTToken hdr iss = do
type SignedJWTToken = ByteString
$(JQ.deriveToJSON defaultJSON ''JWTHeader)
$(JQ.deriveToJSON defaultJSON ''JWTClaims)
signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken
signedJWTToken pk (JWTToken hdr claims) = do
let hc = jwtEncode hdr <> "." <> jwtEncode claims
@@ -122,24 +123,13 @@ instance StrEncoding PNMessageData where
pure PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta}
data APNSNotification = APNSNotification {aps :: APNSNotificationBody, notificationData :: Maybe J.Value}
deriving (Show, Generic)
instance ToJSON APNSNotification where
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
deriving (Show)
data APNSNotificationBody
= APNSBackground {contentAvailable :: Int}
| APNSMutableContent {mutableContent :: Int, alert :: APNSAlertBody, category :: Maybe Text}
| APNSAlert {alert :: APNSAlertBody, badge :: Maybe Int, sound :: Maybe Text, category :: Maybe Text}
deriving (Show, Generic)
apnsJSONOptions :: J.Options
apnsJSONOptions = J.defaultOptions {J.omitNothingFields = True, J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'}
instance ToJSON APNSNotificationBody where
toJSON = J.genericToJSON apnsJSONOptions
toEncoding = J.genericToEncoding apnsJSONOptions
deriving (Show)
type APNSNotificationData = Map Text Text
@@ -305,6 +295,10 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
-- apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
$(JQ.deriveToJSON apnsJSONOptions ''APNSNotificationBody)
$(JQ.deriveToJSON defaultJSON ''APNSNotification)
apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request
apnsRequest c tkn ntf@APNSNotification {aps} = do
signedJWT <- getApnsJWTToken c
@@ -337,7 +331,8 @@ type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProvider
-- this is not a newtype on purpose to have a correct JSON encoding as a record
data APNSErrorResponse = APNSErrorResponse {reason :: Text}
deriving (Generic, FromJSON)
$(JQ.deriveFromJSON defaultJSON ''APNSErrorResponse)
apnsPushProviderClient :: APNSPushClient -> PushProviderClient
apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken _ tknStr} pn = do
@@ -356,15 +351,15 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke
result status reason'
| status == Just N.ok200 = pure ()
| status == Just N.badRequest400 =
case reason' of
"BadDeviceToken" -> throwError PPTokenInvalid
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
"TopicDisallowed" -> throwError PPPermanentError
_ -> err status reason'
case reason' of
"BadDeviceToken" -> throwError PPTokenInvalid
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
"TopicDisallowed" -> throwError PPPermanentError
_ -> err status reason'
| status == Just N.forbidden403 = case reason' of
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
"InvalidProviderToken" -> throwError PPPermanentError
_ -> err status reason'
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
"InvalidProviderToken" -> throwError PPPermanentError
_ -> err status reason'
| status == Just N.gone410 = throwError PPTokenInvalid
| status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater
-- Just tooManyRequests429 -> TooManyRequests - too many requests for the same token
@@ -372,12 +367,3 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke
err :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
err s r = throwError $ PPResponseError s r
liftHTTPS2 a = ExceptT $ first PPConnection <$> a
hApnsTopic :: HeaderName
hApnsTopic = CI.mk "apns-topic"
hApnsPushType :: HeaderName
hApnsPushType = CI.mk "apns-push-type"
hApnsPriority :: HeaderName
hApnsPriority = CI.mk "apns-priority"
@@ -0,0 +1,20 @@
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Notifications.Server.Push.APNS.Internal where
import qualified Data.Aeson as J
import qualified Data.CaseInsensitive as CI
import Network.HTTP.Types (HeaderName)
import Simplex.Messaging.Parsers (defaultJSON)
hApnsTopic :: HeaderName
hApnsTopic = CI.mk "apns-topic"
hApnsPushType :: HeaderName
hApnsPushType = CI.mk "apns-push-type"
hApnsPriority :: HeaderName
hApnsPriority = CI.mk "apns-priority"
apnsJSONOptions :: J.Options
apnsJSONOptions = defaultJSON {J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'}
@@ -4,7 +4,6 @@
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Notifications.Server.Store where
@@ -5,7 +5,6 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StrictData #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module Simplex.Messaging.Notifications.Server.StoreLog
@@ -50,9 +50,9 @@ ntfServerHandshake c kh ntfVRange = do
getHandshake th >>= \case
NtfClientHandshake {ntfVersion, keyHash}
| keyHash /= kh ->
throwError $ TEHandshake IDENTITY
| ntfVersion `isCompatible` ntfVRange -> do
pure (th :: THandle c) {thVersion = ntfVersion}
throwError $ TEHandshake IDENTITY
| ntfVersion `isCompatible` ntfVRange ->
pure (th :: THandle c) {thVersion = ntfVersion}
| otherwise -> throwError $ TEHandshake VERSION
-- | Notifcations server client transport handshake.
+4 -1
View File
@@ -85,7 +85,7 @@ blobFieldDecoder dec = \case
Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
f -> returnError ConversionFailed f "expecting SQLBlob column type"
fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ :: Typeable a => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ fromText = \case
f@(Field (SQLText t) _) ->
case fromText t of
@@ -151,3 +151,6 @@ singleFieldJSON_ objectTag tagModifier =
J.nullaryToObject = True,
J.omitNothingFields = True
}
defaultJSON :: J.Options
defaultJSON = J.defaultOptions {J.omitNothingFields = True}
+33 -48
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -17,6 +16,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
@@ -135,7 +135,7 @@ module Simplex.Messaging.Protocol
noAuthSrv,
-- * TCP transport functions
TransportBatch(..),
TransportBatch (..),
tPut,
tPutLog,
tGet,
@@ -155,7 +155,7 @@ import Control.Applicative (optional, (<|>))
import Control.Concurrent (threadDelay)
import Control.Monad.Except
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
@@ -170,9 +170,7 @@ import Data.Maybe (isJust, isNothing)
import Data.String
import Data.Time.Clock.System (SystemTime (..))
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
import Generic.Random (genericArbitraryU)
import Network.Socket (HostName, ServiceName)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
@@ -182,7 +180,6 @@ import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..))
import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>))
import Simplex.Messaging.Version
import Test.QuickCheck (Arbitrary (..))
currentSMPClientVersion :: Version
currentSMPClientVersion = 2
@@ -309,17 +306,19 @@ instance StrEncoding SubscriptionMode where
SMSubscribe -> "subscribe"
SMOnlyCreate -> "only-create"
strP =
(A.string "subscribe" $> SMSubscribe) <|> (A.string "only-create" $> SMOnlyCreate)
<?> "SubscriptionMode"
(A.string "subscribe" $> SMSubscribe)
<|> (A.string "only-create" $> SMOnlyCreate)
<?> "SubscriptionMode"
instance Encoding SubscriptionMode where
smpEncode = \case
SMSubscribe -> "S"
SMOnlyCreate -> "C"
smpP = A.anyChar >>= \case
'S' -> pure SMSubscribe
'C' -> pure SMOnlyCreate
_ -> fail "bad SubscriptionMode"
smpP =
A.anyChar >>= \case
'S' -> pure SMSubscribe
'C' -> pure SMOnlyCreate
_ -> fail "bad SubscriptionMode"
data BrokerMsg where
-- SMP broker messages (responses, client messages, notifications)
@@ -472,9 +471,7 @@ instance Encoding NMsgMeta where
-- it must be data for correct JSON encoding
data MsgFlags = MsgFlags {notification :: Bool}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON MsgFlags where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
-- this encoding should not become bigger than 7 bytes (currently it is 1 byte)
instance Encoding MsgFlags where
@@ -997,14 +994,7 @@ data ErrorType
INTERNAL
| -- | used internally, never returned by the server (to be removed)
DUPLICATE_ -- not part of SMP protocol, used internally
deriving (Eq, Generic, Read, Show)
instance ToJSON ErrorType where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON ErrorType where
parseJSON = J.genericParseJSON $ sumTypeJSON id
deriving (Eq, Read, Show)
instance StrEncoding ErrorType where
strEncode = \case
@@ -1026,18 +1016,7 @@ data CommandError
HAS_AUTH
| -- | transmission has no required entity ID (e.g. SMP queue)
NO_ENTITY
deriving (Eq, Generic, Read, Show)
instance ToJSON CommandError where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON CommandError where
parseJSON = J.genericParseJSON $ sumTypeJSON id
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
instance Arbitrary CommandError where arbitrary = genericArbitraryU
deriving (Eq, Read, Show)
-- | SMP transmission parser.
transmissionP :: Parser RawTransmission
@@ -1306,7 +1285,7 @@ tPutLog th s = do
pure r
-- ByteString does not include length byte, it is added by tEncodeBatch
data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission
data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission
-- | encodes and batches transmissions into blocks,
batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch]
@@ -1319,22 +1298,22 @@ batchTransmissions batch bSize
let (n, s, ts_) = encodeBatch 0 "" ts
r = if n == 0 then TBLargeTransmission else TBTransmissions n s
rs' = r : rs
in case ts_ of
Just ts' -> mkBatch rs' ts'
_ -> rs'
in case ts_ of
Just ts' -> mkBatch rs' ts'
_ -> rs'
mkBatch1 :: ByteString -> TransportBatch
mkBatch1 s = if B.length s > bSize - 2 then TBLargeTransmission else TBTransmission s
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
encodeBatch n s ts@(t :| ts_)
| n == 255 = (n, s, Just ts)
| otherwise =
let s' = s <> smpEncode (Large t)
n' = n + 1
in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
else case L.nonEmpty ts_ of
Just ts' -> encodeBatch n' s' ts'
_ -> (n', s', Nothing)
let s' = s <> smpEncode (Large t)
n' = n + 1
in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
else case L.nonEmpty ts_ of
Just ts' -> encodeBatch n' s' ts'
_ -> (n', s', Nothing)
tEncode :: SentRawTransmission -> ByteString
tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t
@@ -1373,8 +1352,8 @@ tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId ->
tDecodeParseValidate sessionId v = \case
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
| sessId == sessionId ->
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
Left _ -> tError ""
where
@@ -1385,3 +1364,9 @@ tDecodeParseValidate sessionId v = \case
tParseValidate signed t@(sig, corrId, entityId, command) =
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
in (sig, signed, (CorrId corrId, entityId, cmd))
$(J.deriveJSON defaultJSON ''MsgFlags)
$(J.deriveJSON (sumTypeJSON id) ''CommandError)
$(J.deriveJSON (sumTypeJSON id) ''ErrorType)
+30 -30
View File
@@ -1,6 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
@@ -8,6 +7,7 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -116,9 +116,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
restoreServerMessages
restoreServerStats
raceAny_
( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub :
serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ()) :
map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub
: serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ())
: map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg
)
`finally` withLock (savingLock s) "final" (saveServer False)
where
@@ -148,7 +148,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
updateSubscribers :: STM (Maybe (QueueId, Client))
updateSubscribers = do
(qId, clnt) <- readTQueue $ subQ s
let clientToBeNotified = \c' ->
let clientToBeNotified c' =
if sameClientSession clnt c'
then pure Nothing
else do
@@ -277,9 +277,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
hPutStrLn h $ "Clients: " <> show (length clients)
forM_ (M.toList clients) $ \(cid, Client {sessionId, connected, activeAt, subscriptions}) -> do
hPutStrLn h . B.unpack $ "Client " <> encode cid <> " $" <> encode sessionId
readTVarIO connected >>= hPutStrLn h . (" connected: " <>) . show
readTVarIO activeAt >>= hPutStrLn h . (" activeAt: " <>) . B.unpack . strEncode
readTVarIO subscriptions >>= hPutStrLn h . (" subscriptions: " <>) . show . M.size
readTVarIO connected >>= hPutStrLn h . (" connected: " <>) . show
readTVarIO activeAt >>= hPutStrLn h . (" activeAt: " <>) . B.unpack . strEncode
readTVarIO subscriptions >>= hPutStrLn h . (" subscriptions: " <>) . show . M.size
CPStats -> do
ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, msgSentNtf, msgRecvNtf, qCount, msgCount} <- unliftIO u $ asks serverStats
putStat "fromTime" fromTime
@@ -666,27 +666,27 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
sendMessage qr msgFlags msgBody
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
| otherwise = case status qr of
QueueOff -> return $ err AUTH
QueueActive ->
case C.maxLenBS msgBody of
Left _ -> pure $ err LARGE_MSG
Right body -> do
msg_ <- time "SEND" $ do
q <- getStoreMsgQueue "SEND" $ recipientId qr
expireMessages q
atomically . writeMsg q =<< mkMessage body
case msg_ of
Nothing -> pure $ err QUOTA
Just msg -> time "SEND ok" $ do
stats <- asks serverStats
when (notification msgFlags) $ do
atomically . trySendNotification msg =<< asks idsDrg
atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
atomically $ modifyTVar' (msgSent stats) (+ 1)
atomically $ modifyTVar' (msgCount stats) (subtract 1)
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure ok
QueueOff -> return $ err AUTH
QueueActive ->
case C.maxLenBS msgBody of
Left _ -> pure $ err LARGE_MSG
Right body -> do
msg_ <- time "SEND" $ do
q <- getStoreMsgQueue "SEND" $ recipientId qr
expireMessages q
atomically . writeMsg q =<< mkMessage body
case msg_ of
Nothing -> pure $ err QUOTA
Just msg -> time "SEND ok" $ do
stats <- asks serverStats
when (notification msgFlags) $ do
atomically . trySendNotification msg =<< asks idsDrg
atomically $ modifyTVar' (msgSentNtf stats) (+ 1)
atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr)
atomically $ modifyTVar' (msgSent stats) (+ 1)
atomically $ modifyTVar' (msgCount stats) (subtract 1)
atomically $ updatePeriodStats (activeQueues stats) (recipientId qr)
pure ok
where
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
mkMessage body = do
@@ -767,7 +767,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
msgTs' = msg.msgTs
setDelivered :: Sub -> Message -> STM Bool
setDelivered s msg = tryPutTMVar (delivered s) $ msg.msgId
setDelivered s msg = tryPutTMVar (delivered s) msg.msgId
getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
+1 -1
View File
@@ -31,7 +31,7 @@ import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Version
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
+17 -17
View File
@@ -11,7 +11,6 @@ module Simplex.Messaging.Server.Main where
import Control.Monad (void)
import Crypto.Random (getRandomBytes)
import qualified Data.ByteString.Char8 as B
import Data.Either (fromRight)
import Data.Functor (($>))
import Data.Ini (lookupValue, readIniFile)
import Data.Maybe (fromMaybe)
@@ -24,7 +23,7 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BasicAuth (..), ProtoServerWithAuth (ProtoServerWithAuth), pattern SMPServer)
import Simplex.Messaging.Server (runSMPServer)
import Simplex.Messaging.Server.CLI
import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClientExpiration, defaultMessageExpiration, defMsgExpirationDays)
import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defMsgExpirationDays, defaultInactiveClientExpiration, defaultMessageExpiration)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange)
import Simplex.Messaging.Transport.Client (TransportHost (..))
@@ -60,15 +59,15 @@ smpServerCLI cfgPath logPath =
initializeServer opts
| scripted opts = initialize opts
| otherwise = do
putStrLn "Use `smp-server init -h` for available options."
void $ withPrompt "SMP server will be initialized (press Enter)" getLine
enableStoreLog <- onOffPrompt "Enable store log to restore queues and messages on server restart" True
logStats <- onOffPrompt "Enable logging daily statistics" False
putStrLn "Require a password to create new messaging queues?"
password <- withPrompt "'r' for random (default), 'n' - no password, or enter password: " serverPassword
let host = fromMaybe (ip opts) (fqdn opts)
host' <- withPrompt ("Enter server FQDN or IP address for certificate (" <> host <> "): ") getLine
initialize opts {enableStoreLog, logStats, fqdn = if null host' then fqdn opts else Just host', password}
putStrLn "Use `smp-server init -h` for available options."
void $ withPrompt "SMP server will be initialized (press Enter)" getLine
enableStoreLog <- onOffPrompt "Enable store log to restore queues and messages on server restart" True
logStats <- onOffPrompt "Enable logging daily statistics" False
putStrLn "Require a password to create new messaging queues?"
password <- withPrompt "'r' for random (default), 'n' - no password, or enter password: " serverPassword
let host = fromMaybe (ip opts) (fqdn opts)
host' <- withPrompt ("Enter server FQDN or IP address for certificate (" <> host <> "): ") getLine
initialize opts {enableStoreLog, logStats, fqdn = if null host' then fqdn opts else Just host', password}
where
serverPassword =
getLine >>= \case
@@ -121,8 +120,8 @@ smpServerCLI cfgPath logPath =
\# The password will not be shared with the connecting contacts, you must share it only\n\
\# with the users who you want to allow creating messaging queues on your server.\n"
<> ( case basicAuth of
Just auth -> "create_password: " <> T.unpack (safeDecodeUtf8 $ strEncode auth)
_ -> "# create_password: password to create new queues (any printable ASCII characters without whitespace, '@', ':' and '/')"
Just auth -> "create_password: " <> T.unpack (safeDecodeUtf8 $ strEncode auth)
_ -> "# create_password: password to create new queues (any printable ASCII characters without whitespace, '@', ':' and '/')"
)
<> "\n\n\
\[TRANSPORT]\n\
@@ -141,7 +140,7 @@ smpServerCLI cfgPath logPath =
hSetBuffering stdout LineBuffering
hSetBuffering stderr LineBuffering
fp <- checkSavedFingerprint cfgPath defaultX509Config
let host = fromRight "<hostnames>" $ T.unpack <$> lookupValue "TRANSPORT" "host" ini
let host = either (const "<hostnames>") T.unpack $ lookupValue "TRANSPORT" "host" ini
port = T.unpack $ strictIni "TRANSPORT" "port" ini
cfg@ServerConfig {transports, storeLogFile, newQueueBasicAuth, messageExpiration, inactiveClientExpiration} = serverConfig
srv = ProtoServerWithAuth (SMPServer [THDomainName host] (if port == "5223" then "" else port) (C.KeyHash fp)) newQueueBasicAuth
@@ -186,9 +185,10 @@ smpServerCLI cfgPath logPath =
allowNewQueues = fromMaybe True $ iniOnOff "AUTH" "new_queues" ini,
newQueueBasicAuth = either error id <$> strDecodeIni "AUTH" "create_password" ini,
messageExpiration =
Just defaultMessageExpiration
{ ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini
},
Just
defaultMessageExpiration
{ ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini
},
inactiveClientExpiration =
settingIsOn "INACTIVE_CLIENTS" "disconnect" ini
$> ExpirationConfig
+1 -1
View File
@@ -64,7 +64,7 @@ flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue)
snapshotMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue)
snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue)
where
snapshotTQueue q = do
msgs <- flushTQueue q
-1
View File
@@ -4,7 +4,6 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Simplex.Messaging.Server.StoreLog
+10 -28
View File
@@ -1,7 +1,6 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
@@ -10,6 +9,7 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
-- |
@@ -63,8 +63,7 @@ where
import Control.Applicative ((<|>))
import Control.Monad.Except
import Control.Monad.Trans.Except (throwE)
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import Data.Bifunctor (first)
import Data.Bitraversable (bimapM)
@@ -74,9 +73,7 @@ import qualified Data.ByteString.Lazy as BL
import Data.Default (def)
import Data.Functor (($>))
import Data.Version (showVersion)
import GHC.Generics (Generic)
import GHC.IO.Handle.Internals (ioe_EOF)
import Generic.Random (genericArbitraryU)
import Network.Socket
import qualified Network.TLS as T
import qualified Network.TLS.Extra as TE
@@ -87,7 +84,6 @@ import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON)
import Simplex.Messaging.Transport.Buffer
import Simplex.Messaging.Util (bshow, catchAll, catchAll_)
import Simplex.Messaging.Version
import Test.QuickCheck (Arbitrary (..))
import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -284,14 +280,7 @@ data TransportError
TEBadSession
| -- | transport handshake error
TEHandshake {handshakeErr :: HandshakeError}
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON TransportError where
toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "TE"
toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "TE"
instance FromJSON TransportError where
parseJSON = J.genericParseJSON . sumTypeJSON $ dropPrefix "TE"
deriving (Eq, Read, Show, Exception)
-- | Transport handshake error.
data HandshakeError
@@ -301,18 +290,7 @@ data HandshakeError
VERSION
| -- | incorrect server identity
IDENTITY
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON HandshakeError where
toJSON = J.genericToJSON $ sumTypeJSON id
toEncoding = J.genericToEncoding $ sumTypeJSON id
instance FromJSON HandshakeError where
parseJSON = J.genericParseJSON $ sumTypeJSON id
instance Arbitrary TransportError where arbitrary = genericArbitraryU
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
deriving (Eq, Read, Show, Exception)
-- | SMP encrypted transport error parser.
transportErrorP :: Parser TransportError
@@ -354,9 +332,9 @@ smpServerHandshake c kh smpVRange = do
getHandshake th >>= \case
ClientHandshake {smpVersion, keyHash}
| keyHash /= kh ->
throwE $ TEHandshake IDENTITY
throwE $ TEHandshake IDENTITY
| smpVersion `isCompatible` smpVRange -> do
pure $ smpThHandle th smpVersion
pure $ smpThHandle th smpVersion
| otherwise -> throwE $ TEHandshake VERSION
-- | Client SMP transport handshake.
@@ -385,3 +363,7 @@ getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock t
smpTHandle :: Transport c => c -> THandle c
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, batch = False}
$(J.deriveJSON (sumTypeJSON id) ''HandshakeError)
$(J.deriveJSON (sumTypeJSON $ dropPrefix "TE") ''TransportError)
+4 -4
View File
@@ -8,8 +8,8 @@ import Control.Concurrent.STM
import qualified Control.Exception as E
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import GHC.IO.Exception (IOErrorType (..), IOException (..), ioException)
import System.Timeout (timeout)
import GHC.IO.Exception (ioException, IOException (..), IOErrorType (..))
data TBuffer = TBuffer
{ buffer :: TVar ByteString,
@@ -41,9 +41,9 @@ getBuffered tb@TBuffer {buffer} n t_ getChunk = withBufferLock tb $ do
readChunks firstChunk b
| B.length b >= n = pure b
| otherwise =
get >>= \case
"" -> pure b
s -> readChunks False $ b <> s
get >>= \case
"" -> pure b
s -> readChunks False $ b <> s
where
get
| firstChunk = getChunk
@@ -1,7 +1,5 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Transport.Credentials
( tlsCredentials,
@@ -22,10 +22,9 @@ import qualified Network.TLS as T
import Numeric.Natural (Natural)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport (SessionId)
import Simplex.Messaging.Transport (SessionId, TLS)
import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), runTLSTransportClient)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport (TLS)
import UnliftIO.STM
import UnliftIO.Timeout
@@ -0,0 +1,42 @@
{-# LANGUAGE MultiWayIf #-}
module Simplex.Messaging.Transport.HTTP2.File where
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder, byteString)
import Data.Int (Int64)
import Data.Word (Word32)
import GHC.IO.Handle.Internals (ioe_EOF)
import System.IO (Handle)
fileBlockSize :: Int
fileBlockSize = 16384
hReceiveFile :: (Int -> IO ByteString) -> Handle -> Word32 -> IO Int64
hReceiveFile _ _ 0 = pure 0
hReceiveFile getBody h size = get $ fromIntegral size
where
get sz = do
ch <- getBody fileBlockSize
let chSize = fromIntegral $ B.length ch
if
| chSize > sz -> pure (chSize - sz)
| chSize > 0 -> B.hPut h ch >> get (sz - chSize)
| otherwise -> pure (-fromIntegral sz)
hSendFile :: Handle -> (Builder -> IO ()) -> Word32 -> IO ()
hSendFile h send = go
where
go 0 = pure ()
go sz =
getFileChunk h sz >>= \ch -> do
send $ byteString ch
go $ sz - fromIntegral (B.length ch)
getFileChunk :: Handle -> Word32 -> IO ByteString
getFileChunk h sz = do
ch <- B.hGet h fileBlockSize
if B.null ch
then ioe_EOF
else pure $ B.take (fromIntegral sz) ch -- sz >= xftpBlockSize
@@ -1,5 +1,4 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Transport.HTTP2.Server where
+6 -7
View File
@@ -1,25 +1,22 @@
{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
module Simplex.Messaging.Transport.KeepAlive where
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import Foreign.C (CInt (..))
import GHC.Generics (Generic)
import Network.Socket
import Simplex.Messaging.Parsers (defaultJSON)
data KeepAliveOpts = KeepAliveOpts
{ keepIdle :: Int,
keepIntvl :: Int,
keepCnt :: Int
}
deriving (Eq, Show, Generic, FromJSON)
instance ToJSON KeepAliveOpts where toEncoding = J.genericToEncoding J.defaultOptions
deriving (Eq, Show)
defaultKeepAliveOpts :: KeepAliveOpts
defaultKeepAliveOpts =
@@ -68,3 +65,5 @@ setSocketKeepAlive sock KeepAliveOpts {keepCnt, keepIdle, keepIntvl} = do
setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPIDLE) keepIdle
setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPINTVL) keepIntvl
setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPCNT) keepCnt
$(J.deriveJSON defaultJSON ''KeepAliveOpts)
+24 -12
View File
@@ -5,10 +5,13 @@
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Transport.Server
( runTransportServer,
runTCPServer,
TransportServerConfig (..),
( TransportServerConfig (..),
defaultTransportServerConfig,
runTransportServer,
runTransportServerSocket,
runTCPServer,
runTCPServerSocket,
startTCPServer,
loadSupportedTLSServerParams,
loadTLSServerParams,
loadFingerprint,
@@ -46,10 +49,11 @@ data TransportServerConfig = TransportServerConfig
deriving (Eq, Show)
defaultTransportServerConfig :: TransportServerConfig
defaultTransportServerConfig = TransportServerConfig
{ logTLSErrors = True,
transportTimeout = 40000000
}
defaultTransportServerConfig =
TransportServerConfig
{ logTLSErrors = True,
transportTimeout = 40000000
}
serverTransportConfig :: TransportServerConfig -> TransportConfig
serverTransportConfig TransportServerConfig {logTLSErrors} =
@@ -60,11 +64,15 @@ serverTransportConfig TransportServerConfig {logTLSErrors} =
--
-- All accepted connections are passed to the passed function.
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
runTransportServer started port serverParams cfg server = do
runTransportServer started port = runTransportServerSocket started (startTCPServer started port) (transportName (TProxy :: TProxy c))
-- | Run a transport server with provided connection setup and handler.
runTransportServerSocket :: (MonadUnliftIO m, T.TLSParams p, Transport a) => TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocket started getSocket threadLabel serverParams cfg server = do
u <- askUnliftIO
let tCfg = serverTransportConfig cfg
labelMyThread $ "transport server for " <> transportName (TProxy :: TProxy c)
liftIO . runTCPServer started port $ \conn ->
labelMyThread $ "transport server for " <> threadLabel
liftIO . runTCPServerSocket started getSocket $ \conn ->
E.bracket
(connectTLS Nothing tCfg serverParams conn >>= getServerConnection tCfg)
closeConnection
@@ -72,11 +80,15 @@ runTransportServer started port serverParams cfg server = do
-- | Run TCP server without TLS
runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()
runTCPServer started port server = do
runTCPServer started port = runTCPServerSocket started $ startTCPServer started port
-- | Wrap socket provider in a TCP server bracket.
runTCPServerSocket :: TMVar Bool -> IO Socket -> (Socket -> IO ()) -> IO ()
runTCPServerSocket started getSocket server = do
clients <- atomically TM.empty
clientId <- newTVarIO 0
E.bracket
(startTCPServer started port)
getSocket
(closeServer started clients)
$ \sock -> forever . E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) -> do
-- catchAll_ is needed here in case the connection was closed earlier
@@ -15,9 +15,9 @@ import qualified Network.WebSockets.Stream as S
import Simplex.Messaging.Transport
( TProxy,
Transport (..),
TransportConfig (..),
TransportError (..),
TransportPeer (..),
TransportConfig (..),
closeTLS,
smpBlockSize,
withTlsUnique,
+8 -8
View File
@@ -62,7 +62,7 @@ liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a)
liftEitherError f a = liftIOEither (first f <$> a)
{-# INLINE liftEitherError #-}
liftEitherWith :: (MonadError e' m) => (e -> e') -> Either e a -> m a
liftEitherWith :: MonadError e' m => (e -> e') -> Either e a -> m a
liftEitherWith f = liftEither . first f
{-# INLINE liftEitherWith #-}
@@ -102,7 +102,7 @@ catchAllErrors err action handle = tryAllErrors err action >>= either handle pur
{-# INLINE catchAllErrors #-}
catchThrow :: (MonadUnliftIO m, MonadError e m) => m a -> (E.SomeException -> e) -> m a
catchThrow action err = catchAllErrors err action throwError
catchThrow action err = catchAllErrors err action throwError
{-# INLINE catchThrow #-}
allFinally :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m a -> m b -> m a
@@ -115,12 +115,12 @@ eitherToMaybe = either (const Nothing) Just
groupOn :: Eq k => (a -> k) -> [a] -> [[a]]
groupOn = groupBy . eqOn
-- it is equivalent to groupBy ((==) `on` f),
-- but it redefines `on` to avoid duplicate computation for most values.
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
where
eqOn f = \x -> let fx = f x in \y -> fx == f y
-- it is equivalent to groupBy ((==) `on` f),
-- but it redefines `on` to avoid duplicate computation for most values.
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
eqOn f x = let fx = f x in \y -> fx == f y
groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]]
groupAllOn f = groupOn f . sortOn f
@@ -129,7 +129,7 @@ toChunks :: Int -> [a] -> [NonEmpty a]
toChunks _ [] = []
toChunks n xs =
let (ys, xs') = splitAt n xs
in maybe id (:) (L.nonEmpty ys) (toChunks n xs')
in maybe id (:) (L.nonEmpty ys) (toChunks n xs')
safeDecodeUtf8 :: ByteString -> Text
safeDecodeUtf8 = decodeUtf8With onError
+1 -1
View File
@@ -49,7 +49,7 @@ extra-deps:
- github: simplex-chat/aeson
commit: aab7b5a14d6c5ea64c64dcaee418de1bb00dcc2b
- github: kazu-yamamoto/http2
commit: b5a1b7200cf5bc7044af34ba325284271f6dff25
commit: 804fa283f067bd3fd89b8c5f8d25b3047813a517
# - ../direct-sqlcipher
- github: simplex-chat/direct-sqlcipher
commit: f814ee68b16a9447fbb467ccc8f29bdd3546bfd9
+17 -17
View File
@@ -45,42 +45,42 @@ agentTests (ATransport t) = do
describe "Migration tests" migrationTests
describe "SMP agent protocol syntax" $ syntaxTests t
describe "Establishing duplex connection" $ do
it "should connect via one server and one agent" $
it "should connect via one server and one agent" $ do
smpAgentTest2_1_1 $ testDuplexConnection t
it "should connect via one server and one agent (random IDs)" $
it "should connect via one server and one agent (random IDs)" $ do
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
it "should connect via one server and 2 agents" $
it "should connect via one server and 2 agents" $ do
smpAgentTest2_2_1 $ testDuplexConnection t
it "should connect via one server and 2 agents (random IDs)" $
it "should connect via one server and 2 agents (random IDs)" $ do
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
it "should connect via 2 servers and 2 agents" $
it "should connect via 2 servers and 2 agents" $ do
smpAgentTest2_2_2 $ testDuplexConnection t
it "should connect via 2 servers and 2 agents (random IDs)" $
it "should connect via 2 servers and 2 agents (random IDs)" $ do
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
describe "Establishing connections via `contact connection`" $ do
it "should connect via contact connection with one server and 3 agents" $
it "should connect via contact connection with one server and 3 agents" $ do
smpAgentTest3 $ testContactConnection t
it "should connect via contact connection with one server and 2 agents (random IDs)" $
it "should connect via contact connection with one server and 2 agents (random IDs)" $ do
smpAgentTest2_2_1 $ testContactConnRandomIds t
it "should support rejecting contact request" $
it "should support rejecting contact request" $ do
smpAgentTest2_2_1 $ testRejectContactRequest t
describe "Connection subscriptions" $ do
it "should connect via one server and one agent" $
it "should connect via one server and one agent" $ do
smpAgentTest3_1_1 $ testSubscription t
it "should send notifications to client when server disconnects" $
it "should send notifications to client when server disconnects" $ do
smpAgentServerTest $ testSubscrNotification t
describe "Message delivery and server reconnection" $ do
it "should deliver messages after losing server connection and re-connecting" $
it "should deliver messages after losing server connection and re-connecting" $ do
smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t
it "should connect to the server when server goes up if it initially was down" $
it "should connect to the server when server goes up if it initially was down" $ do
smpAgentTestN [] $ testServerConnectionAfterError t
it "should deliver pending messages after agent restarting" $
it "should deliver pending messages after agent restarting" $ do
smpAgentTest1_1_1 $ testMsgDeliveryAgentRestart t
it "should concurrently deliver messages to connections without blocking" $
it "should concurrently deliver messages to connections without blocking" $ do
smpAgentTest2_2_1 $ testConcurrentMsgDelivery t
it "should deliver messages if one of connections has quota exceeded" $
it "should deliver messages if one of connections has quota exceeded" $ do
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
it "should resume delivering messages after exceeding quota once all messages are received" $
it "should resume delivering messages after exceeding quota once all messages are received" $ do
smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t
type AEntityTransmission p e = (ACorrId, ConnId, ACommand p e)
+12 -13
View File
@@ -2,7 +2,6 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module AgentTests.ConnectionRequestTests where
@@ -113,24 +112,24 @@ connectionRequestTests =
it "should serialize connection requests" $ do
strEncode connectionRequest
`shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
strEncode connectionRequestCurrentRange
`shouldBe` "simplex:/invitation#/?v=1-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1-2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> urlEncode True testDhKeyStrUri
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1-2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
strEncode connectionRequestClientDataEmpty
`shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&data=%7B%7D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&data=%7B%7D"
strEncode connectionRequestClientData
`shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&data=%7B%22type%22%3A%22group_link%22%2C%20%22group_link_id%22%3A%22abc%22%7D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&data=%7B%22type%22%3A%22group_link%22%2C%20%22group_link_id%22%3A%22abc%22%7D"
it "should parse connection requests" $ do
strDecode
( "https://simplex.chat/invitation#/?smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23"
+54 -1
View File
@@ -43,6 +43,7 @@ import Data.Int (Int64)
import qualified Data.Map as M
import Data.Maybe (isNothing)
import qualified Data.Set as S
import Data.Time.Clock (diffUTCTime, getCurrentTime)
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import Data.Type.Equality
import SMPAgentClient
@@ -310,6 +311,7 @@ functionalAPITests t = do
describe "Delivery receipts" $ do
it "should send and receive delivery receipt" $ withSmpServer t testDeliveryReceipts
it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t
it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t
testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int
testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do
@@ -1159,7 +1161,7 @@ testBatchedSubscriptions nCreate nDel t = do
a <- getSMPAgentClient' agentCfg initAgentServers2 testDB
b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2
conns <- runServers $ do
conns <- forM [1 .. nCreate :: Int] . const $ makeConnection a b
conns <- replicateM (nCreate :: Int) $ makeConnection a b
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
let (aIds', bIds') = unzip $ take nDel conns
delete a bIds'
@@ -1928,6 +1930,57 @@ testDeliveryReceiptsVersion t = do
disconnectAgentClient a'
disconnectAgentClient b'
testDeliveryReceiptsConcurrent :: HasCallStack => ATransport -> IO ()
testDeliveryReceiptsConcurrent t =
withSmpServerConfigOn t cfg {msgQueueQuota = 128} testPort $ \_ -> do
withAgentClients2 $ \a b -> do
(aId, bId) <- runRight $ makeConnection a b
t1 <- liftIO getCurrentTime
concurrently_ (runClient "a" a bId) (runClient "b" b aId)
t2 <- liftIO getCurrentTime
diffUTCTime t2 t1 `shouldSatisfy` (< 15)
liftIO $ noMessages a "nothing else should be delivered to alice"
liftIO $ noMessages b "nothing else should be delivered to bob"
where
runClient :: String -> AgentClient -> ConnId -> IO ()
runClient _cName client connId = do
concurrently_ send receive
where
numMsgs = 100
send = runRight_ $
replicateM_ numMsgs $ do
-- liftIO $ print $ cName <> ": sendMessage"
void $ sendMessage client connId SMP.noMsgFlags "hello"
receive =
runRight_ $
-- for each sent message: 1 SENT, 1 RCVD, 1 OK for acknowledging RCVD
-- for each received message: 1 MSG, 1 OK for acknowledging MSG
receiveLoop (numMsgs * 5)
receiveLoop :: Int -> ExceptT AgentErrorType IO ()
receiveLoop 0 = pure ()
receiveLoop n = do
r <- getWithTimeout
case r of
(_, _, SENT _) -> do
-- liftIO $ print $ cName <> ": SENT"
pure ()
(_, _, MSG MsgMeta {recipient = (msgId, _), integrity = MsgOk} _ _) -> do
-- liftIO $ print $ cName <> ": MSG " <> show msgId
ackMessageAsync client (B.pack . show $ n) connId msgId (Just "")
(_, _, RCVD MsgMeta {recipient = (msgId, _), integrity = MsgOk} _) -> do
-- liftIO $ print $ cName <> ": RCVD " <> show msgId
ackMessageAsync client (B.pack . show $ n) connId msgId Nothing
(_, _, OK) -> do
-- liftIO $ print $ cName <> ": OK"
pure ()
r' -> error $ "unexpected event: " <> show r'
receiveLoop (n - 1)
getWithTimeout :: ExceptT AgentErrorType IO (AEntityTransmission 'AEConn)
getWithTimeout = do
1000000 `timeout` get client >>= \case
Just r -> pure r
_ -> error "timeout"
testTwoUsers :: HasCallStack => IO ()
testTwoUsers = withAgentClients2 $ \a b -> do
let nc = netCfg initAgentServers
+1 -1
View File
@@ -504,7 +504,7 @@ testNotificationsSMPRestartBatch n t APNSMockServer {apnsQ} = do
a <- getSMPAgentClient' agentCfg initAgentServers2 testDB
b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2
conns <- runServers $ do
conns <- forM [1 .. n :: Int] . const $ makeConnection a b
conns <- replicateM (n :: Int) $ makeConnection a b
_ <- registerTestToken a "abcd" NMInstant apnsQ
liftIO $ threadDelay 1500000
forM_ conns $ \(aliceId, bobId) -> do
-1
View File
@@ -7,7 +7,6 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module AgentTests.SQLiteTests (storeTests) where
+2 -2
View File
@@ -4,7 +4,7 @@ module CoreTests.BatchingTests (batchingTests) where
import Control.Concurrent.STM
import Control.Monad
import Crypto.Random (MonadRandom(..))
import Crypto.Random (MonadRandom (..))
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.List.NonEmpty as L
@@ -12,7 +12,7 @@ import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport
import Simplex.Messaging.Version (VersionRange(..))
import Simplex.Messaging.Version (VersionRange (..))
import Test.Hspec
batchingTests :: Spec
+1 -1
View File
@@ -195,7 +195,7 @@ testAESGCM = it "should encrypt / decrypt string with a random symmetric key" $
cipher `shouldNotBe` plain
s `shouldBe` plain
testEncoding :: (C.AlgorithmI a) => C.SAlgorithm a -> Spec
testEncoding :: C.AlgorithmI a => C.SAlgorithm a -> Spec
testEncoding alg = it "should encode / decode key" . ioProperty $ do
(k, pk) <- C.generateKeyPair alg
pure $ \(_ :: Int) ->
+57 -6
View File
@@ -1,14 +1,22 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module CoreTests.ProtocolErrorTests where
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..))
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Protocol (CommandError (..), ErrorType (..))
import Simplex.Messaging.Transport (HandshakeError (..), TransportError (..))
import Test.Hspec
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck
@@ -16,15 +24,58 @@ import Test.QuickCheck
protocolErrorTests :: Spec
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
describe "errors parsing / serializing" $ do
it "should parse SMP protocol errors" . property $ \(err :: AgentErrorType) ->
errHasSpaces err
|| parseAll strP (strEncode err) == Right err
it "should parse SMP protocol errors" . property $ \(err :: ErrorType) ->
smpDecode (smpEncode err) == Right err
it "should parse SMP agent errors" . property $ \(err :: AgentErrorType) ->
errHasSpaces err
|| parseAll strP (strEncode err) == Right err
|| strDecode (strEncode err) == Right err
where
errHasSpaces = \case
BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e
BROKER srv _ -> hasSpaces srv
_ -> False
hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s)
deriving instance Generic AgentErrorType
deriving instance Generic CommandErrorType
deriving instance Generic ConnectionErrorType
deriving instance Generic BrokerErrorType
deriving instance Generic SMPAgentError
deriving instance Generic AgentCryptoError
deriving instance Generic ErrorType
deriving instance Generic CommandError
deriving instance Generic TransportError
deriving instance Generic HandshakeError
deriving instance Generic XFTPErrorType
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
instance Arbitrary CommandError where arbitrary = genericArbitraryU
instance Arbitrary TransportError where arbitrary = genericArbitraryU
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU
-1
View File
@@ -1,4 +1,3 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module CoreTests.RetryIntervalTests where
+1 -1
View File
@@ -6,8 +6,8 @@ import Control.Exception (Exception, SomeException, throwIO)
import Control.Monad.Except
import Control.Monad.IO.Class
import Data.IORef
import Simplex.Messaging.Util
import Simplex.Messaging.Client.Agent ()
import Simplex.Messaging.Util
import Test.Hspec
import qualified UnliftIO.Exception as UE
+4 -3
View File
@@ -39,13 +39,14 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do
isCompatible (1 :: Version) (vr 2 2) `shouldBe` False
it "compatibleVersion should pass isCompatible check" . property $
\((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) ->
min1 > max1 || min2 > max2 -- one of ranges is invalid, skip testing it
min1 > max1
|| min2 > max2 -- one of ranges is invalid, skip testing it
|| let w = fromIntegral . fromEnum
vr1 = mkVersionRange (w min1) (w max1) :: VersionRange
vr2 = mkVersionRange (w min2) (w max2) :: VersionRange
in case compatibleVersion vr1 vr2 of
Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2
_ -> True
Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2
_ -> True
where
vr = mkVersionRange
compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation
+7
View File
@@ -38,6 +38,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
import Simplex.Messaging.Notifications.Server.Env
import Simplex.Messaging.Notifications.Server.Push.APNS
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
import Simplex.Messaging.Notifications.Transport
import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport
@@ -186,10 +187,16 @@ instance FromJSON APNSAlertBody where
parseJSON (J.String v) = pure $ APNSAlertText v
parseJSON invalid = JT.prependFailure "parsing Coord failed, " (JT.typeMismatch "Object" invalid)
deriving instance Generic APNSNotificationBody
instance FromJSON APNSNotificationBody where parseJSON = J.genericParseJSON apnsJSONOptions {J.rejectUnknownFields = True}
deriving instance Generic APNSNotification
deriving instance FromJSON APNSNotification
deriving instance Generic APNSErrorResponse
deriving instance ToJSON APNSErrorResponse
getAPNSMockServer :: HTTP2ServerConfig -> IO APNSMockServer
-1
View File
@@ -139,7 +139,6 @@ testNotificationSubscription (ATransport t) =
mTs `shouldBe` msgTs
(msgBody, "hello") #== "delivered from queue"
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, ACK mId1)
pure ()
-- replace token
let tkn' = DeviceToken PPApnsTest "efgh"
RespNtf "7" tId' NROk <- signSendRecvNtf nh tknKey ("7", tId, TRPL tkn')
-1
View File
@@ -832,7 +832,6 @@ testRestoreExpireMessages at@(ATransport t) =
msgs'' <- B.readFile testStoreMsgsFile
length (B.lines msgs'') `shouldBe` 2
B.lines msgs'' `shouldBe` drop 2 (B.lines msgs)
where
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
runTest _ test' server = do
+6 -6
View File
@@ -292,8 +292,8 @@ testXFTPAgentSendRestore = withGlobalLogging logCfgNoLogs $ do
-- receive file
rcp <- getSMPAgentClient' agentCfg initAgentServers testDB2
runRight_ $
void $ testReceive rcp rfd1 filePath
runRight_ . void $
testReceive rcp rfd1 filePath
testXFTPAgentSendCleanup :: HasCallStack => IO ()
testXFTPAgentSendCleanup = withGlobalLogging logCfgNoLogs $ do
@@ -342,8 +342,8 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $
-- receive file
rcp1 <- getSMPAgentClient' agentCfg initAgentServers testDB2
runRight_ $
void $ testReceive rcp1 rfd1 filePath
runRight_ . void $
testReceive rcp1 rfd1 filePath
length <$> listDirectory xftpServerFiles `shouldReturn` 6
@@ -377,8 +377,8 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do
-- receive file
rcp1 <- getSMPAgentClient' agentCfg initAgentServers testDB2
runRight_ $
void $ testReceive rcp1 rfd1 filePath
runRight_ . void $
testReceive rcp1 rfd1 filePath
disconnectAgentClient rcp1
disconnectAgentClient sndr
pure (sfId, sndDescr, rfd2)