diff --git a/.hlint.yaml b/.hlint.yaml new file mode 100644 index 000000000..edf6d20ff --- /dev/null +++ b/.hlint.yaml @@ -0,0 +1 @@ +- ignore: {name: "Use underscore"} diff --git a/cabal.project b/cabal.project index 301ca7d56..08e1865b3 100644 --- a/cabal.project +++ b/cabal.project @@ -19,7 +19,7 @@ source-repository-package source-repository-package type: git location: https://github.com/kazu-yamamoto/http2.git - tag: b5a1b7200cf5bc7044af34ba325284271f6dff25 + tag: 804fa283f067bd3fd89b8c5f8d25b3047813a517 source-repository-package type: git diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 000000000..907a25e7d --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,30 @@ +indentation: 2 +column-limit: none +function-arrows: trailing +comma-style: trailing +import-export-style: trailing +indent-wheres: true +record-brace-space: true +newlines-between-decls: 1 +haddock-style: single-line +haddock-style-module: null +let-style: inline +in-style: right-align +single-constraint-parens: never +unicode: never +respectful: true +fixities: + - infixr 9 . + - infixr 8 .:, .:., .= + - infixr 6 <> + - infixr 5 ++ + - infixl 4 <$>, <$, $>, <$$>, <$?> + - infixl 4 <*>, <*, *>, <**> + - infix 4 ==, /= + - infixr 3 && + - infixl 3 <|> + - infixr 2 || + - infixl 1 >>, >>= + - infixr 1 =<<, >=>, <=< + - infixr 0 $, $! +reexports: [] diff --git a/package.yaml b/package.yaml index c2de94cbb..62a176b64 100644 --- a/package.yaml +++ b/package.yaml @@ -46,8 +46,7 @@ dependencies: - filepath == 1.4.* - hourglass == 0.2.* - http-types == 0.12.* - - http2 == 4.1.* - - generic-random == 1.5.* + - http2 >= 4.1.4 && < 4.2 - ini == 0.4.1 - iproute == 1.7.* - iso8601-time == 0.1.* @@ -56,7 +55,6 @@ dependencies: - network >= 3.1.2.7 && < 3.2 - network-transport == 0.5.6 - optparse-applicative >= 0.15 && < 0.17 - - QuickCheck == 2.14.* - process == 1.6.* - random >= 1.1 && < 1.3 - simple-logger == 0.1.* @@ -152,6 +150,7 @@ tests: dependencies: - simplexmq - deepseq == 1.4.* + - generic-random == 1.5.* - hspec == 2.11.* - hspec-core == 2.11.* - HUnit == 1.6.* diff --git a/protocol/simplex-messaging.md b/protocol/simplex-messaging.md index 77a0c0621..48843cab3 100644 --- a/protocol/simplex-messaging.md +++ b/protocol/simplex-messaging.md @@ -364,9 +364,9 @@ The clients can optionally instruct a dedicated push notification server to subs [`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation. -## SMP Transmission andtransport block structure +## SMP Transmission and transport block structure -Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity. +Each transport block has a fixed size of 16384 bytes for traffic uniformity. From SMP version 4 each block can contain multiple transmissions, version 3 blocks have 1 transmission. Some parts of SMP transmission are padded to a fixed size; this padding is uniformly added as a word16 encoded in network byte order - see `paddedString` syntax. @@ -387,17 +387,17 @@ Each transmission/block for SMP v3 between the client and the server must have t ```abnf paddedTransmission = -transmission = [signature] SP signed -signed = sessionIdentifier SP [corrId] SP [queueId] SP smpCommand +transmission = signature signed +signed = sessionIdentifier corrId queueId smpCommand ; corrId is required in client commands and server responses, ; it is empty in server notifications. -corrId = 1*32(%x21-7F) ; any characters other than control/whitespace -queueId = encoded ; max 32 bytes when decoded (24 bytes is used), +corrId = length *OCTET +queueId = length *OCTET ; empty queue ID is used with "create" command and in some server responses -signature = encoded +signature = length *OCTET ; empty signature can be used with "send" before the queue is secured with secure command ; signature is always empty with "ping" and "serverMsg" -encoded = +length = 1*1 OCTET ``` `base64` encoding should be used with padding, as defined in section 4 of [RFC 4648][9] diff --git a/scripts/docker/entrypoint-xftp-server b/scripts/docker/entrypoint-xftp-server index 55757401e..9e5bf5ac1 100755 --- a/scripts/docker/entrypoint-xftp-server +++ b/scripts/docker/entrypoint-xftp-server @@ -19,7 +19,9 @@ if [ ! -f "${confd}/file-server.ini" ]; then # Set quota case "${QUOTA}" in '') printf 'Please specify $QUOTA environment variable.\n'; exit 1 ;; - *) set -- "$@" --quota "${QUOTA}" ;; + *GB) QUOTA="$(printf ${QUOTA} | tr '[:upper:]' '[:lower:]')"; set -- "$@" --quota "${QUOTA}" ;; + *gb) set -- "$@" --quota "${QUOTA}" ;; + *) printf 'Wrong format. Format should be: 1gb, 10gb, 100gb.\n'; exit 1 ;; esac # Init the certificates and configs diff --git a/simplexmq.cabal b/simplexmq.cabal index 2b98e10c2..4aca340d6 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -114,6 +114,7 @@ library Simplex.Messaging.Notifications.Server.Env Simplex.Messaging.Notifications.Server.Main Simplex.Messaging.Notifications.Server.Push.APNS + Simplex.Messaging.Notifications.Server.Push.APNS.Internal Simplex.Messaging.Notifications.Server.Stats Simplex.Messaging.Notifications.Server.Store Simplex.Messaging.Notifications.Server.StoreLog @@ -140,6 +141,7 @@ library Simplex.Messaging.Transport.Credentials Simplex.Messaging.Transport.HTTP2 Simplex.Messaging.Transport.HTTP2.Client + Simplex.Messaging.Transport.HTTP2.File Simplex.Messaging.Transport.HTTP2.Server Simplex.Messaging.Transport.KeepAlive Simplex.Messaging.Transport.Server @@ -160,8 +162,7 @@ library extra-libraries: crypto build-depends: - QuickCheck ==2.14.* - , aeson ==2.2.* + aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -180,10 +181,9 @@ library , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* - , generic-random ==1.5.* , hourglass ==0.2.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* @@ -225,8 +225,7 @@ executable ntf-server apps/ntf-server ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts build-depends: - QuickCheck ==2.14.* - , aeson ==2.2.* + aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -245,10 +244,9 @@ executable ntf-server , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* - , generic-random ==1.5.* , hourglass ==0.2.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* @@ -291,8 +289,7 @@ executable smp-agent apps/smp-agent ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts build-depends: - QuickCheck ==2.14.* - , aeson ==2.2.* + aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -311,10 +308,9 @@ executable smp-agent , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* - , generic-random ==1.5.* , hourglass ==0.2.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* @@ -357,8 +353,7 @@ executable smp-server apps/smp-server ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts build-depends: - QuickCheck ==2.14.* - , aeson ==2.2.* + aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -377,10 +372,9 @@ executable smp-server , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* - , generic-random ==1.5.* , hourglass ==0.2.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* @@ -423,8 +417,7 @@ executable xftp apps/xftp ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts build-depends: - QuickCheck ==2.14.* - , aeson ==2.2.* + aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -443,10 +436,9 @@ executable xftp , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* - , generic-random ==1.5.* , hourglass ==0.2.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* @@ -489,8 +481,7 @@ executable xftp-server apps/xftp-server ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts build-depends: - QuickCheck ==2.14.* - , aeson ==2.2.* + aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 , asn1-encoding ==0.9.* , asn1-types ==0.3.* @@ -509,10 +500,9 @@ executable xftp-server , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* - , generic-random ==1.5.* , hourglass ==0.2.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* @@ -610,7 +600,7 @@ test-suite simplexmq-test , hspec ==2.11.* , hspec-core ==2.11.* , http-types ==0.12.* - , http2 ==4.1.* + , http2 >=4.1.4 && <4.2 , ini ==0.4.1 , iproute ==1.7.* , iso8601-time ==0.1.* diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 106c61011..bda8e1e9e 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -10,7 +10,6 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module Simplex.FileTransfer.Agent @@ -322,8 +321,8 @@ runXFTPSndPrepareWorker c doWork = do if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting then do fsEncPath <- toFSFilePath $ sndFileEncPath ppath - when (status == SFSEncrypting) $ - whenM (doesFileExist fsEncPath) $ removeFile fsEncPath + when (status == SFSEncrypting) . whenM (doesFileExist fsEncPath) $ + removeFile fsEncPath withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting (digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath withStore c $ \db -> do @@ -441,11 +440,11 @@ runXFTPSndWorker c srv doWork = do | length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients" | length rcvIdsKeys == numRecipients = pure cr | otherwise = do - maxRecipients <- asks $ xftpMaxRecipientsPerRequest . config - let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients - rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients' - cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys' - addRecipients ch cr' + maxRecipients <- asks $ xftpMaxRecipientsPerRequest . config + let numRecipients' = min (numRecipients - length rcvIdsKeys) maxRecipients + rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients' + cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys' + addRecipients ch cr' sndFileToDescrs :: SndFile -> m (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient]) sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest" sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks" @@ -573,7 +572,7 @@ runXFTPDelWorker c srv doWork = do withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay atomically $ assertAgentForeground c loop - retryDone e = delWorkerInternalError c deletedSndChunkReplicaId e + retryDone = delWorkerInternalError c deletedSndChunkReplicaId deleteChunkReplica :: DeletedSndChunkReplica -> m () deleteChunkReplica replica@DeletedSndChunkReplica {userId, deletedSndChunkReplicaId} = do agentXFTPDeleteChunk c userId replica diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index fcf0debde..04e6ff429 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -49,6 +49,7 @@ import Simplex.Messaging.Transport (supportedParameters) import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost) import Simplex.Messaging.Transport.HTTP2 import Simplex.Messaging.Transport.HTTP2.Client +import Simplex.Messaging.Transport.HTTP2.File import Simplex.Messaging.Util (bshow, liftEitherError, whenM) import UnliftIO import UnliftIO.Directory @@ -153,7 +154,7 @@ sendXFTPCommand XFTPClient {config, http2Client = http2@HTTP2Client {sessionId}} forM_ chunkSpec_ $ \XFTPChunkSpec {filePath, chunkOffset, chunkSize} -> withFile filePath ReadMode $ \h -> do hSeek h AbsoluteSeek $ fromIntegral chunkOffset - sendFile h send $ fromIntegral chunkSize + hSendFile h send $ fromIntegral chunkSize done createXFTPChunk :: diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index 4c7e5439d..08e03a556 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -9,7 +9,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module Simplex.FileTransfer.Client.Main @@ -527,8 +526,8 @@ prepareChunkSizes size' = prepareSizes size' (smallSize, bigSize) | size' > size34 chunkSize3 = (chunkSize2, chunkSize3) | otherwise = (chunkSize1, chunkSize2) - -- | size' > size34 chunkSize2 = (chunkSize1, chunkSize2) - -- | otherwise = (chunkSize0, chunkSize1) + -- | size' > size34 chunkSize2 = (chunkSize1, chunkSize2) + -- | otherwise = (chunkSize0, chunkSize1) size34 sz = (fromIntegral sz * 3) `div` 4 prepareSizes 0 = [] prepareSizes size @@ -571,11 +570,11 @@ withRetry retryCount = withRetry' retryCount . withExceptT (CLIError . show) removeFD :: Bool -> FilePath -> IO () removeFD yes fd | yes = do - removeFile fd - putStrLn $ "\nFile description " <> fd <> " is deleted." + removeFile fd + putStrLn $ "\nFile description " <> fd <> " is deleted." | otherwise = do - y <- liftIO . getConfirmation $ "\nFile description " <> fd <> " can't be used again. Delete it" - when y $ removeFile fd + y <- liftIO . getConfirmation $ "\nFile description " <> fd <> " can't be used again. Delete it" + when y $ removeFile fd getConfirmation :: String -> IO Bool getConfirmation prompt = do diff --git a/src/Simplex/FileTransfer/Crypto.hs b/src/Simplex/FileTransfer/Crypto.hs index c0c2c49c3..03dc83a00 100644 --- a/src/Simplex/FileTransfer/Crypto.hs +++ b/src/Simplex/FileTransfer/Crypto.hs @@ -46,12 +46,12 @@ encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do encryptChunks_ get w (!sb, !len) | len == 0 = pure sb | otherwise = do - let chSize = min len 65536 - ch <- liftIO $ get chSize - when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF" - let (ch', sb') = LC.sbEncryptChunk sb ch - liftIO $ B.hPut w ch' - encryptChunks_ get w (sb', len - chSize) + let chSize = min len 65536 + ch <- liftIO $ get chSize + when (B.length ch /= fromIntegral chSize) $ throwError $ FTCEFileIOError "encrypting file: unexpected EOF" + let (ch', sb') = LC.sbEncryptChunk sb ch + liftIO $ B.hPut w ch' + encryptChunks_ get w (sb', len - chSize) decryptChunks :: Int64 -> [FilePath] -> C.SbKey -> C.CbNonce -> (String -> ExceptT String IO CryptoFile) -> ExceptT FTCryptoError IO CryptoFile decryptChunks _ [] _ _ _ = throwError $ FTCEInvalidHeader "empty" diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index dda4bec7f..64a1d8a32 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -1,6 +1,5 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} @@ -9,7 +8,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} - +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module Simplex.FileTransfer.Description @@ -38,7 +37,7 @@ where import Control.Applicative (optional) import Control.Monad ((<=<)) import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) @@ -54,12 +53,11 @@ import Data.Word (Word32) import qualified Data.Yaml as Y import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) -import GHC.Generics (Generic) import Simplex.FileTransfer.Chunks import Simplex.FileTransfer.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Parsers (defaultJSON, parseAll) import Simplex.Messaging.Protocol (XFTPServer) import Simplex.Messaging.Util (bshow, groupAllOn, (<$?>)) @@ -150,21 +148,13 @@ data YAMLFileDescription = YAMLFileDescription chunkSize :: String, replicas :: [YAMLServerReplicas] } - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON YAMLFileDescription where - toJSON = J.genericToJSON J.defaultOptions - toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) data YAMLServerReplicas = YAMLServerReplicas { server :: XFTPServer, chunks :: [String] } - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON YAMLServerReplicas where - toJSON = J.genericToJSON J.defaultOptions - toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) data FileServerReplica = FileServerReplica { chunkNo :: Int, @@ -176,6 +166,13 @@ data FileServerReplica = FileServerReplica } deriving (Show) +newtype FileSize a = FileSize {unFileSize :: a} + deriving (Eq, Show) + +$(J.deriveJSON defaultJSON ''YAMLServerReplicas) + +$(J.deriveJSON defaultJSON ''YAMLFileDescription) + instance FilePartyI p => StrEncoding (ValidFileDescription p) where strEncode (ValidFD fd) = strEncode fd strDecode s = strDecode s >>= (\(AVFD fd) -> checkParty fd) @@ -217,9 +214,6 @@ encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSiz replicas = encodeFileReplicas chunkSize chunks } -newtype FileSize a = FileSize {unFileSize :: a} - deriving (Eq, Show) - instance (Integral a, Show a) => StrEncoding (FileSize a) where strEncode (FileSize b) | b' /= 0 = bshow b @@ -242,9 +236,9 @@ instance (Integral a, Show a) => StrEncoding (FileSize a) where instance (Integral a, Show a) => IsString (FileSize a) where fromString = either error id . strDecode . B.pack -instance (FromField a) => FromField (FileSize a) where fromField f = FileSize <$> fromField f +instance FromField a => FromField (FileSize a) where fromField f = FileSize <$> fromField f -instance (ToField a) => ToField (FileSize a) where toField (FileSize s) = toField s +instance ToField a => ToField (FileSize a) where toField (FileSize s) = toField s groupReplicasByServer :: FileSize Word32 -> [FileChunk] -> [[FileServerReplica]] groupReplicasByServer defChunkSize = diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index 7c8ee4cbf..9d87a8f52 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -7,6 +6,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} @@ -14,8 +14,7 @@ module Simplex.FileTransfer.Protocol where import Control.Applicative ((<|>)) -import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) @@ -25,8 +24,6 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe (isNothing) import Data.Type.Equality import Data.Word (Word32) -import GHC.Generics (Generic) -import Generic.Random (genericArbitraryU) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -59,7 +56,6 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Transport (SessionId, TransportError (..)) import Simplex.Messaging.Util (bshow, (<$?>)) import Simplex.Messaging.Version -import Test.QuickCheck (Arbitrary (..)) currentXFTPVersion :: Version currentXFTPVersion = 1 @@ -69,14 +65,7 @@ xftpBlockSize = 16384 -- | File protocol clients data FileParty = FRecipient | FSender - deriving (Eq, Show, Generic) - -instance FromJSON FileParty where - parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "F" - -instance ToJSON FileParty where - toJSON = J.genericToJSON . enumJSON $ dropPrefix "F" - toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "F" + deriving (Eq, Show) data SFileParty :: FileParty -> Type where SFRecipient :: SFileParty FRecipient @@ -355,14 +344,7 @@ data XFTPErrorType INTERNAL | -- | used internally, never returned by the server (to be removed) DUPLICATE_ -- not part of SMP protocol, used internally - deriving (Eq, Generic, Read, Show) - -instance ToJSON XFTPErrorType where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON XFTPErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Read, Show) instance StrEncoding XFTPErrorType where strEncode = \case @@ -370,8 +352,6 @@ instance StrEncoding XFTPErrorType where e -> bshow e strP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1 -instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU - instance Encoding XFTPErrorType where smpEncode = \case BLOCK -> "BLOCK" @@ -435,3 +415,7 @@ xftpDecodeTransmission sessionId t = do case tParse True t' of t'' :| [] -> Right $ tDecodeParseValidate sessionId currentXFTPVersion t'' _ -> Left BLOCK + +$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty) + +$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index d068731f6..4113b316c 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -169,17 +169,17 @@ processRequest :: HTTP2Request -> M () processRequest HTTP2Request {sessionId, reqBody = body@HTTP2Body {bodyHead}, sendResponse} | B.length bodyHead /= xftpBlockSize = sendXFTPResponse ("", "", FRErr BLOCK) Nothing | otherwise = do - case xftpDecodeTransmission sessionId bodyHead of - Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do - case cmdOrErr of - Right cmd -> do - verifyXFTPTransmission sig_ signed fId cmd >>= \case - VRVerified req -> uncurry send =<< processXFTPRequest body req - VRFailed -> send (FRErr AUTH) Nothing - Left e -> send (FRErr e) Nothing - where - send resp = sendXFTPResponse (corrId, fId, resp) - Left e -> sendXFTPResponse ("", "", FRErr e) Nothing + case xftpDecodeTransmission sessionId bodyHead of + Right (sig_, signed, (corrId, fId, cmdOrErr)) -> do + case cmdOrErr of + Right cmd -> do + verifyXFTPTransmission sig_ signed fId cmd >>= \case + VRVerified req -> uncurry send =<< processXFTPRequest body req + VRFailed -> send (FRErr AUTH) Nothing + Left e -> send (FRErr e) Nothing + where + send resp = sendXFTPResponse (corrId, fId, resp) + Left e -> sendXFTPResponse ("", "", FRErr e) Nothing where sendXFTPResponse :: (CorrId, XFTPFileId, FileResponse) -> Maybe ServerFile -> M () sendXFTPResponse (corrId, fId, resp) serverFile_ = do diff --git a/src/Simplex/FileTransfer/Server/Env.hs b/src/Simplex/FileTransfer/Server/Env.hs index 584594e96..8c82b4a84 100644 --- a/src/Simplex/FileTransfer/Server/Env.hs +++ b/src/Simplex/FileTransfer/Server/Env.hs @@ -26,7 +26,7 @@ import Simplex.FileTransfer.Server.StoreLog import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (BasicAuth, RcvPublicVerifyKey) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig) +import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams) import Simplex.Messaging.Util (tshow) import System.IO (IOMode (..)) import UnliftIO.STM diff --git a/src/Simplex/FileTransfer/Server/Main.hs b/src/Simplex/FileTransfer/Server/Main.hs index 3f082e23c..abe127899 100644 --- a/src/Simplex/FileTransfer/Server/Main.hs +++ b/src/Simplex/FileTransfer/Server/Main.hs @@ -19,7 +19,7 @@ import Options.Applicative import Simplex.FileTransfer.Chunks import Simplex.FileTransfer.Description (FileSize (..)) import Simplex.FileTransfer.Server (runXFTPServer) -import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defaultFileExpiration, defFileExpirationHours) +import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defFileExpirationHours, defaultFileExpiration) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern XFTPServer) @@ -143,9 +143,10 @@ xftpServerCLI cfgPath logPath = do allowNewFiles = fromMaybe True $ iniOnOff "AUTH" "new_files" ini, newFileBasicAuth = either error id <$> strDecodeIni "AUTH" "create_password" ini, fileExpiration = - Just defaultFileExpiration - { ttl = 3600 * readIniDefault defFileExpirationHours "STORE_LOG" "expire_files_hours" ini - }, + Just + defaultFileExpiration + { ttl = 3600 * readIniDefault defFileExpirationHours "STORE_LOG" "expire_files_hours" ini + }, caCertificateFile = c caCrtFile, privateKeyFile = c serverKeyFile, certificateFile = c serverCrtFile, diff --git a/src/Simplex/FileTransfer/Transport.hs b/src/Simplex/FileTransfer/Transport.hs index 219fd4718..1309b9c31 100644 --- a/src/Simplex/FileTransfer/Transport.hs +++ b/src/Simplex/FileTransfer/Transport.hs @@ -1,14 +1,11 @@ {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.FileTransfer.Transport ( supportedFileServerVRange, XFTPRcvChunkSpec (..), - sendFile, receiveFile, sendEncFile, receiveEncFile, @@ -25,10 +22,10 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB import Data.Word (Word32) -import GHC.IO.Handle.Internals (ioe_EOF) import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC +import Simplex.Messaging.Transport.HTTP2.File import Simplex.Messaging.Version import System.IO (Handle, IOMode (..), withFile) @@ -42,18 +39,6 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec supportedFileServerVRange :: VersionRange supportedFileServerVRange = mkVersionRange 1 1 -fileBlockSize :: Int -fileBlockSize = 16 * 1024 - -sendFile :: Handle -> (Builder -> IO ()) -> Word32 -> IO () -sendFile h send = go - where - go 0 = pure () - go sz = - getFileChunk h sz >>= \ch -> do - send $ byteString ch - go $ sz - fromIntegral (B.length ch) - sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO () sendEncFile h send = go where @@ -66,23 +51,10 @@ sendEncFile h send = go send (byteString encCh) `E.catch` \(e :: E.SomeException) -> print e >> E.throwIO e go sbState' $ sz - fromIntegral (B.length ch) -getFileChunk :: Handle -> Word32 -> IO ByteString -getFileChunk h sz = - B.hGet h fileBlockSize >>= \case - "" -> ioe_EOF - ch -> pure $ B.take (fromIntegral sz) ch -- sz >= xftpBlockSize - receiveFile :: (Int -> IO ByteString) -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO () receiveFile getBody = receiveFile_ receive where - receive h sz = do - ch <- getBody fileBlockSize - let chSize = fromIntegral $ B.length ch - if - | chSize > sz -> pure $ Left SIZE - | chSize > 0 -> B.hPut h ch >> receive h (sz - chSize) - | sz == 0 -> pure $ Right () - | otherwise -> pure $ Left SIZE + receive h sz = hReceiveFile getBody h sz >>= \sz' -> pure $ if sz' == 0 then Right () else Left SIZE receiveEncFile :: (Int -> IO ByteString) -> LC.SbState -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO () receiveEncFile getBody = receiveFile_ . receive @@ -91,8 +63,8 @@ receiveEncFile getBody = receiveFile_ . receive ch <- getBody fileBlockSize let chSize = fromIntegral $ B.length ch if - | chSize > sz + authSz -> pure $ Left SIZE - | chSize > 0 -> do + | chSize > sz + authSz -> pure $ Left SIZE + | chSize > 0 -> do let (ch', rest) = B.splitAt (fromIntegral sz) ch (decCh, sbState') = LC.sbDecryptChunk sbState ch' sz' = sz - fromIntegral (B.length ch') @@ -105,7 +77,7 @@ receiveEncFile getBody = receiveFile_ . receive tag = LC.sbAuth sbState' tag'' <- if tagSz == C.authTagSize then pure tag' else (tag' <>) <$> getBody (C.authTagSize - tagSz) pure $ if BA.constEq tag'' tag then Right () else Left CRYPTO - | otherwise -> pure $ Left SIZE + | otherwise -> pure $ Left SIZE authSz = fromIntegral C.authTagSize receiveFile_ :: (Handle -> Word32 -> IO (Either XFTPErrorType ())) -> XFTPRcvChunkSpec -> ExceptT XFTPErrorType IO () diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 4fdd33c69..136bae557 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -14,7 +14,6 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} -- | @@ -43,6 +42,7 @@ module Simplex.Messaging.Agent disconnectAgentClient, resumeAgentClient, withConnLock, + withInvLock, createUser, deleteUser, createConnectionAsync, @@ -151,7 +151,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SubscriptionMode (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, UserProtocol, XFTPServerWithAuth) +import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util @@ -462,8 +462,8 @@ deleteUser' c userId delSMPQueues = do atomically $ TM.delete userId $ smpServers c where delUser = - whenM (withStore' c (`deleteUserWithoutConns` userId)) $ - atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) + whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ + writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId newConnAsync c userId corrId enableNtfs cMode subMode = do @@ -481,16 +481,17 @@ newConnNoQueues c userId connId enableNtfs cMode = do joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do - aVRange <- asks $ smpAgentVRange . config - case crAgentVRange `compatibleVersion` aVRange of - Just (Compatible connAgentVersion) -> do - g <- asks idsDrg - let duplexHS = connAgentVersion /= 1 - cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} - connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo - pure connId - _ -> throwError $ AGENT A_VERSION + withInvLock c (strEncode cReqUri) "joinConnAsync" $ do + aVRange <- asks $ smpAgentVRange . config + case crAgentVRange `compatibleVersion` aVRange of + Just (Compatible connAgentVersion) -> do + g <- asks idsDrg + let duplexHS = connAgentVersion /= 1 + cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo + pure connId + _ -> throwError $ AGENT A_VERSION joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo = throwError $ CMD PROHIBITED @@ -554,11 +555,11 @@ switchConnectionAsync' c corrId connId = SomeConn _ (DuplexConnection cData rqs@(rq :| _rqs) sqs) | isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED | otherwise -> do - when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED - rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH - let rqs' = updatedQs rq1 rqs - pure . connectionStats $ DuplexConnection cData rqs' sqs + when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED + rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH + let rqs' = updatedQs rq1 rqs + pure . connectionStats $ DuplexConnection cData rqs' sqs _ -> throwError $ CMD PROHIBITED newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) @@ -615,24 +616,25 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr _ -> throwError $ AGENT A_VERSION joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do - (aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv - g <- asks idsDrg - connId' <- withStore c $ \db -> runExceptT $ do - connId' <- ExceptT $ createSndConn db g cData q - liftIO $ createRatchet db connId' rc - pure connId' - let sq = (q :: SndQueue) {connId = connId'} - cData' = (cData :: ConnData) {connId = connId'} - duplexHS = connAgentVersion /= 1 - tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case - Right _ -> do - unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO +joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = + withInvLock c (strEncode inv) "joinConnSrv" $ do + (aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv + g <- asks idsDrg + connId' <- withStore c $ \db -> runExceptT $ do + connId' <- ExceptT $ createSndConn db g cData q + liftIO $ createRatchet db connId' rc pure connId' - Left e -> do - -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md - withStore' c (`deleteConn` connId') - throwError e + let sq = (q :: SndQueue) {connId = connId'} + cData' = (cData :: ConnData) {connId = connId'} + duplexHS = connAgentVersion /= 1 + tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case + Right _ -> do + unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO + pure connId' + Left e -> do + -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md + withStore' c (`deleteConn` connId') + throwError e joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config @@ -813,15 +815,15 @@ getNotificationMessage' c nonce encNtfInfo = do where getNtfMessages ntfConnId maxMs nMeta ms | length ms < maxMs = - getConnectionMessage' c ntfConnId >>= \case - Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of - Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} - | msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms) - | otherwise -> getMsg (m : ms) - _ - | SMP.notification msgFlags -> pure $ reverse (m : ms) - | otherwise -> getMsg (m : ms) - _ -> pure $ reverse ms + getConnectionMessage' c ntfConnId >>= \case + Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of + Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} + | msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms) + | otherwise -> getMsg (m : ms) + _ + | SMP.notification msgFlags -> pure $ reverse (m : ms) + | otherwise -> getMsg (m : ms) + _ -> pure $ reverse ms | otherwise = pure $ reverse ms where getMsg = getNtfMessages ntfConnId maxMs nMeta @@ -962,12 +964,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do Just (rq'@RcvQueue {primary}, rq'' : rqs') | primary -> internalErr "ICQDelete: cannot delete primary rcv queue" | otherwise -> do - checkRQSwchStatus rq' RSReceivedMessage - tryError (deleteQueue c rq') >>= \case - Right () -> finalizeSwitch - Left e - | temporaryOrHostError e -> throwError e - | otherwise -> finalizeSwitch >> throwError e + checkRQSwchStatus rq' RSReceivedMessage + tryError (deleteQueue c rq') >>= \case + Right () -> finalizeSwitch + Left e + | temporaryOrHostError e -> throwError e + | otherwise -> finalizeSwitch >> throwError e where finalizeSwitch = do withStore' c $ \db -> deleteConnRcvQueue db rq' @@ -1123,7 +1125,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl -- because the queue must be secured by the time the confirmation or the first HELLO is received | duplexHandshake == Just True -> connErr | otherwise -> - ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast) + ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast) where connErr = case rq_ of -- party initiating connection @@ -1143,8 +1145,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl -- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server), -- the message sending would be retried | temporaryOrHostError e -> do - let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout - ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast) + let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout + ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast) | otherwise -> notifyDel msgId err where msgExpired timeoutSel = do @@ -1286,9 +1288,9 @@ switchConnection' c connId = SomeConn _ conn@(DuplexConnection cData rqs@(rq :| _rqs) _) | isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED | otherwise -> do - when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED - rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted - switchDuplexConnection c conn rq' + when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED + rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted + switchDuplexConnection c conn rq' _ -> throwError $ CMD PROHIBITED switchDuplexConnection :: AgentMonad m => AgentClient -> Connection 'CDuplex -> RcvQueue -> m ConnectionStats @@ -1314,19 +1316,19 @@ abortConnectionSwitch' c connId = SomeConn _ (DuplexConnection cData rqs sqs) -> case switchingRQ rqs of Just rq | canAbortRcvSwitch rq -> do - when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED - -- multiple queues to which the connections switches were possible when repeating switch was allowed - let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs - case L.nonEmpty keepRqs of - Just rqs' -> do - rq' <- withStore' c $ \db -> do - mapM_ (setRcvQueueDeleted db) delRqs - setRcvSwitchStatus db rq Nothing - forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId - let rqs'' = updatedQs rq' rqs' - conn' = DuplexConnection cData rqs'' sqs - pure $ connectionStats conn' - _ -> throwError $ INTERNAL "won't delete all rcv queues in connection" + when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED + -- multiple queues to which the connections switches were possible when repeating switch was allowed + let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs + case L.nonEmpty keepRqs of + Just rqs' -> do + rq' <- withStore' c $ \db -> do + mapM_ (setRcvQueueDeleted db) delRqs + setRcvSwitchStatus db rq Nothing + forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId + let rqs'' = updatedQs rq' rqs' + conn' = DuplexConnection cData rqs'' sqs + pure $ connectionStats conn' + _ -> throwError $ INTERNAL "won't delete all rcv queues in connection" | otherwise -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED @@ -1336,16 +1338,16 @@ synchronizeRatchet' c connId force = withConnLock c connId "synchronizeRatchet" withStore c (`getConn` connId) >>= \case SomeConn _ (DuplexConnection cData rqs sqs) | ratchetSyncAllowed cData || force -> do - -- check queues are not switching? - AgentConfig {e2eEncryptVRange} <- asks config - (pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange - void $ enqueueRatchetKeyMsgs c cData sqs e2eParams - withStore' c $ \db -> do - setConnRatchetSync db connId RSStarted - setRatchetX3dhKeys db connId pk1 pk2 k1 k2 - let cData' = cData {ratchetSyncState = RSStarted} :: ConnData - conn' = DuplexConnection cData' rqs sqs - pure $ connectionStats conn' + -- check queues are not switching? + AgentConfig {e2eEncryptVRange} <- asks config + (pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange + void $ enqueueRatchetKeyMsgs c cData sqs e2eParams + withStore' c $ \db -> do + setConnRatchetSync db connId RSStarted + setRatchetX3dhKeys db connId pk1 pk2 k1 k2 + let cData' = cData {ratchetSyncState = RSStarted} :: ConnData + conn' = DuplexConnection cData' rqs sqs + pure $ connectionStats conn' | otherwise -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED @@ -1521,23 +1523,23 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = -- possible improvement: add minimal time before repeat registration (Just tknId, Nothing) | savedDeviceToken == suppliedDeviceToken -> - when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered + when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered | otherwise -> replaceToken tknId (Just tknId, Just (NTAVerify code)) | savedDeviceToken == suppliedDeviceToken -> - t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code + t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code | otherwise -> replaceToken tknId (Just tknId, Just NTACheck) | savedDeviceToken == suppliedDeviceToken -> do - ns <- asks ntfSupervisor - atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode} - when (ntfTknStatus == NTActive) $ do - cron <- asks $ ntfCron . config - agentNtfEnableCron c tknId tkn cron - when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c - when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete - -- possible improvement: get updated token status from the server, or maybe TCRON could return the current status - pure ntfTknStatus + ns <- asks ntfSupervisor + atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode} + when (ntfTknStatus == NTActive) $ do + cron <- asks $ ntfCron . config + agentNtfEnableCron c tknId tkn cron + when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c + when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete + -- possible improvement: get updated token status from the server, or maybe TCRON could return the current status + pure ntfTknStatus | otherwise -> replaceToken tknId (Just tknId, Just NTADelete) -> do agentNtfDeleteToken c tknId tkn @@ -1647,10 +1649,10 @@ toggleConnectionNtfs' c connId enable = do toggle cData | enableNtfs cData == enable = pure () | otherwise = do - withStore' c $ \db -> setConnectionNtfs db connId enable - ns <- asks ntfSupervisor - let cmd = if enable then NSCCreate else NSCDelete - atomically $ sendNtfSubCommand ns (connId, cmd) + withStore' c $ \db -> setConnectionNtfs db connId enable + ns <- asks ntfSupervisor + let cmd = if enable then NSCCreate else NSCDelete + atomically $ sendNtfSubCommand ns (connId, cmd) deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m () deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do @@ -1743,11 +1745,12 @@ getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration] getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn) debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks -debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do +debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do connLocks <- getLocks cs + invLocks <- getLocks is srvLocks <- getLocks rs delLock <- atomically $ tryReadTMVar d - pure AgentLocks {connLocks, srvLocks, delLock} + pure AgentLocks {connLocks, invLocks, srvLocks, delLock} where getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls) @@ -1912,11 +1915,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s resetRatchetSync :: m (Connection c) resetRatchetSync | rss `notElem` ([RSOk, RSStarted] :: [RatchetSyncState]) = do - let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData - conn'' = updateConnection cData'' conn' - notify . RSYNC RSOk Nothing $ connectionStats conn'' - withStore' c $ \db -> setConnRatchetSync db connId RSOk - pure conn'' + let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData + conn'' = updateConnection cData'' conn' + notify . RSYNC RSOk Nothing $ connectionStats conn'' + withStore' c $ \db -> setConnRatchetSync db connId RSOk + pure conn'' | otherwise = pure conn' Right _ -> prohibited >> ack Left e@(AGENT A_DUPLICATE) -> do @@ -1924,11 +1927,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} | userAck -> ackDel internalId | otherwise -> do - liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case - AgentMessage _ (A_MSG body) -> do - logServer "<--" c srv rId "MSG " - notify $ MSG msgMeta msgFlags body - _ -> pure () + liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case + AgentMessage _ (A_MSG body) -> do + logServer "<--" c srv rId "MSG " + notify $ MSG msgMeta msgFlags body + _ -> pure () _ -> checkDuplicateHash e encryptedMsgHash >> ack Left (AGENT (A_CRYPTO e)) -> do exists <- withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash @@ -1976,9 +1979,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s case msgAVRange `compatibleVersion` aVRange of Just (Compatible av) | av > connAgentVersion -> do - withStore' c $ \db -> setConnAgentVersion db connId av - let cData'' = cData' {connAgentVersion = av} :: ConnData - pure $ updateConnection cData'' conn' + withStore' c $ \db -> setConnAgentVersion db connId av + let cData'' = cData' {connAgentVersion = av} :: ConnData + pure $ updateConnection cData'' conn' | otherwise -> pure conn' Nothing -> pure conn' ack :: m () @@ -1994,9 +1997,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s processEND = \case Just (Right clnt) | sessId == sessionId clnt -> do - removeSubscription c connId - notify' END - pure "END" + removeSubscription c connId + notify' END + pure "END" | otherwise -> ignored _ -> ignored ignored = pure "END from disconnected client - ignored" @@ -2186,12 +2189,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s case findRQ (smpServer, senderId) rqs of Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'} | status' == New || status' == Confirmed -> do - checkRQSwchStatus rq RSSendingQADD - logServer "<--" c srv rId $ "MSG " <> logSecret senderId - let dhSecret = C.dh' dhPublicKey dhPrivKey - withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer' - enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey - notify . SWITCH QDRcv SPConfirmed $ connectionStats conn' + checkRQSwchStatus rq RSSendingQADD + logServer "<--" c srv rId $ "MSG " <> logSecret senderId + let dhSecret = C.dh' dhPublicKey dhPrivKey + withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer' + enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey + notify . SWITCH QDRcv SPConfirmed $ connectionStats conn' | otherwise -> qError "QKEY: queue already secured" _ -> qError "QKEY: queue address not found in connection" where @@ -2227,8 +2230,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s ereadyMsg rcPrev (DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = do let CR.Ratchet {rcSnd} = rcPrev -- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation - when (isNothing rcSnd) $ - void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId + when (isNothing rcSnd) . void $ + enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId) smpInvitation :: Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () smpInvitation conn' connReq@(CRInvitationUri crData _) cInfo = do @@ -2267,9 +2270,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s getSendRatchetKeys | rss == RSStarted = withStore c (`getRatchetX3dhKeys'` connId) | otherwise = do - (pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams - void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams - pure (pk1, pk2, k1, k2) + (pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams + void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams + pure (pk1, pk2, k1, k2) notifyAgreed :: m () notifyAgreed = do let cData'' = cData' {ratchetSyncState = RSAgreed} :: ConnData @@ -2285,11 +2288,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, C.PublicKeyX448, C.PublicKeyX448) -> m () initRatchet e2eEncryptVRange (pk1, pk2, k1, k2) | rkHash k1 k2 <= rkHashRcv = do - recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams + recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams | otherwise = do - (_, rcDHRs) <- liftIO C.generateKeyPair' - recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams - void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId + (_, rcDHRs) <- liftIO C.generateKeyPair' + recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams + void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash @@ -2347,8 +2350,8 @@ mkAgentConfirmation :: AgentMonad m => Compatible Version -> AgentClient -> Conn mkAgentConfirmation (Compatible agentVersion) c cData sq srv connInfo subMode | agentVersion == 1 = pure $ AgentConnInfo connInfo | otherwise = do - qInfo <- createReplyQueue c cData sq subMode srv - pure $ AgentConnInfoReply (qInfo :| []) connInfo + qInfo <- createReplyQueue c cData sq subMode srv + pure $ AgentConnInfoReply (qInfo :| []) connInfo enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () enqueueConfirmation c cData sq connInfo e2eEncryption_ = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 586749606..23ba9f8be 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -2,7 +2,6 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -15,6 +14,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilyDependencies #-} @@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client ProtocolTestStep (..), newAgentClient, withConnLock, + withInvLock, closeAgentClient, closeProtocolServerClients, closeXFTPServerClient, @@ -117,8 +118,7 @@ import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random (getRandomBytes) -import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) @@ -137,7 +137,6 @@ import Data.Text (Text) import Data.Text.Encoding import Data.Time (UTCTime, defaultTimeLocale, formatTime, getCurrentTime) import Data.Word (Word16) -import GHC.Generics (Generic) import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X @@ -164,7 +163,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse) +import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse) import Simplex.Messaging.Protocol ( AProtocolType (..), BrokerMsg, @@ -249,6 +248,8 @@ data AgentClient = AgentClient getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()), -- locks to prevent concurrent operations with connection connLocks :: TMap ConnId Lock, + -- locks to prevent concurrent operations with connection request invitations + invLocks :: TMap ByteString Lock, -- lock to prevent concurrency between periodic and async connection deletions deleteLock :: Lock, -- locks to prevent concurrent reconnections to SMP servers @@ -279,10 +280,13 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int} data AgentState = ASForeground | ASSuspending | ASSuspended deriving (Eq, Show) -data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String} - deriving (Show, Generic, FromJSON) - -instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions +data AgentLocks = AgentLocks + { connLocks :: Map String String, + invLocks :: Map String String, + srvLocks :: Map String String, + delLock :: Maybe String + } + deriving (Show) data AgentStatsKey = AgentStatsKey { userId :: UserId, @@ -325,6 +329,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do agentState <- newTVar ASForeground getMsgLocks <- TM.empty connLocks <- TM.empty + invLocks <- TM.empty deleteLock <- createLock reconnectLocks <- TM.empty reconnections <- newTAsyncs @@ -362,6 +367,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do agentState, getMsgLocks, connLocks, + invLocks, deleteLock, reconnectLocks, reconnections, @@ -645,6 +651,9 @@ withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a withConnLock _ "" _ = id withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name +withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a +withInvLock AgentClient {invLocks} = withLockMap_ invLocks + withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure where @@ -725,23 +734,14 @@ data ProtocolTestStep | TSDownloadFile | TSCompareFile | TSDeleteFile - deriving (Eq, Show, Generic) - -instance ToJSON ProtocolTestStep where - toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TS" - toJSON = J.genericToJSON . enumJSON $ dropPrefix "TS" - -instance FromJSON ProtocolTestStep where - parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TS" + deriving (Eq, Show) data ProtocolTestFailure = ProtocolTestFailure { testStep :: ProtocolTestStep, testError :: AgentErrorType } - deriving (Eq, Show, Generic, FromJSON) + deriving (Eq, Show) -instance ToJSON ProtocolTestFailure where toEncoding = J.genericToEncoding J.defaultOptions - runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe ProtocolTestFailure) runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do cfg <- getClientConfig c smpCfg @@ -901,8 +901,8 @@ subscribeQueues c qs = do subscribeQueues_ u smp qs' = do rs <- sendBatch subscribeSMPQueues smp qs' mapM_ (uncurry $ processSubResult c) rs - when (any temporaryClientError . lefts . map snd $ L.toList rs) $ - unliftIO u $ reconnectServer c $ transportSession' smp + when (any temporaryClientError . lefts . map snd $ L.toList rs) . unliftIO u $ + reconnectServer c (transportSession' smp) pure rs type BatchResponses e r = (NonEmpty (RcvQueue, Either e r)) @@ -989,7 +989,7 @@ sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, mkInvitation = do let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo} agentCbEncryptOnce v dhPublicKey . smpEncode $ - SMP.ClientMessage SMP.PHEmpty $ smpEncode agentEnvelope + SMP.ClientMessage SMP.PHEmpty (smpEncode agentEnvelope) getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta) getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do @@ -1324,7 +1324,7 @@ userServers c = case protocolTypeI @p of SPSMP -> smpServers c SPXFTP -> xftpServers c -pickServer :: forall p m. (AgentMonad' m) => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p) +pickServer :: forall p m. AgentMonad' m => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p) pickServer = \case srv :| [] -> pure srv servers -> do @@ -1343,7 +1343,7 @@ withUserServers c userId action = Just srvs -> action srvs _ -> throwError $ INTERNAL "unknown userId - no user servers" -withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> ((ProtoServerWithAuth p) -> m a) -> m a +withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> m a) -> m a withNextSrv c userId usedSrvs initUsed action = do used <- readTVarIO usedSrvs srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used @@ -1355,22 +1355,14 @@ withNextSrv c userId usedSrvs initUsed action = do action srvAuth data SubInfo = SubInfo {userId :: UserId, server :: Text, rcvId :: Text, subError :: Maybe String} - deriving (Show, Generic) - -instance ToJSON SubInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} - -instance FromJSON SubInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True} + deriving (Show) data SubscriptionsInfo = SubscriptionsInfo { activeSubscriptions :: [SubInfo], pendingSubscriptions :: [SubInfo], removedSubscriptions :: [SubInfo] } - deriving (Show, Generic) - -instance ToJSON SubscriptionsInfo where toEncoding = J.genericToEncoding J.defaultOptions - -instance FromJSON SubscriptionsInfo where parseJSON = J.genericParseJSON J.defaultOptions + deriving (Show) getAgentSubscriptions :: MonadIO m => AgentClient -> m SubscriptionsInfo getAgentSubscriptions c = do @@ -1382,6 +1374,16 @@ getAgentSubscriptions c = do getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c) getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c) subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo - subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err} + subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err} enc :: StrEncoding a => a -> Text enc = decodeLatin1 . strEncode + +$(J.deriveJSON defaultJSON ''AgentLocks) + +$(J.deriveJSON (enumJSON $ dropPrefix "TS") ''ProtocolTestStep) + +$(J.deriveJSON defaultJSON ''ProtocolTestFailure) + +$(J.deriveJSON defaultJSON ''SubInfo) + +$(J.deriveJSON defaultJSON ''SubscriptionsInfo) diff --git a/src/Simplex/Messaging/Agent/Lock.hs b/src/Simplex/Messaging/Agent/Lock.hs index e0dd22713..10062495d 100644 --- a/src/Simplex/Messaging/Agent/Lock.hs +++ b/src/Simplex/Messaging/Agent/Lock.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE NamedFieldPuns #-} - module Simplex.Messaging.Agent.Lock where import Control.Monad (void) diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 1e7454723..8434ddbdf 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -5,7 +5,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module Simplex.Messaging.Agent.NtfSubSupervisor @@ -100,14 +99,14 @@ processNtfSub c (connId, cmd) = do Just (action, _) -- subscription was marked for deletion / is being deleted | isDeleteNtfSubAction action -> do - if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted - then resetSubscription - else withNtfServer c $ \ntfServer -> do - withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate) - addNtfNTFWorker ntfServer + if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted + then resetSubscription + else withNtfServer c $ \ntfServer -> do + withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate) + addNtfNTFWorker ntfServer | otherwise -> case action of - NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer - NtfSubSMPAction _ -> addNtfSMPWorker smpServer + NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer + NtfSubSMPAction _ -> addNtfSMPWorker smpServer rotate :: m () rotate = do withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate) @@ -291,11 +290,11 @@ rescheduleAction :: AgentMonad' m => TMVar () -> UTCTime -> UTCTime -> m Bool rescheduleAction doWork ts actionTs | actionTs <= ts = pure False | otherwise = do - void . atomically $ tryTakeTMVar doWork - void . forkIO $ do - liftIO $ threadDelay' $ diffToMicroseconds $ diffUTCTime actionTs ts - void . atomically $ tryPutTMVar doWork () - pure True + void . atomically $ tryTakeTMVar doWork + void . forkIO $ do + liftIO $ threadDelay' $ diffToMicroseconds $ diffUTCTime actionTs ts + void . atomically $ tryPutTMVar doWork () + pure True retryOnError :: AgentMonad' m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m () retryOnError c name loop done e = do diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index b9f0a5b1c..481034e01 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -1,6 +1,5 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -13,6 +12,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -151,7 +151,7 @@ import Control.Monad (unless) import Control.Monad.Except (runExceptT, throwError) import Control.Monad.IO.Class import Data.Aeson (FromJSON (..), ToJSON (..)) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Base64 @@ -176,8 +176,6 @@ import Data.Typeable () import Data.Word (Word32) import Database.SQLite.Simple.FromField import Database.SQLite.Simple.ToField -import GHC.Generics (Generic) -import Generic.Random (genericArbitraryU) import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType) import Simplex.Messaging.Agent.QueryString @@ -213,7 +211,6 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version -import Test.QuickCheck (Arbitrary (..)) import Text.Read import UnliftIO.Exception (Exception) @@ -614,15 +611,11 @@ data RcvQueueInfo = RcvQueueInfo rcvSwitchStatus :: Maybe RcvSwitchStatus, canAbortSwitch :: Bool } - deriving (Eq, Show, Generic) - -instance ToJSON RcvQueueInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} - -instance FromJSON RcvQueueInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True} + deriving (Eq, Show) instance StrEncoding RcvQueueInfo where strEncode RcvQueueInfo {rcvServer, rcvSwitchStatus, canAbortSwitch} = - "srv=" <> strEncode rcvServer + ("srv=" <> strEncode rcvServer) <> maybe "" (\switch -> ";switch=" <> strEncode switch) rcvSwitchStatus <> (";can_abort_switch=" <> strEncode canAbortSwitch) strP = do @@ -635,11 +628,7 @@ data SndQueueInfo = SndQueueInfo { sndServer :: SMPServer, sndSwitchStatus :: Maybe SndSwitchStatus } - deriving (Eq, Show, Generic) - -instance ToJSON SndQueueInfo where toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} - -instance FromJSON SndQueueInfo where parseJSON = J.genericParseJSON J.defaultOptions {J.omitNothingFields = True} + deriving (Eq, Show) instance StrEncoding SndQueueInfo where strEncode SndQueueInfo {sndServer, sndSwitchStatus} = @@ -656,13 +645,11 @@ data ConnectionStats = ConnectionStats ratchetSyncState :: RatchetSyncState, ratchetSyncSupported :: Bool } - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON ConnectionStats where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) instance StrEncoding ConnectionStats where strEncode ConnectionStats {connAgentVersion, rcvQueuesInfo, sndQueuesInfo, ratchetSyncState, ratchetSyncSupported} = - "agent_version=" <> strEncode connAgentVersion + ("agent_version=" <> strEncode connAgentVersion) <> (" rcv=" <> strEncodeList rcvQueuesInfo) <> (" snd=" <> strEncodeList sndQueuesInfo) <> (" sync=" <> strEncode ratchetSyncState) @@ -1048,7 +1035,7 @@ instance StrEncoding MsgReceiptStatus where MROk -> "ok" MRBadMsgHash -> "badMsgHash" strP = - A.takeWhile1 (/= ' ') >>= \ case + A.takeWhile1 (/= ' ') >>= \case "ok" -> pure MROk "badMsgHash" -> pure MRBadMsgHash _ -> fail "bad MsgReceiptStatus" @@ -1391,7 +1378,7 @@ type AgentMsgId = Int64 -- | Result of received message integrity validation. data MsgIntegrity = MsgOk | MsgError {errorInfo :: MsgErrorType} - deriving (Eq, Show, Generic) + deriving (Eq, Show) instance StrEncoding MsgIntegrity where strP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> strP) @@ -1399,20 +1386,13 @@ instance StrEncoding MsgIntegrity where MsgOk -> "OK" MsgError e -> "ERR " <> strEncode e -instance ToJSON MsgIntegrity where - toJSON = J.genericToJSON $ sumTypeJSON fstToLower - toEncoding = J.genericToEncoding $ sumTypeJSON fstToLower - -instance FromJSON MsgIntegrity where - parseJSON = J.genericParseJSON $ sumTypeJSON fstToLower - -- | Error of message integrity validation. data MsgErrorType = MsgSkipped {fromMsgId :: AgentMsgId, toMsgId :: AgentMsgId} | MsgBadId {msgId :: AgentMsgId} | MsgBadHash | MsgDuplicate - deriving (Eq, Show, Generic) + deriving (Eq, Show) instance StrEncoding MsgErrorType where strP = @@ -1427,13 +1407,6 @@ instance StrEncoding MsgErrorType where MsgBadHash -> "HASH" MsgDuplicate -> "DUPLICATE" -instance ToJSON MsgErrorType where - toJSON = J.genericToJSON $ sumTypeJSON fstToLower - toEncoding = J.genericToEncoding $ sumTypeJSON fstToLower - -instance FromJSON MsgErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON fstToLower - -- | Error type used in errors sent to agent clients. data AgentErrorType = -- | command or response error @@ -1454,14 +1427,7 @@ data AgentErrorType INTERNAL {internalErr :: String} | -- | agent inactive INACTIVE - deriving (Eq, Generic, Show, Exception) - -instance ToJSON AgentErrorType where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON AgentErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Show, Exception) -- | SMP agent protocol command or response error. data CommandErrorType @@ -1475,14 +1441,7 @@ data CommandErrorType SIZE | -- | message does not fit in SMP block LARGE - deriving (Eq, Generic, Read, Show, Exception) - -instance ToJSON CommandErrorType where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON CommandErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Read, Show, Exception) -- | Connection error. data ConnectionErrorType @@ -1496,14 +1455,7 @@ data ConnectionErrorType NOT_ACCEPTED | -- | connection not available on reply confirmation/HELLO after timeout NOT_AVAILABLE - deriving (Eq, Generic, Read, Show, Exception) - -instance ToJSON ConnectionErrorType where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON ConnectionErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Read, Show, Exception) -- | SMP server errors. data BrokerErrorType @@ -1519,14 +1471,7 @@ data BrokerErrorType TRANSPORT {transportErr :: TransportError} | -- | command response timeout TIMEOUT - deriving (Eq, Generic, Read, Show, Exception) - -instance ToJSON BrokerErrorType where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON BrokerErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Read, Show, Exception) -- | Errors of another SMP agent. data SMPAgentError @@ -1543,7 +1488,7 @@ data SMPAgentError A_DUPLICATE | -- | error in the message to add/delete/etc queue in connection A_QUEUE {queueErr :: String} - deriving (Eq, Generic, Read, Show, Exception) + deriving (Eq, Read, Show, Exception) data AgentCryptoError = -- | AES decryption error @@ -1556,14 +1501,7 @@ data AgentCryptoError RATCHET_EARLIER Word32 | -- | too many skipped messages RATCHET_SKIPPED Word32 - deriving (Eq, Generic, Read, Show, Exception) - -instance ToJSON AgentCryptoError where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON AgentCryptoError where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Read, Show, Exception) instance StrEncoding AgentCryptoError where strP = @@ -1579,13 +1517,6 @@ instance StrEncoding AgentCryptoError where RATCHET_EARLIER n -> "RATCHET_EARLIER " <> strEncode n RATCHET_SKIPPED n -> "RATCHET_SKIPPED " <> strEncode n -instance ToJSON SMPAgentError where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON SMPAgentError where - parseJSON = J.genericParseJSON $ sumTypeJSON id - instance StrEncoding AgentErrorType where strP = "CMD " *> (CMD <$> parseRead1) @@ -1620,18 +1551,6 @@ instance StrEncoding AgentErrorType where where text = encodeUtf8 . T.pack -instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU - -instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU - -instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU - -instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU - -instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU - -instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU - cryptoErrToSyncState :: AgentCryptoError -> RatchetSyncState cryptoErrToSyncState = \case DECRYPT_AES -> RSAllowed @@ -1957,3 +1876,25 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody unless (B.null s) $ throwError $ CMD SIZE pure body Nothing -> return . Left $ CMD SYNTAX + +$(J.deriveJSON defaultJSON ''RcvQueueInfo) + +$(J.deriveJSON defaultJSON ''SndQueueInfo) + +$(J.deriveJSON defaultJSON ''ConnectionStats) + +$(J.deriveJSON (sumTypeJSON fstToLower) ''MsgErrorType) + +$(J.deriveJSON (sumTypeJSON fstToLower) ''MsgIntegrity) + +$(J.deriveJSON (sumTypeJSON id) ''CommandErrorType) + +$(J.deriveJSON (sumTypeJSON id) ''ConnectionErrorType) + +$(J.deriveJSON (sumTypeJSON id) ''BrokerErrorType) + +$(J.deriveJSON (sumTypeJSON id) ''AgentCryptoError) + +$(J.deriveJSON (sumTypeJSON id) ''SMPAgentError) + +$(J.deriveJSON (sumTypeJSON id) ''AgentErrorType) diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs index 97d537a5a..3538d0aab 100644 --- a/src/Simplex/Messaging/Agent/RetryInterval.hs +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -1,6 +1,5 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Agent.RetryInterval diff --git a/src/Simplex/Messaging/Agent/Server.hs b/src/Simplex/Messaging/Agent/Server.hs index 32a085511..ec66a5aa7 100644 --- a/src/Simplex/Messaging/Agent/Server.hs +++ b/src/Simplex/Messaging/Agent/Server.hs @@ -23,7 +23,7 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore) import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion) -import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer, defaultTransportServerConfig) +import Simplex.Messaging.Transport.Server (defaultTransportServerConfig, loadTLSServerParams, runTransportServer) import Simplex.Messaging.Util (bshow) import UnliftIO.Async (race_) import qualified UnliftIO.Exception as E diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index dda6c7c65..27b193693 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -576,6 +576,8 @@ data StoreError SEServerNotFound | -- | Connection already used. SEConnDuplicate + | -- | Confirmed snd queue already exists. + SESndQueueExists | -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. -- 'upgradeRcvConnToDuplex' and 'upgradeSndConnToDuplex' do not allow duplex connections - they would also return this error. SEBadConnType ConnType diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index e14e68e75..a08b758bd 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,7 +1,6 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -17,12 +16,12 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} - -{-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} module Simplex.Messaging.Agent.Store.SQLite ( SQLiteStore (..), @@ -220,8 +219,7 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class import Crypto.Random (ChaChaDRG, randomBytesGenerate) -import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (second) import Data.ByteString (ByteString) @@ -247,7 +245,6 @@ import Database.SQLite.Simple.FromField import Database.SQLite.Simple.QQ (sql) import Database.SQLite.Simple.ToField (ToField (..)) import qualified Database.SQLite3 as SQLite3 -import GHC.Generics (Generic) import Network.Socket (ServiceName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Description @@ -267,7 +264,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Parsers (blobFieldParser, dropPrefix, fromTextField_, sumTypeJSON) +import Simplex.Messaging.Parsers (blobFieldParser, defaultJSON, dropPrefix, fromTextField_, sumTypeJSON) import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport.Client (TransportHost) @@ -277,7 +274,7 @@ import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) import System.FilePath (takeDirectory) import System.IO (hFlush, stdout) -import UnliftIO.Exception (onException, bracketOnError) +import UnliftIO.Exception (bracketOnError, onException) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -287,7 +284,7 @@ data MigrationError = MEUpgrade {upMigrations :: [UpMigration]} | MEDowngrade {downMigrations :: [String]} | MigrationError {mtrError :: MTRError} - deriving (Eq, Show, Generic) + deriving (Eq, Show) migrationErrorDescription :: MigrationError -> String migrationErrorDescription = \case @@ -297,18 +294,12 @@ migrationErrorDescription = \case "Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms MigrationError err -> mtrErrorDescription err -instance ToJSON MigrationError where - toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "ME" - toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "ME" - data UpMigration = UpMigration {upName :: String, withDown :: Bool} - deriving (Eq, Show, Generic, FromJSON) + deriving (Eq, Show) upMigration :: Migration -> UpMigration upMigration Migration {name, down} = UpMigration name $ isJust down -instance ToJSON UpMigration where toEncoding = J.genericToEncoding J.defaultOptions - data MigrationConfirmation = MCYesUp | MCYesUpDown | MCConsole | MCError deriving (Eq, Show) @@ -347,10 +338,10 @@ migrateSchema st migrations confirmMigrations = do Right ms@(MTRUp ums) | dbNew st -> Migrations.run st ms $> Right () | otherwise -> case confirmMigrations of - MCYesUp -> run ms - MCYesUpDown -> run ms - MCConsole -> confirm err >> run ms - MCError -> pure $ Left err + MCYesUp -> run ms + MCYesUpDown -> run ms + MCConsole -> confirm err >> run ms + MCError -> pure $ Left err where err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums) Right ms@(MTRDown dms) -> case confirmMigrations of @@ -561,10 +552,23 @@ createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, dupl createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId) createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = - createConn_ gVar cData $ \connId -> do - serverKeyHash_ <- createServer_ db server - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake) - void $ insertSndQueue_ db connId q serverKeyHash_ + -- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_ + ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $ + createConn_ gVar cData $ \connId -> do + serverKeyHash_ <- createServer_ db server + DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake) + void $ insertSndQueue_ db connId q serverKeyHash_ + +checkConfirmedSndQueueExists_ :: DB.Connection -> SndQueue -> IO Bool +checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do + fromMaybe False + <$> maybeFirstRow + fromOnly + ( DB.query + db + "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1" + (host server, port server, sndId, New) + ) getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do @@ -980,7 +984,7 @@ updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} = getPendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO [InternalId] getPendingMsgs db connId SndQueue {dbQueueId} = map fromOnly - <$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) + <$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? ORDER BY internal_id ASC" (connId, dbQueueId) deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO () deletePendingMsgs db connId SndQueue {dbQueueId} = @@ -1079,12 +1083,12 @@ countPendingSndDeliveries_ db connId msgId = do deleteRcvMsgHashesExpired :: DB.Connection -> NominalDiffTime -> IO () deleteRcvMsgHashesExpired db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime DB.execute db "DELETE FROM encrypted_rcv_message_hashes WHERE created_at < ?" (Only cutoffTs) deleteSndMsgsExpired :: DB.Connection -> NominalDiffTime -> IO () deleteSndMsgsExpired db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime DB.execute db "DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL" @@ -1163,7 +1167,7 @@ getSkippedMsgKeys :: DB.Connection -> ConnId -> IO SkippedMsgKeys getSkippedMsgKeys db connId = skipped <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId) where - skipped ms = foldl' addSkippedKey M.empty ms + skipped = foldl' addSkippedKey M.empty addSkippedKey smks (hk, msgN, mk) = M.alter (Just . addMsgKey) hk smks where addMsgKey = maybe (M.singleton msgN mk) (M.insert msgN mk) @@ -1734,15 +1738,15 @@ getAnyConn deleted' dbConn connId = Just (cData@ConnData {deleted}, cMode) | deleted /= deleted' -> pure $ Left SEConnNotFound | otherwise -> do - rQ <- getRcvQueuesByConnId_ dbConn connId - sQ <- getSndQueuesByConnId_ dbConn connId - pure $ case (rQ, sQ, cMode) of - (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) - (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) - (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) - (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) - (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) - _ -> Left SEConnNotFound + rQ <- getRcvQueuesByConnId_ dbConn connId + sQ <- getSndQueuesByConnId_ dbConn connId + pure $ case (rQ, sQ, cMode) of + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) + _ -> Left SEConnNotFound getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getConns = getAnyConns_ False @@ -1804,7 +1808,7 @@ checkRatchetKeyHashExists db connId hash = do deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO () deleteRatchetKeyHashesExpired db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime DB.execute db "DELETE FROM processed_ratchet_key_hashes WHERE created_at < ?" (Only cutoffTs) -- | returns all connection queues, the first queue is the primary one @@ -2253,7 +2257,7 @@ deleteRcvFile' db rcvFileId = getNextRcvChunkToDownload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe RcvFileChunk) getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime maybeFirstRow toChunk $ DB.query db @@ -2290,7 +2294,7 @@ getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = d getNextRcvFileToDecrypt :: DB.Connection -> NominalDiffTime -> IO (Maybe RcvFile) getNextRcvFileToDecrypt db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime fileId_ :: Maybe DBRcvFileId <- maybeFirstRow fromOnly $ DB.query @@ -2308,7 +2312,7 @@ getNextRcvFileToDecrypt db ttl = do getPendingRcvFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] getPendingRcvFilesServers db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime map toXFTPServer <$> DB.query db @@ -2350,7 +2354,7 @@ getCleanupRcvFilesDeleted db = getRcvFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBRcvFileId, RcvFileId, FilePath)] getRcvFilesExpired db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime DB.query db [sql| @@ -2458,7 +2462,7 @@ getChunkReplicaRecipients_ db replicaId = getNextSndFileToPrepare :: DB.Connection -> NominalDiffTime -> IO (Maybe SndFile) getNextSndFileToPrepare db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime fileId_ :: Maybe DBSndFileId <- maybeFirstRow fromOnly $ DB.query @@ -2539,7 +2543,7 @@ createSndFileReplica db SndFileChunk {sndChunkId} NewSndChunkReplica {server, re getNextSndChunkToUpload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe SndFileChunk) getNextSndChunkToUpload db server@ProtocolServer {host, port, keyHash} ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime chunk_ <- maybeFirstRow toChunk $ DB.query @@ -2608,7 +2612,7 @@ updateSndChunkReplicaStatus db replicaId status = do getPendingSndFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] getPendingSndFilesServers db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime map toXFTPServer <$> DB.query db @@ -2647,7 +2651,7 @@ getCleanupSndFilesDeleted db = getSndFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBSndFileId, SndFileId, Maybe FilePath)] getSndFilesExpired db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime DB.query db [sql| @@ -2687,7 +2691,7 @@ getDeletedSndChunkReplica db deletedSndChunkReplicaId = getNextDeletedSndChunkReplica :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Maybe DeletedSndChunkReplica) getNextDeletedSndChunkReplica db ProtocolServer {host, port, keyHash} ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime replicaId_ :: Maybe Int64 <- maybeFirstRow fromOnly $ DB.query @@ -2716,7 +2720,7 @@ deleteDeletedSndChunkReplica db deletedSndChunkReplicaId = getPendingDelFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] getPendingDelFilesServers db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime map toXFTPServer <$> DB.query db @@ -2731,5 +2735,9 @@ getPendingDelFilesServers db ttl = do deleteDeletedSndChunkReplicasExpired :: DB.Connection -> NominalDiffTime -> IO () deleteDeletedSndChunkReplicasExpired db ttl = do - cutoffTs <- addUTCTime (- ttl) <$> getCurrentTime + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE created_at < ?" (Only cutoffTs) + +$(J.deriveJSON defaultJSON ''UpMigration) + +$(J.deriveToJSON (sumTypeJSON $ dropPrefix "ME") ''MigrationError) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs index 789f7214e..bfba973c1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs @@ -1,7 +1,7 @@ {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Agent.Store.SQLite.DB ( Connection (..), @@ -20,13 +20,12 @@ where import Control.Concurrent.STM import Control.Monad (when) -import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.Int (Int64) import Data.Time (diffUTCTime, getCurrentTime) import Database.SQLite.Simple (FromRow, NamedParam, Query, ToRow) import qualified Database.SQLite.Simple as SQL -import GHC.Generics (Generic) +import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (diffToMilliseconds) @@ -41,9 +40,7 @@ data SlowQueryStats = SlowQueryStats timeMax :: Int64, timeAvg :: Int64 } - deriving (Show, Generic, FromJSON) - -instance ToJSON SlowQueryStats where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Show) timeIt :: TMap Query SlowQueryStats -> Query -> IO a -> IO a timeIt slow sql a = do @@ -100,3 +97,5 @@ query_ Connection {conn, slow} sql = timeIt slow sql $ SQL.query_ conn sql queryNamed :: FromRow r => Connection -> Query -> [NamedParam] -> IO [r] queryNamed Connection {conn, slow} sql = timeIt slow sql . SQL.queryNamed conn sql {-# INLINE queryNamed #-} + +$(J.deriveJSON defaultJSON ''SlowQueryStats) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 6d46b7cc0..8ce7d6514 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} @@ -27,8 +26,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations where import Control.Monad (forM_, when) -import Data.Aeson (ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.List (intercalate, sortOn) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M @@ -40,7 +38,6 @@ import Database.SQLite.Simple (Connection, Only (..), Query (..)) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.QQ (sql) import qualified Database.SQLite3 as SQLite3 -import GHC.Generics (Generic) import Simplex.Messaging.Agent.Protocol (extraSMPServerHosts) import Simplex.Messaging.Agent.Store.SQLite.Common import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial @@ -169,11 +166,7 @@ data MigrationsToRun = MTRUp [Migration] | MTRDown [DownMigration] | MTRNone data MTRError = MTRENoDown {dbMigrations :: [String]} | MTREDifferent {appMigration :: String, dbMigration :: String} - deriving (Eq, Show, Generic) - -instance ToJSON MTRError where - toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "MTRE" - toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "MTRE" + deriving (Eq, Show) mtrErrorDescription :: MTRError -> String mtrErrorDescription = \case @@ -192,3 +185,5 @@ migrationsToRun [] dbMs migrationsToRun (a : as) (d : ds) | name a == name d = migrationsToRun as ds | otherwise = Left $ MTREDifferent (name a) (name d) + +$(J.deriveJSON (sumTypeJSON $ dropPrefix "MTRE") ''MTRError) diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index ffdaf3631..5f1ddf104 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE LambdaCase #-} module Simplex.Messaging.Agent.TRcvQueues ( TRcvQueues (getRcvQueues), empty, diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 5152c0212..11d75b7ee 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -1,6 +1,5 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -10,6 +9,7 @@ {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} @@ -86,8 +86,7 @@ import Control.Exception import Control.Monad import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except -import Data.Aeson (FromJSON (..), ToJSON (..)) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) @@ -97,13 +96,12 @@ import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Maybe (fromMaybe) import Data.Time.Clock (UTCTime, getCurrentTime) -import GHC.Generics (Generic) import Network.Socket (ServiceName) import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (dropPrefix, enumJSON) +import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON) import Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -190,14 +188,7 @@ data HostMode HMOnion | -- | prefer (or require) public hosts HMPublic - deriving (Eq, Show, Generic) - -instance FromJSON HostMode where - parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "HM" - -instance ToJSON HostMode where - toJSON = J.genericToJSON . enumJSON $ dropPrefix "HM" - toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "HM" + deriving (Eq, Show) -- | network configuration for the client data NetworkConfig = NetworkConfig @@ -223,21 +214,10 @@ data NetworkConfig = NetworkConfig smpPingCount :: Int, logTLSErrors :: Bool } - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON NetworkConfig where - toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True} - toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} + deriving (Eq, Show) data TransportSessionMode = TSMUser | TSMEntity - deriving (Eq, Show, Generic) - -instance ToJSON TransportSessionMode where - toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM" - toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM" - -instance FromJSON TransportSessionMode where - parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM" + deriving (Eq, Show) defaultNetworkConfig :: NetworkConfig defaultNetworkConfig = @@ -429,11 +409,11 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, where response entityId | entityId == entId = - case respOrErr of - Left e -> Left $ PCEResponseError e - Right r -> case protocolError r of - Just e -> Left $ PCEProtocolError e - _ -> Right r + case respOrErr of + Left e -> Left $ PCEResponseError e + Right r -> case protocolError r of + Just e -> Left $ PCEProtocolError e + _ -> Right r | otherwise = Left . PCEUnexpectedResponse $ bshow respOrErr sendMsg :: Either err msg -> IO () sendMsg = \case @@ -661,13 +641,13 @@ sendProtocolCommands c@ProtocolClient {batch, blockSize} cs = do validate rs | diff == 0 = pure $ L.fromList rs | diff > 0 = do - putStrLn "send error: fewer responses than expected" - pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock) + putStrLn "send error: fewer responses than expected" + pure $ L.fromList $ rs <> replicate diff (Response "" $ Left $ PCETransportError TEBadBlock) | otherwise = do - putStrLn "send error: more responses than expected" - pure $ L.fromList $ take (L.length cs) rs - where - diff = L.length cs - length rs + putStrLn "send error: more responses than expected" + pure $ L.fromList $ take (L.length cs) rs + where + diff = L.length cs - length rs streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {batch, blockSize} cs cb = do @@ -688,8 +668,8 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do (: []) <$> getResponse c r data ClientBatch err msg - -- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch - = CBTransmissions ByteString Int [Request err msg] + = -- ByteString in CBTransmissions does not include count byte, it is added by tEncodeBatch + CBTransmissions ByteString Int [Request err msg] | CBTransmission ByteString (Request err msg) | CBLargeTransmission (Request err msg) @@ -713,9 +693,9 @@ batchClientTransmissions batch blkSize encodeBatch :: ByteString -> Int -> [Request err msg] -> NonEmpty (PCTransmission err msg) -> (ClientBatch err msg, Maybe (NonEmpty (PCTransmission err msg))) encodeBatch s n rs ts@((t, r) :| ts_) | B.length s' <= blkSize - 3 && n < 255 = - case L.nonEmpty ts_ of - Just ts' -> encodeBatch s' n' rs' ts' - Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing) + case L.nonEmpty ts_ of + Just ts' -> encodeBatch s' n' rs' ts' + Nothing -> (CBTransmissions s' n' (reverse rs'), Nothing) | n == 0 = (CBLargeTransmission r, L.nonEmpty ts_) | otherwise = (CBTransmissions s n (reverse rs), Just ts) where @@ -765,3 +745,9 @@ mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCo r <- Request entId <$> newEmptyTMVar TM.insert corrId r sentCommands pure r + +$(J.deriveJSON (enumJSON $ dropPrefix "HM") ''HostMode) + +$(J.deriveJSON (enumJSON $ dropPrefix "TSM") ''TransportSessionMode) + +$(J.deriveJSON defaultJSON ''NetworkConfig) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 489223270..4d0d81bbc 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -18,7 +18,7 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except -import Data.Bifunctor (first, bimap) +import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (partitionEithers) @@ -36,11 +36,11 @@ import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer, NtfPrivateSignKey, NotifierId, RcvPrivateSignKey, RecipientId) +import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateSignKey, ProtocolServer (..), QueueId, RcvPrivateSignKey, RecipientId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport -import Simplex.Messaging.Util (catchAll_, ($>>=), toChunks) +import Simplex.Messaging.Util (catchAll_, toChunks, ($>>=)) import System.Timeout (timeout) import UnliftIO (async) import UnliftIO.Exception (Exception) @@ -238,7 +238,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = (tempErrs, finalErrs) = partition (temporaryClientError . snd) errs mapM_ (atomically . addSubscription ca srv) oks mapM_ (liftIO . notify . CAResubscribed srv) $ L.nonEmpty $ map fst oks - mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs + mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs mapM_ (liftIO . notify . CASubError srv) $ L.nonEmpty finalErrs mapM_ (throwE . snd) $ listToMaybe tempErrs @@ -281,7 +281,7 @@ subscribeQueue ca srv sub = do handleErr e = do atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $ - removePendingSubscription ca srv $ fst sub + removePendingSubscription ca srv (fst sub) throwE e subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateSignKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ())) @@ -300,14 +300,15 @@ subscribeQueues_ party ca srv subs = do smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateSignKey) -> IO (NonEmpty (QueueId, Either SMPClientError ())) smpSubscribeQueues party ca smp srv subs = do rs <- L.zip subs <$> subscribe smp (L.map swap subs) - atomically $ forM rs $ \(sub, r) -> (fst sub,) <$> case r of - Right () -> do - addSubscription ca srv $ first (party,) sub - pure $ Right () - Left e -> do - when (e /= PCENetworkError && e /= PCEResponseTimeout) $ - removePendingSubscription ca srv $ (party,) $ fst sub - pure $ Left e + atomically $ forM rs $ \(sub, r) -> + (fst sub,) <$> case r of + Right () -> do + addSubscription ca srv $ first (party,) sub + pure $ Right () + Left e -> do + when (e /= PCENetworkError && e /= PCEResponseTimeout) $ + removePendingSubscription ca srv (party, fst sub) + pure $ Left e where subscribe = case party of SPRecipient -> subscribeSMPQueues diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 503138132..cfc8156cf 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -680,9 +680,9 @@ instance CryptoSignature ASignature where signatureBytes (ASignature _ sig) = signatureBytes sig decodeSignature s | B.length s == Ed25519.signatureSize = - ASignature SEd25519 . SignatureEd25519 <$> ed Ed25519.signature s + ASignature SEd25519 . SignatureEd25519 <$> ed Ed25519.signature s | B.length s == Ed448.signatureSize = - ASignature SEd448 . SignatureEd448 <$> ed Ed448.signature s + ASignature SEd448 . SignatureEd448 <$> ed Ed448.signature s | otherwise = Left "bad signature size" where ed alg = first show . CE.eitherCryptoError . alg diff --git a/src/Simplex/Messaging/Crypto/File.hs b/src/Simplex/Messaging/Crypto/File.hs index 8de0bbb61..9afc5d583 100644 --- a/src/Simplex/Messaging/Crypto/File.hs +++ b/src/Simplex/Messaging/Crypto/File.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Crypto.File ( CryptoFile (..), @@ -24,19 +23,18 @@ where import Control.Exception import Control.Monad import Control.Monad.Except -import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import qualified Data.ByteArray as BA import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as LB import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe (isJust) -import GHC.Generics (Generic) import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Lazy (LazyByteString) import qualified Simplex.Messaging.Crypto.Lazy as LC +import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.Util (liftEitherWith) import System.Directory (getFileSize) import UnliftIO (Handle, IOMode (..), liftIO) @@ -45,16 +43,10 @@ import UnliftIO.STM -- Possibly encrypted local file data CryptoFile = CryptoFile {filePath :: FilePath, cryptoArgs :: Maybe CryptoFileArgs} - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON CryptoFile where - toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} - toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True} + deriving (Eq, Show) data CryptoFileArgs = CFArgs {fileKey :: C.SbKey, fileNonce :: C.CbNonce} - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON CryptoFileArgs where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) data CryptoFileHandle = CFHandle Handle (Maybe (TVar LC.SbState)) @@ -124,3 +116,7 @@ getFileContentsSize :: CryptoFile -> IO Integer getFileContentsSize (CryptoFile path cfArgs) = do size <- getFileSize path pure $ if isJust cfArgs then size - fromIntegral C.authTagSize else size + +$(J.deriveJSON defaultJSON ''CryptoFileArgs) + +$(J.deriveJSON defaultJSON ''CryptoFile) diff --git a/src/Simplex/Messaging/Crypto/Lazy.hs b/src/Simplex/Messaging/Crypto/Lazy.hs index 6fb37adf7..e0117108c 100644 --- a/src/Simplex/Messaging/Crypto/Lazy.hs +++ b/src/Simplex/Messaging/Crypto/Lazy.hs @@ -1,7 +1,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -90,10 +89,10 @@ unPad = fmap snd . splitLen splitLen :: LazyByteString -> Either CryptoError (Int64, LazyByteString) splitLen padded | LB.length lenStr == 8 = case smpDecode $ LB.toStrict lenStr of - Right len - | len < 0 -> Left CryptoInvalidMsgError - | otherwise -> Right (len, LB.take len rest) - Left _ -> Left CryptoInvalidMsgError + Right len + | len < 0 -> Left CryptoInvalidMsgError + | otherwise -> Right (len, LB.take len rest) + Left _ -> Left CryptoInvalidMsgError | otherwise = Left CryptoInvalidMsgError where (lenStr, rest) = LB.splitAt 8 padded @@ -112,10 +111,10 @@ sbDecrypt :: SbKey -> CbNonce -> LazyByteString -> Either CryptoError LazyByteSt sbDecrypt (SbKey key) (CbNonce nonce) packet | LB.length tag' < 16 = Left CBDecryptError | otherwise = case secretBox sbDecryptChunk key nonce c of - Right (tag :| cs) - | BA.constEq (LB.toStrict tag') tag -> unPad $ LB.fromChunks cs - | otherwise -> Left CBDecryptError - Left e -> Left e + Right (tag :| cs) + | BA.constEq (LB.toStrict tag') tag -> unPad $ LB.fromChunks cs + | otherwise -> Left CBDecryptError + Left e -> Left e where (tag', c) = LB.splitAt 16 packet diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index c04b418c1..b84975a3d 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -1,12 +1,12 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -20,8 +20,9 @@ import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import Crypto.Hash (SHA512) import qualified Crypto.KDF.HKDF as H -import Data.Aeson (FromJSON, ToJSON) +import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson as J +import qualified Data.Aeson.TH as JQ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as LB @@ -33,12 +34,11 @@ import Data.Typeable (Typeable) import Data.Word (Word32) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) -import GHC.Generics import Simplex.Messaging.Agent.QueryString import Simplex.Messaging.Crypto import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (blobFieldDecoder, parseE, parseE') +import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE') import Simplex.Messaging.Version currentE2EEncryptVersion :: Version @@ -112,11 +112,11 @@ x3dh v (sk1, rk1) dh1 dh2 dh3 = (hk, nhk, sk) -- for backwards compatibility with clients using agent version before 3.4.0 | v == 1 = - let (hk', rest) = B.splitAt 32 dhs - in uncurry (hk',,) $ B.splitAt 32 rest + let (hk', rest) = B.splitAt 32 dhs + in uncurry (hk',,) $ B.splitAt 32 rest | otherwise = - let salt = B.replicate 64 '\0' - in hkdf3 salt dhs "SimpleXX3DH" + let salt = B.replicate 64 '\0' + in hkdf3 salt dhs "SimpleXX3DH" type RatchetX448 = Ratchet 'X448 @@ -135,29 +135,20 @@ data Ratchet a = Ratchet rcNHKs :: HeaderKey, rcNHKr :: HeaderKey } - deriving (Eq, Show, Generic, FromJSON) - -instance AlgorithmI a => ToJSON (Ratchet a) where - toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, rcCKs :: RatchetKey, rcHKs :: HeaderKey } - deriving (Eq, Show, Generic, FromJSON) - -instance AlgorithmI a => ToJSON (SndRatchet a) where - toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) data RcvRatchet = RcvRatchet { rcCKr :: RatchetKey, rcHKr :: HeaderKey } - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON RcvRatchet where - toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) type SkippedMsgKeys = Map HeaderKey SkippedHdrMsgKeys @@ -204,10 +195,6 @@ instance ToJSON RatchetKey where instance FromJSON RatchetKey where parseJSON = fmap RatchetKey . strParseJSON "Key" -instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode - -instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict' - instance ToField MessageKey where toField = toField . smpEncode instance FromField MessageKey where fromField = blobFieldDecoder smpDecode @@ -428,9 +415,9 @@ rcDecrypt rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do | rcNr + maxSkip < untilN = Left $ CERatchetTooManySkipped (untilN + 1 - rcNr) | rcNr == untilN = Right (r, M.empty) | otherwise = - let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty - r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'} - in Right (r', M.singleton rcHKr mks) + let (rcCKr', rcNr', mks) = advanceRcvRatchet (untilN - rcNr) rcCKr rcNr M.empty + r' = r {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr'} + in Right (r', M.singleton rcHKr mks) advanceRcvRatchet :: Word32 -> RatchetKey -> Word32 -> SkippedHdrMsgKeys -> (RatchetKey, Word32, SkippedHdrMsgKeys) advanceRcvRatchet 0 ck msgNs mks = (ck, msgNs, mks) advanceRcvRatchet n ck msgNs mks = @@ -493,3 +480,23 @@ hkdf3 salt ikm info = (s1, s2, s3) out = H.expand prk info 96 (s1, rest) = B.splitAt 32 out (s2, s3) = B.splitAt 32 rest + +$(JQ.deriveJSON defaultJSON ''RcvRatchet) + +instance AlgorithmI a => ToJSON (SndRatchet a) where + toEncoding = $(JQ.mkToEncoding defaultJSON ''SndRatchet) + toJSON = $(JQ.mkToJSON defaultJSON ''SndRatchet) + +instance AlgorithmI a => FromJSON (SndRatchet a) where + parseJSON = $(JQ.mkParseJSON defaultJSON ''SndRatchet) + +instance AlgorithmI a => ToJSON (Ratchet a) where + toEncoding = $(JQ.mkToEncoding defaultJSON ''Ratchet) + toJSON = $(JQ.mkToJSON defaultJSON ''Ratchet) + +instance AlgorithmI a => FromJSON (Ratchet a) where + parseJSON = $(JQ.mkParseJSON defaultJSON ''Ratchet) + +instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode + +instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict' diff --git a/src/Simplex/Messaging/Encoding.hs b/src/Simplex/Messaging/Encoding.hs index 5d5dec32a..f2b0609bd 100644 --- a/src/Simplex/Messaging/Encoding.hs +++ b/src/Simplex/Messaging/Encoding.hs @@ -109,7 +109,7 @@ lenP = fromIntegral . c2w <$> A.anyChar {-# INLINE lenP #-} instance Encoding a => Encoding (Maybe a) where - smpEncode s = maybe "0" (("1" <>) . smpEncode) s + smpEncode = maybe "0" (("1" <>) . smpEncode) {-# INLINE smpEncode #-} smpP = smpP >>= \case diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 3594a17c2..77b9c10bf 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -518,7 +518,7 @@ instance ToJSON NtfTknStatus where instance FromJSON NtfTknStatus where parseJSON = J.withText "NtfTknStatus" $ either fail pure . smpDecode . encodeUtf8 - + checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e) checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of Just Refl -> Right c diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 5b591fce4..6d1b55d6f 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -16,13 +16,12 @@ import Control.Logger.Simple import Control.Monad import Control.Monad.Except import Control.Monad.Reader -import Data.Bifunctor (second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) import Data.List (intercalate, sort) -import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M @@ -205,10 +204,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge Just err -> (subs, oks, err : errs) -- permanent error, log and don't retry subscription Nothing -> (sub : subs, oks, errs) -- temporary error, retry subscription - -- | Subscribe to queues. The list of results can have a different order. + -- \| Subscribe to queues. The list of results can have a different order. subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO (NonEmpty (NtfSubData, Either SMPClientError ())) subscribeQueues srv subs = - L.map (second snd) . L.zip subs <$> subscribeQueuesNtfs ca srv (L.map sub subs) + L.zipWith (\s r -> (s, snd r)) subs <$> subscribeQueuesNtfs ca srv (L.map sub subs) where sub NtfSubData {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey) @@ -248,10 +247,11 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge forM errs (\((_, ntfId), err) -> handleSubError (SMPQueueNtf srv ntfId) err) >>= logSubErrors srv . catMaybes . L.toList - logSubStatus srv event n = when (n > 0) $ - logInfo $ "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)" + logSubStatus srv event n = + when (n > 0) . logInfo $ + "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)" - logSubErrors :: SMPServer -> [NtfSubStatus] -> M () + logSubErrors :: SMPServer -> [NtfSubStatus] -> M () logSubErrors srv errs = forM_ (L.group $ sort errs) $ \errs' -> do logError $ "SMP subscription errors on server " <> showServer' srv <> ": " <> tshow (L.head errs') <> " (" <> tshow (length errs') <> " errors)" @@ -289,14 +289,14 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do case ntf of PNVerification _ | status /= NTInvalid && status /= NTExpired -> - deliverNotification pp tkn ntf >>= \case - Right _ -> do - status_ <- atomically $ stateTVar tknStatus $ \case - NTActive -> (Nothing, NTActive) - NTConfirmed -> (Nothing, NTConfirmed) - _ -> (Just NTConfirmed, NTConfirmed) - forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status' - _ -> pure () + deliverNotification pp tkn ntf >>= \case + Right _ -> do + status_ <- atomically $ stateTVar tknStatus $ \case + NTActive -> (Nothing, NTActive) + NTConfirmed -> (Nothing, NTConfirmed) + _ -> (Just NTConfirmed, NTConfirmed) + forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status' + _ -> pure () | otherwise -> logError "bad notification token status" PNCheckMessages -> checkActiveTkn status $ do void $ deliverNotification pp tkn ntf @@ -345,7 +345,8 @@ runNtfClientTransport th@THandle {sessionId} = do raceAny_ ([liftIO $ send th c, client c s ps, receive th c] <> disconnectThread_ c expCfg) `finally` liftIO (clientDisconnected c) where - disconnectThread_ c expCfg = maybe [] ((: []) . liftIO . disconnectTransport th c activeAt) expCfg + disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport th c activeAt expCfg] + disconnectThread_ _ _ = [] clientDisconnected :: NtfServerClient -> IO () clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False @@ -463,16 +464,16 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu else pure $ NRErr AUTH TVFY code -- this allows repeated verification for cases when client connection dropped before server response | (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do - logDebug "TVFY - token verified" - st <- asks store - updateTknStatus tkn NTActive - tIds <- atomically $ removeInactiveTokenRegistrations st tkn - forM_ tIds cancelInvervalNotifications - incNtfStat tknVerified - pure NROk + logDebug "TVFY - token verified" + st <- asks store + updateTknStatus tkn NTActive + tIds <- atomically $ removeInactiveTokenRegistrations st tkn + forM_ tIds cancelInvervalNotifications + incNtfStat tknVerified + pure NROk | otherwise -> do - logDebug "TVFY - incorrect code or token status" - pure $ NRErr AUTH + logDebug "TVFY - incorrect code or token status" + pure $ NRErr AUTH TCHK -> do logDebug "TCHK" pure $ NRTkn status @@ -509,16 +510,16 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu TCRN int | int < 20 -> pure $ NRErr QUOTA | otherwise -> do - logDebug "TCRN" - atomically $ writeTVar tknCronInterval int - atomically (TM.lookup tknId intervalNotifiers) >>= \case - Nothing -> runIntervalNotifier int - Just IntervalNotifier {interval, action} -> - unless (interval == int) $ do - uninterruptibleCancel action - runIntervalNotifier int - withNtfLog $ \s -> logTokenCron s tknId int - pure NROk + logDebug "TCRN" + atomically $ writeTVar tknCronInterval int + atomically (TM.lookup tknId intervalNotifiers) >>= \case + Nothing -> runIntervalNotifier int + Just IntervalNotifier {interval, action} -> + unless (interval == int) $ do + uninterruptibleCancel action + runIntervalNotifier int + withNtfLog $ \s -> logTokenCron s tknId int + pure NROk where runIntervalNotifier interval = do action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60 diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index d6989c4d4..032da3b89 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -34,7 +34,7 @@ import Simplex.Messaging.Server.Expiration import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport) -import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig) +import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams) import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index 216890be2..4e0106aab 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -7,7 +7,6 @@ module Simplex.Messaging.Notifications.Server.Main where -import Data.Either (fromRight) import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) import Data.Maybe (fromMaybe) @@ -92,7 +91,7 @@ ntfServerCLI cfgPath logPath = hSetBuffering stdout LineBuffering hSetBuffering stderr LineBuffering fp <- checkSavedFingerprint cfgPath defaultX509Config - let host = fromRight "" $ T.unpack <$> lookupValue "TRANSPORT" "host" ini + let host = either (const "") T.unpack $ lookupValue "TRANSPORT" "host" ini port = T.unpack $ strictIni "TRANSPORT" "port" ini cfg@NtfServerConfig {transports, storeLogFile} = serverConfig srv = ProtoServerWithAuth (NtfServer [THDomainName host] (if port == "443" then "" else port) (C.KeyHash fp)) Nothing diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs index c436406f5..fd07e0a02 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs @@ -1,9 +1,9 @@ {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Use newtype instead of data" #-} @@ -23,15 +23,15 @@ import qualified Crypto.Store.PKCS8 as PK import Data.ASN1.BinaryEncoding (DER (..)) import Data.ASN1.Encoding import Data.ASN1.Types -import Data.Aeson (FromJSON, ToJSON, (.=)) +import Data.Aeson (ToJSON, (.=)) import qualified Data.Aeson as J import qualified Data.Aeson.Encoding as JE +import qualified Data.Aeson.TH as JQ import Data.Bifunctor (first) import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Builder (lazyByteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Lazy.Char8 as LB -import qualified Data.CaseInsensitive as CI import Data.Int (Int64) import Data.Map.Strict (Map) import Data.Maybe (isNothing) @@ -40,8 +40,7 @@ import qualified Data.Text as T import Data.Time.Clock.System import qualified Data.X509 as X import qualified Data.X509.CertificateStore as XS -import GHC.Generics (Generic) -import Network.HTTP.Types (HeaderName, Status) +import Network.HTTP.Types (Status) import qualified Network.HTTP.Types as N import Network.HTTP2.Client (Request) import qualified Network.HTTP2.Client as H @@ -49,7 +48,9 @@ import Network.Socket (HostName, ServiceName) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Server.Push.APNS.Internal import Simplex.Messaging.Notifications.Server.Store (NtfTknData (..)) +import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.Protocol (EncNMsgMeta) import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..)) import Simplex.Messaging.Transport.HTTP2.Client @@ -61,17 +62,13 @@ data JWTHeader = JWTHeader { alg :: Text, -- key algorithm, ES256 for APNS kid :: Text -- key ID } - deriving (Show, Generic) - -instance ToJSON JWTHeader where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Show) data JWTClaims = JWTClaims { iss :: Text, -- issuer, team ID for APNS iat :: Int64 -- issue time, seconds from epoch } - deriving (Show, Generic) - -instance ToJSON JWTClaims where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Show) data JWTToken = JWTToken JWTHeader JWTClaims deriving (Show) @@ -83,6 +80,10 @@ mkJWTToken hdr iss = do type SignedJWTToken = ByteString +$(JQ.deriveToJSON defaultJSON ''JWTHeader) + +$(JQ.deriveToJSON defaultJSON ''JWTClaims) + signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken signedJWTToken pk (JWTToken hdr claims) = do let hc = jwtEncode hdr <> "." <> jwtEncode claims @@ -122,24 +123,13 @@ instance StrEncoding PNMessageData where pure PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} data APNSNotification = APNSNotification {aps :: APNSNotificationBody, notificationData :: Maybe J.Value} - deriving (Show, Generic) - -instance ToJSON APNSNotification where - toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True} - toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} + deriving (Show) data APNSNotificationBody = APNSBackground {contentAvailable :: Int} | APNSMutableContent {mutableContent :: Int, alert :: APNSAlertBody, category :: Maybe Text} | APNSAlert {alert :: APNSAlertBody, badge :: Maybe Int, sound :: Maybe Text, category :: Maybe Text} - deriving (Show, Generic) - -apnsJSONOptions :: J.Options -apnsJSONOptions = J.defaultOptions {J.omitNothingFields = True, J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'} - -instance ToJSON APNSNotificationBody where - toJSON = J.genericToJSON apnsJSONOptions - toEncoding = J.genericToEncoding apnsJSONOptions + deriving (Show) type APNSNotificationData = Map Text Text @@ -305,6 +295,10 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case -- apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing} +$(JQ.deriveToJSON apnsJSONOptions ''APNSNotificationBody) + +$(JQ.deriveToJSON defaultJSON ''APNSNotification) + apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request apnsRequest c tkn ntf@APNSNotification {aps} = do signedJWT <- getApnsJWTToken c @@ -337,7 +331,8 @@ type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProvider -- this is not a newtype on purpose to have a correct JSON encoding as a record data APNSErrorResponse = APNSErrorResponse {reason :: Text} - deriving (Generic, FromJSON) + +$(JQ.deriveFromJSON defaultJSON ''APNSErrorResponse) apnsPushProviderClient :: APNSPushClient -> PushProviderClient apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken _ tknStr} pn = do @@ -356,15 +351,15 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke result status reason' | status == Just N.ok200 = pure () | status == Just N.badRequest400 = - case reason' of - "BadDeviceToken" -> throwError PPTokenInvalid - "DeviceTokenNotForTopic" -> throwError PPTokenInvalid - "TopicDisallowed" -> throwError PPPermanentError - _ -> err status reason' + case reason' of + "BadDeviceToken" -> throwError PPTokenInvalid + "DeviceTokenNotForTopic" -> throwError PPTokenInvalid + "TopicDisallowed" -> throwError PPPermanentError + _ -> err status reason' | status == Just N.forbidden403 = case reason' of - "ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed - "InvalidProviderToken" -> throwError PPPermanentError - _ -> err status reason' + "ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed + "InvalidProviderToken" -> throwError PPPermanentError + _ -> err status reason' | status == Just N.gone410 = throwError PPTokenInvalid | status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater -- Just tooManyRequests429 -> TooManyRequests - too many requests for the same token @@ -372,12 +367,3 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke err :: Maybe Status -> Text -> ExceptT PushProviderError IO () err s r = throwError $ PPResponseError s r liftHTTPS2 a = ExceptT $ first PPConnection <$> a - -hApnsTopic :: HeaderName -hApnsTopic = CI.mk "apns-topic" - -hApnsPushType :: HeaderName -hApnsPushType = CI.mk "apns-push-type" - -hApnsPriority :: HeaderName -hApnsPriority = CI.mk "apns-priority" diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS/Internal.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS/Internal.hs new file mode 100644 index 000000000..8e79d40e3 --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS/Internal.hs @@ -0,0 +1,20 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Simplex.Messaging.Notifications.Server.Push.APNS.Internal where + +import qualified Data.Aeson as J +import qualified Data.CaseInsensitive as CI +import Network.HTTP.Types (HeaderName) +import Simplex.Messaging.Parsers (defaultJSON) + +hApnsTopic :: HeaderName +hApnsTopic = CI.mk "apns-topic" + +hApnsPushType :: HeaderName +hApnsPushType = CI.mk "apns-push-type" + +hApnsPriority :: HeaderName +hApnsPriority = CI.mk "apns-priority" + +apnsJSONOptions :: J.Options +apnsJSONOptions = defaultJSON {J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'} diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 7be6b3d54..b7750ae2c 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -4,7 +4,6 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Notifications.Server.Store where diff --git a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs index 441f60ec2..3ed28eb52 100644 --- a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs @@ -5,7 +5,6 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE StrictData #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module Simplex.Messaging.Notifications.Server.StoreLog diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index 33abe56a0..54bd354bb 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -50,9 +50,9 @@ ntfServerHandshake c kh ntfVRange = do getHandshake th >>= \case NtfClientHandshake {ntfVersion, keyHash} | keyHash /= kh -> - throwError $ TEHandshake IDENTITY - | ntfVersion `isCompatible` ntfVRange -> do - pure (th :: THandle c) {thVersion = ntfVersion} + throwError $ TEHandshake IDENTITY + | ntfVersion `isCompatible` ntfVRange -> + pure (th :: THandle c) {thVersion = ntfVersion} | otherwise -> throwError $ TEHandshake VERSION -- | Notifcations server client transport handshake. diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 3363cbcc9..39cb0383c 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -85,7 +85,7 @@ blobFieldDecoder dec = \case Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e) f -> returnError ConversionFailed f "expecting SQLBlob column type" -fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a +fromTextField_ :: Typeable a => (Text -> Maybe a) -> Field -> Ok a fromTextField_ fromText = \case f@(Field (SQLText t) _) -> case fromText t of @@ -151,3 +151,6 @@ singleFieldJSON_ objectTag tagModifier = J.nullaryToObject = True, J.omitNothingFields = True } + +defaultJSON :: J.Options +defaultJSON = J.defaultOptions {J.omitNothingFields = True} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index a3c014d61..3d5d44c7d 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,7 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -17,6 +16,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilyDependencies #-} @@ -135,7 +135,7 @@ module Simplex.Messaging.Protocol noAuthSrv, -- * TCP transport functions - TransportBatch(..), + TransportBatch (..), tPut, tPutLog, tGet, @@ -155,7 +155,7 @@ import Control.Applicative (optional, (<|>)) import Control.Concurrent (threadDelay) import Control.Monad.Except import Data.Aeson (FromJSON (..), ToJSON (..)) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) @@ -170,9 +170,7 @@ import Data.Maybe (isJust, isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality -import GHC.Generics (Generic) import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) -import Generic.Random (genericArbitraryU) import Network.Socket (HostName, ServiceName) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -182,7 +180,6 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..)) import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>)) import Simplex.Messaging.Version -import Test.QuickCheck (Arbitrary (..)) currentSMPClientVersion :: Version currentSMPClientVersion = 2 @@ -309,17 +306,19 @@ instance StrEncoding SubscriptionMode where SMSubscribe -> "subscribe" SMOnlyCreate -> "only-create" strP = - (A.string "subscribe" $> SMSubscribe) <|> (A.string "only-create" $> SMOnlyCreate) - "SubscriptionMode" + (A.string "subscribe" $> SMSubscribe) + <|> (A.string "only-create" $> SMOnlyCreate) + "SubscriptionMode" instance Encoding SubscriptionMode where smpEncode = \case SMSubscribe -> "S" SMOnlyCreate -> "C" - smpP = A.anyChar >>= \case - 'S' -> pure SMSubscribe - 'C' -> pure SMOnlyCreate - _ -> fail "bad SubscriptionMode" + smpP = + A.anyChar >>= \case + 'S' -> pure SMSubscribe + 'C' -> pure SMOnlyCreate + _ -> fail "bad SubscriptionMode" data BrokerMsg where -- SMP broker messages (responses, client messages, notifications) @@ -472,9 +471,7 @@ instance Encoding NMsgMeta where -- it must be data for correct JSON encoding data MsgFlags = MsgFlags {notification :: Bool} - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON MsgFlags where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) -- this encoding should not become bigger than 7 bytes (currently it is 1 byte) instance Encoding MsgFlags where @@ -997,14 +994,7 @@ data ErrorType INTERNAL | -- | used internally, never returned by the server (to be removed) DUPLICATE_ -- not part of SMP protocol, used internally - deriving (Eq, Generic, Read, Show) - -instance ToJSON ErrorType where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON ErrorType where - parseJSON = J.genericParseJSON $ sumTypeJSON id + deriving (Eq, Read, Show) instance StrEncoding ErrorType where strEncode = \case @@ -1026,18 +1016,7 @@ data CommandError HAS_AUTH | -- | transmission has no required entity ID (e.g. SMP queue) NO_ENTITY - deriving (Eq, Generic, Read, Show) - -instance ToJSON CommandError where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON CommandError where - parseJSON = J.genericParseJSON $ sumTypeJSON id - -instance Arbitrary ErrorType where arbitrary = genericArbitraryU - -instance Arbitrary CommandError where arbitrary = genericArbitraryU + deriving (Eq, Read, Show) -- | SMP transmission parser. transmissionP :: Parser RawTransmission @@ -1306,7 +1285,7 @@ tPutLog th s = do pure r -- ByteString does not include length byte, it is added by tEncodeBatch -data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission +data TransportBatch = TBTransmissions Int ByteString | TBTransmission ByteString | TBLargeTransmission -- | encodes and batches transmissions into blocks, batchTransmissions :: Bool -> Int -> NonEmpty SentRawTransmission -> [TransportBatch] @@ -1319,22 +1298,22 @@ batchTransmissions batch bSize let (n, s, ts_) = encodeBatch 0 "" ts r = if n == 0 then TBLargeTransmission else TBTransmissions n s rs' = r : rs - in case ts_ of - Just ts' -> mkBatch rs' ts' - _ -> rs' + in case ts_ of + Just ts' -> mkBatch rs' ts' + _ -> rs' mkBatch1 :: ByteString -> TransportBatch mkBatch1 s = if B.length s > bSize - 2 then TBLargeTransmission else TBTransmission s encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString)) encodeBatch n s ts@(t :| ts_) | n == 255 = (n, s, Just ts) | otherwise = - let s' = s <> smpEncode (Large t) - n' = n + 1 - in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch - then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts - else case L.nonEmpty ts_ of - Just ts' -> encodeBatch n' s' ts' - _ -> (n', s', Nothing) + let s' = s <> smpEncode (Large t) + n' = n + 1 + in if B.length s' > bSize - 3 -- one byte is reserved for the number of messages in the batch + then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts + else case L.nonEmpty ts_ of + Just ts' -> encodeBatch n' s' ts' + _ -> (n', s', Nothing) tEncode :: SentRawTransmission -> ByteString tEncode (sig, t) = smpEncode (C.signatureBytes sig) <> t @@ -1373,8 +1352,8 @@ tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => SessionId -> tDecodeParseValidate sessionId v = \case Right RawTransmission {signature, signed, sessId, corrId, entityId, command} | sessId == sessionId -> - let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature - in either (const $ tError corrId) (tParseValidate signed) decodedTransmission + let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature + in either (const $ tError corrId) (tParseValidate signed) decodedTransmission | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession)) Left _ -> tError "" where @@ -1385,3 +1364,9 @@ tDecodeParseValidate sessionId v = \case tParseValidate signed t@(sig, corrId, entityId, command) = let cmd = parseProtocol @err @cmd v command >>= checkCredentials t in (sig, signed, (CorrId corrId, entityId, cmd)) + +$(J.deriveJSON defaultJSON ''MsgFlags) + +$(J.deriveJSON (sumTypeJSON id) ''CommandError) + +$(J.deriveJSON (sumTypeJSON id) ''ErrorType) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 955e86731..326ad0d8e 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -1,6 +1,5 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} @@ -8,6 +7,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -116,9 +116,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do restoreServerMessages restoreServerStats raceAny_ - ( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub : - serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ()) : - map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg + ( serverThread s "server subscribedQ" subscribedQ subscribers subscriptions cancelSub + : serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ()) + : map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg ) `finally` withLock (savingLock s) "final" (saveServer False) where @@ -148,7 +148,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do updateSubscribers :: STM (Maybe (QueueId, Client)) updateSubscribers = do (qId, clnt) <- readTQueue $ subQ s - let clientToBeNotified = \c' -> + let clientToBeNotified c' = if sameClientSession clnt c' then pure Nothing else do @@ -277,9 +277,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do hPutStrLn h $ "Clients: " <> show (length clients) forM_ (M.toList clients) $ \(cid, Client {sessionId, connected, activeAt, subscriptions}) -> do hPutStrLn h . B.unpack $ "Client " <> encode cid <> " $" <> encode sessionId - readTVarIO connected >>= hPutStrLn h . (" connected: " <>) . show - readTVarIO activeAt >>= hPutStrLn h . (" activeAt: " <>) . B.unpack . strEncode - readTVarIO subscriptions >>= hPutStrLn h . (" subscriptions: " <>) . show . M.size + readTVarIO connected >>= hPutStrLn h . (" connected: " <>) . show + readTVarIO activeAt >>= hPutStrLn h . (" activeAt: " <>) . B.unpack . strEncode + readTVarIO subscriptions >>= hPutStrLn h . (" subscriptions: " <>) . show . M.size CPStats -> do ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, msgSentNtf, msgRecvNtf, qCount, msgCount} <- unliftIO u $ asks serverStats putStat "fromTime" fromTime @@ -666,27 +666,27 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv sendMessage qr msgFlags msgBody | B.length msgBody > maxMessageLength = pure $ err LARGE_MSG | otherwise = case status qr of - QueueOff -> return $ err AUTH - QueueActive -> - case C.maxLenBS msgBody of - Left _ -> pure $ err LARGE_MSG - Right body -> do - msg_ <- time "SEND" $ do - q <- getStoreMsgQueue "SEND" $ recipientId qr - expireMessages q - atomically . writeMsg q =<< mkMessage body - case msg_ of - Nothing -> pure $ err QUOTA - Just msg -> time "SEND ok" $ do - stats <- asks serverStats - when (notification msgFlags) $ do - atomically . trySendNotification msg =<< asks idsDrg - atomically $ modifyTVar' (msgSentNtf stats) (+ 1) - atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr) - atomically $ modifyTVar' (msgSent stats) (+ 1) - atomically $ modifyTVar' (msgCount stats) (subtract 1) - atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) - pure ok + QueueOff -> return $ err AUTH + QueueActive -> + case C.maxLenBS msgBody of + Left _ -> pure $ err LARGE_MSG + Right body -> do + msg_ <- time "SEND" $ do + q <- getStoreMsgQueue "SEND" $ recipientId qr + expireMessages q + atomically . writeMsg q =<< mkMessage body + case msg_ of + Nothing -> pure $ err QUOTA + Just msg -> time "SEND ok" $ do + stats <- asks serverStats + when (notification msgFlags) $ do + atomically . trySendNotification msg =<< asks idsDrg + atomically $ modifyTVar' (msgSentNtf stats) (+ 1) + atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr) + atomically $ modifyTVar' (msgSent stats) (+ 1) + atomically $ modifyTVar' (msgCount stats) (subtract 1) + atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) + pure ok where mkMessage :: C.MaxLenBS MaxMessageLen -> m Message mkMessage body = do @@ -767,7 +767,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv msgTs' = msg.msgTs setDelivered :: Sub -> Message -> STM Bool - setDelivered s msg = tryPutTMVar (delivered s) $ msg.msgId + setDelivered s msg = tryPutTMVar (delivered s) msg.msgId getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 386355e80..518667b26 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -31,7 +31,7 @@ import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport) -import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig) +import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams) import Simplex.Messaging.Version import System.IO (IOMode (..)) import System.Mem.Weak (Weak) diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 749a8d6ba..324ed49c8 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -11,7 +11,6 @@ module Simplex.Messaging.Server.Main where import Control.Monad (void) import Crypto.Random (getRandomBytes) import qualified Data.ByteString.Char8 as B -import Data.Either (fromRight) import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) import Data.Maybe (fromMaybe) @@ -24,7 +23,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BasicAuth (..), ProtoServerWithAuth (ProtoServerWithAuth), pattern SMPServer) import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.CLI -import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClientExpiration, defaultMessageExpiration, defMsgExpirationDays) +import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defMsgExpirationDays, defaultInactiveClientExpiration, defaultMessageExpiration) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange) import Simplex.Messaging.Transport.Client (TransportHost (..)) @@ -60,15 +59,15 @@ smpServerCLI cfgPath logPath = initializeServer opts | scripted opts = initialize opts | otherwise = do - putStrLn "Use `smp-server init -h` for available options." - void $ withPrompt "SMP server will be initialized (press Enter)" getLine - enableStoreLog <- onOffPrompt "Enable store log to restore queues and messages on server restart" True - logStats <- onOffPrompt "Enable logging daily statistics" False - putStrLn "Require a password to create new messaging queues?" - password <- withPrompt "'r' for random (default), 'n' - no password, or enter password: " serverPassword - let host = fromMaybe (ip opts) (fqdn opts) - host' <- withPrompt ("Enter server FQDN or IP address for certificate (" <> host <> "): ") getLine - initialize opts {enableStoreLog, logStats, fqdn = if null host' then fqdn opts else Just host', password} + putStrLn "Use `smp-server init -h` for available options." + void $ withPrompt "SMP server will be initialized (press Enter)" getLine + enableStoreLog <- onOffPrompt "Enable store log to restore queues and messages on server restart" True + logStats <- onOffPrompt "Enable logging daily statistics" False + putStrLn "Require a password to create new messaging queues?" + password <- withPrompt "'r' for random (default), 'n' - no password, or enter password: " serverPassword + let host = fromMaybe (ip opts) (fqdn opts) + host' <- withPrompt ("Enter server FQDN or IP address for certificate (" <> host <> "): ") getLine + initialize opts {enableStoreLog, logStats, fqdn = if null host' then fqdn opts else Just host', password} where serverPassword = getLine >>= \case @@ -121,8 +120,8 @@ smpServerCLI cfgPath logPath = \# The password will not be shared with the connecting contacts, you must share it only\n\ \# with the users who you want to allow creating messaging queues on your server.\n" <> ( case basicAuth of - Just auth -> "create_password: " <> T.unpack (safeDecodeUtf8 $ strEncode auth) - _ -> "# create_password: password to create new queues (any printable ASCII characters without whitespace, '@', ':' and '/')" + Just auth -> "create_password: " <> T.unpack (safeDecodeUtf8 $ strEncode auth) + _ -> "# create_password: password to create new queues (any printable ASCII characters without whitespace, '@', ':' and '/')" ) <> "\n\n\ \[TRANSPORT]\n\ @@ -141,7 +140,7 @@ smpServerCLI cfgPath logPath = hSetBuffering stdout LineBuffering hSetBuffering stderr LineBuffering fp <- checkSavedFingerprint cfgPath defaultX509Config - let host = fromRight "" $ T.unpack <$> lookupValue "TRANSPORT" "host" ini + let host = either (const "") T.unpack $ lookupValue "TRANSPORT" "host" ini port = T.unpack $ strictIni "TRANSPORT" "port" ini cfg@ServerConfig {transports, storeLogFile, newQueueBasicAuth, messageExpiration, inactiveClientExpiration} = serverConfig srv = ProtoServerWithAuth (SMPServer [THDomainName host] (if port == "5223" then "" else port) (C.KeyHash fp)) newQueueBasicAuth @@ -186,9 +185,10 @@ smpServerCLI cfgPath logPath = allowNewQueues = fromMaybe True $ iniOnOff "AUTH" "new_queues" ini, newQueueBasicAuth = either error id <$> strDecodeIni "AUTH" "create_password" ini, messageExpiration = - Just defaultMessageExpiration - { ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini - }, + Just + defaultMessageExpiration + { ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini + }, inactiveClientExpiration = settingIsOn "INACTIVE_CLIENTS" "disconnect" ini $> ExpirationConfig diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 95e425d8e..74f204103 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -64,7 +64,7 @@ flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) snapshotMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] -snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue) +snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue) where snapshotTQueue q = do msgs <- flushTQueue q diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs index fceae16f4..80fb178d4 100644 --- a/src/Simplex/Messaging/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -4,7 +4,6 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TupleSections #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Simplex.Messaging.Server.StoreLog diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index de49da35a..6509c1f6f 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -1,7 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -10,6 +9,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} -- | @@ -63,8 +63,7 @@ where import Control.Applicative ((<|>)) import Control.Monad.Except import Control.Monad.Trans.Except (throwE) -import Data.Aeson (FromJSON, ToJSON) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser) import Data.Bifunctor (first) import Data.Bitraversable (bimapM) @@ -74,9 +73,7 @@ import qualified Data.ByteString.Lazy as BL import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) -import GHC.Generics (Generic) import GHC.IO.Handle.Internals (ioe_EOF) -import Generic.Random (genericArbitraryU) import Network.Socket import qualified Network.TLS as T import qualified Network.TLS.Extra as TE @@ -87,7 +84,6 @@ import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON) import Simplex.Messaging.Transport.Buffer import Simplex.Messaging.Util (bshow, catchAll, catchAll_) import Simplex.Messaging.Version -import Test.QuickCheck (Arbitrary (..)) import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -284,14 +280,7 @@ data TransportError TEBadSession | -- | transport handshake error TEHandshake {handshakeErr :: HandshakeError} - deriving (Eq, Generic, Read, Show, Exception) - -instance ToJSON TransportError where - toJSON = J.genericToJSON . sumTypeJSON $ dropPrefix "TE" - toEncoding = J.genericToEncoding . sumTypeJSON $ dropPrefix "TE" - -instance FromJSON TransportError where - parseJSON = J.genericParseJSON . sumTypeJSON $ dropPrefix "TE" + deriving (Eq, Read, Show, Exception) -- | Transport handshake error. data HandshakeError @@ -301,18 +290,7 @@ data HandshakeError VERSION | -- | incorrect server identity IDENTITY - deriving (Eq, Generic, Read, Show, Exception) - -instance ToJSON HandshakeError where - toJSON = J.genericToJSON $ sumTypeJSON id - toEncoding = J.genericToEncoding $ sumTypeJSON id - -instance FromJSON HandshakeError where - parseJSON = J.genericParseJSON $ sumTypeJSON id - -instance Arbitrary TransportError where arbitrary = genericArbitraryU - -instance Arbitrary HandshakeError where arbitrary = genericArbitraryU + deriving (Eq, Read, Show, Exception) -- | SMP encrypted transport error parser. transportErrorP :: Parser TransportError @@ -354,9 +332,9 @@ smpServerHandshake c kh smpVRange = do getHandshake th >>= \case ClientHandshake {smpVersion, keyHash} | keyHash /= kh -> - throwE $ TEHandshake IDENTITY + throwE $ TEHandshake IDENTITY | smpVersion `isCompatible` smpVRange -> do - pure $ smpThHandle th smpVersion + pure $ smpThHandle th smpVersion | otherwise -> throwE $ TEHandshake VERSION -- | Client SMP transport handshake. @@ -385,3 +363,7 @@ getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock t smpTHandle :: Transport c => c -> THandle c smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, batch = False} + +$(J.deriveJSON (sumTypeJSON id) ''HandshakeError) + +$(J.deriveJSON (sumTypeJSON $ dropPrefix "TE") ''TransportError) diff --git a/src/Simplex/Messaging/Transport/Buffer.hs b/src/Simplex/Messaging/Transport/Buffer.hs index 141690386..251471679 100644 --- a/src/Simplex/Messaging/Transport/Buffer.hs +++ b/src/Simplex/Messaging/Transport/Buffer.hs @@ -8,8 +8,8 @@ import Control.Concurrent.STM import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import GHC.IO.Exception (IOErrorType (..), IOException (..), ioException) import System.Timeout (timeout) -import GHC.IO.Exception (ioException, IOException (..), IOErrorType (..)) data TBuffer = TBuffer { buffer :: TVar ByteString, @@ -41,9 +41,9 @@ getBuffered tb@TBuffer {buffer} n t_ getChunk = withBufferLock tb $ do readChunks firstChunk b | B.length b >= n = pure b | otherwise = - get >>= \case - "" -> pure b - s -> readChunks False $ b <> s + get >>= \case + "" -> pure b + s -> readChunks False $ b <> s where get | firstChunk = getChunk diff --git a/src/Simplex/Messaging/Transport/Credentials.hs b/src/Simplex/Messaging/Transport/Credentials.hs index a44dd9ead..db03b5c3a 100644 --- a/src/Simplex/Messaging/Transport/Credentials.hs +++ b/src/Simplex/Messaging/Transport/Credentials.hs @@ -1,7 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Transport.Credentials ( tlsCredentials, diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index 73fa13786..449a9bc59 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -22,10 +22,9 @@ import qualified Network.TLS as T import Numeric.Natural (Natural) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Transport (SessionId) +import Simplex.Messaging.Transport (SessionId, TLS) import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), runTLSTransportClient) import Simplex.Messaging.Transport.HTTP2 -import Simplex.Messaging.Transport (TLS) import UnliftIO.STM import UnliftIO.Timeout diff --git a/src/Simplex/Messaging/Transport/HTTP2/File.hs b/src/Simplex/Messaging/Transport/HTTP2/File.hs new file mode 100644 index 000000000..10238f161 --- /dev/null +++ b/src/Simplex/Messaging/Transport/HTTP2/File.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE MultiWayIf #-} + +module Simplex.Messaging.Transport.HTTP2.File where + +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import Data.ByteString.Builder (Builder, byteString) +import Data.Int (Int64) +import Data.Word (Word32) +import GHC.IO.Handle.Internals (ioe_EOF) +import System.IO (Handle) + +fileBlockSize :: Int +fileBlockSize = 16384 + +hReceiveFile :: (Int -> IO ByteString) -> Handle -> Word32 -> IO Int64 +hReceiveFile _ _ 0 = pure 0 +hReceiveFile getBody h size = get $ fromIntegral size + where + get sz = do + ch <- getBody fileBlockSize + let chSize = fromIntegral $ B.length ch + if + | chSize > sz -> pure (chSize - sz) + | chSize > 0 -> B.hPut h ch >> get (sz - chSize) + | otherwise -> pure (-fromIntegral sz) + +hSendFile :: Handle -> (Builder -> IO ()) -> Word32 -> IO () +hSendFile h send = go + where + go 0 = pure () + go sz = + getFileChunk h sz >>= \ch -> do + send $ byteString ch + go $ sz - fromIntegral (B.length ch) + +getFileChunk :: Handle -> Word32 -> IO ByteString +getFileChunk h sz = do + ch <- B.hGet h fileBlockSize + if B.null ch + then ioe_EOF + else pure $ B.take (fromIntegral sz) ch -- sz >= xftpBlockSize diff --git a/src/Simplex/Messaging/Transport/HTTP2/Server.hs b/src/Simplex/Messaging/Transport/HTTP2/Server.hs index ad4849c9d..139205235 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Server.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Server.hs @@ -1,5 +1,4 @@ {-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Transport.HTTP2.Server where diff --git a/src/Simplex/Messaging/Transport/KeepAlive.hs b/src/Simplex/Messaging/Transport/KeepAlive.hs index 35ef21fb5..52d5e7aaf 100644 --- a/src/Simplex/Messaging/Transport/KeepAlive.hs +++ b/src/Simplex/Messaging/Transport/KeepAlive.hs @@ -1,25 +1,22 @@ {-# LANGUAGE CApiFFI #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Transport.KeepAlive where -import Data.Aeson (FromJSON (..), ToJSON (..)) -import qualified Data.Aeson as J +import qualified Data.Aeson.TH as J import Foreign.C (CInt (..)) -import GHC.Generics (Generic) import Network.Socket +import Simplex.Messaging.Parsers (defaultJSON) data KeepAliveOpts = KeepAliveOpts { keepIdle :: Int, keepIntvl :: Int, keepCnt :: Int } - deriving (Eq, Show, Generic, FromJSON) - -instance ToJSON KeepAliveOpts where toEncoding = J.genericToEncoding J.defaultOptions + deriving (Eq, Show) defaultKeepAliveOpts :: KeepAliveOpts defaultKeepAliveOpts = @@ -68,3 +65,5 @@ setSocketKeepAlive sock KeepAliveOpts {keepCnt, keepIdle, keepIntvl} = do setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPIDLE) keepIdle setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPINTVL) keepIntvl setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPCNT) keepCnt + +$(J.deriveJSON defaultJSON ''KeepAliveOpts) diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index 8876135b1..806123e9f 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -5,10 +5,13 @@ {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Transport.Server - ( runTransportServer, - runTCPServer, - TransportServerConfig (..), + ( TransportServerConfig (..), defaultTransportServerConfig, + runTransportServer, + runTransportServerSocket, + runTCPServer, + runTCPServerSocket, + startTCPServer, loadSupportedTLSServerParams, loadTLSServerParams, loadFingerprint, @@ -46,10 +49,11 @@ data TransportServerConfig = TransportServerConfig deriving (Eq, Show) defaultTransportServerConfig :: TransportServerConfig -defaultTransportServerConfig = TransportServerConfig - { logTLSErrors = True, - transportTimeout = 40000000 - } +defaultTransportServerConfig = + TransportServerConfig + { logTLSErrors = True, + transportTimeout = 40000000 + } serverTransportConfig :: TransportServerConfig -> TransportConfig serverTransportConfig TransportServerConfig {logTLSErrors} = @@ -60,11 +64,15 @@ serverTransportConfig TransportServerConfig {logTLSErrors} = -- -- All accepted connections are passed to the passed function. runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m () -runTransportServer started port serverParams cfg server = do +runTransportServer started port = runTransportServerSocket started (startTCPServer started port) (transportName (TProxy :: TProxy c)) + +-- | Run a transport server with provided connection setup and handler. +runTransportServerSocket :: (MonadUnliftIO m, T.TLSParams p, Transport a) => TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m () +runTransportServerSocket started getSocket threadLabel serverParams cfg server = do u <- askUnliftIO let tCfg = serverTransportConfig cfg - labelMyThread $ "transport server for " <> transportName (TProxy :: TProxy c) - liftIO . runTCPServer started port $ \conn -> + labelMyThread $ "transport server for " <> threadLabel + liftIO . runTCPServerSocket started getSocket $ \conn -> E.bracket (connectTLS Nothing tCfg serverParams conn >>= getServerConnection tCfg) closeConnection @@ -72,11 +80,15 @@ runTransportServer started port serverParams cfg server = do -- | Run TCP server without TLS runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO () -runTCPServer started port server = do +runTCPServer started port = runTCPServerSocket started $ startTCPServer started port + +-- | Wrap socket provider in a TCP server bracket. +runTCPServerSocket :: TMVar Bool -> IO Socket -> (Socket -> IO ()) -> IO () +runTCPServerSocket started getSocket server = do clients <- atomically TM.empty clientId <- newTVarIO 0 E.bracket - (startTCPServer started port) + getSocket (closeServer started clients) $ \sock -> forever . E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) -> do -- catchAll_ is needed here in case the connection was closed earlier diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index a0633e09e..ae78da1fe 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -15,9 +15,9 @@ import qualified Network.WebSockets.Stream as S import Simplex.Messaging.Transport ( TProxy, Transport (..), + TransportConfig (..), TransportError (..), TransportPeer (..), - TransportConfig (..), closeTLS, smpBlockSize, withTlsUnique, diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index f235a3341..2dca0956a 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -62,7 +62,7 @@ liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) liftEitherError f a = liftIOEither (first f <$> a) {-# INLINE liftEitherError #-} -liftEitherWith :: (MonadError e' m) => (e -> e') -> Either e a -> m a +liftEitherWith :: MonadError e' m => (e -> e') -> Either e a -> m a liftEitherWith f = liftEither . first f {-# INLINE liftEitherWith #-} @@ -102,7 +102,7 @@ catchAllErrors err action handle = tryAllErrors err action >>= either handle pur {-# INLINE catchAllErrors #-} catchThrow :: (MonadUnliftIO m, MonadError e m) => m a -> (E.SomeException -> e) -> m a -catchThrow action err = catchAllErrors err action throwError +catchThrow action err = catchAllErrors err action throwError {-# INLINE catchThrow #-} allFinally :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m a -> m b -> m a @@ -115,12 +115,12 @@ eitherToMaybe = either (const Nothing) Just groupOn :: Eq k => (a -> k) -> [a] -> [[a]] groupOn = groupBy . eqOn - -- it is equivalent to groupBy ((==) `on` f), - -- but it redefines `on` to avoid duplicate computation for most values. - -- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn - -- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f` where - eqOn f = \x -> let fx = f x in \y -> fx == f y + -- it is equivalent to groupBy ((==) `on` f), + -- but it redefines `on` to avoid duplicate computation for most values. + -- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn + -- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f` + eqOn f x = let fx = f x in \y -> fx == f y groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]] groupAllOn f = groupOn f . sortOn f @@ -129,7 +129,7 @@ toChunks :: Int -> [a] -> [NonEmpty a] toChunks _ [] = [] toChunks n xs = let (ys, xs') = splitAt n xs - in maybe id (:) (L.nonEmpty ys) (toChunks n xs') + in maybe id (:) (L.nonEmpty ys) (toChunks n xs') safeDecodeUtf8 :: ByteString -> Text safeDecodeUtf8 = decodeUtf8With onError diff --git a/stack.yaml b/stack.yaml index 58f50b42f..4ba98eedb 100644 --- a/stack.yaml +++ b/stack.yaml @@ -49,7 +49,7 @@ extra-deps: - github: simplex-chat/aeson commit: aab7b5a14d6c5ea64c64dcaee418de1bb00dcc2b - github: kazu-yamamoto/http2 - commit: b5a1b7200cf5bc7044af34ba325284271f6dff25 + commit: 804fa283f067bd3fd89b8c5f8d25b3047813a517 # - ../direct-sqlcipher - github: simplex-chat/direct-sqlcipher commit: f814ee68b16a9447fbb467ccc8f29bdd3546bfd9 diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index fbd2a54ed..a00a6985b 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -45,42 +45,42 @@ agentTests (ATransport t) = do describe "Migration tests" migrationTests describe "SMP agent protocol syntax" $ syntaxTests t describe "Establishing duplex connection" $ do - it "should connect via one server and one agent" $ + it "should connect via one server and one agent" $ do smpAgentTest2_1_1 $ testDuplexConnection t - it "should connect via one server and one agent (random IDs)" $ + it "should connect via one server and one agent (random IDs)" $ do smpAgentTest2_1_1 $ testDuplexConnRandomIds t - it "should connect via one server and 2 agents" $ + it "should connect via one server and 2 agents" $ do smpAgentTest2_2_1 $ testDuplexConnection t - it "should connect via one server and 2 agents (random IDs)" $ + it "should connect via one server and 2 agents (random IDs)" $ do smpAgentTest2_2_1 $ testDuplexConnRandomIds t - it "should connect via 2 servers and 2 agents" $ + it "should connect via 2 servers and 2 agents" $ do smpAgentTest2_2_2 $ testDuplexConnection t - it "should connect via 2 servers and 2 agents (random IDs)" $ + it "should connect via 2 servers and 2 agents (random IDs)" $ do smpAgentTest2_2_2 $ testDuplexConnRandomIds t describe "Establishing connections via `contact connection`" $ do - it "should connect via contact connection with one server and 3 agents" $ + it "should connect via contact connection with one server and 3 agents" $ do smpAgentTest3 $ testContactConnection t - it "should connect via contact connection with one server and 2 agents (random IDs)" $ + it "should connect via contact connection with one server and 2 agents (random IDs)" $ do smpAgentTest2_2_1 $ testContactConnRandomIds t - it "should support rejecting contact request" $ + it "should support rejecting contact request" $ do smpAgentTest2_2_1 $ testRejectContactRequest t describe "Connection subscriptions" $ do - it "should connect via one server and one agent" $ + it "should connect via one server and one agent" $ do smpAgentTest3_1_1 $ testSubscription t - it "should send notifications to client when server disconnects" $ + it "should send notifications to client when server disconnects" $ do smpAgentServerTest $ testSubscrNotification t describe "Message delivery and server reconnection" $ do - it "should deliver messages after losing server connection and re-connecting" $ + it "should deliver messages after losing server connection and re-connecting" $ do smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t - it "should connect to the server when server goes up if it initially was down" $ + it "should connect to the server when server goes up if it initially was down" $ do smpAgentTestN [] $ testServerConnectionAfterError t - it "should deliver pending messages after agent restarting" $ + it "should deliver pending messages after agent restarting" $ do smpAgentTest1_1_1 $ testMsgDeliveryAgentRestart t - it "should concurrently deliver messages to connections without blocking" $ + it "should concurrently deliver messages to connections without blocking" $ do smpAgentTest2_2_1 $ testConcurrentMsgDelivery t - it "should deliver messages if one of connections has quota exceeded" $ + it "should deliver messages if one of connections has quota exceeded" $ do smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t - it "should resume delivering messages after exceeding quota once all messages are received" $ + it "should resume delivering messages after exceeding quota once all messages are received" $ do smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t type AEntityTransmission p e = (ACorrId, ConnId, ACommand p e) diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 06b06adde..9548443a7 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -2,7 +2,6 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.ConnectionRequestTests where @@ -113,24 +112,24 @@ connectionRequestTests = it "should serialize connection requests" $ do strEncode connectionRequest `shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" - <> urlEncode True testDhKeyStrUri - <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> urlEncode True testDhKeyStrUri + <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" strEncode connectionRequestCurrentRange `shouldBe` "simplex:/invitation#/?v=1-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" - <> urlEncode True testDhKeyStrUri - <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" - <> urlEncode True testDhKeyStrUri - <> "&e2e=v%3D1-2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> urlEncode True testDhKeyStrUri + <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" + <> urlEncode True testDhKeyStrUri + <> "&e2e=v%3D1-2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" strEncode connectionRequestClientDataEmpty `shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" - <> urlEncode True testDhKeyStrUri - <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" - <> "&data=%7B%7D" + <> urlEncode True testDhKeyStrUri + <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> "&data=%7B%7D" strEncode connectionRequestClientData `shouldBe` "simplex:/invitation#/?v=1&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" - <> urlEncode True testDhKeyStrUri - <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" - <> "&data=%7B%22type%22%3A%22group_link%22%2C%20%22group_link_id%22%3A%22abc%22%7D" + <> urlEncode True testDhKeyStrUri + <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> "&data=%7B%22type%22%3A%22group_link%22%2C%20%22group_link_id%22%3A%22abc%22%7D" it "should parse connection requests" $ do strDecode ( "https://simplex.chat/invitation#/?smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23" diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 522949ebe..8f19c776e 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -43,6 +43,7 @@ import Data.Int (Int64) import qualified Data.Map as M import Data.Maybe (isNothing) import qualified Data.Set as S +import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality import SMPAgentClient @@ -310,6 +311,7 @@ functionalAPITests t = do describe "Delivery receipts" $ do it "should send and receive delivery receipt" $ withSmpServer t testDeliveryReceipts it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t + it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do @@ -1159,7 +1161,7 @@ testBatchedSubscriptions nCreate nDel t = do a <- getSMPAgentClient' agentCfg initAgentServers2 testDB b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2 conns <- runServers $ do - conns <- forM [1 .. nCreate :: Int] . const $ makeConnection a b + conns <- replicateM (nCreate :: Int) $ makeConnection a b forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' @@ -1928,6 +1930,57 @@ testDeliveryReceiptsVersion t = do disconnectAgentClient a' disconnectAgentClient b' +testDeliveryReceiptsConcurrent :: HasCallStack => ATransport -> IO () +testDeliveryReceiptsConcurrent t = + withSmpServerConfigOn t cfg {msgQueueQuota = 128} testPort $ \_ -> do + withAgentClients2 $ \a b -> do + (aId, bId) <- runRight $ makeConnection a b + t1 <- liftIO getCurrentTime + concurrently_ (runClient "a" a bId) (runClient "b" b aId) + t2 <- liftIO getCurrentTime + diffUTCTime t2 t1 `shouldSatisfy` (< 15) + liftIO $ noMessages a "nothing else should be delivered to alice" + liftIO $ noMessages b "nothing else should be delivered to bob" + where + runClient :: String -> AgentClient -> ConnId -> IO () + runClient _cName client connId = do + concurrently_ send receive + where + numMsgs = 100 + send = runRight_ $ + replicateM_ numMsgs $ do + -- liftIO $ print $ cName <> ": sendMessage" + void $ sendMessage client connId SMP.noMsgFlags "hello" + receive = + runRight_ $ + -- for each sent message: 1 SENT, 1 RCVD, 1 OK for acknowledging RCVD + -- for each received message: 1 MSG, 1 OK for acknowledging MSG + receiveLoop (numMsgs * 5) + receiveLoop :: Int -> ExceptT AgentErrorType IO () + receiveLoop 0 = pure () + receiveLoop n = do + r <- getWithTimeout + case r of + (_, _, SENT _) -> do + -- liftIO $ print $ cName <> ": SENT" + pure () + (_, _, MSG MsgMeta {recipient = (msgId, _), integrity = MsgOk} _ _) -> do + -- liftIO $ print $ cName <> ": MSG " <> show msgId + ackMessageAsync client (B.pack . show $ n) connId msgId (Just "") + (_, _, RCVD MsgMeta {recipient = (msgId, _), integrity = MsgOk} _) -> do + -- liftIO $ print $ cName <> ": RCVD " <> show msgId + ackMessageAsync client (B.pack . show $ n) connId msgId Nothing + (_, _, OK) -> do + -- liftIO $ print $ cName <> ": OK" + pure () + r' -> error $ "unexpected event: " <> show r' + receiveLoop (n - 1) + getWithTimeout :: ExceptT AgentErrorType IO (AEntityTransmission 'AEConn) + getWithTimeout = do + 1000000 `timeout` get client >>= \case + Just r -> pure r + _ -> error "timeout" + testTwoUsers :: HasCallStack => IO () testTwoUsers = withAgentClients2 $ \a b -> do let nc = netCfg initAgentServers diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 0b1755d02..da9f4c322 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -504,7 +504,7 @@ testNotificationsSMPRestartBatch n t APNSMockServer {apnsQ} = do a <- getSMPAgentClient' agentCfg initAgentServers2 testDB b <- getSMPAgentClient' agentCfg initAgentServers2 testDB2 conns <- runServers $ do - conns <- forM [1 .. n :: Int] . const $ makeConnection a b + conns <- replicateM (n :: Int) $ makeConnection a b _ <- registerTestToken a "abcd" NMInstant apnsQ liftIO $ threadDelay 1500000 forM_ conns $ \(aliceId, bobId) -> do diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index a2c8e3929..cf6e8373b 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -7,7 +7,6 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} - {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.SQLiteTests (storeTests) where diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index ab9763ff6..260411e6e 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -4,7 +4,7 @@ module CoreTests.BatchingTests (batchingTests) where import Control.Concurrent.STM import Control.Monad -import Crypto.Random (MonadRandom(..)) +import Crypto.Random (MonadRandom (..)) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.List.NonEmpty as L @@ -12,7 +12,7 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Version (VersionRange(..)) +import Simplex.Messaging.Version (VersionRange (..)) import Test.Hspec batchingTests :: Spec diff --git a/tests/CoreTests/CryptoTests.hs b/tests/CoreTests/CryptoTests.hs index 0f47d395f..3ad26a886 100644 --- a/tests/CoreTests/CryptoTests.hs +++ b/tests/CoreTests/CryptoTests.hs @@ -195,7 +195,7 @@ testAESGCM = it "should encrypt / decrypt string with a random symmetric key" $ cipher `shouldNotBe` plain s `shouldBe` plain -testEncoding :: (C.AlgorithmI a) => C.SAlgorithm a -> Spec +testEncoding :: C.AlgorithmI a => C.SAlgorithm a -> Spec testEncoding alg = it "should encode / decode key" . ioProperty $ do (k, pk) <- C.generateKeyPair alg pure $ \(_ :: Int) -> diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index 39a00eb88..cc6da7b6c 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -1,14 +1,22 @@ +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} module CoreTests.ProtocolErrorTests where import qualified Data.ByteString.Char8 as B import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) -import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..)) +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) +import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Protocol (CommandError (..), ErrorType (..)) +import Simplex.Messaging.Transport (HandshakeError (..), TransportError (..)) import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck @@ -16,15 +24,58 @@ import Test.QuickCheck protocolErrorTests :: Spec protocolErrorTests = modifyMaxSuccess (const 1000) $ do describe "errors parsing / serializing" $ do - it "should parse SMP protocol errors" . property $ \(err :: AgentErrorType) -> - errHasSpaces err - || parseAll strP (strEncode err) == Right err + it "should parse SMP protocol errors" . property $ \(err :: ErrorType) -> + smpDecode (smpEncode err) == Right err it "should parse SMP agent errors" . property $ \(err :: AgentErrorType) -> errHasSpaces err - || parseAll strP (strEncode err) == Right err + || strDecode (strEncode err) == Right err where errHasSpaces = \case BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e BROKER srv _ -> hasSpaces srv _ -> False hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s) + +deriving instance Generic AgentErrorType + +deriving instance Generic CommandErrorType + +deriving instance Generic ConnectionErrorType + +deriving instance Generic BrokerErrorType + +deriving instance Generic SMPAgentError + +deriving instance Generic AgentCryptoError + +deriving instance Generic ErrorType + +deriving instance Generic CommandError + +deriving instance Generic TransportError + +deriving instance Generic HandshakeError + +deriving instance Generic XFTPErrorType + +instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU + +instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU + +instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU + +instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU + +instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU + +instance Arbitrary AgentCryptoError where arbitrary = genericArbitraryU + +instance Arbitrary ErrorType where arbitrary = genericArbitraryU + +instance Arbitrary CommandError where arbitrary = genericArbitraryU + +instance Arbitrary TransportError where arbitrary = genericArbitraryU + +instance Arbitrary HandshakeError where arbitrary = genericArbitraryU + +instance Arbitrary XFTPErrorType where arbitrary = genericArbitraryU diff --git a/tests/CoreTests/RetryIntervalTests.hs b/tests/CoreTests/RetryIntervalTests.hs index d49bd3d14..7097df989 100644 --- a/tests/CoreTests/RetryIntervalTests.hs +++ b/tests/CoreTests/RetryIntervalTests.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} module CoreTests.RetryIntervalTests where diff --git a/tests/CoreTests/UtilTests.hs b/tests/CoreTests/UtilTests.hs index 1dd205b83..9e413e838 100644 --- a/tests/CoreTests/UtilTests.hs +++ b/tests/CoreTests/UtilTests.hs @@ -6,8 +6,8 @@ import Control.Exception (Exception, SomeException, throwIO) import Control.Monad.Except import Control.Monad.IO.Class import Data.IORef -import Simplex.Messaging.Util import Simplex.Messaging.Client.Agent () +import Simplex.Messaging.Util import Test.Hspec import qualified UnliftIO.Exception as UE diff --git a/tests/CoreTests/VersionRangeTests.hs b/tests/CoreTests/VersionRangeTests.hs index 4a623cd87..be02e38b7 100644 --- a/tests/CoreTests/VersionRangeTests.hs +++ b/tests/CoreTests/VersionRangeTests.hs @@ -39,13 +39,14 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do isCompatible (1 :: Version) (vr 2 2) `shouldBe` False it "compatibleVersion should pass isCompatible check" . property $ \((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) -> - min1 > max1 || min2 > max2 -- one of ranges is invalid, skip testing it + min1 > max1 + || min2 > max2 -- one of ranges is invalid, skip testing it || let w = fromIntegral . fromEnum vr1 = mkVersionRange (w min1) (w max1) :: VersionRange vr2 = mkVersionRange (w min2) (w max2) :: VersionRange in case compatibleVersion vr1 vr2 of - Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2 - _ -> True + Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2 + _ -> True where vr = mkVersionRange compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index d4f6a856d..15a42fa8c 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -38,6 +38,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Server (runNtfServerBlocking) import Simplex.Messaging.Notifications.Server.Env import Simplex.Messaging.Notifications.Server.Push.APNS +import Simplex.Messaging.Notifications.Server.Push.APNS.Internal import Simplex.Messaging.Notifications.Transport import Simplex.Messaging.Protocol import Simplex.Messaging.Transport @@ -186,10 +187,16 @@ instance FromJSON APNSAlertBody where parseJSON (J.String v) = pure $ APNSAlertText v parseJSON invalid = JT.prependFailure "parsing Coord failed, " (JT.typeMismatch "Object" invalid) +deriving instance Generic APNSNotificationBody + instance FromJSON APNSNotificationBody where parseJSON = J.genericParseJSON apnsJSONOptions {J.rejectUnknownFields = True} +deriving instance Generic APNSNotification + deriving instance FromJSON APNSNotification +deriving instance Generic APNSErrorResponse + deriving instance ToJSON APNSErrorResponse getAPNSMockServer :: HTTP2ServerConfig -> IO APNSMockServer diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index d06ededa9..77a4b1945 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -139,7 +139,6 @@ testNotificationSubscription (ATransport t) = mTs `shouldBe` msgTs (msgBody, "hello") #== "delivered from queue" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, ACK mId1) - pure () -- replace token let tkn' = DeviceToken PPApnsTest "efgh" RespNtf "7" tId' NROk <- signSendRecvNtf nh tknKey ("7", tId, TRPL tkn') diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index a45491e97..cc20d3958 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -832,7 +832,6 @@ testRestoreExpireMessages at@(ATransport t) = msgs'' <- B.readFile testStoreMsgsFile length (B.lines msgs'') `shouldBe` 2 B.lines msgs'' `shouldBe` drop 2 (B.lines msgs) - where runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 75d68c4da..465c9c2b6 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -292,8 +292,8 @@ testXFTPAgentSendRestore = withGlobalLogging logCfgNoLogs $ do -- receive file rcp <- getSMPAgentClient' agentCfg initAgentServers testDB2 - runRight_ $ - void $ testReceive rcp rfd1 filePath + runRight_ . void $ + testReceive rcp rfd1 filePath testXFTPAgentSendCleanup :: HasCallStack => IO () testXFTPAgentSendCleanup = withGlobalLogging logCfgNoLogs $ do @@ -342,8 +342,8 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $ -- receive file rcp1 <- getSMPAgentClient' agentCfg initAgentServers testDB2 - runRight_ $ - void $ testReceive rcp1 rfd1 filePath + runRight_ . void $ + testReceive rcp1 rfd1 filePath length <$> listDirectory xftpServerFiles `shouldReturn` 6 @@ -377,8 +377,8 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do -- receive file rcp1 <- getSMPAgentClient' agentCfg initAgentServers testDB2 - runRight_ $ - void $ testReceive rcp1 rfd1 filePath + runRight_ . void $ + testReceive rcp1 rfd1 filePath disconnectAgentClient rcp1 disconnectAgentClient sndr pure (sfId, sndDescr, rfd2)