From ee90ea6a69fe8283d37d9821cd83798fd0a76260 Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Thu, 28 Mar 2024 01:35:09 +0200 Subject: [PATCH 1/5] replace base64-bytestring with base64 (#1065) * replace base64-bytestring with base64 * minify * use bytestring-0.10 compatible fork PR pending... * bump base64 fork with text compat * move compat details to modules * switch repo * add back module * cleanup * minify * clean imports * rename --------- Co-authored-by: Evgeny Poberezkin --- cabal.project | 6 ++++ package.yaml | 2 +- simplexmq.cabal | 16 +++++---- src/Simplex/FileTransfer/Server.hs | 4 +-- src/Simplex/Messaging/Agent/Client.hs | 2 +- src/Simplex/Messaging/Agent/Protocol.hs | 2 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 2 +- src/Simplex/Messaging/Crypto.hs | 4 +-- src/Simplex/Messaging/Encoding/Base64.hs | 29 ++++++++++++++++ src/Simplex/Messaging/Encoding/Base64/URL.hs | 33 +++++++++++++++++++ src/Simplex/Messaging/Encoding/String.hs | 17 ++++------ .../Notifications/Server/Push/APNS.hs | 8 +++-- src/Simplex/Messaging/Parsers.hs | 18 +--------- src/Simplex/Messaging/Protocol.hs | 2 +- src/Simplex/Messaging/Server.hs | 2 +- tests/AgentTests/NotificationTests.hs | 4 +-- tests/NtfServerTests.hs | 2 +- tests/ServerTests.hs | 2 +- tests/XFTPServerTests.hs | 4 +-- 19 files changed, 105 insertions(+), 54 deletions(-) create mode 100644 src/Simplex/Messaging/Encoding/Base64.hs create mode 100644 src/Simplex/Messaging/Encoding/Base64/URL.hs diff --git a/cabal.project b/cabal.project index 43afe30ea..27811ed5e 100644 --- a/cabal.project +++ b/cabal.project @@ -14,6 +14,12 @@ source-repository-package location: https://github.com/simplex-chat/aeson.git tag: aab7b5a14d6c5ea64c64dcaee418de1bb00dcc2b +-- old bs/text compat for 8.10 +source-repository-package + type: git + location: https://github.com/simplex-chat/base64.git + tag: 2d77b6dbcaffc00570a70be8694049f3710e7c94 + source-repository-package type: git location: https://github.com/simplex-chat/hs-socks.git diff --git a/package.yaml b/package.yaml index 17b071148..92a79b283 100644 --- a/package.yaml +++ b/package.yaml @@ -31,7 +31,7 @@ dependencies: - async == 2.2.* - attoparsec == 0.14.* - base >= 4.14 && < 5 - - base64-bytestring >= 1.0 && < 1.3 + - base64 == 1.0.* - case-insensitive == 1.2.* - composition == 1.0.* - constraints >= 0.12 && < 0.14 diff --git a/simplexmq.cabal b/simplexmq.cabal index 37f1d7ed4..aeb324ca9 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -119,6 +119,8 @@ library Simplex.Messaging.Crypto.SNTRUP761.Bindings.FFI Simplex.Messaging.Crypto.SNTRUP761.Bindings.RNG Simplex.Messaging.Encoding + Simplex.Messaging.Encoding.Base64 + Simplex.Messaging.Encoding.Base64.URL Simplex.Messaging.Encoding.String Simplex.Messaging.Notifications.Client Simplex.Messaging.Notifications.Protocol @@ -187,7 +189,7 @@ library , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 @@ -259,7 +261,7 @@ executable ntf-server , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 @@ -332,7 +334,7 @@ executable smp-agent , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 @@ -405,7 +407,7 @@ executable smp-server , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 @@ -478,7 +480,7 @@ executable xftp , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 @@ -551,7 +553,7 @@ executable xftp-server , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 @@ -658,7 +660,7 @@ test-suite simplexmq-test , async ==2.2.* , attoparsec ==0.14.* , base >=4.14 && <5 - , base64-bytestring >=1.0 && <1.3 + , base64 ==1.0.* , case-insensitive ==1.2.* , composition ==1.0.* , constraints >=0.12 && <0.14 diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 5b526e692..100301cac 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -17,7 +17,6 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Data.Bifunctor (first) -import qualified Data.ByteString.Base64.URL as B64 import Data.ByteString.Builder (byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -46,6 +45,7 @@ import Simplex.FileTransfer.Server.StoreLog import Simplex.FileTransfer.Transport import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (CorrId, RcvPublicDhKey, RcvPublicAuthKey, RecipientId, TransmissionAuth) import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization) @@ -382,7 +382,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case \used -> let used' = used + fromIntegral size in if used' <= quota then (True, used') else (False, used) receive = do path <- asks $ filesPath . config - let fPath = path B.unpack (B64.encode senderId) + let fPath = path B.unpack (U.encode senderId) receiveChunk (XFTPRcvChunkSpec fPath size digest) >>= \case Right () -> do stats <- asks serverStats diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 834026e6a..3f979125f 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -142,7 +142,6 @@ import Crypto.Random (ChaChaDRG) import qualified Data.Aeson as J import qualified Data.Aeson.TH as J import Data.Bifunctor (bimap, first, second) -import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:.)) @@ -182,6 +181,7 @@ import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.Base64 (encode) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 2c06e0279..df9907fe0 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -163,7 +163,6 @@ import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) @@ -202,6 +201,7 @@ import Simplex.Messaging.Crypto.Ratchet SndE2ERatchetParams ) import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.Base64 (base64P, encode) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 2f6707c5a..4f5c1573b 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -231,7 +231,6 @@ import Data.Bifunctor (first, second) import Data.ByteArray (ScrubbedBytes) import qualified Data.ByteArray as BA import Data.ByteString (ByteString) -import qualified Data.ByteString.Base64.URL as U import qualified Data.ByteString.Char8 as B import Data.Char (toLower) import Data.Functor (($>)) @@ -271,6 +270,7 @@ import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..)) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) import Simplex.Messaging.Notifications.Types diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 28183a1fc..84d1882fa 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -211,8 +211,6 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (bimap, first) import Data.ByteArray (ByteArrayAccess) import qualified Data.ByteArray as BA -import Data.ByteString.Base64 (decode, encode) -import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.ByteString.Lazy (fromStrict, toStrict) @@ -230,6 +228,8 @@ import Database.SQLite.Simple.ToField (ToField (..)) import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+)) import Network.Transport.Internal (decodeWord16, encodeWord16) import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.Base64 (decode, encode) +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString) import Simplex.Messaging.Util ((<$?>)) diff --git a/src/Simplex/Messaging/Encoding/Base64.hs b/src/Simplex/Messaging/Encoding/Base64.hs new file mode 100644 index 000000000..951250abc --- /dev/null +++ b/src/Simplex/Messaging/Encoding/Base64.hs @@ -0,0 +1,29 @@ +{-# LANGUAGE OverloadedStrings #-} + +-- | Compatibility wrappers for base64 package, Base64 (padded) variant. +module Simplex.Messaging.Encoding.Base64 where + +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Base64.Types (extractBase64) +import Data.Bifunctor (first) +import Data.ByteString.Base64 (decodeBase64Untyped, encodeBase64') +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import qualified Data.Text as T + +encode :: ByteString -> ByteString +encode = extractBase64 . encodeBase64' +{-# INLINE encode #-} + +decode :: ByteString -> Either String ByteString +decode = first T.unpack . decodeBase64Untyped +{-# INLINE decode #-} + +base64P :: A.Parser ByteString +base64P = do + str <- A.takeWhile1 (`B.elem` base64Alphabet) + pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length + either (fail . T.unpack) pure $ decodeBase64Untyped (str <> pad) + +base64Alphabet :: ByteString +base64Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" diff --git a/src/Simplex/Messaging/Encoding/Base64/URL.hs b/src/Simplex/Messaging/Encoding/Base64/URL.hs new file mode 100644 index 000000000..247002376 --- /dev/null +++ b/src/Simplex/Messaging/Encoding/Base64/URL.hs @@ -0,0 +1,33 @@ +{-# LANGUAGE OverloadedStrings #-} + +-- | Compatibility wrappers for base64 package, Base64URL-padded variant. +module Simplex.Messaging.Encoding.Base64.URL where + +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Base64.Types (extractBase64) +import Data.Bifunctor (first) +import Data.ByteString.Base64.URL (decodeBase64Lenient, decodeBase64UnpaddedUntyped, decodeBase64Untyped, encodeBase64') +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import qualified Data.Text as T + +encode :: ByteString -> ByteString +encode = extractBase64 . encodeBase64' +{-# INLINE encode #-} + +decode :: ByteString -> Either String ByteString +decode = first T.unpack . decodeBase64Untyped +{-# INLINE decode #-} + +decodeLenient :: ByteString -> ByteString +decodeLenient = decodeBase64Lenient +{-# INLINE decodeLenient #-} + +base64urlP :: A.Parser ByteString +base64urlP = do + str <- A.takeWhile1 (`B.elem` base64AlphabetURL) + _pad <- A.takeWhile (== '=') -- correct amount of padding can be derived from str length + either (fail . T.unpack) pure $ decodeBase64UnpaddedUntyped str + +base64AlphabetURL :: ByteString +base64AlphabetURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index fcefdc73d..46dc659a9 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -10,7 +10,6 @@ module Simplex.Messaging.Encoding.String strToJSON, strToJEncoding, strParseJSON, - base64urlP, strEncodeList, strListP, ) @@ -23,10 +22,8 @@ import qualified Data.Aeson.Encoding as JE import qualified Data.Aeson.Types as JT import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A -import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Char (isAlphaNum) import Data.Int (Int64) import qualified Data.List.NonEmpty as L import Data.Set (Set) @@ -38,6 +35,7 @@ import Data.Time.Clock.System (SystemTime (..)) import Data.Time.Format.ISO8601 import Data.Word (Word16, Word32) import Simplex.Messaging.Encoding +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$?>)) @@ -54,19 +52,16 @@ class StrEncoding a where strDecode :: ByteString -> Either String a strDecode = parseAll strP strP :: Parser a - strP = strDecode <$?> base64urlP + strP = strDecode <$?> U.base64urlP -- base64url encoding/decoding of ByteStrings - the parser only allows non-empty strings instance StrEncoding ByteString where strEncode = U.encode + {-# INLINE strEncode #-} strDecode = U.decode - strP = base64urlP - -base64urlP :: Parser ByteString -base64urlP = do - str <- A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_') - pad <- A.takeWhile (== '=') - either fail pure $ U.decode (str <> pad) + {-# INLINE strDecode #-} + strP = U.base64urlP + {-# INLINE strP #-} newtype Str = Str {unStr :: ByteString} deriving (Eq, Show) diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs index 9c3de04df..45f6bf637 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs @@ -27,8 +27,9 @@ import Data.Aeson (ToJSON, (.=)) import qualified Data.Aeson as J import qualified Data.Aeson.Encoding as JE import qualified Data.Aeson.TH as JQ +import Data.Base64.Types (extractBase64) import Data.Bifunctor (first) -import qualified Data.ByteString.Base64.URL as U +import qualified Data.ByteString.Base64.URL as UP import Data.ByteString.Builder (lazyByteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Lazy.Char8 as LB @@ -46,6 +47,7 @@ import Network.HTTP2.Client (Request) import qualified Network.HTTP2.Client as H import Network.Socket (HostName, ServiceName) import qualified Simplex.Messaging.Crypto as C +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS.Internal @@ -91,8 +93,8 @@ signedJWTToken pk (JWTToken hdr claims) = do pure $ hc <> "." <> serialize sig where jwtEncode :: ToJSON a => a -> ByteString - jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode - serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence] + jwtEncode = extractBase64 . UP.encodeBase64Unpadded' . LB.toStrict . J.encode + serialize sig = extractBase64 . UP.encodeBase64Unpadded' $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence] readECPrivateKey :: FilePath -> IO EC.PrivateKey readECPrivateKey f = do diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 39cb0383c..17486ab9c 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -10,10 +10,9 @@ import qualified Data.Aeson as J import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) -import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Char (isAlphaNum, toLower) +import Data.Char (toLower) import Data.String import Data.Text (Text) import qualified Data.Text as T @@ -24,23 +23,8 @@ import Database.SQLite.Simple (ResultError (..), SQLData (..)) import Database.SQLite.Simple.FromField (FieldParser, returnError) import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) -import Simplex.Messaging.Util ((<$?>)) import Text.Read (readMaybe) -base64P :: Parser ByteString -base64P = decode <$?> paddedBase64 rawBase64P - -paddedBase64 :: Parser ByteString -> Parser ByteString -paddedBase64 raw = (<>) <$> raw <*> pad - where - pad = A.takeWhile (== '=') - -rawBase64P :: Parser ByteString -rawBase64P = A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/') - --- rawBase64UriP :: Parser ByteString --- rawBase64UriP = A.takeWhile1 (\c -> isAlphaNum c || c == '-' || c == '_') - tsISO8601P :: Parser UTCTime tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 2c593fc6f..a867a915e 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -174,7 +174,6 @@ import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) -import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isPrint, isSpace) @@ -192,6 +191,7 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import Network.Socket (ServiceName) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import qualified Simplex.Messaging.Encoding.Base64 as B64 import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.ServiceScheme diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 181af8fac..a7e25a82b 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -45,7 +45,6 @@ import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random import Data.Bifunctor (first) -import Data.ByteString.Base64 (encode) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (fromRight, partitionEithers) @@ -68,6 +67,7 @@ import Network.Socket (ServiceName, Socket, socketToHandle) import Simplex.Messaging.Agent.Lock import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding (Encoding (smpEncode)) +import Simplex.Messaging.Encoding.Base64 (encode) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol import Simplex.Messaging.Server.Control diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 18fccbf52..c884cbd93 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -42,7 +42,6 @@ import Control.Monad.Trans.Except import qualified Data.Aeson as J import qualified Data.Aeson.Types as JT import Data.Bifunctor (bimap, first) -import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Text.Encoding (encodeUtf8) @@ -51,6 +50,7 @@ import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, te import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn) import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO) import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken) @@ -345,7 +345,7 @@ testRunNTFServerTests :: ATransport -> NtfServer -> IO (Maybe ProtocolTestFailur testRunNTFServerTests t srv = withNtfServerThreadOn t ntfTestPort $ \ntf -> do a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB - r <- runRight $ testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing + r <- runRight $ testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing killThread ntf pure r diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index e7e2018c2..027675aeb 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -15,7 +15,6 @@ import Control.Concurrent (threadDelay) import qualified Data.Aeson as J import qualified Data.Aeson.Types as JT import Data.Bifunctor (first) -import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import Data.Text.Encoding (encodeUtf8) import NtfClient @@ -35,6 +34,7 @@ import ServerTests import qualified Simplex.Messaging.Agent.Protocol as AP import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 03935fed5..008b4da88 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -22,7 +22,6 @@ import Control.Exception (SomeException, try) import Control.Monad import Control.Monad.IO.Class import Data.Bifunctor (first) -import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.Set as S @@ -31,6 +30,7 @@ import GHC.Stack (withFrozenCallStack) import SMPClient import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.Base64 (encode) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol diff --git a/tests/XFTPServerTests.hs b/tests/XFTPServerTests.hs index 451406275..4650edf57 100644 --- a/tests/XFTPServerTests.hs +++ b/tests/XFTPServerTests.hs @@ -12,7 +12,6 @@ import Control.Exception (SomeException) import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift -import qualified Data.ByteString.Base64.URL as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB @@ -26,6 +25,7 @@ import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..) import Simplex.Messaging.Client (ProtocolClientError (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC +import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Protocol (BasicAuth, SenderId) import Simplex.Messaging.Server.Expiration (ExpirationConfig (..)) import Simplex.Messaging.Util (liftIOEither) @@ -75,7 +75,7 @@ createTestChunk fp = do pure bytes readChunk :: SenderId -> IO ByteString -readChunk sId = B.readFile (xftpServerFiles B.unpack (B64.encode sId)) +readChunk sId = B.readFile (xftpServerFiles B.unpack (U.encode sId)) testFileChunkDelivery :: Expectation testFileChunkDelivery = xftpTest $ \c -> runRight_ $ runTestFileChunkDelivery c c From bbc9eccf4d73b451278a73f06869c2d43c8709ef Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Thu, 28 Mar 2024 20:12:48 +0200 Subject: [PATCH 2/5] xftp: prevent overwriting completed upload (#1063) * xftp: prevent overwriting completed upload * add size check for skipCommitted * fix import * fail on incorrect size --------- Co-authored-by: Evgeny Poberezkin --- src/Simplex/FileTransfer/Server.hs | 32 ++++++++++++++++++++---------- tests/XFTPServerTests.hs | 20 +++++++++++++++++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 100301cac..4874aea1a 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedLists #-} @@ -25,7 +26,7 @@ import Data.List (intercalate) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, isJust) import qualified Data.Text as T import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) @@ -47,13 +48,14 @@ import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (CorrId, RcvPublicDhKey, RcvPublicAuthKey, RecipientId, TransmissionAuth) +import Simplex.Messaging.Protocol (CorrId, RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth) import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.Stats import Simplex.Messaging.Transport (THandleParams (..)) import Simplex.Messaging.Transport.Buffer (trimCR) import Simplex.Messaging.Transport.HTTP2 +import Simplex.Messaging.Transport.HTTP2.File (fileBlockSize) import Simplex.Messaging.Transport.HTTP2.Server import Simplex.Messaging.Transport.Server (runTCPServer) import Simplex.Messaging.Util @@ -67,13 +69,12 @@ import qualified UnliftIO.Exception as E type M a = ReaderT XFTPEnv IO a -data XFTPTransportRequest = - XFTPTransportRequest - { thParams :: THandleParams XFTPVersion, - reqBody :: HTTP2Body, - request :: H.Request, - sendResponse :: H.Response -> IO () - } +data XFTPTransportRequest = XFTPTransportRequest + { thParams :: THandleParams XFTPVersion, + reqBody :: HTTP2Body, + request :: H.Request, + sendResponse :: H.Response -> IO () + } runXFTPServer :: XFTPServerConfig -> IO () runXFTPServer cfg = do @@ -373,8 +374,19 @@ processXFTPRequest HTTP2Body {bodyPart} = \case receiveServerFile FileRec {senderId, fileInfo = FileInfo {size, digest}, filePath} = case bodyPart of Nothing -> pure $ FRErr SIZE -- TODO validate body size from request before downloading, once it's populated - Just getBody -> ifM reserve receive (pure $ FRErr QUOTA) -- TODO: handle duplicate uploads + Just getBody -> skipCommitted $ ifM reserve receive (pure $ FRErr QUOTA) where + -- having a filePath means the file is already uploaded and committed, must not change anything + skipCommitted = ifM (isJust <$> readTVarIO filePath) (liftIO $ drain $ fromIntegral size) + where + -- can't send FROk without reading the request body or a client will block on sending it + -- can't send any old error as the client would fail or restart indefinitely + drain s = do + bs <- B.length <$> getBody fileBlockSize + if + | bs == s -> pure FROk + | bs == 0 || bs > s -> pure $ FRErr SIZE + | otherwise -> drain (s - bs) reserve = do us <- asks $ usedStorage . store quota <- asks $ fromMaybe maxBound . fileSizeQuota . config diff --git a/tests/XFTPServerTests.hs b/tests/XFTPServerTests.hs index 4650edf57..71700280a 100644 --- a/tests/XFTPServerTests.hs +++ b/tests/XFTPServerTests.hs @@ -60,6 +60,7 @@ xftpServerTests = it "prohibited when FNEW disabled" $ testFileBasicAuth False (Just "pwd") (Just "pwd") False it "allowed with correct basic auth" $ testFileBasicAuth True (Just "pwd") (Just "pwd") True it "allowed with auth on server without auth" $ testFileBasicAuth True Nothing (Just "any") True + it "should not change content for uploaded and committed files" testFileSkipCommitted chSize :: Integral a => a chSize = kb 128 @@ -372,3 +373,22 @@ testFileBasicAuth allowNewFiles newFileBasicAuth clntAuth success = else do void (createXFTPChunk c spKey file [rcvKey] clntAuth) `catchError` (liftIO . (`shouldBe` PCEProtocolError AUTH)) + +testFileSkipCommitted :: IO () +testFileSkipCommitted = + withXFTPServerCfg testXFTPServerConfig $ + \_ -> testXFTPClient $ \c -> do + g <- C.newRandom + (sndKey, spKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rcvKey, rpKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + bytes <- createTestChunk testChunkPath + digest <- LC.sha256Hash <$> LB.readFile testChunkPath + let file = FileInfo {sndKey, size = chSize, digest} + chunkSpec = XFTPChunkSpec {filePath = testChunkPath, chunkOffset = 0, chunkSize = chSize} + runRight_ $ do + (sId, [rId]) <- createXFTPChunk c spKey file [rcvKey] Nothing + uploadXFTPChunk c spKey sId chunkSpec + void . liftIO $ createTestChunk testChunkPath -- trash chunk contents + uploadXFTPChunk c spKey sId chunkSpec -- upload again to get FROk without getting stuck + downloadXFTPChunk g c rpKey rId $ XFTPRcvChunkSpec "tests/tmp/received_chunk" chSize digest + liftIO $ B.readFile "tests/tmp/received_chunk" `shouldReturn` bytes -- new chunk content got ignored From 44410535fdd3e2345629cbe7bf94f0caf331cb0d Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Thu, 28 Mar 2024 18:16:36 +0000 Subject: [PATCH 3/5] do not pass key to control port of xftp server (#1074) --- src/Simplex/FileTransfer/Server.hs | 8 +++----- src/Simplex/FileTransfer/Server/Control.hs | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 4874aea1a..2d957a7b2 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -223,15 +223,13 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira | Just auth == user = CPRUser | otherwise = CPRNone CPStatsRTS -> E.tryAny getRTSStats >>= either (hPrint h) (hPrint h) - CPDelete fileId fKey -> withUserRole $ unliftIO u $ do + CPDelete fileId -> withUserRole $ unliftIO u $ do fs <- asks store r <- runExceptT $ do let asSender = ExceptT . atomically $ getFile fs SFSender fileId let asRecipient = ExceptT . atomically $ getFile fs SFRecipient fileId - (fr, fKey') <- asSender `catchError` const asRecipient - if fKey == fKey' - then ExceptT $ deleteServerFile_ fr - else throwError AUTH + (fr, _) <- asSender `catchError` const asRecipient + ExceptT $ deleteServerFile_ fr liftIO . hPutStrLn h $ either (\e -> "error: " <> show e) (\() -> "ok") r CPHelp -> hPutStrLn h "commands: stats-rts, delete, help, quit" CPQuit -> pure () diff --git a/src/Simplex/FileTransfer/Server/Control.hs b/src/Simplex/FileTransfer/Server/Control.hs index d8d0c425f..54d349c3b 100644 --- a/src/Simplex/FileTransfer/Server/Control.hs +++ b/src/Simplex/FileTransfer/Server/Control.hs @@ -14,7 +14,7 @@ data CPClientRole = CPRNone | CPRUser | CPRAdmin data ControlProtocol = CPAuth BasicAuth | CPStatsRTS - | CPDelete ByteString C.APublicAuthKey + | CPDelete ByteString | CPHelp | CPQuit | CPSkip @@ -23,7 +23,7 @@ instance StrEncoding ControlProtocol where strEncode = \case CPAuth tok -> "auth " <> strEncode tok CPStatsRTS -> "stats-rts" - CPDelete fId fKey -> strEncode (Str "delete", fId, fKey) + CPDelete fId -> strEncode (Str "delete", fId) CPHelp -> "help" CPQuit -> "quit" CPSkip -> "" @@ -31,7 +31,7 @@ instance StrEncoding ControlProtocol where A.takeTill (== ' ') >>= \case "auth" -> CPAuth <$> _strP "stats-rts" -> pure CPStatsRTS - "delete" -> CPDelete <$> _strP <*> _strP + "delete" -> CPDelete <$> _strP "help" -> pure CPHelp "quit" -> pure CPQuit "" -> pure CPSkip From 6c48092f5ead4adbeb38cda3a9d1fdad1e230e5c Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Thu, 28 Mar 2024 18:23:19 +0000 Subject: [PATCH 4/5] 5.6.1.0 --- CHANGELOG.md | 10 ++++++++++ package.yaml | 2 +- simplexmq.cabal | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 003857b65..ed15e7083 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# 5.6.1 + +Version 5.6.1.0. + +- Much faster iOS notification server start time (fewer skipped notifications). +- Fix SMP server stored message stats. +- Prevent overwriting uploaded XFTP files with subsequent upload attempts. +- Faster base64 encoding/parsing. +- Control port audit log and authentication. + # 5.6.0 Version 5.6.0.4. diff --git a/package.yaml b/package.yaml index 92a79b283..89572ce94 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 5.6.0.4 +version: 5.6.1.0 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index aeb324ca9..1b35839b7 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 5.6.0.4 +version: 5.6.1.0 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and From 6ded721daaca76c416408396aa068a95616f6eaf Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sun, 31 Mar 2024 20:50:35 +0100 Subject: [PATCH 5/5] remove monad typeclasses to reduce overhead (#1077) * remove monad typeclasses to reduce overhead * remove unliftIO * StrictData * inline * optional agent port * avoid MonadUnliftIO instance (#1078) * avoid MonadUnliftIO instance * simpler liftError' * rename * narrow down instance * revert --------- Co-authored-by: Evgeny Poberezkin * logServer --------- Co-authored-by: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> --- apps/smp-agent/Main.hs | 4 +- package.yaml | 4 + simplexmq.cabal | 28 +- src/Simplex/FileTransfer/Agent.hs | 176 +++-- src/Simplex/FileTransfer/Client.hs | 10 +- src/Simplex/FileTransfer/Client/Agent.hs | 5 +- src/Simplex/FileTransfer/Client/Main.hs | 9 +- src/Simplex/FileTransfer/Crypto.hs | 2 +- src/Simplex/FileTransfer/Server.hs | 8 +- src/Simplex/FileTransfer/Server/Control.hs | 1 - src/Simplex/FileTransfer/Server/Env.hs | 2 +- src/Simplex/Messaging/Agent.hs | 748 +++++++++--------- src/Simplex/Messaging/Agent/Client.hs | 288 ++++--- src/Simplex/Messaging/Agent/Env/SQLite.hs | 30 +- src/Simplex/Messaging/Agent/Lock.hs | 12 +- .../Messaging/Agent/NtfSubSupervisor.hs | 89 ++- src/Simplex/Messaging/Agent/Protocol.hs | 12 +- src/Simplex/Messaging/Agent/Server.hs | 40 +- src/Simplex/Messaging/Client.hs | 33 +- src/Simplex/Messaging/Client/Agent.hs | 38 +- src/Simplex/Messaging/Crypto/File.hs | 2 +- src/Simplex/Messaging/Notifications/Server.hs | 3 +- .../Messaging/Notifications/Server/Env.hs | 2 +- src/Simplex/Messaging/Server.hs | 77 +- src/Simplex/Messaging/Server/Env/STM.hs | 12 +- src/Simplex/Messaging/Transport/Client.hs | 9 +- src/Simplex/Messaging/Transport/Server.hs | 14 +- src/Simplex/Messaging/Util.hs | 32 +- src/Simplex/RemoteControl/Client.hs | 27 +- src/Simplex/RemoteControl/Discovery.hs | 21 +- src/Simplex/RemoteControl/Types.hs | 11 - tests/AgentTests/FunctionalAPITests.hs | 57 +- tests/AgentTests/NotificationTests.hs | 29 +- tests/NtfClient.hs | 7 +- tests/SMPAgentClient.hs | 21 +- tests/SMPClient.hs | 13 +- tests/ServerTests.hs | 2 +- tests/XFTPAgent.hs | 31 +- tests/XFTPServerTests.hs | 3 +- 39 files changed, 1023 insertions(+), 889 deletions(-) diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index 84067d945..583127837 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -38,7 +39,8 @@ logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} -- Warning: this SMP agent server is experimental - it does not work correctly with multiple connected TCP clients in some cases. main :: IO () main = do - putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig) + let AgentConfig {tcpPort} = cfg + putStrLn $ maybe (error "no agent port") (\port -> "SMP agent listening on port " ++ port) tcpPort setLogLevel LogInfo -- LogError Right st <- createAgentStore agentDbFile agentDbKey False MCConsole withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers st diff --git a/package.yaml b/package.yaml index 89572ce94..ecd925830 100644 --- a/package.yaml +++ b/package.yaml @@ -179,3 +179,7 @@ ghc-options: - -Wincomplete-record-updates - -Wincomplete-uni-patterns - -Wunused-type-patterns + - -O2 + +default-extensions: + - StrictData diff --git a/simplexmq.cabal b/simplexmq.cabal index 1b35839b7..80549f8f4 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -173,7 +173,9 @@ library Paths_simplexmq hs-source-dirs: src - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 include-dirs: cbits c-sources: @@ -252,7 +254,9 @@ executable ntf-server Paths_simplexmq hs-source-dirs: apps/ntf-server - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -325,7 +329,9 @@ executable smp-agent Paths_simplexmq hs-source-dirs: apps/smp-agent - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -398,7 +404,9 @@ executable smp-server Paths_simplexmq hs-source-dirs: apps/smp-server - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -471,7 +479,9 @@ executable xftp Paths_simplexmq hs-source-dirs: apps/xftp - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -544,7 +554,9 @@ executable xftp-server Paths_simplexmq hs-source-dirs: apps/xftp-server - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded -rtsopts + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -649,7 +661,9 @@ test-suite simplexmq-test Paths_simplexmq hs-source-dirs: tests - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns + default-extensions: + StrictData + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 build-depends: HUnit ==1.6.* , QuickCheck ==2.14.* diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 2abf8e3dc..bae008e58 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -74,8 +74,9 @@ import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM) import System.FilePath (takeFileName, ()) import UnliftIO import UnliftIO.Directory +import qualified UnliftIO.Exception as E -startXFTPWorkers :: AgentMonad m => AgentClient -> Maybe FilePath -> m () +startXFTPWorkers :: AgentClient -> Maybe FilePath -> AM () startXFTPWorkers c workDir = do wd <- asks $ xftpWorkDir . xftpAgent atomically $ writeTVar wd workDir @@ -84,23 +85,26 @@ startXFTPWorkers c workDir = do startSndFiles cfg startDelFiles cfg where + startRcvFiles :: AgentConfig -> AM () startRcvFiles AgentConfig {rcvFilesTTL} = do pendingRcvServers <- withStore' c (`getPendingRcvFilesServers` rcvFilesTTL) - forM_ pendingRcvServers $ \s -> resumeXFTPRcvWork c (Just s) + lift . forM_ pendingRcvServers $ \s -> resumeXFTPRcvWork c (Just s) -- start local worker for files pending decryption, -- no need to make an extra query for the check -- as the worker will check the store anyway - resumeXFTPRcvWork c Nothing + lift $ resumeXFTPRcvWork c Nothing + startSndFiles :: AgentConfig -> AM () startSndFiles AgentConfig {sndFilesTTL} = do -- start worker for files pending encryption/creation - resumeXFTPSndWork c Nothing + lift $ resumeXFTPSndWork c Nothing pendingSndServers <- withStore' c (`getPendingSndFilesServers` sndFilesTTL) - forM_ pendingSndServers $ \s -> resumeXFTPSndWork c (Just s) + lift . forM_ pendingSndServers $ \s -> resumeXFTPSndWork c (Just s) + startDelFiles :: AgentConfig -> AM () startDelFiles AgentConfig {rcvFilesTTL} = do pendingDelServers <- withStore' c (`getPendingDelFilesServers` rcvFilesTTL) - forM_ pendingDelServers $ resumeXFTPDelWork c + lift . forM_ pendingDelServers $ resumeXFTPDelWork c -closeXFTPAgent :: MonadUnliftIO m => XFTPAgent -> m () +closeXFTPAgent :: XFTPAgent -> IO () closeXFTPAgent a = do stopWorkers $ xftpRcvWorkers a stopWorkers $ xftpSndWorkers a @@ -108,16 +112,16 @@ closeXFTPAgent a = do where stopWorkers workers = atomically (swapTVar workers M.empty) >>= mapM_ (liftIO . cancelWorker) -xftpReceiveFile' :: AgentMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> m RcvFileId +xftpReceiveFile' :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> AM RcvFileId xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redirect}) cfArgs = do g <- asks random - prefixPath <- getPrefixPath "rcv.xftp" + prefixPath <- lift $ getPrefixPath "rcv.xftp" createDirectory prefixPath let relPrefixPath = takeFileName prefixPath relTmpPath = relPrefixPath "xftp.encrypted" relSavePath = relPrefixPath "xftp.decrypted" - createDirectory =<< toFSFilePath relTmpPath - createEmptyFile =<< toFSFilePath relSavePath + lift $ createDirectory =<< toFSFilePath relTmpPath + lift $ createEmptyFile =<< toFSFilePath relSavePath let saveFile = CryptoFile relSavePath cfArgs fId <- case redirect of Nothing -> withStore c $ \db -> createRcvFile db g userId fd relPrefixPath relTmpPath saveFile @@ -125,8 +129,8 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi -- prepare description paths let relTmpPathRedirect = relPrefixPath "xftp.redirect-encrypted" relSavePathRedirect = relPrefixPath "xftp.redirect-decrypted" - createDirectory =<< toFSFilePath relTmpPathRedirect - createEmptyFile =<< toFSFilePath relSavePathRedirect + lift $ createDirectory =<< toFSFilePath relTmpPathRedirect + lift $ createEmptyFile =<< toFSFilePath relSavePathRedirect cfArgsRedirect <- atomically $ CF.randomArgs g let saveFileRedirect = CryptoFile relSavePathRedirect $ Just cfArgsRedirect -- create download tasks @@ -134,42 +138,42 @@ xftpReceiveFile' c userId (ValidFileDescription fd@FileDescription {chunks, redi forM_ chunks (downloadChunk c) pure fId -downloadChunk :: AgentMonad m => AgentClient -> FileChunk -> m () +downloadChunk :: AgentClient -> FileChunk -> AM () downloadChunk c FileChunk {replicas = (FileChunkReplica {server} : _)} = do - void $ getXFTPRcvWorker True c (Just server) + lift . void $ getXFTPRcvWorker True c (Just server) downloadChunk _ _ = throwError $ INTERNAL "no replicas" -getPrefixPath :: AgentMonad m => String -> m FilePath +getPrefixPath :: String -> AM' FilePath getPrefixPath suffix = do workPath <- getXFTPWorkPath ts <- liftIO getCurrentTime let isoTime = formatTime defaultTimeLocale "%Y%m%d_%H%M%S_%6q" ts uniqueCombine workPath (isoTime <> "_" <> suffix) -toFSFilePath :: AgentMonad m => FilePath -> m FilePath +toFSFilePath :: FilePath -> AM' FilePath toFSFilePath f = ( f) <$> getXFTPWorkPath -createEmptyFile :: AgentMonad m => FilePath -> m () +createEmptyFile :: FilePath -> AM' () createEmptyFile fPath = liftIO $ B.writeFile fPath "" -resumeXFTPRcvWork :: AgentMonad' m => AgentClient -> Maybe XFTPServer -> m () +resumeXFTPRcvWork :: AgentClient -> Maybe XFTPServer -> AM' () resumeXFTPRcvWork = void .: getXFTPRcvWorker False -getXFTPRcvWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe XFTPServer -> m Worker +getXFTPRcvWorker :: Bool -> AgentClient -> Maybe XFTPServer -> AM' Worker getXFTPRcvWorker hasWork c server = do ws <- asks $ xftpRcvWorkers . xftpAgent getAgentWorker "xftp_rcv" hasWork c server ws $ maybe (runXFTPRcvLocalWorker c) (runXFTPRcvWorker c) server -runXFTPRcvWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m () +runXFTPRcvWorker :: AgentClient -> XFTPServer -> Worker -> AM () runXFTPRcvWorker c srv Worker {doWork} = do cfg <- asks config forever $ do - waitForWork doWork + lift $ waitForWork doWork atomically $ assertAgentForeground c runXFTPOperation cfg where - runXFTPOperation :: AgentConfig -> m () + runXFTPOperation :: AgentConfig -> AM () runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = withWork c doWork (\db -> getNextRcvChunkToDownload db srv rcvFilesTTL) $ \case RcvFileChunk {rcvFileId, rcvFileEntityId, fileTmpPath, replicas = []} -> rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) "chunk has no replicas" @@ -182,14 +186,14 @@ runXFTPRcvWorker c srv Worker {doWork} = do retryLoop loop e replicaDelay = do flip catchAgentError (\_ -> pure ()) $ do when notifyOnRetry $ notify c rcvFileEntityId $ RFERR e - closeXFTPServerClient c userId server digest + liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay atomically $ assertAgentForeground c loop retryDone e = rcvWorkerInternalError c rcvFileId rcvFileEntityId (Just fileTmpPath) (show e) - downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> m () + downloadFileChunk :: RcvFileChunk -> RcvFileChunkReplica -> AM () downloadFileChunk RcvFileChunk {userId, rcvFileId, rcvFileEntityId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath} replica = do - fsFileTmpPath <- toFSFilePath fileTmpPath + fsFileTmpPath <- lift $ toFSFilePath fileTmpPath chunkPath <- uniqueCombine fsFileTmpPath $ show chunkNo let chunkSpec = XFTPRcvChunkSpec chunkPath (unFileSize chunkSize) (unFileDigest digest) relChunkPath = fileTmpPath takeFileName chunkPath @@ -206,7 +210,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do liftIO . when complete $ updateRcvFileStatus db rcvFileId RFSReceived pure (entityId, complete, RFPROG rcvd total) notify c entityId progress - when complete . void $ + when complete . lift . void $ getXFTPRcvWorker True c Nothing where receivedSize :: [RcvFileChunk] -> Int64 @@ -222,37 +226,37 @@ withRetryIntervalLimit maxN ri action = withRetryIntervalCount ri $ \n delay loop -> when (n < maxN) $ action delay loop -retryOnError :: AgentMonad m => Text -> m a -> m a -> AgentErrorType -> m a +retryOnError :: Text -> AM a -> AM a -> AgentErrorType -> AM a retryOnError name loop done e = do logError $ name <> " error: " <> tshow e if temporaryAgentError e then loop else done -rcvWorkerInternalError :: AgentMonad m => AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> m () +rcvWorkerInternalError :: AgentClient -> DBRcvFileId -> RcvFileId -> Maybe FilePath -> String -> AM () rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath internalErrStr = do - forM_ tmpPath (removePath <=< toFSFilePath) + lift $ forM_ tmpPath (removePath <=< toFSFilePath) withStore' c $ \db -> updateRcvFileError db rcvFileId internalErrStr notify c rcvFileEntityId $ RFERR $ INTERNAL internalErrStr -runXFTPRcvLocalWorker :: forall m. AgentMonad m => AgentClient -> Worker -> m () +runXFTPRcvLocalWorker :: AgentClient -> Worker -> AM () runXFTPRcvLocalWorker c Worker {doWork} = do cfg <- asks config forever $ do - waitForWork doWork + lift $ waitForWork doWork atomically $ assertAgentForeground c runXFTPOperation cfg where - runXFTPOperation :: AgentConfig -> m () + runXFTPOperation :: AgentConfig -> AM () runXFTPOperation AgentConfig {rcvFilesTTL} = withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $ \f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath} -> decryptFile f `catchAgentError` (rcvWorkerInternalError c rcvFileId rcvFileEntityId tmpPath . show) - decryptFile :: RcvFile -> m () + decryptFile :: RcvFile -> AM () decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do let CryptoFile savePath cfArgs = saveFile - fsSavePath <- toFSFilePath savePath - when (status == RFSDecrypting) $ + fsSavePath <- lift $ toFSFilePath savePath + lift . when (status == RFSDecrypting) $ whenM (doesFileExist fsSavePath) (removeFile fsSavePath >> createEmptyFile fsSavePath) withStore' c $ \db -> updateRcvFileStatus db rcvFileId RFSDecrypting chunkPaths <- getChunkPaths chunks @@ -265,16 +269,16 @@ runXFTPRcvLocalWorker c Worker {doWork} = do case redirect of Nothing -> do notify c rcvFileEntityId $ RFDONE fsSavePath - forM_ tmpPath (removePath <=< toFSFilePath) + lift $ forM_ tmpPath (removePath <=< toFSFilePath) atomically $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) Just RcvFileRedirect {redirectFileInfo, redirectDbId} -> do let RedirectFileInfo {size = redirectSize, digest = redirectDigest} = redirectFileInfo - forM_ tmpPath (removePath <=< toFSFilePath) + lift $ forM_ tmpPath (removePath <=< toFSFilePath) atomically $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) -- proceed with redirect - yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `finally` (toFSFilePath fsSavePath >>= removePath) + yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath) next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of Left _ -> throwError . XFTP $ XFTP.REDIRECT "decode error" Right (ValidFileDescription fd@FileDescription {size = dstSize, digest = dstDigest}) @@ -285,19 +289,19 @@ runXFTPRcvLocalWorker c Worker {doWork} = do withStore c $ \db -> updateRcvFileRedirect db redirectDbId next forM_ nextChunks (downloadChunk c) where - getChunkPaths :: [RcvFileChunk] -> m [FilePath] + getChunkPaths :: [RcvFileChunk] -> AM [FilePath] getChunkPaths [] = pure [] getChunkPaths (RcvFileChunk {chunkTmpPath = Just path} : cs) = do ps <- getChunkPaths cs - fsPath <- toFSFilePath path + fsPath <- lift $ toFSFilePath path pure $ fsPath : ps getChunkPaths (RcvFileChunk {chunkTmpPath = Nothing} : _cs) = throwError $ INTERNAL "no chunk path" -xftpDeleteRcvFile' :: AgentMonad m => AgentClient -> RcvFileId -> m () +xftpDeleteRcvFile' :: AgentClient -> RcvFileId -> AM' () xftpDeleteRcvFile' c rcvFileEntityId = xftpDeleteRcvFiles' c [rcvFileEntityId] -xftpDeleteRcvFiles' :: forall m. AgentMonad m => AgentClient -> [RcvFileId] -> m () +xftpDeleteRcvFiles' :: AgentClient -> [RcvFileId] -> AM' () xftpDeleteRcvFiles' c rcvFileEntityIds = do rcvFiles <- rights <$> withStoreBatch c (\db -> map (fmap (first storeError) . getRcvFileByEntityId db) rcvFileEntityIds) redirects <- rights <$> batchFiles getRcvFileRedirects rcvFiles @@ -309,29 +313,29 @@ xftpDeleteRcvFiles' c rcvFileEntityIds = do (removePath . (workPath )) prefixPath `catchAll_` pure () where fileComplete RcvFile {status} = status == RFSComplete || status == RFSError - batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> m [Either AgentErrorType a] + batchFiles :: (DB.Connection -> DBRcvFileId -> IO a) -> [RcvFile] -> AM' [Either AgentErrorType a] batchFiles f rcvFiles = withStoreBatch' c $ \db -> map (\RcvFile {rcvFileId} -> f db rcvFileId) rcvFiles -notify :: forall m e. (MonadUnliftIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m () +notify :: forall m e. (MonadIO m, AEntityI e) => AgentClient -> EntityId -> ACommand 'Agent e -> m () notify c entId cmd = atomically $ writeTBQueue (subQ c) ("", entId, APC (sAEntity @e) cmd) -xftpSendFile' :: AgentMonad m => AgentClient -> UserId -> CryptoFile -> Int -> m SndFileId +xftpSendFile' :: AgentClient -> UserId -> CryptoFile -> Int -> AM SndFileId xftpSendFile' c userId file numRecipients = do g <- asks random - prefixPath <- getPrefixPath "snd.xftp" + prefixPath <- lift $ getPrefixPath "snd.xftp" createDirectory prefixPath let relPrefixPath = takeFileName prefixPath key <- atomically $ C.randomSbKey g nonce <- atomically $ C.randomCbNonce g -- saving absolute filePath will not allow to restore file encryption after app update, but it's a short window fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce Nothing - void $ getXFTPSndWorker True c Nothing + lift . void $ getXFTPSndWorker True c Nothing pure fId -xftpSendDescription' :: forall m. AgentMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> m SndFileId +xftpSendDescription' :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> AM SndFileId xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {size, digest}) numRecipients = do g <- asks random - prefixPath <- getPrefixPath "snd.xftp" + prefixPath <- lift $ getPrefixPath "snd.xftp" createDirectory prefixPath let relPrefixPath = takeFileName prefixPath let directYaml = prefixPath "direct.yaml" @@ -341,39 +345,39 @@ xftpSendDescription' c userId (ValidFileDescription fdDirect@FileDescription {si key <- atomically $ C.randomSbKey g nonce <- atomically $ C.randomCbNonce g fId <- withStore c $ \db -> createSndFile db g userId file numRecipients relPrefixPath key nonce $ Just RedirectFileInfo {size, digest} - void $ getXFTPSndWorker True c Nothing + lift . void $ getXFTPSndWorker True c Nothing pure fId -resumeXFTPSndWork :: AgentMonad' m => AgentClient -> Maybe XFTPServer -> m () +resumeXFTPSndWork :: AgentClient -> Maybe XFTPServer -> AM' () resumeXFTPSndWork = void .: getXFTPSndWorker False -getXFTPSndWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe XFTPServer -> m Worker +getXFTPSndWorker :: Bool -> AgentClient -> Maybe XFTPServer -> AM' Worker getXFTPSndWorker hasWork c server = do ws <- asks $ xftpSndWorkers . xftpAgent getAgentWorker "xftp_snd" hasWork c server ws $ maybe (runXFTPSndPrepareWorker c) (runXFTPSndWorker c) server -runXFTPSndPrepareWorker :: forall m. AgentMonad m => AgentClient -> Worker -> m () +runXFTPSndPrepareWorker :: AgentClient -> Worker -> AM () runXFTPSndPrepareWorker c Worker {doWork} = do cfg <- asks config forever $ do - waitForWork doWork + lift $ waitForWork doWork atomically $ assertAgentForeground c runXFTPOperation cfg where - runXFTPOperation :: AgentConfig -> m () + runXFTPOperation :: AgentConfig -> AM () runXFTPOperation cfg@AgentConfig {sndFilesTTL} = withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $ \f@SndFile {sndFileId, sndFileEntityId, prefixPath} -> prepareFile cfg f `catchAgentError` (sndWorkerInternalError c sndFileId sndFileEntityId prefixPath . show) - prepareFile :: AgentConfig -> SndFile -> m () + prepareFile :: AgentConfig -> SndFile -> AM () prepareFile _ SndFile {prefixPath = Nothing} = throwError $ INTERNAL "no prefix path" prepareFile cfg sndFile@SndFile {sndFileId, userId, prefixPath = Just ppath, status} = do SndFile {numRecipients, chunks} <- if status /= SFSEncrypted -- status is SFSNew or SFSEncrypting then do - fsEncPath <- toFSFilePath $ sndFileEncPath ppath + fsEncPath <- lift . toFSFilePath $ sndFileEncPath ppath when (status == SFSEncrypting) . whenM (doesFileExist fsEncPath) $ removeFile fsEncPath withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting @@ -389,7 +393,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do withStore' c $ \db -> updateSndFileStatus db sndFileId SFSUploading where AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg - encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)]) + encryptFileForUpload :: SndFile -> FilePath -> AM (FileDigest, [(XFTPChunkSpec, FileDigest)]) encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do let CryptoFile {filePath} = srcFile fileName = takeFileName filePath @@ -412,12 +416,12 @@ runXFTPSndPrepareWorker c Worker {doWork} = do chunkCreated :: SndFileChunk -> Bool chunkCreated SndFileChunk {replicas} = any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas - createChunk :: Int -> SndFileChunk -> m () + createChunk :: Int -> SndFileChunk -> AM () createChunk numRecipients' ch = do atomically $ assertAgentForeground c (replica, ProtoServerWithAuth srv _) <- tryCreate withStore' c $ \db -> createSndFileReplica db ch replica - void $ getXFTPSndWorker True c (Just srv) + lift . void $ getXFTPSndWorker True c (Just srv) where tryCreate = do usedSrvs <- newTVarIO ([] :: [XFTPServer]) @@ -433,21 +437,21 @@ runXFTPSndPrepareWorker c Worker {doWork} = do replica <- agentXFTPNewChunk c ch numRecipients' srvAuth pure (replica, srvAuth) -sndWorkerInternalError :: AgentMonad m => AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> m () +sndWorkerInternalError :: AgentClient -> DBSndFileId -> SndFileId -> Maybe FilePath -> String -> AM () sndWorkerInternalError c sndFileId sndFileEntityId prefixPath internalErrStr = do - forM_ prefixPath $ removePath <=< toFSFilePath + lift . forM_ prefixPath $ removePath <=< toFSFilePath withStore' c $ \db -> updateSndFileError db sndFileId internalErrStr notify c sndFileEntityId $ SFERR $ INTERNAL internalErrStr -runXFTPSndWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m () +runXFTPSndWorker :: AgentClient -> XFTPServer -> Worker -> AM () runXFTPSndWorker c srv Worker {doWork} = do cfg <- asks config forever $ do - waitForWork doWork + lift $ waitForWork doWork atomically $ assertAgentForeground c runXFTPOperation cfg where - runXFTPOperation :: AgentConfig -> m () + runXFTPOperation :: AgentConfig -> AM () runXFTPOperation cfg@AgentConfig {sndFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do withWork c doWork (\db -> getNextSndChunkToUpload db srv sndFilesTTL) $ \case SndFileChunk {sndFileId, sndFileEntityId, filePrefixPath, replicas = []} -> sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) "chunk has no replicas" @@ -460,15 +464,15 @@ runXFTPSndWorker c srv Worker {doWork} = do retryLoop loop e replicaDelay = do flip catchAgentError (\_ -> pure ()) $ do when notifyOnRetry $ notify c sndFileEntityId $ SFERR e - closeXFTPServerClient c userId server digest + liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay atomically $ assertAgentForeground c loop retryDone e = sndWorkerInternalError c sndFileId sndFileEntityId (Just filePrefixPath) (show e) - uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> m () + uploadFileChunk :: AgentConfig -> SndFileChunk -> SndFileChunkReplica -> AM () uploadFileChunk AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients} sndFileChunk@SndFileChunk {sndFileId, userId, chunkSpec = chunkSpec@XFTPChunkSpec {filePath}, digest = chunkDigest} replica = do replica'@SndFileChunkReplica {sndChunkReplicaId} <- addRecipients sndFileChunk replica - fsFilePath <- toFSFilePath filePath + fsFilePath <- lift $ toFSFilePath filePath unlessM (doesFileExist fsFilePath) $ throwError $ INTERNAL "encrypted file doesn't exist on upload" let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec atomically $ assertAgentForeground c @@ -484,10 +488,10 @@ runXFTPSndWorker c srv Worker {doWork} = do when complete $ do (sndDescr, rcvDescrs) <- sndFileToDescrs sf notify c sndFileEntityId $ SFDONE sndDescr rcvDescrs - forM_ prefixPath $ removePath <=< toFSFilePath + lift . forM_ prefixPath $ removePath <=< toFSFilePath withStore' c $ \db -> updateSndFileComplete db sndFileId where - addRecipients :: SndFileChunk -> SndFileChunkReplica -> m SndFileChunkReplica + addRecipients :: SndFileChunk -> SndFileChunkReplica -> AM SndFileChunkReplica addRecipients ch@SndFileChunk {numRecipients} cr@SndFileChunkReplica {rcvIdsKeys} | length rcvIdsKeys > numRecipients = throwError $ INTERNAL "too many recipients" | length rcvIdsKeys == numRecipients = pure cr @@ -496,7 +500,7 @@ runXFTPSndWorker c srv Worker {doWork} = do rcvIdsKeys' <- agentXFTPAddRecipients c userId chunkDigest cr numRecipients' cr' <- withStore' c $ \db -> addSndChunkReplicaRecipients db cr $ L.toList rcvIdsKeys' addRecipients ch cr' - sndFileToDescrs :: SndFile -> m (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient]) + sndFileToDescrs :: SndFile -> AM (ValidFileDescription 'FSender, [ValidFileDescription 'FRecipient]) sndFileToDescrs SndFile {digest = Nothing} = throwError $ INTERNAL "snd file has no digest" sndFileToDescrs SndFile {chunks = []} = throwError $ INTERNAL "snd file has no chunks" sndFileToDescrs SndFile {digest = Just digest, key, nonce, chunks = chunks@(fstChunk : _), redirect} = do @@ -511,7 +515,7 @@ runXFTPSndWorker c srv Worker {doWork} = do fdRcvs = createRcvFileDescriptions fdRcv chunks validFdRcvs <- either (throwError . INTERNAL) pure $ mapM validateFileDescription fdRcvs pure (validFdSnd, validFdRcvs) - toSndDescrChunk :: SndFileChunk -> m FileChunk + toSndDescrChunk :: SndFileChunk -> AM FileChunk toSndDescrChunk SndFileChunk {replicas = []} = throwError $ INTERNAL "snd file chunk has no replicas" toSndDescrChunk ch@SndFileChunk {chunkNo, digest = chDigest, replicas = (SndFileChunkReplica {server, replicaId, replicaKey} : _)} = do let chunkSize = FileSize $ sndChunkSize ch @@ -562,10 +566,10 @@ runXFTPSndWorker c srv Worker {doWork} = do chunkUploaded SndFileChunk {replicas} = any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSUploaded) replicas -deleteSndFileInternal :: AgentMonad m => AgentClient -> SndFileId -> m () +deleteSndFileInternal :: AgentClient -> SndFileId -> AM' () deleteSndFileInternal c sndFileEntityId = deleteSndFilesInternal c [sndFileEntityId] -deleteSndFilesInternal :: forall m. AgentMonad m => AgentClient -> [SndFileId] -> m () +deleteSndFilesInternal :: AgentClient -> [SndFileId] -> AM' () deleteSndFilesInternal c sndFileEntityIds = do sndFiles <- rights <$> withStoreBatch c (\db -> map (fmap (first storeError) . getSndFileByEntityId db) sndFileEntityIds) let (toDelete, toMarkDeleted) = partition fileComplete sndFiles @@ -576,15 +580,15 @@ deleteSndFilesInternal c sndFileEntityIds = do batchFiles_ updateSndFileDeleted toMarkDeleted where fileComplete SndFile {status} = status == SFSComplete || status == SFSError - batchFiles_ :: (DB.Connection -> DBSndFileId -> IO a) -> [SndFile] -> m () + batchFiles_ :: (DB.Connection -> DBSndFileId -> IO a) -> [SndFile] -> AM' () batchFiles_ f sndFiles = void $ withStoreBatch' c $ \db -> map (\SndFile {sndFileId} -> f db sndFileId) sndFiles -deleteSndFileRemote :: forall m. AgentMonad m => AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> m () +deleteSndFileRemote :: AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> AM' () deleteSndFileRemote c userId sndFileEntityId sfd = deleteSndFilesRemote c userId [(sndFileEntityId, sfd)] -deleteSndFilesRemote :: forall m. AgentMonad m => AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> m () +deleteSndFilesRemote :: AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> AM' () deleteSndFilesRemote c userId sndFileIdsDescrs = do - deleteSndFilesInternal c (map fst sndFileIdsDescrs) `catchAgentError` (notify c "" . SFERR) + deleteSndFilesInternal c (map fst sndFileIdsDescrs) `E.catchAny` (notify c "" . SFERR . INTERNAL . show) let rs = concatMap (mapMaybe chunkReplica . fdChunks . snd) sndFileIdsDescrs void $ withStoreBatch' c (\db -> map (uncurry $ createDeletedSndChunkReplica db userId) rs) let servers = S.fromList $ map (\(FileChunkReplica {server}, _) -> server) rs @@ -596,23 +600,23 @@ deleteSndFilesRemote c userId sndFileIdsDescrs = do FileChunk {digest, replicas = replica : _} -> Just (replica, digest) _ -> Nothing -resumeXFTPDelWork :: AgentMonad' m => AgentClient -> XFTPServer -> m () +resumeXFTPDelWork :: AgentClient -> XFTPServer -> AM' () resumeXFTPDelWork = void .: getXFTPDelWorker False -getXFTPDelWorker :: AgentMonad' m => Bool -> AgentClient -> XFTPServer -> m Worker +getXFTPDelWorker :: Bool -> AgentClient -> XFTPServer -> AM' Worker getXFTPDelWorker hasWork c server = do ws <- asks $ xftpDelWorkers . xftpAgent getAgentWorker "xftp_del" hasWork c server ws $ runXFTPDelWorker c server -runXFTPDelWorker :: forall m. AgentMonad m => AgentClient -> XFTPServer -> Worker -> m () +runXFTPDelWorker :: AgentClient -> XFTPServer -> Worker -> AM () runXFTPDelWorker c srv Worker {doWork} = do cfg <- asks config forever $ do - waitForWork doWork + lift $ waitForWork doWork atomically $ assertAgentForeground c runXFTPOperation cfg where - runXFTPOperation :: AgentConfig -> m () + runXFTPOperation :: AgentConfig -> AM () runXFTPOperation AgentConfig {rcvFilesTTL, reconnectInterval = ri, xftpNotifyErrsOnRetry = notifyOnRetry, xftpConsecutiveRetries} = do -- no point in deleting files older than rcv ttl, as they will be expired on server withWork c doWork (\db -> getNextDeletedSndChunkReplica db srv rcvFilesTTL) processDeletedReplica @@ -626,7 +630,7 @@ runXFTPDelWorker c srv Worker {doWork} = do retryLoop loop e replicaDelay = do flip catchAgentError (\_ -> pure ()) $ do when notifyOnRetry $ notify c "" $ SFERR e - closeXFTPServerClient c userId server chunkDigest + liftIO $ closeXFTPServerClient c userId server chunkDigest withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay atomically $ assertAgentForeground c loop @@ -635,7 +639,7 @@ runXFTPDelWorker c srv Worker {doWork} = do agentXFTPDeleteChunk c userId replica withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId -delWorkerInternalError :: AgentMonad m => AgentClient -> Int64 -> AgentErrorType -> m () +delWorkerInternalError :: AgentClient -> Int64 -> AgentErrorType -> AM () delWorkerInternalError c deletedSndChunkReplicaId e = do withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId notify c "" $ SFERR e diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index d9c4c058a..ea0c351ca 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -50,7 +50,7 @@ import Simplex.Messaging.Transport.Client (TransportClientConfig, TransportHost) import Simplex.Messaging.Transport.HTTP2 import Simplex.Messaging.Transport.HTTP2.Client import Simplex.Messaging.Transport.HTTP2.File -import Simplex.Messaging.Util (bshow, liftEitherError, whenM) +import Simplex.Messaging.Util (bshow, whenM) import UnliftIO import UnliftIO.Directory @@ -98,7 +98,7 @@ getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {xftpNetworkC clientVar <- newTVarIO Nothing let usePort = if null port then "443" else port clientDisconnected = readTVarIO clientVar >>= mapM_ disconnected - http2Client <- liftEitherError xftpClientError $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected + http2Client <- withExceptT xftpClientError . ExceptT $ getVerifiedHTTP2Client (Just username) useHost usePort (Just keyHash) Nothing http2Config clientDisconnected let HTTP2Client {sessionId} = http2Client thParams = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = currentXFTPVersion, thAuth = Nothing, implySessId = False, batch = True} c = XFTPClient {http2Client, thParams, transportSession, config} @@ -145,7 +145,7 @@ sendXFTPTransmission :: XFTPClient -> ByteString -> Maybe XFTPChunkSpec -> Excep sendXFTPTransmission XFTPClient {config, thParams, http2Client} t chunkSpec_ = do let req = H.requestStreaming N.methodPost "/" [] streamBody reqTimeout = (\XFTPChunkSpec {chunkSize} -> chunkTimeout config chunkSize) <$> chunkSpec_ - HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- liftEitherError xftpClientError $ sendRequest http2Client req reqTimeout + HTTP2Response {respBody = body@HTTP2Body {bodyHead}} <- withExceptT xftpClientError . ExceptT $ sendRequest http2Client req reqTimeout when (B.length bodyHead /= xftpBlockSize) $ throwError $ PCEResponseError BLOCK -- TODO validate that the file ID is the same as in the request? (_, _, (_, _fId, respOrErr)) <- liftEither . first PCEResponseError $ xftpDecodeTransmission thParams bodyHead @@ -196,9 +196,9 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec { let dhSecret = C.dh' sDhKey rpDhKey cbState <- liftEither . first PCECryptoError $ LC.cbInit dhSecret cbNonce let t = chunkTimeout config chunkSize - t `timeout` download cbState >>= maybe (throwError PCEResponseTimeout) pure + ExceptT (sequence <$> (t `timeout` download cbState)) >>= maybe (throwError PCEResponseTimeout) pure where - download cbState = + download cbState = runExceptT $ withExceptT PCEResponseError $ receiveEncFile chunkPart cbState chunkSpec `catchError` \e -> whenM (doesFileExist filePath) (removeFile filePath) >> throwError e diff --git a/src/Simplex/FileTransfer/Client/Agent.hs b/src/Simplex/FileTransfer/Client/Agent.hs index d52b17be5..1dafc8108 100644 --- a/src/Simplex/FileTransfer/Client/Agent.hs +++ b/src/Simplex/FileTransfer/Client/Agent.hs @@ -10,6 +10,7 @@ module Simplex.FileTransfer.Client.Agent where import Control.Logger.Simple (logInfo) import Control.Monad import Control.Monad.Except +import Control.Monad.Trans (lift) import Data.Bifunctor (first) import qualified Data.ByteString.Char8 as B import Data.Text (Text) @@ -86,7 +87,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do waitForXFTPClient :: XFTPClientVar -> ME XFTPClient waitForXFTPClient clientVar = do let XFTPClientConfig {xftpNetworkConfig = NetworkConfig {tcpConnectTimeout}} = xftpConfig config - client_ <- tcpConnectTimeout `timeout` atomically (readTMVar clientVar) + client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar) liftEither $ case client_ of Just (Right c) -> Right c Just (Left e) -> Left e @@ -110,7 +111,7 @@ getXFTPServerClient XFTPClientAgent {xftpClients, config} srv = do TM.delete srv xftpClients throwError e tryConnectAsync :: ME () - tryConnectAsync = void . async $ do + tryConnectAsync = void . lift . async . runExceptT $ do withRetryInterval (reconnectInterval config) $ \_ loop -> void $ tryConnectClient loop showServer :: XFTPServer -> Text diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index bca41cea8..b3fa494ed 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -37,6 +37,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB import Data.Char (toLower) +import Data.Either (partitionEithers) import Data.Int (Int64) import Data.List (foldl', sortOn) import Data.List.NonEmpty (NonEmpty (..), nonEmpty) @@ -321,7 +322,9 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re -- the reason we don't do pooled downloads here within one server is that http2 library doesn't handle cleint concurrency, even though -- upload doesn't allow other requests within the same client until complete (but download does allow). logInfo $ "uploading " <> tshow (length chunks) <> " chunks..." - map snd . sortOn fst . concat <$> pooledForConcurrentlyN 16 chunks' (mapM $ uploadFileChunk a) + (errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 chunks' . mapM $ runExceptT . uploadFileChunk a) + mapM_ throwError errs + pure $ map snd (sortOn fst rs) where uploadFileChunk :: XFTPClientAgent -> (Int, XFTPChunkSpec, XFTPServerWithAuth) -> ExceptT CLIError IO (Int, SentFileChunk) uploadFileChunk a (chunkNo, chunkSpec@XFTPChunkSpec {chunkSize}, ProtoServerWithAuth xftpServer auth) = do @@ -433,7 +436,9 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, FileChunkReplica {server} : _ -> server srvChunks = groupAllOn srv chunks g <- liftIO C.newRandom - chunkPaths <- map snd . sortOn fst . concat <$> pooledForConcurrentlyN 16 srvChunks (mapM $ downloadFileChunk g a encPath size downloadedChunks) + (errs, rs) <- partitionEithers . concat <$> liftIO (pooledForConcurrentlyN 16 srvChunks $ mapM $ runExceptT . downloadFileChunk g a encPath size downloadedChunks) + mapM_ throwError errs + let chunkPaths = map snd $ sortOn fst rs encDigest <- liftIO $ LC.sha512Hash <$> readChunks chunkPaths when (encDigest /= unFileDigest digest) $ throwError $ CLIError "File digest mismatch" encSize <- liftIO $ foldM (\s path -> (s +) . fromIntegral <$> getFileSize path) 0 chunkPaths diff --git a/src/Simplex/FileTransfer/Crypto.hs b/src/Simplex/FileTransfer/Crypto.hs index 03dc83a00..547a5675a 100644 --- a/src/Simplex/FileTransfer/Crypto.hs +++ b/src/Simplex/FileTransfer/Crypto.hs @@ -29,7 +29,7 @@ import UnliftIO.Directory (removeFile) encryptFile :: CryptoFile -> ByteString -> C.SbKey -> C.CbNonce -> Int64 -> Int64 -> FilePath -> ExceptT FTCryptoError IO () encryptFile srcFile fileHdr key nonce fileSize' encSize encFile = do sb <- liftEitherWith FTCECryptoError $ LC.sbInit key nonce - CF.withFile srcFile ReadMode $ \r -> withFile encFile WriteMode $ \w -> do + CF.withFile srcFile ReadMode $ \r -> ExceptT . withFile encFile WriteMode $ \w -> runExceptT $ do let lenStr = smpEncode fileSize' (hdr, !sb') = LC.sbEncryptChunk sb $ lenStr <> fileHdr padLen = encSize - authTagSize - fileSize' - 8 diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 2d957a7b2..dd74a975d 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -330,7 +330,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case -- TODO validate body empty sId <- ExceptT $ addFileRetry st file 3 ts rcps <- mapM (ExceptT . addRecipientRetry st 3 sId) rks - withFileLog $ \sl -> do + lift $ withFileLog $ \sl -> do logAddFile sl sId file ts logAddRecipients sl sId rcps stats <- asks serverStats @@ -362,7 +362,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case st <- asks store r <- runExceptT $ do rcps <- mapM (ExceptT . addRecipientRetry st 3 sId) rks - withFileLog $ \sl -> logAddRecipients sl sId rcps + lift $ withFileLog $ \sl -> logAddRecipients sl sId rcps stats <- asks serverStats atomically $ modifyTVar' (fileRecipients stats) (+ length rks) let rIds = L.map (\(FileRecipient rId _) -> rId) rcps @@ -457,7 +457,7 @@ deleteServerFile_ FileRec {senderId, fileInfo, filePath} = do atomically $ modifyTVar' (filesCount stats) (subtract 1) atomically $ modifyTVar' (filesSize stats) (subtract $ fromIntegral $ size fileInfo) -randomId :: (MonadUnliftIO m, MonadReader XFTPEnv m) => Int -> m ByteString +randomId :: Int -> M ByteString randomId n = atomically . C.randomBytes n =<< asks random getFileId :: M XFTPFileId @@ -465,7 +465,7 @@ getFileId = do size <- asks (fileIdSize . config) atomically . C.randomBytes size =<< asks random -withFileLog :: (MonadIO m, MonadReader XFTPEnv m) => (StoreLog 'WriteMode -> IO a) -> m () +withFileLog :: (StoreLog 'WriteMode -> IO a) -> M () withFileLog action = liftIO . mapM_ action =<< asks storeLog incFileStat :: (FileServerStats -> TVar Int) -> M () diff --git a/src/Simplex/FileTransfer/Server/Control.hs b/src/Simplex/FileTransfer/Server/Control.hs index 54d349c3b..a8786170e 100644 --- a/src/Simplex/FileTransfer/Server/Control.hs +++ b/src/Simplex/FileTransfer/Server/Control.hs @@ -5,7 +5,6 @@ module Simplex.FileTransfer.Server.Control where import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString (ByteString) -import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BasicAuth) diff --git a/src/Simplex/FileTransfer/Server/Env.hs b/src/Simplex/FileTransfer/Server/Env.hs index d71864a43..a3afe0f60 100644 --- a/src/Simplex/FileTransfer/Server/Env.hs +++ b/src/Simplex/FileTransfer/Server/Env.hs @@ -94,7 +94,7 @@ defaultFileExpiration = checkInterval = 2 * 3600 -- seconds, 2 hours } -newXFTPServerEnv :: (MonadUnliftIO m, MonadRandom m) => XFTPServerConfig -> m XFTPEnv +newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile} = do random <- liftIO C.newRandom store <- atomically newFileStore diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index e9b05c5f0..7330e823f 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -30,13 +30,11 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md module Simplex.Messaging.Agent ( -- * queue-based SMP agent - getAgentClient, runAgentClient, -- * SMP agent functional API AgentClient (..), - AgentMonad, - AgentErrorMonad, + AE, SubscriptionsInfo (..), getSMPAgentClient, getSMPAgentClient_, @@ -121,7 +119,6 @@ where import Control.Logger.Simple (logError, logInfo, showText) import Control.Monad import Control.Monad.Except -import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (ChaChaDRG) import qualified Data.Aeson as J @@ -151,7 +148,7 @@ import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Util (removePath) import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite -import Simplex.Messaging.Agent.Lock (withLock) +import Simplex.Messaging.Agent.Lock (withLock', withLock) import Simplex.Messaging.Agent.NtfSubSupervisor import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval @@ -188,18 +185,20 @@ import UnliftIO.STM -- import GHC.Conc (unsafeIOToSTM) +type AE a = ExceptT AgentErrorType IO a + -- | Creates an SMP agent client instance -getSMPAgentClient :: MonadIO m => AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> m AgentClient +getSMPAgentClient :: AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> IO AgentClient getSMPAgentClient = getSMPAgentClient_ 1 {-# INLINE getSMPAgentClient #-} -getSMPAgentClient_ :: MonadIO m => Int -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> m AgentClient +getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> IO AgentClient getSMPAgentClient_ clientId cfg initServers store backgroundMode = liftIO $ newSMPAgentEnv cfg store >>= runReaderT runAgent where runAgent = do - c@AgentClient {acThread} <- getAgentClient clientId initServers - t <- runAgentThreads c `forkFinally` const (disconnectAgentClient c) + c@AgentClient {acThread} <- atomically . newAgentClient clientId initServers =<< ask + t <- runAgentThreads c `forkFinally` const (liftIO $ disconnectAgentClient c) atomically . writeTVar acThread . Just =<< mkWeakThreadId t pure c runAgentThreads c @@ -215,7 +214,7 @@ getSMPAgentClient_ clientId cfg initServers store backgroundMode = logError $ "Agent thread " <> name <> " crashed: " <> tshow e atomically $ writeTBQueue subQ ("", "", APC SAEConn $ ERR $ CRITICAL True $ show e) -disconnectAgentClient :: MonadUnliftIO m => AgentClient -> m () +disconnectAgentClient :: AgentClient -> IO () disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns, xftpAgent = xa}} = do closeAgentClient c closeNtfSupervisor ns @@ -223,311 +222,328 @@ disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns, xftpAge logConnection c False -- only used in the tests -disposeAgentClient :: MonadUnliftIO m => AgentClient -> m () +disposeAgentClient :: AgentClient -> IO () disposeAgentClient c@AgentClient {acThread, agentEnv = Env {store}} = do t_ <- atomically (swapTVar acThread Nothing) $>>= (liftIO . deRefWeak) disconnectAgentClient c mapM_ killThread t_ liftIO $ closeSQLiteStore store -resumeAgentClient :: MonadIO m => AgentClient -> m () +resumeAgentClient :: AgentClient -> IO () resumeAgentClient c = atomically $ writeTVar (active c) True +{-# INLINE resumeAgentClient #-} -type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m) - -createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> m UserId +createUser :: AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> AE UserId createUser c = withAgentEnv c .: createUser' c +{-# INLINE createUser #-} -- | Delete user record optionally deleting all user's connections on SMP servers -deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m () +deleteUser :: AgentClient -> UserId -> Bool -> AE () deleteUser c = withAgentEnv c .: deleteUser' c +{-# INLINE deleteUser #-} -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id -createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId +createConnectionAsync :: ConnectionModeI c => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> AE ConnId createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs +{-# INLINE createConnectionAsync #-} -- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id -joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnectionAsync :: AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:: joinConnAsync c userId aCorrId enableNtfs +{-# INLINE joinConnectionAsync #-} -- | Allow connection to continue after CONF notification (LET command), no synchronous response -allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () +allowConnectionAsync :: AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> AE () allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c +{-# INLINE allowConnectionAsync #-} -- | Accept contact after REQ notification (ACPT command) asynchronously, synchronous response is new connection id -acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContactAsync :: AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:: acceptContactAsync' c aCorrId enableNtfs +{-# INLINE acceptContactAsync #-} -- | Acknowledge message (ACK command) asynchronously, no synchronous response -ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () +ackMessageAsync :: AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE () ackMessageAsync c = withAgentEnv c .:: ackMessageAsync' c +{-# INLINE ackMessageAsync #-} -- | Switch connection to the new receive queue -switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ConnectionStats +switchConnectionAsync :: AgentClient -> ACorrId -> ConnId -> AE ConnectionStats switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c +{-# INLINE switchConnectionAsync #-} -- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response -deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> Bool -> ConnId -> m () +deleteConnectionAsync :: AgentClient -> Bool -> ConnId -> AE () deleteConnectionAsync c waitDelivery = withAgentEnv c . deleteConnectionAsync' c waitDelivery +{-# INLINE deleteConnectionAsync #-} -- | Delete SMP agent connections using batch commands asynchronously, no synchronous response -deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync :: AgentClient -> Bool -> [ConnId] -> AE () deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' c waitDelivery +{-# INLINE deleteConnectionsAsync #-} -- | Create SMP agent connection (NEW command) -createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs +{-# INLINE createConnection #-} -- | Join SMP agent connection (JOIN command) -joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId joinConnection c userId enableNtfs = withAgentEnv c .:: joinConn c userId "" enableNtfs +{-# INLINE joinConnection #-} -- | Allow connection to continue after CONF notification (LET command) -allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () +allowConnection :: AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> AE () allowConnection c = withAgentEnv c .:. allowConnection' c +{-# INLINE allowConnection #-} -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContact :: AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs +{-# INLINE acceptContact #-} -- | Reject contact (RJCT command) -rejectContact :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> m () +rejectContact :: AgentClient -> ConnId -> ConfirmationId -> AE () rejectContact c = withAgentEnv c .: rejectContact' c +{-# INLINE rejectContact #-} -- | Subscribe to receive connection messages (SUB command) -subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +subscribeConnection :: AgentClient -> ConnId -> AE () subscribeConnection c = withAgentEnv c . subscribeConnection' c +{-# INLINE subscribeConnection #-} -- | Subscribe to receive connection messages from multiple connections, batching commands when possible -subscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +subscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) subscribeConnections c = withAgentEnv c . subscribeConnections' c +{-# INLINE subscribeConnections #-} -- | Get connection message (GET command) -getConnectionMessage :: AgentErrorMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta) +getConnectionMessage :: AgentClient -> ConnId -> AE (Maybe SMPMsgMeta) getConnectionMessage c = withAgentEnv c . getConnectionMessage' c +{-# INLINE getConnectionMessage #-} -- | Get connection message for received notification -getNotificationMessage :: AgentErrorMonad m => AgentClient -> C.CbNonce -> ByteString -> m (NotificationInfo, [SMPMsgMeta]) +getNotificationMessage :: AgentClient -> C.CbNonce -> ByteString -> AE (NotificationInfo, [SMPMsgMeta]) getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c +{-# INLINE getNotificationMessage #-} -resubscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +resubscribeConnection :: AgentClient -> ConnId -> AE () resubscribeConnection c = withAgentEnv c . resubscribeConnection' c +{-# INLINE resubscribeConnection #-} -resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) resubscribeConnections c = withAgentEnv c . resubscribeConnections' c +{-# INLINE resubscribeConnections #-} -- | Send message to the connection (SEND command) -sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, PQEncryption) +sendMessage :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AE (AgentMsgId, PQEncryption) sendMessage c = withAgentEnv c .:: sendMessage' c +{-# INLINE sendMessage #-} type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) -sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, PQEncryption)] -sendMessages c = withAgentEnv c . sendMessages' c +sendMessages :: AgentClient -> [MsgReq] -> IO [Either AgentErrorType (AgentMsgId, PQEncryption)] +sendMessages c = withAgentEnv' c . sendMessages' c +{-# INLINE sendMessages #-} -sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) -sendMessagesB c = withAgentEnv c . sendMessagesB' c +sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> IO (t (Either AgentErrorType (AgentMsgId, PQEncryption))) +sendMessagesB c = withAgentEnv' c . sendMessagesB' c +{-# INLINE sendMessagesB #-} -ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () +ackMessage :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE () ackMessage c = withAgentEnv c .:. ackMessage' c +{-# INLINE ackMessage #-} -- | Switch connection to the new receive queue -switchConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats +switchConnection :: AgentClient -> ConnId -> AE ConnectionStats switchConnection c = withAgentEnv c . switchConnection' c +{-# INLINE switchConnection #-} -- | Abort switching connection to the new receive queue -abortConnectionSwitch :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats +abortConnectionSwitch :: AgentClient -> ConnId -> AE ConnectionStats abortConnectionSwitch c = withAgentEnv c . abortConnectionSwitch' c +{-# INLINE abortConnectionSwitch #-} -- | Re-synchronize connection ratchet keys -synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> PQSupport -> Bool -> m ConnectionStats +synchronizeRatchet :: AgentClient -> ConnId -> PQSupport -> Bool -> AE ConnectionStats synchronizeRatchet c = withAgentEnv c .:. synchronizeRatchet' c +{-# INLINE synchronizeRatchet #-} -- | Suspend SMP agent connection (OFF command) -suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +suspendConnection :: AgentClient -> ConnId -> AE () suspendConnection c = withAgentEnv c . suspendConnection' c +{-# INLINE suspendConnection #-} -- | Delete SMP agent connection (DEL command) -deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +deleteConnection :: AgentClient -> ConnId -> AE () deleteConnection c = withAgentEnv c . deleteConnection' c +{-# INLINE deleteConnection #-} -- | Delete multiple connections, batching commands when possible -deleteConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +deleteConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) deleteConnections c = withAgentEnv c . deleteConnections' c +{-# INLINE deleteConnections #-} -- | get servers used for connection -getConnectionServers :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats +getConnectionServers :: AgentClient -> ConnId -> AE ConnectionStats getConnectionServers c = withAgentEnv c . getConnectionServers' c +{-# INLINE getConnectionServers #-} -- | get connection ratchet associated data hash for verification (should match peer AD hash) -getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m ByteString +getConnectionRatchetAdHash :: AgentClient -> ConnId -> AE ByteString getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c - --- | Change servers to be used for creating new queues -setProtocolServers :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentErrorMonad m) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> m () -setProtocolServers c = withAgentEnv c .: setProtocolServers' c +{-# INLINE getConnectionRatchetAdHash #-} -- | Test protocol server -testProtocolServer :: forall p m. (ProtocolTypeI p, AgentErrorMonad m) => AgentClient -> UserId -> ProtoServerWithAuth p -> m (Maybe ProtocolTestFailure) -testProtocolServer c userId srv = withAgentEnv c $ case protocolTypeI @p of +testProtocolServer :: forall p. ProtocolTypeI p => AgentClient -> UserId -> ProtoServerWithAuth p -> IO (Maybe ProtocolTestFailure) +testProtocolServer c userId srv = withAgentEnv' c $ case protocolTypeI @p of SPSMP -> runSMPServerTest c userId srv SPXFTP -> runXFTPServerTest c userId srv SPNTF -> runNTFServerTest c userId srv -setNtfServers :: MonadUnliftIO m => AgentClient -> [NtfServer] -> m () -setNtfServers c = withAgentEnv c . setNtfServers' c - -- | set SOCKS5 proxy on/off and optionally set TCP timeout -setNetworkConfig :: MonadUnliftIO m => AgentClient -> NetworkConfig -> m () +setNetworkConfig :: AgentClient -> NetworkConfig -> IO () setNetworkConfig c cfg' = do cfg <- atomically $ do swapTVar (useNetworkConfig c) cfg' when (cfg /= cfg') $ reconnectAllServers c -getNetworkConfig :: AgentErrorMonad m => AgentClient -> m NetworkConfig +getNetworkConfig :: AgentClient -> IO NetworkConfig getNetworkConfig = readTVarIO . useNetworkConfig +{-# INLINE getNetworkConfig #-} -reconnectAllServers :: MonadUnliftIO m => AgentClient -> m () -reconnectAllServers c = liftIO $ do +reconnectAllServers :: AgentClient -> IO () +reconnectAllServers c = do reconnectServerClients c smpClients reconnectServerClients c ntfClients -- | Register device notifications token -registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus +registerNtfToken :: AgentClient -> DeviceToken -> NotificationsMode -> AE NtfTknStatus registerNtfToken c = withAgentEnv c .: registerNtfToken' c +{-# INLINE registerNtfToken #-} -- | Verify device notifications token -verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> m () +verifyNtfToken :: AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> AE () verifyNtfToken c = withAgentEnv c .:. verifyNtfToken' c +{-# INLINE verifyNtfToken #-} -checkNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m NtfTknStatus +checkNtfToken :: AgentClient -> DeviceToken -> AE NtfTknStatus checkNtfToken c = withAgentEnv c . checkNtfToken' c +{-# INLINE checkNtfToken #-} -deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m () +deleteNtfToken :: AgentClient -> DeviceToken -> AE () deleteNtfToken c = withAgentEnv c . deleteNtfToken' c +{-# INLINE deleteNtfToken #-} -getNtfToken :: AgentErrorMonad m => AgentClient -> m (DeviceToken, NtfTknStatus, NotificationsMode, NtfServer) +getNtfToken :: AgentClient -> AE (DeviceToken, NtfTknStatus, NotificationsMode, NtfServer) getNtfToken c = withAgentEnv c $ getNtfToken' c +{-# INLINE getNtfToken #-} -getNtfTokenData :: AgentErrorMonad m => AgentClient -> m NtfToken +getNtfTokenData :: AgentClient -> AE NtfToken getNtfTokenData c = withAgentEnv c $ getNtfTokenData' c +{-# INLINE getNtfTokenData #-} -- | Set connection notifications on/off -toggleConnectionNtfs :: AgentErrorMonad m => AgentClient -> ConnId -> Bool -> m () +toggleConnectionNtfs :: AgentClient -> ConnId -> Bool -> AE () toggleConnectionNtfs c = withAgentEnv c .: toggleConnectionNtfs' c +{-# INLINE toggleConnectionNtfs #-} -xftpStartWorkers :: AgentErrorMonad m => AgentClient -> Maybe FilePath -> m () +xftpStartWorkers :: AgentClient -> Maybe FilePath -> AE () xftpStartWorkers c = withAgentEnv c . startXFTPWorkers c +{-# INLINE xftpStartWorkers #-} -- | Receive XFTP file -xftpReceiveFile :: AgentErrorMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> m RcvFileId +xftpReceiveFile :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> AE RcvFileId xftpReceiveFile c = withAgentEnv c .:. xftpReceiveFile' c +{-# INLINE xftpReceiveFile #-} -- | Delete XFTP rcv file (deletes work files from file system and db records) -xftpDeleteRcvFile :: AgentErrorMonad m => AgentClient -> RcvFileId -> m () -xftpDeleteRcvFile c = withAgentEnv c . xftpDeleteRcvFile' c +xftpDeleteRcvFile :: AgentClient -> RcvFileId -> IO () +xftpDeleteRcvFile c = withAgentEnv' c . xftpDeleteRcvFile' c +{-# INLINE xftpDeleteRcvFile #-} -- | Delete multiple rcv files, batching operations when possible (deletes work files from file system and db records) -xftpDeleteRcvFiles :: AgentErrorMonad m => AgentClient -> [RcvFileId] -> m () -xftpDeleteRcvFiles c = withAgentEnv c . xftpDeleteRcvFiles' c +xftpDeleteRcvFiles :: AgentClient -> [RcvFileId] -> IO () +xftpDeleteRcvFiles c = withAgentEnv' c . xftpDeleteRcvFiles' c +{-# INLINE xftpDeleteRcvFiles #-} -- | Send XFTP file -xftpSendFile :: AgentErrorMonad m => AgentClient -> UserId -> CryptoFile -> Int -> m SndFileId +xftpSendFile :: AgentClient -> UserId -> CryptoFile -> Int -> AE SndFileId xftpSendFile c = withAgentEnv c .:. xftpSendFile' c +{-# INLINE xftpSendFile #-} -- | Send XFTP file -xftpSendDescription :: AgentErrorMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> m SndFileId +xftpSendDescription :: AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Int -> AE SndFileId xftpSendDescription c = withAgentEnv c .:. xftpSendDescription' c +{-# INLINE xftpSendDescription #-} -- | Delete XFTP snd file internally (deletes work files from file system and db records) -xftpDeleteSndFileInternal :: AgentErrorMonad m => AgentClient -> SndFileId -> m () -xftpDeleteSndFileInternal c = withAgentEnv c . deleteSndFileInternal c +xftpDeleteSndFileInternal :: AgentClient -> SndFileId -> IO () +xftpDeleteSndFileInternal c = withAgentEnv' c . deleteSndFileInternal c +{-# INLINE xftpDeleteSndFileInternal #-} -- | Delete multiple snd files internally, batching operations when possible (deletes work files from file system and db records) -xftpDeleteSndFilesInternal :: AgentErrorMonad m => AgentClient -> [SndFileId] -> m () -xftpDeleteSndFilesInternal c = withAgentEnv c . deleteSndFilesInternal c +xftpDeleteSndFilesInternal :: AgentClient -> [SndFileId] -> IO () +xftpDeleteSndFilesInternal c = withAgentEnv' c . deleteSndFilesInternal c +{-# INLINE xftpDeleteSndFilesInternal #-} -- | Delete XFTP snd file chunks on servers -xftpDeleteSndFileRemote :: AgentErrorMonad m => AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> m () -xftpDeleteSndFileRemote c = withAgentEnv c .:. deleteSndFileRemote c +xftpDeleteSndFileRemote :: AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> IO () +xftpDeleteSndFileRemote c = withAgentEnv' c .:. deleteSndFileRemote c +{-# INLINE xftpDeleteSndFileRemote #-} -- | Delete XFTP snd file chunks on servers for multiple snd files, batching operations when possible -xftpDeleteSndFilesRemote :: AgentErrorMonad m => AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> m () -xftpDeleteSndFilesRemote c = withAgentEnv c .: deleteSndFilesRemote c +xftpDeleteSndFilesRemote :: AgentClient -> UserId -> [(SndFileId, ValidFileDescription 'FSender)] -> IO () +xftpDeleteSndFilesRemote c = withAgentEnv' c .: deleteSndFilesRemote c +{-# INLINE xftpDeleteSndFilesRemote #-} -- | Create new remote host pairing -rcNewHostPairing :: AgentErrorMonad m => AgentClient -> m RCHostPairing -rcNewHostPairing c = withAgentEnv c $ liftIO . newRCHostPairing =<< asks random +rcNewHostPairing :: AgentClient -> IO RCHostPairing +rcNewHostPairing AgentClient {agentEnv = Env {random}} = newRCHostPairing random +{-# INLINE rcNewHostPairing #-} -- | start TLS server for remote host with optional multicast -rcConnectHost :: AgentErrorMonad m => AgentClient -> RCHostPairing -> J.Value -> Bool -> Maybe RCCtrlAddress -> Maybe Word16 -> m RCHostConnection -rcConnectHost c = withAgentEnv c .::. rcConnectHost' - -rcConnectHost' :: AgentMonad m => RCHostPairing -> J.Value -> Bool -> Maybe RCCtrlAddress -> Maybe Word16 -> m RCHostConnection -rcConnectHost' pairing ctrlAppInfo multicast rcAddr_ port_ = do - drg <- asks random - liftError RCP $ connectRCHost drg pairing ctrlAppInfo multicast rcAddr_ port_ +rcConnectHost :: AgentClient -> RCHostPairing -> J.Value -> Bool -> Maybe RCCtrlAddress -> Maybe Word16 -> AE RCHostConnection +rcConnectHost AgentClient {agentEnv = Env {random}} = withExceptT RCP .::. connectRCHost random +{-# INLINE rcConnectHost #-} -- | connect to remote controller via URI -rcConnectCtrl :: AgentErrorMonad m => AgentClient -> RCVerifiedInvitation -> Maybe RCCtrlPairing -> J.Value -> m RCCtrlConnection -rcConnectCtrl c = withAgentEnv c .:. rcConnectCtrl' - -rcConnectCtrl' :: AgentMonad m => RCVerifiedInvitation -> Maybe RCCtrlPairing -> J.Value -> m RCCtrlConnection -rcConnectCtrl' verifiedInv pairing_ hostAppInfo = do - drg <- asks random - liftError RCP $ connectRCCtrl drg verifiedInv pairing_ hostAppInfo +rcConnectCtrl :: AgentClient -> RCVerifiedInvitation -> Maybe RCCtrlPairing -> J.Value -> AE RCCtrlConnection +rcConnectCtrl AgentClient {agentEnv = Env {random}} = withExceptT RCP .:. connectRCCtrl random +{-# INLINE rcConnectCtrl #-} -- | connect to known remote controller via multicast -rcDiscoverCtrl :: AgentErrorMonad m => AgentClient -> NonEmpty RCCtrlPairing -> m (RCCtrlPairing, RCVerifiedInvitation) -rcDiscoverCtrl c = withAgentEnv c . rcDiscoverCtrl' +rcDiscoverCtrl :: AgentClient -> NonEmpty RCCtrlPairing -> AE (RCCtrlPairing, RCVerifiedInvitation) +rcDiscoverCtrl AgentClient {agentEnv = Env {multicastSubscribers = subs}} = withExceptT RCP . discoverRCCtrl subs +{-# INLINE rcDiscoverCtrl #-} -rcDiscoverCtrl' :: AgentMonad m => NonEmpty RCCtrlPairing -> m (RCCtrlPairing, RCVerifiedInvitation) -rcDiscoverCtrl' pairings = do - subs <- asks multicastSubscribers - liftError RCP $ discoverRCCtrl subs pairings - --- | Activate operations -foregroundAgent :: MonadUnliftIO m => AgentClient -> m () -foregroundAgent c = withAgentEnv c $ foregroundAgent' c - --- | Suspend operations with max delay to deliver pending messages -suspendAgent :: MonadUnliftIO m => AgentClient -> Int -> m () -suspendAgent c = withAgentEnv c . suspendAgent' c - -execAgentStoreSQL :: AgentErrorMonad m => AgentClient -> Text -> m [Text] -execAgentStoreSQL c = withAgentEnv c . execAgentStoreSQL' c - -getAgentMigrations :: AgentErrorMonad m => AgentClient -> m [UpMigration] -getAgentMigrations c = withAgentEnv c $ getAgentMigrations' c - -debugAgentLocks :: MonadUnliftIO m => AgentClient -> m AgentLocks -debugAgentLocks c = withAgentEnv c $ debugAgentLocks' c - -getAgentStats :: MonadIO m => AgentClient -> m [(AgentStatsKey, Int)] +getAgentStats :: AgentClient -> IO [(AgentStatsKey, Int)] getAgentStats c = readTVarIO (agentStats c) >>= mapM (\(k, cnt) -> (k,) <$> readTVarIO cnt) . M.assocs -resetAgentStats :: MonadIO m => AgentClient -> m () +resetAgentStats :: AgentClient -> IO () resetAgentStats = atomically . TM.clear . agentStats +{-# INLINE resetAgentStats #-} -withAgentEnv :: AgentClient -> ReaderT Env m a -> m a -withAgentEnv c = (`runReaderT` agentEnv c) +withAgentEnv' :: AgentClient -> AM' a -> IO a +withAgentEnv' c = (`runReaderT` agentEnv c) +{-# INLINE withAgentEnv' #-} --- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. -getAgentClient :: AgentMonad' m => Int -> InitialAgentServers -> m AgentClient -getAgentClient clientId initServers = ask >>= atomically . newAgentClient clientId initServers -{-# INLINE getAgentClient #-} +withAgentEnv :: AgentClient -> AM a -> AE a +withAgentEnv c a = ExceptT $ runExceptT a `runReaderT` agentEnv c +{-# INLINE withAgentEnv #-} -logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m () +logConnection :: AgentClient -> Bool -> IO () logConnection c connected = let event = if connected then "connected to" else "disconnected from" in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"] -- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's. -runAgentClient :: AgentMonad' m => AgentClient -> m () +runAgentClient :: AgentClient -> AM' () runAgentClient c = race_ (subscriber c) (client c) +{-# INLINE runAgentClient #-} -client :: forall m. AgentMonad' m => AgentClient -> m () +client :: AgentClient -> AM' () client c@AgentClient {rcvQ, subQ} = forever $ do (corrId, entId, cmd) <- atomically $ readTBQueue rcvQ runExceptT (processCommand c (entId, cmd)) @@ -536,7 +552,7 @@ client c@AgentClient {rcvQ, subQ} = forever $ do Right (entId', resp) -> (corrId, entId', resp) -- | execute any SMP agent command -processCommand :: forall m. AgentMonad m => AgentClient -> (EntityId, APartyCmd 'Client) -> m (EntityId, APartyCmd 'Agent) +processCommand :: AgentClient -> (EntityId, APartyCmd 'Client) -> AM (EntityId, APartyCmd 'Agent) processCommand c (connId, APC e cmd) = second (APC e) <$> case cmd of NEW enableNtfs (ACM cMode) pqIK subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing pqIK subMode @@ -556,14 +572,14 @@ processCommand c (connId, APC e cmd) = userId :: UserId userId = 1 -createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> m UserId +createUser' :: AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> AM UserId createUser' c smp xftp = do userId <- withStore' c createUserRecord atomically $ TM.insert userId smp $ smpServers c atomically $ TM.insert userId xftp $ xftpServers c pure userId -deleteUser' :: AgentMonad m => AgentClient -> UserId -> Bool -> m () +deleteUser' :: AgentClient -> UserId -> Bool -> AM () deleteUser' c userId delSMPQueues = do if delSMPQueues then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c False @@ -574,23 +590,23 @@ deleteUser' c userId delSMPQueues = do whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) -newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId +newConnAsync :: ConnectionModeI c => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> AM ConnId newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do connId <- newConnNoQueues c userId "" enableNtfs cMode (CR.connPQEncryption pqInitKeys) enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pqInitKeys subMode pure connId -newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQSupport -> m ConnId +newConnNoQueues :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQSupport -> AM ConnId newConnNoQueues c userId connId enableNtfs cMode pqSupport = do g <- asks random connAgentVersion <- asks $ maxVersion . ($ pqSupport) . smpAgentVRange . config let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} withStore c $ \db -> createNewConn db g cData cMode -joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnAsync :: AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId joinConnAsync c userId corrId enableNtfs cReqUri@CRInvitationUri {} cInfo pqSup subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do - compatibleInvitationUri cReqUri pqSup >>= \case + lift (compatibleInvitationUri cReqUri pqSup) >>= \case Just (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible connAgentVersion) -> do g <- asks random let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) @@ -602,14 +618,14 @@ joinConnAsync c userId corrId enableNtfs cReqUri@CRInvitationUri {} cInfo pqSup joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo _pqEncryption = throwError $ CMD PROHIBITED -allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () +allowConnectionAsync' :: AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> AM () allowConnectionAsync' c corrId connId confId ownConnInfo = withStore c (`getConn` connId) >>= \case SomeConn _ (RcvConnection _ RcvQueue {server}) -> enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED -acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContactAsync' :: AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqSupport subMode = do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case @@ -620,7 +636,7 @@ acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqSupport subMode = do throwError err _ -> throwError $ CMD PROHIBITED -ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () +ackMessageAsync' :: AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AM () ackMessageAsync' c corrId connId msgId rcptInfo_ = do SomeConn cType _ <- withStore c (`getConn` connId) case cType of @@ -630,7 +646,7 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do SCContact -> throwError $ CMD PROHIBITED SCNew -> throwError $ CMD PROHIBITED where - enqueueAck :: m () + enqueueAck :: AM () enqueueAck = do let mId = InternalId msgId RcvMsg {msgType} <- withStore c $ \db -> getRcvMsg db connId mId @@ -638,24 +654,26 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do (RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId mId enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId rcptInfo_ -deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> Bool -> ConnId -> m () +deleteConnectionAsync' :: AgentClient -> Bool -> ConnId -> AM () deleteConnectionAsync' c waitDelivery connId = deleteConnectionsAsync' c waitDelivery [connId] +{-# INLINE deleteConnectionAsync' #-} -deleteConnectionsAsync' :: AgentMonad m => AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync' :: AgentClient -> Bool -> [ConnId] -> AM () deleteConnectionsAsync' = deleteConnectionsAsync_ $ pure () +{-# INLINE deleteConnectionsAsync' #-} -deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync_ :: AM () -> AgentClient -> Bool -> [ConnId] -> AM () deleteConnectionsAsync_ onSuccess c waitDelivery connIds = case connIds of [] -> onSuccess _ -> do (_, rqs, connIds') <- prepareDeleteConnections_ getConns c waitDelivery connIds withStore' c $ \db -> forM_ connIds' $ setConnDeleted db waitDelivery - void . forkIO $ - withLock (deleteLock c) "deleteConnectionsAsync" $ - deleteConnQueues c waitDelivery True rqs >> onSuccess + void . lift . forkIO $ + withLock' (deleteLock c) "deleteConnectionsAsync" $ + deleteConnQueues c waitDelivery True rqs >> void (runExceptT onSuccess) -- | Add connection to the new receive queue -switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ConnectionStats +switchConnectionAsync' :: AgentClient -> ACorrId -> ConnId -> AM ConnectionStats switchConnectionAsync' c corrId connId = withConnLock c connId "switchConnectionAsync" $ withStore c (`getConn` connId) >>= \case @@ -669,16 +687,16 @@ switchConnectionAsync' c corrId connId = pure . connectionStats $ DuplexConnection cData rqs' sqs _ -> throwError $ CMD PROHIBITED -newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +newConn :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, ConnectionRequestUri c) newConn c userId connId enableNtfs cMode clientData pqInitKeys subMode = getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode -newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, ConnectionRequestUri c) newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do connId' <- newConnNoQueues c userId connId enableNtfs cMode (CR.connPQEncryption pqInitKeys) newRcvConnSrv c userId connId' enableNtfs cMode clientData pqInitKeys subMode srv -newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newRcvConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, ConnectionRequestUri c) newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do case (cMode, pqInitKeys) of (SCMContact, CR.IKUsePQ) -> throwError $ CMD PROHIBITED @@ -686,7 +704,7 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config (rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e rq' <- withStore c $ \db -> updateNewConnRcv db connId rq - case subMode of + liftIO $ case subMode of SMOnlyCreate -> pure () SMSubscribe -> addSubscription c rq' when enableNtfs $ do @@ -703,7 +721,7 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eVRange) -joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConn :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> @@ -711,9 +729,9 @@ joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do _ -> getSMPServer c userId joinConnSrv c userId connId enableNtfs cReq cInfo pqSupport subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation :: UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> AM (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) startJoinInvitation userId connId enableNtfs cReqUri pqSup = - compatibleInvitationUri cReqUri pqSup >>= \case + lift (compatibleInvitationUri cReqUri pqSup) >>= \case Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), aVersion@(Compatible connAgentVersion)) -> do g <- asks random let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) @@ -723,13 +741,13 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = maxSupported <- asks $ maxVersion . ($ pqSup) . e2eEncryptVRange . config let rcVs = CR.RatchetVersions {current = v, maxSupported} rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams - q <- newSndQueue userId "" qInfo + q <- lift $ newSndQueue userId "" qInfo let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} pure (aVersion, cData, q, rc, e2eSndParams) Nothing -> throwError $ AGENT A_VERSION -connRequestPQSupport :: MonadUnliftIO m => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe (VersionSMPA, PQSupport)) -connRequestPQSupport c pqSup cReq = withAgentEnv c $ case cReq of +connRequestPQSupport :: AgentClient -> PQSupport -> ConnectionRequestUri c -> IO (Maybe (VersionSMPA, PQSupport)) +connRequestPQSupport c pqSup cReq = withAgentEnv' c $ case cReq of CRInvitationUri {} -> invPQSupported <$$> compatibleInvitationUri cReq pqSup where invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = (agentV, pqSup `CR.pqSupportAnd` versionPQSupport_ agentV (Just e2eV)) @@ -737,7 +755,7 @@ connRequestPQSupport c pqSup cReq = withAgentEnv c $ case cReq of where ctPQSupported (_, Compatible agentV) = (agentV, pqSup `CR.pqSupportAnd` versionPQSupport_ agentV Nothing) -compatibleInvitationUri :: AgentMonad' m => ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible (CR.RcvE2ERatchetParams 'C.X448), Compatible VersionSMPA)) +compatibleInvitationUri :: ConnectionRequestUri 'CMInvitation -> PQSupport -> AM' (Maybe (Compatible SMPQueueInfo, Compatible (CR.RcvE2ERatchetParams 'C.X448), Compatible VersionSMPA)) compatibleInvitationUri (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSup = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config pure $ @@ -746,7 +764,7 @@ compatibleInvitationUri (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQue <*> (e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange pqSup) <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) -compatibleContactUri :: AgentMonad' m => ConnectionRequestUri 'CMContact -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible VersionSMPA)) +compatibleContactUri :: ConnectionRequestUri 'CMContact -> PQSupport -> AM' (Maybe (Compatible SMPQueueInfo, Compatible VersionSMPA)) compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) pqSup = do AgentConfig {smpClientVRange, smpAgentVRange} <- asks config pure $ @@ -756,8 +774,9 @@ compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ +{-# INLINE versionPQSupport_ #-} -joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m ConnId +joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM ConnId joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSup @@ -774,14 +793,14 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMod void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = - compatibleContactUri cReqUri pqSup >>= \case + lift (compatibleContactUri cReqUri pqSup) >>= \case Just (qInfo, vrsn) -> do (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv sendInvitation c userId qInfo vrsn cReq cInfo pure connId' Nothing -> throwError $ AGENT A_VERSION -joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m () +joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM () joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport q' <- withStore c $ \db -> runExceptT $ do @@ -791,12 +810,12 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqSupport _srv = do throwError $ CMD PROHIBITED -createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> m SMPQueueInfo +createReplyQueue :: AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM SMPQueueInfo createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do (rq, qUri) <- newRcvQueue c userId connId srv (versionToRange smpClientVersion) subMode let qInfo = toVersionT qUri smpClientVersion rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq - case subMode of + liftIO $ case subMode of SMOnlyCreate -> pure () SMSubscribe -> addSubscription c rq' when enableNtfs $ do @@ -805,7 +824,7 @@ createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVers pure qInfo -- | Approve confirmation (LET command) in Reader monad -allowConnection' :: AgentMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () +allowConnection' :: AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> AM () allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConnection" $ do withStore c (`getConn` connId) >>= \case SomeConn _ (RcvConnection _ rq@RcvQueue {server, rcvId, e2ePrivKey, smpClientVersion = v}) -> do @@ -819,7 +838,7 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwError $ CMD PROHIBITED -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case @@ -831,15 +850,17 @@ acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withCon _ -> throwError $ CMD PROHIBITED -- | Reject contact (RJCT command) in Reader monad -rejectContact' :: AgentMonad m => AgentClient -> ConnId -> InvitationId -> m () +rejectContact' :: AgentClient -> ConnId -> InvitationId -> AM () rejectContact' c contactConnId invId = withStore c $ \db -> deleteInvitation db contactConnId invId +{-# INLINE rejectContact' #-} -- | Subscribe to receive connection messages (SUB command) in Reader monad -subscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m () +subscribeConnection' :: AgentClient -> ConnId -> AM () subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId] +{-# INLINE subscribeConnection' #-} -toConnResult :: AgentMonad m => ConnId -> Map ConnId (Either AgentErrorType ()) -> m () +toConnResult :: ConnId -> Map ConnId (Either AgentErrorType ()) -> AM () toConnResult connId rs = case M.lookup connId rs of Just (Right ()) -> when (M.size rs > 1) $ logError $ T.pack $ "too many results " <> show (M.size rs) Just (Left e) -> throwError e @@ -847,19 +868,19 @@ toConnResult connId rs = case M.lookup connId rs of type QCmdResult = (QueueStatus, Either AgentErrorType ()) -subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections' _ [] = pure M.empty subscribeConnections' c connIds = do conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConns` connIds) let (errs, cs) = M.mapEither id conns errs' = M.map (Left . storeError) errs (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs - mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs + mapM_ (mapM_ (\(cData, sqs) -> mapM_ (lift . resumeMsgDelivery c cData) sqs) . sndQueue) cs mapM_ (resumeConnCmds c) $ M.keys cs - rcvRs <- connResults <$> subscribeQueues c (concat $ M.elems rcvQs) + rcvRs <- lift $ connResults <$> subscribeQueues c (concat $ M.elems rcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) - when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns + when (instantNotifications tkn) . void . lift . forkIO . void . runExceptT $ sendNtfCreate ns rcvRs conns let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())]) notifyResultError rs pure rs @@ -891,7 +912,7 @@ subscribeConnections' c connIds = do order (Active, _) = 2 order (_, Right _) = 3 order _ = 4 - sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> m () + sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> AM () sendNtfCreate ns rcvRs conns = forM_ (M.assocs rcvRs) $ \case (connId, Right _) -> forM_ (M.lookup connId conns) $ \case @@ -905,17 +926,18 @@ subscribeConnections' c connIds = do DuplexConnection cData _ sqs -> Just (cData, sqs) SndConnection cData sq -> Just (cData, [sq]) _ -> Nothing - notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m () + notifyResultError :: Map ConnId (Either AgentErrorType ()) -> AM () notifyResultError rs = do let actual = M.size rs expected = length connIds when (actual /= expected) . atomically $ writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected) -resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m () +resubscribeConnection' :: AgentClient -> ConnId -> AM () resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId] +{-# INLINE resubscribeConnection' #-} -resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +resubscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) resubscribeConnections' _ [] = pure M.empty resubscribeConnections' c connIds = do let r = M.fromList . zip connIds . repeat $ Right () @@ -923,7 +945,7 @@ resubscribeConnections' c connIds = do -- union is left-biased, so results returned by subscribeConnections' take precedence (`M.union` r) <$> subscribeConnections' c connIds' -getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta) +getConnectionMessage' :: AgentClient -> ConnId -> AM (Maybe SMPMsgMeta) getConnectionMessage' c connId = do whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED SomeConn _ conn <- withStore c (`getConn` connId) @@ -934,7 +956,7 @@ getConnectionMessage' c connId = do SndConnection _ _ -> throwError $ CONN SIMPLEX NewConnection _ -> throwError $ CMD PROHIBITED -getNotificationMessage' :: forall m. AgentMonad m => AgentClient -> C.CbNonce -> ByteString -> m (NotificationInfo, [SMPMsgMeta]) +getNotificationMessage' :: AgentClient -> C.CbNonce -> ByteString -> AM (NotificationInfo, [SMPMsgMeta]) getNotificationMessage' c nonce encNtfInfo = do withStore' c getActiveNtfToken >>= \case Just NtfToken {ntfDhSecret = Just dhSecret} -> do @@ -960,14 +982,16 @@ getNotificationMessage' c nonce encNtfInfo = do Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad -sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, PQEncryption) -sendMessage' c connId pqEnc msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg))) +sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption) +sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg))) +{-# INLINE sendMessage' #-} -- | Send multiple messages to different connections (SEND command) in Reader monad -sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, PQEncryption)] +sendMessages' :: AgentClient -> [MsgReq] -> AM' [Either AgentErrorType (AgentMsgId, PQEncryption)] sendMessages' c = sendMessagesB' c . map Right +{-# INLINE sendMessages' #-} -sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) +sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' @@ -994,37 +1018,38 @@ sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do -- / async command processing v v v -enqueueCommand :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> m () +enqueueCommand :: AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> AM () enqueueCommand c corrId connId server aCommand = do withStore c $ \db -> createCommand db corrId connId server aCommand - void $ getAsyncCmdWorker True c server + lift . void $ getAsyncCmdWorker True c server -resumeSrvCmds :: forall m. AgentMonad' m => AgentClient -> Maybe SMPServer -> m () +resumeSrvCmds :: AgentClient -> Maybe SMPServer -> AM' () resumeSrvCmds = void .: getAsyncCmdWorker False +{-# INLINE resumeSrvCmds #-} -resumeConnCmds :: forall m. AgentMonad m => AgentClient -> ConnId -> m () +resumeConnCmds :: AgentClient -> ConnId -> AM () resumeConnCmds c connId = unlessM connQueued $ withStore' c (`getPendingCommandServers` connId) - >>= mapM_ (resumeSrvCmds c) + >>= mapM_ (lift . resumeSrvCmds c) where connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connCmdsQueued c) -getAsyncCmdWorker :: AgentMonad' m => Bool -> AgentClient -> Maybe SMPServer -> m Worker +getAsyncCmdWorker :: Bool -> AgentClient -> Maybe SMPServer -> AM' Worker getAsyncCmdWorker hasWork c server = getAgentWorker "async_cmd" hasWork c server (asyncCmdWorkers c) (runCommandProcessing c server) -runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> Worker -> m () +runCommandProcessing :: AgentClient -> Maybe SMPServer -> Worker -> AM () runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do ri <- asks $ messageRetryInterval . config -- different retry interval? forever $ do atomically $ endAgentOperation c AOSndNetwork - waitForWork doWork + lift $ waitForWork doWork atomically $ throwWhenInactive c atomically $ beginAgentOperation c AOSndNetwork withWork c doWork (`getPendingServerCommand` server_) $ processCmd (riFast ri) where - processCmd :: RetryInterval -> PendingCommand -> m () + processCmd :: RetryInterval -> PendingCommand -> AM () processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} = case command of AClientCommand (APC _ cmd) -> case cmd of NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do @@ -1109,7 +1134,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do ack srv rId srvMsgId = do rq <- withStore c $ \db -> getRcvQueue db connId srv rId ackQueueMessage c rq srvMsgId - secure :: RcvQueue -> SMP.SndPublicAuthKey -> m () + secure :: RcvQueue -> SMP.SndPublicAuthKey -> AM () secure rq senderKey = do secureQueue c rq senderKey withStore' c $ \db -> setRcvQueueStatus db rq Secured @@ -1121,7 +1146,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do noServer a = case server_ of Nothing -> a _ -> internalErr "command requires no server" - withDuplexConn :: (Connection 'CDuplex -> m ()) -> m () + withDuplexConn :: (Connection 'CDuplex -> AM ()) -> AM () withDuplexConn a = withStore c (`getConn` connId) >>= \case SomeConn _ conn@DuplexConnection {} -> a conn @@ -1135,20 +1160,21 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do tryWithLock name = tryCommand . withConnLock c connId name internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command) cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId) - notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify :: forall e. AEntityI e => ACommand 'Agent e -> AM () notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd) -- ^ ^ ^ async command processing / -enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, PQEncryption) +enqueueMessages :: AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, PQEncryption) enqueueMessages c cData sqs msgFlags aMessage = do when (ratchetSyncSendProhibited cData) $ throwError $ INTERNAL "enqueueMessages: ratchet is not synchronized" enqueueMessages' c cData sqs msgFlags aMessage -enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) +enqueueMessages' :: AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, CR.PQEncryption) enqueueMessages' c cData sqs msgFlags aMessage = - liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, Nothing, msgFlags, aMessage))) + ExceptT $ runIdentity <$> enqueueMessagesB c (Identity (Right (cData, sqs, Nothing, msgFlags, aMessage))) +{-# INLINE enqueueMessages' #-} -enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) +enqueueMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) enqueueMessagesB c reqs = do reqs' <- enqueueMessageB c reqs enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' @@ -1156,13 +1182,15 @@ enqueueMessagesB c reqs = do isActiveSndQ :: SndQueue -> Bool isActiveSndQ SndQueue {status} = status == Secured || status == Active +{-# INLINE isActiveSndQ #-} -enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, PQEncryption) +enqueueMessage :: AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, PQEncryption) enqueueMessage c cData sq msgFlags aMessage = - liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], Nothing, msgFlags, aMessage))) + ExceptT $ fmap fst . runIdentity <$> enqueueMessageB c (Identity (Right (cData, [sq], Nothing, msgFlags, aMessage))) +{-# INLINE enqueueMessage #-} -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB :: forall t. (Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> AM' (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do cfg <- asks config reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db cfg) reqs @@ -1191,10 +1219,11 @@ enqueueMessageB c reqs = do liftIO $ createSndMsgDelivery db connId sq internalId pure (req, internalId, pqEnc) -enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () +enqueueSavedMessage :: AgentClient -> ConnData -> AgentMsgId -> SndQueue -> AM' () enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) +{-# INLINE enqueueSavedMessage #-} -enqueueSavedMessageB :: (AgentMonad' m, Foldable t) => AgentClient -> t (ConnData, [SndQueue], AgentMsgId) -> m () +enqueueSavedMessageB :: (Foldable t) => AgentClient -> t (ConnData, [SndQueue], AgentMsgId) -> AM' () enqueueSavedMessageB c reqs = do -- saving to the database is in the start to avoid race conditions when delivery is read from queue before it is saved void $ withStoreBatch' c $ \db -> concatMap (storeDeliveries db) reqs @@ -1206,10 +1235,11 @@ enqueueSavedMessageB c reqs = do let mId = InternalId msgId in map (\sq -> createSndMsgDelivery db connId sq mId) sqs -resumeMsgDelivery :: forall m. AgentMonad' m => AgentClient -> ConnData -> SndQueue -> m () +resumeMsgDelivery :: AgentClient -> ConnData -> SndQueue -> AM' () resumeMsgDelivery = void .:. getDeliveryWorker False +{-# INLINE resumeMsgDelivery #-} -getDeliveryWorker :: AgentMonad' m => Bool -> AgentClient -> ConnData -> SndQueue -> m (Worker, TMVar ()) +getDeliveryWorker :: Bool -> AgentClient -> ConnData -> SndQueue -> AM' (Worker, TMVar ()) getDeliveryWorker hasWork c cData sq = getAgentWorker' fst mkLock "msg_delivery" hasWork c (qAddress sq) (smpDeliveryWorkers c) (runSmpQueueMsgDelivery c cData sq) where @@ -1217,17 +1247,17 @@ getDeliveryWorker hasWork c cData sq = retryLock <- newEmptyTMVar pure (w, retryLock) -submitPendingMsg :: AgentMonad' m => AgentClient -> ConnData -> SndQueue -> m () +submitPendingMsg :: AgentClient -> ConnData -> SndQueue -> AM' () submitPendingMsg c cData sq = do atomically $ modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + 1} void $ getDeliveryWorker True c cData sq -runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> (Worker, TMVar ()) -> m () +runSmpQueueMsgDelivery :: AgentClient -> ConnData -> SndQueue -> (Worker, TMVar ()) -> AM () runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork}, qLock) = do AgentConfig {messageRetryInterval = ri, messageTimeout, helloTimeout, quotaExceededTimeout} <- asks config forever $ do atomically $ endAgentOperation c AOSndNetwork - waitForWork doWork + lift $ waitForWork doWork atomically $ throwWhenInactive c atomically $ throwWhenNoDelivery c sq atomically $ beginAgentOperation c AOSndNetwork @@ -1348,26 +1378,26 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork setSndQueueStatus db sq Confirmed when (isJust rq_) $ removeConfirmations db connId where - notifyDelMsgs :: InternalId -> AgentErrorType -> UTCTime -> m () + notifyDelMsgs :: InternalId -> AgentErrorType -> UTCTime -> AM () notifyDelMsgs msgId err expireTs = do notifyDel msgId $ MERR (unId msgId) err msgIds_ <- withStore' c $ \db -> getExpiredSndMessages db connId sq expireTs forM_ (L.nonEmpty msgIds_) $ \msgIds -> do notify $ MERRS (L.map unId msgIds) err withStore' c $ \db -> forM_ msgIds $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure () - delMsg :: InternalId -> m () + delMsg :: InternalId -> AM () delMsg = delMsgKeep False - delMsgKeep :: Bool -> InternalId -> m () + delMsgKeep :: Bool -> InternalId -> AM () delMsgKeep keepForReceipt msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId keepForReceipt - notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify :: forall e. AEntityI e => ACommand 'Agent e -> AM () notify cmd = atomically $ writeTBQueue subQ ("", connId, APC (sAEntity @e) cmd) - notifyDel :: AEntityI e => InternalId -> ACommand 'Agent e -> m () + notifyDel :: AEntityI e => InternalId -> ACommand 'Agent e -> AM () notifyDel msgId cmd = notify cmd >> delMsg msgId connError msgId = notifyDel msgId . ERR . CONN qError msgId = notifyDel msgId . ERR . AGENT . A_QUEUE internalErr msgId = notifyDel msgId . ERR . INTERNAL -retrySndOp :: AgentMonad m => AgentClient -> m () -> m () +retrySndOp :: AgentClient -> AM () -> AM () retrySndOp c loop = do -- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent atomically $ endAgentOperation c AOSndNetwork @@ -1375,7 +1405,7 @@ retrySndOp c loop = do atomically $ beginAgentOperation c AOSndNetwork loop -ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () +ackMessage' :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AM () ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do SomeConn _ conn <- withStore c (`getConn` connId) case conn of @@ -1385,14 +1415,14 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do ContactConnection {} -> throwError $ CMD PROHIBITED NewConnection _ -> throwError $ CMD PROHIBITED where - ack :: m () + ack :: AM () ack = do -- the stored message was delivered via a specific queue, the rest failed to decrypt and were already acknowledged (rq, srvMsgId) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId ackQueueMessage c rq srvMsgId - del :: m () + del :: AM () del = withStore' c $ \db -> deleteMsg db connId $ InternalId msgId - sendRcpt :: Connection 'CDuplex -> m () + sendRcpt :: Connection 'CDuplex -> AM () sendRcpt (DuplexConnection cData@ConnData {connAgentVersion} _ sqs) = do msg@RcvMsg {msgType, msgReceipt} <- withStore c $ \db -> getRcvMsg db connId $ InternalId msgId case rcptInfo_ of @@ -1408,7 +1438,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do withStore' c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId _ -> pure () -switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats +switchConnection' :: AgentClient -> ConnId -> AM ConnectionStats switchConnection' c connId = withConnLock c connId "switchConnection" $ withStore c (`getConn` connId) >>= \case @@ -1420,7 +1450,7 @@ switchConnection' c connId = switchDuplexConnection c conn rq' _ -> throwError $ CMD PROHIBITED -switchDuplexConnection :: AgentMonad m => AgentClient -> Connection 'CDuplex -> RcvQueue -> m ConnectionStats +switchDuplexConnection :: AgentClient -> Connection 'CDuplex -> RcvQueue -> AM ConnectionStats switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs sqs) rq@RcvQueue {server, dbQueueId = DBQueueId dbQueueId, sndId} = do checkRQSwchStatus rq RSSwitchStarted clientVRange <- asks $ smpClientVRange . config @@ -1430,13 +1460,13 @@ switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs s (q, qUri) <- newRcvQueue c userId connId srv' clientVRange SMSubscribe let rq' = (q :: NewRcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' - addSubscription c rq'' + liftIO $ addSubscription c rq'' void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD let rqs' = updatedQs rq1 rqs <> [rq''] pure . connectionStats $ DuplexConnection cData rqs' sqs -abortConnectionSwitch' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats +abortConnectionSwitch' :: AgentClient -> ConnId -> AM ConnectionStats abortConnectionSwitch' c connId = withConnLock c connId "abortConnectionSwitch" $ withStore c (`getConn` connId) >>= \case @@ -1460,7 +1490,7 @@ abortConnectionSwitch' c connId = _ -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED -synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> PQSupport -> Bool -> m ConnectionStats +synchronizeRatchet' :: AgentClient -> ConnId -> PQSupport -> Bool -> AM ConnectionStats synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchronizeRatchet" $ do withStore c (`getConn` connId) >>= \case SomeConn _ (DuplexConnection cData@ConnData {pqSupport} rqs sqs) @@ -1481,14 +1511,14 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni | otherwise -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED -ackQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> SMP.MsgId -> m () +ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM () ackQueueMessage c rq srvMsgId = sendAck c rq srvMsgId `catchAgentError` \case SMP SMP.NO_MSG -> pure () e -> throwError e -- | Suspend SMP agent connection (OFF command) in Reader monad -suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m () +suspendConnection' :: AgentClient -> ConnId -> AM () suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do SomeConn _ conn <- withStore c (`getConn` connId) case conn of @@ -1501,8 +1531,9 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do -- | Delete SMP agent connection (DEL command) in Reader monad -- unlike deleteConnectionAsync, this function does not mark connection as deleted in case of deletion failure -- currently it is used only in tests -deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () +deleteConnection' :: AgentClient -> ConnId -> AM () deleteConnection' c connId = toConnResult connId =<< deleteConnections' c [connId] +{-# INLINE deleteConnection' #-} connRcvQueues :: Connection d -> [RcvQueue] connRcvQueues = \case @@ -1512,30 +1543,31 @@ connRcvQueues = \case SndConnection _ _ -> [] NewConnection _ -> [] -disableConn :: AgentMonad m => AgentClient -> ConnId -> m () +disableConn :: AgentClient -> ConnId -> AM' () disableConn c connId = do atomically $ removeSubscription c connId ns <- asks ntfSupervisor atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete) -- Unlike deleteConnectionsAsync, this function does not mark connections as deleted in case of deletion failure. -deleteConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +deleteConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) deleteConnections' = deleteConnections_ getConns False False +{-# INLINE deleteConnections' #-} -deleteDeletedConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +deleteDeletedConns :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) deleteDeletedConns = deleteConnections_ getDeletedConns True False +{-# INLINE deleteDeletedConns #-} -deleteDeletedWaitingDeliveryConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +deleteDeletedWaitingDeliveryConns :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) deleteDeletedWaitingDeliveryConns = deleteConnections_ getConns True True +{-# INLINE deleteDeletedWaitingDeliveryConns #-} prepareDeleteConnections_ :: - forall m. - AgentMonad m => (DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) -> AgentClient -> Bool -> [ConnId] -> - m (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) + AM (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) prepareDeleteConnections_ getConnections c waitDelivery connIds = do conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConnections` connIds) let (errs, cs) = M.mapEither id conns @@ -1543,13 +1575,13 @@ prepareDeleteConnections_ getConnections c waitDelivery connIds = do (delRs, rcvQs) = M.mapEither rcvQueues cs rqs = concat $ M.elems rcvQs connIds' = M.keys rcvQs - forM_ connIds' $ disableConn c + lift . forM_ connIds' $ disableConn c -- ! delRs is not used to notify about the result in any of the calling functions, -- ! it is only used to check results count in deleteConnections_; -- ! if it was used to notify about the result, it might be necessary to differentiate -- ! between completed deletions of connections, and deletions delayed due to wait for delivery (see deleteConn) deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing - rs' <- catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) (M.keys delRs)) + rs' <- lift $ catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) (M.keys delRs)) forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN) pure (errs' <> delRs, rqs, connIds') where @@ -1559,7 +1591,7 @@ prepareDeleteConnections_ getConnections c waitDelivery connIds = do rqs -> Right rqs notify = atomically . writeTBQueue (subQ c) -deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ())) +deleteConnQueues :: AgentClient -> Bool -> Bool -> [RcvQueue] -> AM' (Map ConnId (Either AgentErrorType ())) deleteConnQueues c waitDelivery ntf rqs = do rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs) let connIds = M.keys $ M.filter isRight rs @@ -1568,7 +1600,7 @@ deleteConnQueues c waitDelivery ntf rqs = do forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN) pure rs where - deleteQueueRecs :: [(RcvQueue, Either AgentErrorType ())] -> m [(RcvQueue, Either AgentErrorType ())] + deleteQueueRecs :: [(RcvQueue, Either AgentErrorType ())] -> AM' [(RcvQueue, Either AgentErrorType ())] deleteQueueRecs rs = do maxErrs <- asks $ deleteErrorCount . config (rs', notifyActions) <- unzip . rights <$> withStoreBatch' c (\db -> map (deleteQueueRec db maxErrs) rs) @@ -1579,7 +1611,7 @@ deleteConnQueues c waitDelivery ntf rqs = do DB.Connection -> Int -> (RcvQueue, Either AgentErrorType ()) -> - IO ((RcvQueue, Either AgentErrorType ()), Maybe (m ())) + IO ((RcvQueue, Either AgentErrorType ()), Maybe (AM' ())) deleteQueueRec db maxErrs (rq, r) = case r of Right _ -> deleteConnRcvQueue db rq $> ((rq, r), Just (notifyRQ rq Nothing)) Left e @@ -1603,35 +1635,33 @@ deleteConnQueues c waitDelivery ntf rqs = do order _ = 3 deleteConnections_ :: - forall m. - AgentMonad m => (DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) -> Bool -> Bool -> AgentClient -> [ConnId] -> - m (Map ConnId (Either AgentErrorType ())) + AM (Map ConnId (Either AgentErrorType ())) deleteConnections_ _ _ _ _ [] = pure M.empty deleteConnections_ getConnections ntf waitDelivery c connIds = do (rs, rqs, _) <- prepareDeleteConnections_ getConnections c waitDelivery connIds - rcvRs <- deleteConnQueues c waitDelivery ntf rqs + rcvRs <- lift $ deleteConnQueues c waitDelivery ntf rqs let rs' = M.union rs rcvRs notifyResultError rs' pure rs' where - notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m () + notifyResultError :: Map ConnId (Either AgentErrorType ()) -> AM () notifyResultError rs = do let actual = M.size rs expected = length connIds when (actual /= expected) . atomically $ writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected) -getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats +getConnectionServers' :: AgentClient -> ConnId -> AM ConnectionStats getConnectionServers' c connId = do SomeConn _ conn <- withStore c (`getConn` connId) pure $ connectionStats conn -getConnectionRatchetAdHash' :: AgentMonad m => AgentClient -> ConnId -> m ByteString +getConnectionRatchetAdHash' :: AgentClient -> ConnId -> AM ByteString getConnectionRatchetAdHash' c connId = do CR.Ratchet {rcAD = Str rcAD} <- withStore c (`getRatchet` connId) pure $ C.sha256Hash rcAD @@ -1659,10 +1689,11 @@ connectionStats = \case } -- | Change servers to be used for creating new queues, in Reader monad -setProtocolServers' :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> m () -setProtocolServers' c userId srvs = atomically $ TM.insert userId srvs (userServers c) +setProtocolServers :: (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> IO () +setProtocolServers c userId srvs = atomically $ TM.insert userId srvs (userServers c) +{-# INLINE setProtocolServers #-} -registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus +registerNtfToken' :: AgentClient -> DeviceToken -> NotificationsMode -> AM NtfTknStatus registerNtfToken' c suppliedDeviceToken suppliedNtfMode = withStore' c getSavedNtfToken >>= \case Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId, ntfTknStatus, ntfTknAction, ntfMode = savedNtfMode} -> do @@ -1701,7 +1732,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = withStore' c $ \db -> updateNtfMode db tkn suppliedNtfMode pure status where - replaceToken :: NtfTokenId -> m NtfTknStatus + replaceToken :: NtfTokenId -> AM NtfTknStatus replaceToken tknId = do ns <- asks ntfSupervisor tryReplace ns `catchAgentError` \e -> @@ -1720,9 +1751,9 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = _ -> createToken where t tkn = withToken c tkn Nothing - createToken :: m NtfTknStatus + createToken :: AM NtfTknStatus createToken = - getNtfServer c >>= \case + lift (getNtfServer c) >>= \case Just ntfServer -> asks (rcvAuthAlg . config) >>= \case C.AuthAlg a -> do @@ -1734,7 +1765,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = registerToken tkn pure NTRegistered _ -> throwError $ CMD PROHIBITED - registerToken :: NtfToken -> m () + registerToken :: NtfToken -> AM () registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do (tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey let dhSecret = C.dh' srvPubDhKey privDhKey @@ -1742,7 +1773,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = ns <- asks ntfSupervisor atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode} -verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> m () +verifyNtfToken' :: AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> AM () verifyNtfToken' c deviceToken nonce code = withStore' c getSavedNtfToken >>= \case Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId = Just tknId, ntfDhSecret = Just dhSecret, ntfMode} -> do @@ -1757,7 +1788,7 @@ verifyNtfToken' c deviceToken nonce code = when (ntfMode == NMInstant) $ initializeNtfSubs c _ -> throwError $ CMD PROHIBITED -checkNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m NtfTknStatus +checkNtfToken' :: AgentClient -> DeviceToken -> AM NtfTknStatus checkNtfToken' c deviceToken = withStore' c getSavedNtfToken >>= \case Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId = Just tknId} -> do @@ -1765,7 +1796,7 @@ checkNtfToken' c deviceToken = agentNtfCheckToken c tknId tkn _ -> throwError $ CMD PROHIBITED -deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m () +deleteNtfToken' :: AgentClient -> DeviceToken -> AM () deleteNtfToken' c deviceToken = withStore' c getSavedNtfToken >>= \case Just tkn@NtfToken {deviceToken = savedDeviceToken} -> do @@ -1774,20 +1805,20 @@ deleteNtfToken' c deviceToken = deleteNtfSubs c NSCSmpDelete _ -> throwError $ CMD PROHIBITED -getNtfToken' :: AgentMonad m => AgentClient -> m (DeviceToken, NtfTknStatus, NotificationsMode, NtfServer) +getNtfToken' :: AgentClient -> AM (DeviceToken, NtfTknStatus, NotificationsMode, NtfServer) getNtfToken' c = withStore' c getSavedNtfToken >>= \case Just NtfToken {deviceToken, ntfTknStatus, ntfMode, ntfServer} -> pure (deviceToken, ntfTknStatus, ntfMode, ntfServer) _ -> throwError $ CMD PROHIBITED -getNtfTokenData' :: AgentMonad m => AgentClient -> m NtfToken +getNtfTokenData' :: AgentClient -> AM NtfToken getNtfTokenData' c = withStore' c getSavedNtfToken >>= \case Just tkn -> pure tkn _ -> throwError $ CMD PROHIBITED -- | Set connection notifications, in Reader monad -toggleConnectionNtfs' :: forall m. AgentMonad m => AgentClient -> ConnId -> Bool -> m () +toggleConnectionNtfs' :: AgentClient -> ConnId -> Bool -> AM () toggleConnectionNtfs' c connId enable = do SomeConn _ conn <- withStore c (`getConn` connId) case conn of @@ -1796,7 +1827,7 @@ toggleConnectionNtfs' c connId enable = do ContactConnection cData _ -> toggle cData _ -> throwError $ CONN SIMPLEX where - toggle :: ConnData -> m () + toggle :: ConnData -> AM () toggle cData | enableNtfs cData == enable = pure () | otherwise = do @@ -1805,7 +1836,7 @@ toggleConnectionNtfs' c connId enable = do let cmd = if enable then NSCCreate else NSCDelete atomically $ sendNtfSubCommand ns (connId, cmd) -deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m () +deleteToken_ :: AgentClient -> NtfToken -> AM () deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do ns <- asks ntfSupervisor forM_ ntfTokenId $ \tknId -> do @@ -1818,7 +1849,7 @@ deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do withStore' c $ \db -> removeNtfToken db tkn atomically $ nsRemoveNtfToken ns -withToken :: AgentMonad m => AgentClient -> NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m NtfTknStatus +withToken :: AgentClient -> NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> AM a -> AM NtfTknStatus withToken c tkn@NtfToken {deviceToken, ntfMode} from_ (toStatus, toAction_) f = do ns <- asks ntfSupervisor forM_ from_ $ \(status, action) -> do @@ -1837,16 +1868,17 @@ withToken c tkn@NtfToken {deviceToken, ntfMode} from_ (toStatus, toAction_) f = throwError e Left e -> throwError e -initializeNtfSubs :: AgentMonad m => AgentClient -> m () +initializeNtfSubs :: AgentClient -> AM () initializeNtfSubs c = sendNtfConnCommands c NSCCreate +{-# INLINE initializeNtfSubs #-} -deleteNtfSubs :: AgentMonad m => AgentClient -> NtfSupervisorCommand -> m () +deleteNtfSubs :: AgentClient -> NtfSupervisorCommand -> AM () deleteNtfSubs c deleteCmd = do ns <- asks ntfSupervisor void . atomically . flushTBQueue $ ntfSubQ ns sendNtfConnCommands c deleteCmd -sendNtfConnCommands :: AgentMonad m => AgentClient -> NtfSupervisorCommand -> m () +sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM () sendNtfConnCommands c cmd = do ns <- asks ntfSupervisor connIds <- atomically $ getSubscriptions c @@ -1857,23 +1889,26 @@ sendNtfConnCommands c cmd = do _ -> atomically $ writeTBQueue (subQ c) ("", connId, APC SAEConn $ ERR $ INTERNAL "no connection data") -setNtfServers' :: AgentMonad' m => AgentClient -> [NtfServer] -> m () -setNtfServers' c = atomically . writeTVar (ntfServers c) +setNtfServers :: AgentClient -> [NtfServer] -> IO () +setNtfServers c = atomically . writeTVar (ntfServers c) +{-# INLINE setNtfServers #-} -foregroundAgent' :: AgentMonad' m => AgentClient -> m () -foregroundAgent' c = do +-- | Activate operations +foregroundAgent :: AgentClient -> IO () +foregroundAgent c = do atomically $ writeTVar (agentState c) ASForeground mapM_ activate $ reverse agentOperations where activate opSel = atomically $ modifyTVar' (opSel c) $ \s -> s {opSuspended = False} -suspendAgent' :: AgentMonad' m => AgentClient -> Int -> m () -suspendAgent' c 0 = do +-- | Suspend operations with max delay to deliver pending messages +suspendAgent :: AgentClient -> Int -> IO () +suspendAgent c 0 = do atomically $ writeTVar (agentState c) ASSuspended mapM_ suspend agentOperations where suspend opSel = atomically $ modifyTVar' (opSel c) $ \s -> s {opSuspended = True} -suspendAgent' c@AgentClient {agentState = as} maxDelay = do +suspendAgent c@AgentClient {agentState = as} maxDelay = do state <- atomically $ do writeTVar as ASSuspending @@ -1889,14 +1924,14 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do -- unsafeIOToSTM $ putStrLn $ "in timeout: suspendSendingAndDatabase" suspendSendingAndDatabase c -execAgentStoreSQL' :: AgentMonad m => AgentClient -> Text -> m [Text] -execAgentStoreSQL' c sql = withStore' c (`execSQL` sql) +execAgentStoreSQL :: AgentClient -> Text -> AE [Text] +execAgentStoreSQL c sql = withAgentEnv c $ withStore' c (`execSQL` sql) -getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration] -getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn) +getAgentMigrations :: AgentClient -> AE [UpMigration] +getAgentMigrations c = withAgentEnv c $ map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn) -debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks -debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, deleteLock = d} = do +debugAgentLocks :: AgentClient -> IO AgentLocks +debugAgentLocks AgentClient {connLocks = cs, invLocks = is, deleteLock = d} = do connLocks <- getLocks cs invLocks <- getLocks is delLock <- atomically $ tryReadTMVar d @@ -1904,10 +1939,11 @@ debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, deleteLock = d} = d where getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls) -getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth +getSMPServer :: AgentClient -> UserId -> AM SMPServerWithAuth getSMPServer c userId = withUserServers c userId pickServer +{-# INLINE getSMPServer #-} -subscriber :: AgentMonad' m => AgentClient -> m () +subscriber :: AgentClient -> AM' () subscriber c@AgentClient {msgQ} = forever $ do t <- atomically $ readTBQueue msgQ agentOperationBracket c AORcvNetwork waitUntilActive $ @@ -1915,7 +1951,7 @@ subscriber c@AgentClient {msgQ} = forever $ do Left e -> liftIO $ print e Right _ -> return () -cleanupManager :: forall m. AgentMonad' m => AgentClient -> m () +cleanupManager :: AgentClient -> AM' () cleanupManager c@AgentClient {subQ} = do delay <- asks (initialCleanupDelay . config) liftIO $ threadDelay' delay @@ -1935,7 +1971,7 @@ cleanupManager c@AgentClient {subQ} = do run SFERR deleteExpiredReplicasForDeletion liftIO $ threadDelay' int where - run :: forall e. AEntityI e => (AgentErrorType -> ACommand 'Agent e) -> ExceptT AgentErrorType m () -> m () + run :: forall e. AEntityI e => (AgentErrorType -> ACommand 'Agent e) -> AM () -> AM' () run err a = do waitActive . runExceptT $ a `catchAgentError` (notify "" . err) step <- asks $ cleanupStepInterval . config @@ -1951,50 +1987,50 @@ cleanupManager c@AgentClient {subQ} = do rcvFilesTTL <- asks $ rcvFilesTTL . config rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL) forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do - removePath =<< toFSFilePath p + lift $ removePath =<< toFSFilePath p withStore' c (`deleteRcvFile'` dbId) deleteRcvFilesDeleted = do rcvDeleted <- withStore' c getCleanupRcvFilesDeleted forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do - removePath =<< toFSFilePath p + lift $ removePath =<< toFSFilePath p withStore' c (`deleteRcvFile'` dbId) deleteRcvFilesTmpPaths = do rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do - removePath =<< toFSFilePath p + lift $ removePath =<< toFSFilePath p withStore' c (`updateRcvFileNoTmpPath` dbId) deleteSndFilesExpired = do sndFilesTTL <- asks $ sndFilesTTL . config sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL) forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do - forM_ p $ removePath <=< toFSFilePath + lift . forM_ p $ removePath <=< toFSFilePath withStore' c (`deleteSndFile'` dbId) deleteSndFilesDeleted = do sndDeleted <- withStore' c getCleanupSndFilesDeleted forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do - forM_ p $ removePath <=< toFSFilePath + lift . forM_ p $ removePath <=< toFSFilePath withStore' c (`deleteSndFile'` dbId) deleteSndFilesPrefixPaths = do sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do - removePath =<< toFSFilePath p + lift $ removePath =<< toFSFilePath p withStore' c (`updateSndFileNoPrefixPath` dbId) deleteExpiredReplicasForDeletion = do rcvFilesTTL <- asks $ rcvFilesTTL . config withStore' c (`deleteDeletedSndChunkReplicasExpired` rcvFilesTTL) - notify :: forall e. AEntityI e => EntityId -> ACommand 'Agent e -> ExceptT AgentErrorType m () + notify :: forall e. AEntityI e => EntityId -> ACommand 'Agent e -> AM () notify entId cmd = atomically $ writeTBQueue subQ ("", entId, APC (sAEntity @e) cmd) data ACKd = ACKd | ACKPending -- | make sure to ACK or throw in each message processing branch -- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL -processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission SMPVersion BrokerMsg -> m () +processSMPTransmission :: AgentClient -> ServerTransmission SMPVersion BrokerMsg -> AM () processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, rId, cmd) = do (rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId) processSMP rq conn $ toConnData conn where - processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> m () + processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> AM () processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn @@ -2069,9 +2105,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, QTEST _ -> logServer "<--" c srv rId ("MSG :" <> logSecret srvMsgId) >> ackDel msgId EREADY _ -> qDuplexAckDel conn'' "EREADY" $ ereadyMsg rcPrev where - qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m ACKd + qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> AM ()) -> AM ACKd qDuplexAckDel conn'' name a = qDuplex conn'' name a >> ackDel msgId - resetRatchetSync :: m (Connection c) + resetRatchetSync :: AM (Connection c) resetRatchetSync | rss `notElem` ([RSOk, RSStarted] :: [RatchetSyncState]) = do let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData @@ -2098,7 +2134,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, unless exists notifySync ack where - notifySync :: m () + notifySync :: AM () notifySync = qDuplex conn' "AGENT A_CRYPTO error" $ \connDuplex -> do let rss' = cryptoErrToSyncState e when (rss `elem` ([RSOk, RSAllowed, RSRequired] :: [RatchetSyncState])) $ do @@ -2108,11 +2144,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, withStore' c $ \db -> setConnRatchetSync db connId rss' Left e -> checkDuplicateHash e encryptedMsgHash >> ack where - checkDuplicateHash :: AgentErrorType -> ByteString -> m () + checkDuplicateHash :: AgentErrorType -> ByteString -> AM () checkDuplicateHash e encryptedMsgHash = unlessM (withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash) $ throwError e - agentClientMsg :: TVar ChaChaDRG -> ByteString -> m (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448)) + agentClientMsg :: TVar ChaChaDRG -> ByteString -> AM (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448)) agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY (agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage @@ -2132,7 +2168,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> pure Nothing _ -> prohibited >> ack _ -> prohibited >> ack - updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c) + updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> AM (Connection c) updateConnVersion conn' cData'@ConnData {pqSupport} msgAgentVersion = do aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion @@ -2144,11 +2180,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, pure $ updateConnection cData'' conn' | otherwise -> pure conn' Nothing -> pure conn' - ack :: m ACKd + ack :: AM ACKd ack = enqueueCmd (ICAck rId srvMsgId) $> ACKd - ackDel :: InternalId -> m ACKd + ackDel :: InternalId -> AM ACKd ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd - handleNotifyAck :: m ACKd -> m ACKd + handleNotifyAck :: AM ACKd -> AM ACKd handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack SMP.END -> atomically (TM.lookup tSess smpClients $>>= (tryReadTMVar . sessionVar) >>= processEND) @@ -2167,19 +2203,19 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, logServer "<--" c srv rId $ "unexpected: " <> bshow cmd notify . ERR $ BROKER (B.unpack $ strEncode srv) UNEXPECTED where - notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify :: forall e m. MonadIO m => AEntityI e => ACommand 'Agent e -> m () notify = atomically . notify' notify' :: forall e. AEntityI e => ACommand 'Agent e -> STM () notify' msg = writeTBQueue subQ ("", connId, APC (sAEntity @e) msg) - prohibited :: m () + prohibited :: AM () prohibited = notify . ERR $ AGENT A_PROHIBITED - enqueueCmd :: InternalCommand -> m () + enqueueCmd :: InternalCommand -> AM () enqueueCmd = enqueueCommand c "" connId (Just srv) . AInternalCommand - decryptClientMessage :: C.DhSecretX25519 -> SMP.ClientMsgEnvelope -> m (SMP.PrivHeader, AgentMsgEnvelope) + decryptClientMessage :: C.DhSecretX25519 -> SMP.ClientMsgEnvelope -> AM (SMP.PrivHeader, AgentMsgEnvelope) decryptClientMessage e2eDh SMP.ClientMsgEnvelope {cmNonce, cmEncBody} = do clientMsg <- agentCbDecrypt e2eDh cmNonce cmEncBody SMP.ClientMessage privHeader clientBody <- parseMessage clientMsg @@ -2192,10 +2228,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- else throwError $ AGENT A_VERSION pure (privHeader, agentEnvelope) - parseMessage :: Encoding a => ByteString -> m a + parseMessage :: Encoding a => ByteString -> AM a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m () + smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> AM () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config @@ -2246,7 +2282,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> prohibited _ -> prohibited - helloMsg :: SMP.MsgId -> MsgMeta -> Connection c -> m () + helloMsg :: SMP.MsgId -> MsgMeta -> Connection c -> AM () helloMsg srvMsgId MsgMeta {pqEncryption} conn' = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId case status of @@ -2261,12 +2297,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, | otherwise -> enqueueDuplexHello sq _ -> pure () where - enqueueDuplexHello :: SndQueue -> m () + enqueueDuplexHello :: SndQueue -> AM () enqueueDuplexHello sq = do let cData' = toConnData conn' void $ enqueueMessage c cData' sq SMP.MsgFlags {notification = True} HELLO - continueSending :: SMP.MsgId -> (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () + continueSending :: SMP.MsgId -> (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> AM () continueSending srvMsgId addr (DuplexConnection _ _ sqs) = case findQ addr sqs of Just sq -> do @@ -2276,7 +2312,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, >>= mapM_ (\(_, retryLock) -> tryPutTMVar retryLock ()) Nothing -> qError "QCONT: queue address not found" - messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> m ACKd + messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing @@ -2284,9 +2320,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, Just rs' -> notify (RCVD msgMeta rs') $> ACKPending Nothing -> ack where - ack :: m ACKd + ack :: AM ACKd ack = enqueueCmd (ICAck rId srvMsgId) $> ACKd - clientReceipt :: AMessageReceipt -> m (Maybe MsgReceipt) + clientReceipt :: AMessageReceipt -> AM (Maybe MsgReceipt) clientReceipt AMessageReceipt {agentMsgId, msgHash} = do let sndMsgId = InternalSndId agentMsgId SndMsg {internalId = InternalId msgId, msgType, internalHash, msgReceipt} <- withStore c $ \db -> getSndMsgViaRcpt db connId sndMsgId @@ -2301,7 +2337,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, pure $ Just rcpt -- processed by queue sender - qAddMsg :: SMP.MsgId -> NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m () + qAddMsg :: SMP.MsgId -> NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> AM () qAddMsg _ ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported" qAddMsg srvMsgId ((qUri, Just addr) :| _) (DuplexConnection cData' rqs sqs) = do when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized") @@ -2316,7 +2352,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, Just sqs' -> do -- move inside case? withStore' c $ \db -> mapM_ (deleteConnSndQueue db connId) delSqs - sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo + sq_@SndQueue {sndPublicKey, e2ePubKey} <- lift $ newSndQueue userId connId qInfo let sq'' = (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} sq2 <- withStore c $ \db -> addConnSndQueue db connId sq'' case (sndPublicKey, e2ePubKey) of @@ -2334,7 +2370,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> throwError $ AGENT A_VERSION -- processed by queue recipient - qKeyMsg :: SMP.MsgId -> NonEmpty (SMPQueueInfo, SndPublicAuthKey) -> Connection 'CDuplex -> m () + qKeyMsg :: SMP.MsgId -> NonEmpty (SMPQueueInfo, SndPublicAuthKey) -> Connection 'CDuplex -> AM () qKeyMsg srvMsgId ((qInfo, senderKey) :| _) conn'@(DuplexConnection cData' rqs _) = do when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized") clientVRange <- asks $ smpClientVRange . config @@ -2355,7 +2391,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- processed by queue sender -- mark queue as Secured and to start sending messages to it - qUseMsg :: SMP.MsgId -> NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> m () + qUseMsg :: SMP.MsgId -> NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> AM () -- NOTE: does not yet support the change of the primary status during the rotation qUseMsg srvMsgId ((addr, _primary) :| _) (DuplexConnection cData' rqs sqs) = do when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized") @@ -2376,24 +2412,24 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> qError "QUSE: switching SndQueue not found in connection" _ -> qError "QUSE: switched queue address not found in connection" - qError :: String -> m a + qError :: String -> AM a qError = throwError . AGENT . A_QUEUE - ereadyMsg :: CR.RatchetX448 -> Connection 'CDuplex -> m () + ereadyMsg :: CR.RatchetX448 -> Connection 'CDuplex -> AM () ereadyMsg rcPrev (DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = do let CR.Ratchet {rcSnd} = rcPrev -- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation when (isNothing rcSnd) . void $ enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId) - smpInvitation :: SMP.MsgId -> Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () + smpInvitation :: SMP.MsgId -> Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> AM () smpInvitation srvMsgId conn' connReq@(CRInvitationUri crData _) cInfo = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId case conn' of ContactConnection {} -> do -- show connection request even if invitaion via contact address is not compatible. -- in case invitation not compatible, assume there is no PQ encryption support. - pqSupport <- maybe PQSupportOff pqSupported <$> compatibleInvitationUri connReq PQSupportOn + pqSupport <- lift $ maybe PQSupportOff pqSupported <$> compatibleInvitationUri connReq PQSupportOn g <- asks random let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo} invId <- withStore c $ \db -> createInvitation db g newInv @@ -2404,12 +2440,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, pqSupported (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible agentVersion) = PQSupportOn `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just v) - qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m a) -> m a + qDuplex :: Connection c -> String -> (Connection 'CDuplex -> AM a) -> AM a qDuplex conn' name action = case conn' of DuplexConnection {} -> action conn' _ -> qError $ name <> ": message must be sent to duplex connection" - newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () + newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> AM () newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqSupport} _ sqs) = unlessM ratchetExists $ do AgentConfig {e2eEncryptVRange} <- asks config @@ -2422,12 +2458,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, where rkHashRcv = rkHash k1Rcv k2Rcv rkHash k1 k2 = C.sha256Hash $ C.pubKeyBytes k1 <> C.pubKeyBytes k2 - ratchetExists :: m Bool + ratchetExists :: AM Bool ratchetExists = withStore' c $ \db -> do exists <- checkRatchetKeyHashExists db connId rkHashRcv unless exists $ addProcessedRatchetKeyHash db connId rkHashRcv pure exists - getSendRatchetKeys :: m (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) + getSendRatchetKeys :: AM (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) getSendRatchetKeys = case rss of RSOk -> sendReplyKey -- receiving client RSAllowed -> sendReplyKey @@ -2450,19 +2486,19 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let cData'' = cData' {ratchetSyncState = RSRequired} :: ConnData conn'' = updateConnection cData'' conn' notify $ RSYNC RSRequired (Just RATCHET_SYNC) (connectionStats conn'') - notifyAgreed :: m () + notifyAgreed :: AM () notifyAgreed = do let cData'' = cData' {ratchetSyncState = RSAgreed} :: ConnData conn'' = updateConnection cData'' conn' notify . RSYNC RSAgreed Nothing $ connectionStats conn'' - recreateRatchet :: CR.Ratchet 'C.X448 -> m () + recreateRatchet :: CR.Ratchet 'C.X448 -> AM () recreateRatchet rc = withStore' c $ \db -> do setConnRatchetSync db connId RSAgreed deleteRatchet db connId createRatchet db connId rc -- compare public keys `k1` in AgentRatchetKey messages sent by self and other party -- to determine ratchet initilization ordering - initRatchet :: CR.RatchetVersions -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () + initRatchet :: CR.RatchetVersions -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> AM () initRatchet rcVs (pk1, pk2, pKem) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams @@ -2482,43 +2518,45 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash | otherwise = MsgError MsgDuplicate -- this case is not possible -checkRQSwchStatus :: AgentMonad m => RcvQueue -> RcvSwitchStatus -> m () +checkRQSwchStatus :: RcvQueue -> RcvSwitchStatus -> AM () checkRQSwchStatus rq@RcvQueue {rcvSwchStatus} expected = unless (rcvSwchStatus == Just expected) $ switchStatusError rq expected rcvSwchStatus +{-# INLINE checkRQSwchStatus #-} -checkSQSwchStatus :: AgentMonad m => SndQueue -> SndSwitchStatus -> m () +checkSQSwchStatus :: SndQueue -> SndSwitchStatus -> AM () checkSQSwchStatus sq@SndQueue {sndSwchStatus} expected = unless (sndSwchStatus == Just expected) $ switchStatusError sq expected sndSwchStatus +{-# INLINE checkSQSwchStatus #-} -switchStatusError :: (SMPQueueRec q, AgentMonad m, Show a) => q -> a -> Maybe a -> m () +switchStatusError :: (SMPQueueRec q, Show a) => q -> a -> Maybe a -> AM () switchStatusError q expected actual = throwError . INTERNAL $ ("unexpected switch status, queueId=" <> show (queueId q)) <> (", expected=" <> show expected) <> (", actual=" <> show actual) -connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m () +connectReplyQueues :: AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> AM () connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do clientVRange <- asks $ smpClientVRange . config case qInfo `proveCompatible` clientVRange of Nothing -> throwError $ AGENT A_VERSION Just qInfo' -> do - sq <- newSndQueue userId connId qInfo' + sq <- lift $ newSndQueue userId connId qInfo' sq' <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq enqueueConfirmation c cData sq' ownConnInfo Nothing -confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () confirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode - submitPendingMsg c cData sq + lift $ submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueue :: Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq srv connInfo e2eEncryption_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed where - mkConfirmation :: AgentMessage -> m MsgBody + mkConfirmation :: AgentMessage -> AM MsgBody mkConfirmation aMessage = do -- the version to be used when PQSupport is disabled currentE2EVersion <- asks $ maxVersion . ($ PQSupportOff) . e2eEncryptVRange . config @@ -2528,17 +2566,17 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq s (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) currentE2EVersion pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} -mkAgentConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage +mkAgentConfirmation :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM AgentMessage mkAgentConfirmation c cData sq srv connInfo subMode = do qInfo <- createReplyQueue c cData sq subMode srv pure $ AgentConnInfoReply (qInfo :| []) connInfo -enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> m () +enqueueConfirmation :: AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AM () enqueueConfirmation c cData sq connInfo e2eEncryption_ = do storeConfirmation c cData sq e2eEncryption_ $ AgentConnInfo connInfo - submitPendingMsg c cData sq + lift $ submitPendingMsg c cData sq -storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AgentMessage -> m () +storeConfirmation :: AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AgentMessage -> AM () storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq e2eEncryption_ agentMsg = do -- the version to be used when PQSupport is disabled currentE2EVersion <- asks $ maxVersion . ($ PQSupportOff) . e2eEncryptVRange . config @@ -2555,19 +2593,19 @@ storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId -enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m () +enqueueRatchetKeyMsgs :: AgentClient -> ConnData -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> AM () enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do msgId <- enqueueRatchetKey c cData sq e2eEncryption - mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs + mapM_ (lift . enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs -enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m AgentMsgId +enqueueRatchetKey :: AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> AM AgentMsgId enqueueRatchetKey c cData@ConnData {connId, pqSupport} sq e2eEncryption = do aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config msgId <- storeRatchetKey $ maxVersion aVRange - submitPendingMsg c cData sq + lift $ submitPendingMsg c cData sq pure $ unId msgId where - storeRatchetKey :: VersionSMPA -> m InternalId + storeRatchetKey :: VersionSMPA -> AM InternalId storeRatchetKey agentVersion = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId @@ -2587,7 +2625,7 @@ agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (VersionSMPA - agentRatchetEncrypt db ConnData {connId, connAgentVersion = v, pqSupport} msg getPaddedLen pqEnc_ currentE2EVersion = do rc <- ExceptT $ getRatchet db connId let paddedLen = getPaddedLen v pqSupport - (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ currentE2EVersion + (encMsg, rc') <- withExceptT (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ currentE2EVersion liftIO $ updateRatchet db connId rc' CR.SMDNoChange pure (encMsg, CR.rcSndKEM rc') @@ -2600,11 +2638,11 @@ agentRatchetDecrypt g db connId encAgentMsg = do agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) agentRatchetDecrypt' g db connId rc encAgentMsg = do skipped <- liftIO $ getSkippedMsgKeys db connId - (agentMsgBody_, rc', skippedDiff) <- liftE (SEAgentError . cryptoError) $ CR.rcDecrypt g rc skipped encAgentMsg + (agentMsgBody_, rc', skippedDiff) <- withExceptT (SEAgentError . cryptoError) $ CR.rcDecrypt g rc skipped encAgentMsg liftIO $ updateRatchet db connId rc' skippedDiff liftEither $ bimap (SEAgentError . cryptoError) (,CR.rcRcvKEM rc') agentMsgBody_ -newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m NewSndQueue +newSndQueue :: UserId -> ConnId -> Compatible SMPQueueInfo -> AM' NewSndQueue newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do C.AuthAlg a <- asks $ sndAuthAlg . config g <- asks random diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 3f979125f..7798fcadd 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -293,10 +293,11 @@ data AgentClient = AgentClient agentEnv :: Env } -getAgentWorker :: (AgentMonad' m, Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT AgentErrorType m ()) -> m Worker +getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker getAgentWorker = getAgentWorker' id pure +{-# INLINE getAgentWorker #-} -getAgentWorker' :: forall a k m. (AgentMonad' m, Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT AgentErrorType m ()) -> m a +getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a getAgentWorker' toW fromW name hasWork c key ws work = do atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w where @@ -310,9 +311,9 @@ getAgentWorker' toW fromW name hasWork c key ws work = do | otherwise = pure w runWorker w = runWorkerAsync (toW w) runWork where - runWork :: m () + runWork :: AM' () runWork = tryAgentError' (work w) >>= restartOrDelete - restartOrDelete :: Either AgentErrorType () -> m () + restartOrDelete :: Either AgentErrorType () -> AM' () restartOrDelete e_ = do t <- liftIO getSystemTime maxRestarts <- asks $ maxWorkerRestartsPerMin . config @@ -350,7 +351,7 @@ newWorker c = do restarts <- newTVar $ RestartCount 0 0 pure Worker {workerId, doWork, action, restarts} -runWorkerAsync :: AgentMonad' m => Worker -> m () -> m () +runWorkerAsync :: Worker -> AM' () -> AM' () runWorkerAsync Worker {action} work = E.bracket (atomically $ takeTMVar action) -- get current action, locking to avoid race conditions @@ -394,6 +395,7 @@ data AgentStatsKey = AgentStatsKey } deriving (Eq, Ord, Show) +-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. newAgentClient :: Int -> InitialAgentServers -> Env -> STM AgentClient newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do let qSize = tbqSize $ config agentEnv @@ -469,13 +471,15 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = agentClientStore :: AgentClient -> SQLiteStore agentClientStore AgentClient {agentEnv = Env {store}} = store +{-# INLINE agentClientStore #-} agentDRG :: AgentClient -> TVar ChaChaDRG agentDRG AgentClient {agentEnv = Env {random}} = random +{-# INLINE agentDRG #-} class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where type Client msg = c | c -> msg - getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg) + getProtocolServerClient :: AgentClient -> TransportSession msg -> AM (Client msg) clientProtocolError :: err -> AgentErrorType closeProtocolServerClient :: Client msg -> IO () clientServer :: Client msg -> String @@ -509,7 +513,7 @@ instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where clientTransportHost = X.xftpTransportHost clientSessionTs = X.xftpSessionTs -getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient +getSMPServerClient :: AgentClient -> SMPTransportSession -> AM SMPClient getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INACTIVE atomically (getTSessVar c tSess smpClients) @@ -520,13 +524,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, -- make it expensive to check for pending subscriptions. newClient v = newProtocolClient c tSess smpClients connectClient v - `catchAgentError` \e -> resubscribeSMPSession c tSess >> throwError e - connectClient :: SMPClientVar -> m SMPClient + `catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwError e + connectClient :: SMPClientVar -> AM SMPClient connectClient v = do - cfg <- getClientConfig c smpCfg + cfg <- lift $ getClientConfig c smpCfg g <- asks random env <- ask - liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v) + liftError' (protocolClientError SMP $ B.unpack $ strEncode srv) $ + getProtocolClient g tSess cfg (Just msgQ) $ clientDisconnected env v clientDisconnected :: Env -> SMPClientVar -> SMPClient -> IO () clientDisconnected env v client = do @@ -557,7 +562,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO () notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd) -resubscribeSMPSession :: AgentMonad' m => AgentClient -> SMPTransportSession -> m () +resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' () resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess = atomically getWorkerVar >>= mapM_ (either newSubWorker (\_ -> pure ())) where @@ -585,12 +590,12 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess = whenM (isEmptyTMVar $ sessionVar v) retry removeTSessVar v tSess smpSubWorkers -reconnectSMPClient :: forall m. AgentMonad m => TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m () +reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM () reconnectSMPClient tc c tSess@(_, srv, _) qs = do NetworkConfig {tcpTimeout} <- readTVarIO $ useNetworkConfig c -- this allows 3x of timeout per batch of subscription (90 queues per batch empirically) let t = (length qs `div` 90 + 1) * tcpTimeout * 3 - t `timeout` resubscribe >>= \case + ExceptT (sequence <$> (t `timeout` runExceptT resubscribe)) >>= \case Just _ -> atomically $ writeTVar tc 0 Nothing -> do tc' <- atomically $ stateTVar tc $ \i -> (i + 1, i + 1) @@ -599,10 +604,10 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do msg = show tc' <> " consecutive subscription timeouts: " <> show (length qs) <> " queues, transport session: " <> show tSess atomically $ writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ err msg) where - resubscribe :: m () + resubscribe :: AM () resubscribe = do cs <- readTVarIO $ RQ.getConnections $ activeSubs c - rs <- subscribeQueues c $ L.toList qs + rs <- lift . subscribeQueues c $ L.toList qs let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs liftIO $ do let conns = filter (`M.notMember` cs) okConns @@ -616,7 +621,7 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO () notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd) -getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient +getNtfServerClient :: AgentClient -> NtfTransportSession -> AM NtfClient getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INACTIVE atomically (getTSessVar c tSess ntfClients) @@ -624,11 +629,12 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d (newProtocolClient c tSess ntfClients connectClient) (waitForProtocolClient c tSess) where - connectClient :: NtfClientVar -> m NtfClient + connectClient :: NtfClientVar -> AM NtfClient connectClient v = do - cfg <- getClientConfig c ntfCfg + cfg <- lift $ getClientConfig c ntfCfg g <- asks random - liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient g tSess cfg Nothing $ clientDisconnected v) + liftError' (protocolClientError NTF $ B.unpack $ strEncode srv) $ + getProtocolClient g tSess cfg Nothing $ clientDisconnected v clientDisconnected :: NtfClientVar -> NtfClient -> IO () clientDisconnected v client = do @@ -637,7 +643,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv -getXFTPServerClient :: forall m. AgentMonad m => AgentClient -> XFTPTransportSession -> m XFTPClient +getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INACTIVE atomically (getTSessVar c tSess xftpClients) @@ -645,11 +651,12 @@ getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@ (newProtocolClient c tSess xftpClients connectClient) (waitForProtocolClient c tSess) where - connectClient :: XFTPClientVar -> m XFTPClient + connectClient :: XFTPClientVar -> AM XFTPClient connectClient v = do cfg <- asks $ xftpCfg . config xftpNetworkConfig <- readTVarIO useNetworkConfig - liftEitherError (protocolClientError XFTP $ B.unpack $ strEncode srv) (X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v) + liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $ + X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v clientDisconnected :: XFTPClientVar -> XFTPClient -> IO () clientDisconnected v client = do @@ -671,6 +678,7 @@ getTSessVar c tSess vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lo removeTSessVar :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM () removeTSessVar = void .:. removeTSessVar' +{-# INLINE removeTSessVar #-} removeTSessVar' :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM Bool removeTSessVar' v tSess vs = @@ -678,7 +686,7 @@ removeTSessVar' v tSess vs = Just v' | sessionVarId v == sessionVarId v' -> TM.delete tSess vs $> True _ -> pure False -waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (Client msg) +waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg) waitForProtocolClient c (_, srv, _) v = do NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) @@ -689,14 +697,14 @@ waitForProtocolClient c (_, srv, _) v = do -- clientConnected arg is only passed for SMP server newProtocolClient :: - forall v err msg m. - (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => + forall v err msg. + (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> TMap (TransportSession msg) (ClientVar msg) -> - (ClientVar msg -> m (Client msg)) -> + (ClientVar msg -> AM (Client msg)) -> ClientVar msg -> - m (Client msg) + AM (Client msg) newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = tryAgentError (connectClient v) >>= \case Right client -> do @@ -715,14 +723,14 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost -getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v) +getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v) getClientConfig AgentClient {useNetworkConfig} cfgSel = do cfg <- asks $ cfgSel . config networkConfig <- readTVarIO useNetworkConfig pure cfg {networkConfig} -closeAgentClient :: MonadIO m => AgentClient -> m () -closeAgentClient c = liftIO $ do +closeAgentClient :: AgentClient -> IO () +closeAgentClient c = do atomically $ writeTVar (active c) False closeProtocolServerClients c smpClients closeProtocolServerClients c ntfClients @@ -750,9 +758,11 @@ cancelWorker Worker {doWork, action} = do waitUntilActive :: AgentClient -> STM () waitUntilActive c = unlessM (readTVar $ active c) retry +{-# INLINE waitUntilActive #-} throwWhenInactive :: AgentClient -> STM () throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled +{-# INLINE throwWhenInactive #-} -- this function is used to remove workers once delivery is complete, not when it is removed from the map throwWhenNoDelivery :: AgentClient -> SndQueue -> STM () @@ -779,82 +789,98 @@ closeClient_ c v = do Just (Right client) -> closeProtocolServerClient client `catchAll_` pure () _ -> pure () -closeXFTPServerClient :: AgentMonad' m => AgentClient -> UserId -> XFTPServer -> FileDigest -> m () +closeXFTPServerClient :: AgentClient -> UserId -> XFTPServer -> FileDigest -> IO () closeXFTPServerClient c userId server (FileDigest chunkDigest) = - mkTransportSession c userId server chunkDigest >>= liftIO . closeClient c xftpClients + mkTransportSession c userId server chunkDigest >>= closeClient c xftpClients -withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a -withConnLock _ "" _ = id -withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name +withConnLock :: AgentClient -> ConnId -> String -> AM a -> AM a +withConnLock c connId name = ExceptT . withConnLock' c connId name . runExceptT +{-# INLINE withConnLock #-} -withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a -withInvLock AgentClient {invLocks} = withLockMap_ invLocks +withConnLock' :: AgentClient -> ConnId -> String -> AM' a -> AM' a +withConnLock' _ "" _ = id +withConnLock' AgentClient {connLocks} connId name = withLockMap_ connLocks connId name +{-# INLINE withConnLock' #-} -withConnLocks :: MonadUnliftIO m => AgentClient -> [ConnId] -> String -> m a -> m a +withInvLock :: AgentClient -> ByteString -> String -> AM a -> AM a +withInvLock c key name = ExceptT . withInvLock' c key name . runExceptT +{-# INLINE withInvLock #-} + +withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a +withInvLock' AgentClient {invLocks} = withLockMap_ invLocks +{-# INLINE withInvLock' #-} + +withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null) +{-# INLINE withConnLocks #-} withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a withLockMap_ = withGetLock . getMapLock +{-# INLINE withLockMap_ #-} withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a withLocksMap_ = withGetLocks . getMapLock +{-# INLINE withLocksMap_ #-} getMapLock :: Ord k => TMap k Lock -> k -> STM Lock getMapLock locks key = TM.lookup key locks >>= maybe newLock pure where newLock = createLock >>= \l -> TM.insert key l locks $> l -withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a +withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> AM a) -> AM a withClient_ c tSess@(userId, srv, _) statCmd action = do cl <- getProtocolServerClient c tSess (action cl <* stat cl "OK") `catchAgentError` logServerError cl where stat cl = liftIO . incClientStat c userId cl statCmd - logServerError :: Client msg -> AgentErrorType -> m a + logServerError :: Client msg -> AgentErrorType -> AM a logServerError cl e = do logServer "<--" c srv "" $ strEncode e stat cl $ strEncode e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a +withLogClient_ :: ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> AM a) -> AM a withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do logServer "-->" c srv entId cmdStr res <- withClient_ c tSess cmdStr action logServer "<--" c srv entId "OK" return res -withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withClient :: forall v err msg a. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> AM a withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client +{-# INLINE withClient #-} -withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withLogClient :: forall v err msg a. ProtocolServerClient v err msg => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> AM a withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client +{-# INLINE withLogClient #-} -withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a +withSMPClient :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> AM a withSMPClient c q cmdStr action = do - tSess <- mkSMPTransportSession c q + tSess <- liftIO $ mkSMPTransportSession c q withLogClient c tSess (queueId q) cmdStr action -withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a +withSMPClient_ :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> AM a) -> AM a withSMPClient_ c q cmdStr action = do - tSess <- mkSMPTransportSession c q + tSess <- liftIO $ mkSMPTransportSession c q withLogClient_ c tSess (queueId q) cmdStr action -withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> m a +withNtfClient :: AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT NtfClientError IO a) -> AM a withNtfClient c srv = withLogClient c (0, srv, Nothing) withXFTPClient :: - (AgentMonad m, ProtocolServerClient v err msg) => + ProtocolServerClient v err msg => AgentClient -> (UserId, ProtoServer msg, EntityId) -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO b) -> - m b + AM b withXFTPClient c (userId, srv, entityId) cmdStr action = do - tSess <- mkTransportSession c userId srv entityId + tSess <- liftIO $ mkTransportSession c userId srv entityId withLogClient c tSess entityId cmdStr action -liftClient :: (AgentMonad m, Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> m a +liftClient :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a liftClient protocolError_ = liftError . protocolClientError protocolError_ +{-# INLINE liftClient #-} protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType protocolClientError protocolError_ host = \case @@ -889,7 +915,7 @@ data ProtocolTestFailure = ProtocolTestFailure } deriving (Eq, Show) -runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe ProtocolTestFailure) +runSMPServerTest :: AgentClient -> UserId -> SMPServerWithAuth -> AM' (Maybe ProtocolTestFailure) runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do cfg <- getClientConfig c smpCfg C.AuthAlg ra <- asks $ rcvAuthAlg . config @@ -915,7 +941,7 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do testErr :: ProtocolTestStep -> SMPClientError -> ProtocolTestFailure testErr step = ProtocolTestFailure step . protocolClientError SMP addr -runXFTPServerTest :: forall m. AgentMonad m => AgentClient -> UserId -> XFTPServerWithAuth -> m (Maybe ProtocolTestFailure) +runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe ProtocolTestFailure) runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do cfg <- asks $ xftpCfg . config g <- asks random @@ -949,7 +975,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do testErr step = ProtocolTestFailure step . protocolClientError XFTP addr chSize :: Integral a => a chSize = kb 64 - getTempFilePath :: FilePath -> m FilePath + getTempFilePath :: FilePath -> AM' FilePath getTempFilePath workPath = do ts <- liftIO getCurrentTime let isoTime = formatTime defaultTimeLocale "%Y-%m-%dT%H%M%S.%6q" ts @@ -963,7 +989,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do createTestChunk :: FilePath -> IO () createTestChunk fp = B.writeFile fp =<< atomically . C.randomBytes chSize =<< C.newRandom -runNTFServerTest :: AgentMonad m => AgentClient -> UserId -> NtfServerWithAuth -> m (Maybe ProtocolTestFailure) +runNTFServerTest :: AgentClient -> UserId -> NtfServerWithAuth -> AM' (Maybe ProtocolTestFailure) runNTFServerTest c userId (ProtoServerWithAuth srv _) = do cfg <- getClientConfig c ntfCfg C.AuthAlg a <- asks $ rcvAuthAlg . config @@ -987,27 +1013,32 @@ runNTFServerTest c userId (ProtoServerWithAuth srv _) = do testErr :: ProtocolTestStep -> SMPClientError -> ProtocolTestFailure testErr step = ProtocolTestFailure step . protocolClientError NTF addr -getXFTPWorkPath :: AgentMonad m => m FilePath +getXFTPWorkPath :: AM' FilePath getXFTPWorkPath = do workDir <- readTVarIO =<< asks (xftpWorkDir . xftpAgent) maybe getTemporaryDirectory pure workDir -mkTransportSession :: AgentMonad' m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg) +mkTransportSession :: AgentClient -> UserId -> ProtoServer msg -> EntityId -> IO (TransportSession msg) mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c +{-# INLINE mkTransportSession #-} mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing) +{-# INLINE mkTSession #-} -mkSMPTransportSession :: (AgentMonad' m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession +mkSMPTransportSession :: SMPQueueRec q => AgentClient -> q -> IO SMPTransportSession mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c +{-# INLINE mkSMPTransportSession #-} mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) +{-# INLINE mkSMPTSession #-} -getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode +getSessionMode :: AgentClient -> IO TransportSessionMode getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig +{-# INLINE getSessionMode #-} -newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) +newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri) newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do C.AuthAlg a <- asks (rcvAuthAlg . config) g <- asks random @@ -1015,10 +1046,10 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do (dhKey, privDhKey) <- atomically $ C.generateKeyPair g (e2eDhKey, e2ePrivKey) <- atomically $ C.generateKeyPair g logServer "-->" c srv "" "NEW" - tSess <- mkTransportSession c userId srv connId + tSess <- liftIO $ mkTransportSession c userId srv connId QIK {rcvId, sndId, rcvPublicDhKey} <- withClient c tSess "NEW" $ \smp -> createSMPQueue smp rKeys dhKey auth subMode - logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] + liftIO . logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] let rq = RcvQueue { userId, @@ -1057,14 +1088,16 @@ temporaryAgentError = \case BROKER _ TIMEOUT -> True INACTIVE -> True _ -> False +{-# INLINE temporaryAgentError #-} temporaryOrHostError :: AgentErrorType -> Bool temporaryOrHostError = \case BROKER _ HOST -> True e -> temporaryAgentError e +{-# INLINE temporaryOrHostError #-} -- | Subscribe to queues. The list of results can have a different order. -subscribeQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] +subscribeQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())] subscribeQueues c qs = do (errs, qs') <- partitionEithers <$> mapM checkQueue qs atomically $ do @@ -1088,11 +1121,11 @@ subscribeQueues c qs = do type BatchResponses e r = (NonEmpty (RcvQueue, Either e r)) -- statBatchSize is not used to batch the commands, only for traffic statistics -sendTSessionBatches :: forall m q r. AgentMonad' m => ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> m [(RcvQueue, Either AgentErrorType r)] +sendTSessionBatches :: forall q r. ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)] sendTSessionBatches statCmd statBatchSize toRQ action c qs = concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues) where - batchQueues :: m [(SMPTransportSession, NonEmpty q)] + batchQueues :: AM' [(SMPTransportSession, NonEmpty q)] batchQueues = do mode <- sessionMode <$> readTVarIO (useNetworkConfig c) pure . M.assocs $ foldl' (batch mode) M.empty qs @@ -1100,7 +1133,7 @@ sendTSessionBatches statCmd statBatchSize toRQ action c qs = batch mode m q = let tSess = mkSMPTSession (toRQ q) mode in M.alter (Just . maybe [q] (q <|)) tSess m - sendClientBatch :: (SMPTransportSession, NonEmpty q) -> m (BatchResponses AgentErrorType r) + sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses AgentErrorType r) sendClientBatch (tSess@(userId, srv, _), qs') = tryAgentError' (getSMPServerClient c tSess) >>= \case Left e -> pure $ L.map ((,Left e) . toRQ) qs' @@ -1120,7 +1153,7 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) -addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m () +addSubscription :: AgentClient -> RcvQueue -> IO () addSubscription c rq@RcvQueue {connId} = atomically $ do modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ activeSubs c @@ -1128,6 +1161,7 @@ addSubscription c rq@RcvQueue {connId} = atomically $ do hasActiveSubscription :: AgentClient -> ConnId -> STM Bool hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c +{-# INLINE hasActiveSubscription #-} removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do @@ -1137,19 +1171,23 @@ removeSubscription c connId = do getSubscriptions :: AgentClient -> STM (Set ConnId) getSubscriptions = readTVar . subscrConns +{-# INLINE getSubscriptions #-} logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] +{-# INLINE logServer #-} showServer :: ProtocolServer s -> ByteString showServer ProtocolServer {host, port} = strEncode host <> B.pack (if null port then "" else ':' : port) +{-# INLINE showServer #-} logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs +{-# INLINE logSecret #-} -sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () +sendConfirmation :: AgentClient -> SndQueue -> ByteString -> AM () sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation = withSMPClient_ c sq "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation @@ -1157,21 +1195,21 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" -sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () +sendInvitation :: AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> AM () sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do - tSess <- mkTransportSession c userId smpServer senderId + tSess <- liftIO $ mkTransportSession c userId smpServer senderId withLogClient_ c tSess senderId "SEND " $ \smp -> do msg <- mkInvitation liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg where - mkInvitation :: m ByteString + mkInvitation :: AM ByteString -- this is only encrypted with per-queue E2E, not with double ratchet mkInvitation = do let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo} agentCbEncryptOnce v dhPublicKey . smpEncode $ SMP.ClientMessage SMP.PHEmpty (smpEncode agentEnvelope) -getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta) +getQueueMessage :: AgentClient -> RcvQueue -> AM (Maybe SMPMsgMeta) getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do atomically createTakeGetLock msg_ <- withSMPClient c rq "GET" $ \smp -> @@ -1186,23 +1224,23 @@ getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do takeTMVar l pure $ Just l -decryptSMPMessage :: AgentMonad m => RcvQueue -> SMP.RcvMessage -> m SMP.ClientRcvMsgBody +decryptSMPMessage :: RcvQueue -> SMP.RcvMessage -> AM SMP.ClientRcvMsgBody decryptSMPMessage rq SMP.RcvMessage {msgId, msgBody = SMP.EncRcvMsgBody body} = liftEither . parse SMP.clientRcvMsgBodyP (AGENT A_MESSAGE) =<< decrypt body where decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId) -secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicAuthKey -> m () +secureQueue :: AgentClient -> RcvQueue -> SndPublicAuthKey -> AM () secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey = withSMPClient c rq "KEY " $ \smp -> secureSMPQueue smp rcvPrivateKey rcvId senderKey -enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> SMP.NtfPublicAuthKey -> SMP.RcvNtfPublicDhKey -> m (SMP.NotifierId, SMP.RcvNtfPublicDhKey) +enableQueueNotifications :: AgentClient -> RcvQueue -> SMP.NtfPublicAuthKey -> SMP.RcvNtfPublicDhKey -> AM (SMP.NotifierId, SMP.RcvNtfPublicDhKey) enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey = withSMPClient c rq "NKEY " $ \smp -> enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey -enableQueuesNtfs :: forall m. AgentMonad' m => AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> m [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))] +enableQueuesNtfs :: AgentClient -> [(RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey)] -> AM' [(RcvQueue, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))] enableQueuesNtfs = sendTSessionBatches "NKEY" 90 fst3 enableQueues_ where fst3 (x, _, _) = x @@ -1211,15 +1249,15 @@ enableQueuesNtfs = sendTSessionBatches "NKEY" 90 fst3 enableQueues_ queueCreds :: (RcvQueue, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) -> (SMP.RcvPrivateAuthKey, SMP.RecipientId, SMP.NtfPublicAuthKey, SMP.RcvNtfPublicDhKey) queueCreds (RcvQueue {rcvPrivateKey, rcvId}, notifierKey, rcvNtfPublicDhKey) = (rcvPrivateKey, rcvId, notifierKey, rcvNtfPublicDhKey) -disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m () +disableQueueNotifications :: AgentClient -> RcvQueue -> AM () disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} = withSMPClient c rq "NDEL" $ \smp -> disableSMPQueueNotifications smp rcvPrivateKey rcvId -disableQueuesNtfs :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] +disableQueuesNtfs :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())] disableQueuesNtfs = sendTSessionBatches "NDEL" 90 id $ sendBatch disableSMPQueuesNtfs -sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m () +sendAck :: AgentClient -> RcvQueue -> MsgId -> AM () sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do withSMPClient c rq ("ACK:" <> logSecret msgId) $ \smp -> ackSMPMessage smp rcvPrivateKey rcvId msgId @@ -1233,93 +1271,93 @@ releaseGetLock :: AgentClient -> RcvQueue -> STM () releaseGetLock c RcvQueue {server, rcvId} = TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ()) -suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () +suspendQueue :: AgentClient -> RcvQueue -> AM () suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = withSMPClient c rq "OFF" $ \smp -> suspendSMPQueue smp rcvPrivateKey rcvId -deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () +deleteQueue :: AgentClient -> RcvQueue -> AM () deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = do withSMPClient c rq "DEL" $ \smp -> deleteSMPQueue smp rcvPrivateKey rcvId -deleteQueues :: forall m. AgentMonad' m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] +deleteQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())] deleteQueues = sendTSessionBatches "DEL" 90 id $ sendBatch deleteSMPQueues -sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m () +sendAgentMessage :: AgentClient -> SndQueue -> MsgFlags -> ByteString -> AM () sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg = withSMPClient_ c sq "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg -agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519) +agentNtfRegisterToken :: AgentClient -> NtfToken -> NtfPublicAuthKey -> C.PublicKeyX25519 -> AM (NtfTokenId, C.PublicKeyX25519) agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey = withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey) -agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m () +agentNtfVerifyToken :: AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> AM () agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code = withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code -agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus +agentNtfCheckToken :: AgentClient -> NtfTokenId -> NtfToken -> AM NtfTknStatus agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} = withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId -agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m () +agentNtfReplaceToken :: AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> AM () agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token = withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token -agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m () +agentNtfDeleteToken :: AgentClient -> NtfTokenId -> NtfToken -> AM () agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} = withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId -agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m () +agentNtfEnableCron :: AgentClient -> NtfTokenId -> NtfToken -> Word16 -> AM () agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval = withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval -agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> SMP.NtfPrivateAuthKey -> m NtfSubscriptionId +agentNtfCreateSubscription :: AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> SMP.NtfPrivateAuthKey -> AM NtfSubscriptionId agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey = withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey) -agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus +agentNtfCheckSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM NtfSubStatus agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} = withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId -agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m () +agentNtfDeleteSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM () agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} = withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId -agentXFTPDownloadChunk :: AgentMonad m => AgentClient -> UserId -> FileDigest -> RcvFileChunkReplica -> XFTPRcvChunkSpec -> m () +agentXFTPDownloadChunk :: AgentClient -> UserId -> FileDigest -> RcvFileChunkReplica -> XFTPRcvChunkSpec -> AM () agentXFTPDownloadChunk c userId (FileDigest chunkDigest) RcvFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} chunkSpec = do g <- asks random withXFTPClient c (userId, server, chunkDigest) "FGET" $ \xftp -> X.downloadXFTPChunk g xftp replicaKey fId chunkSpec -agentXFTPNewChunk :: AgentMonad m => AgentClient -> SndFileChunk -> Int -> XFTPServerWithAuth -> m NewSndChunkReplica +agentXFTPNewChunk :: AgentClient -> SndFileChunk -> Int -> XFTPServerWithAuth -> AM NewSndChunkReplica agentXFTPNewChunk c SndFileChunk {userId, chunkSpec = XFTPChunkSpec {chunkSize}, digest = FileDigest chunkDigest} n (ProtoServerWithAuth srv auth) = do rKeys <- xftpRcvKeys n (sndKey, replicaKey) <- atomically . C.generateAuthKeyPair C.SEd25519 =<< asks random let fileInfo = FileInfo {sndKey, size = fromIntegral chunkSize, digest = chunkDigest} logServer "-->" c srv "" "FNEW" - tSess <- mkTransportSession c userId srv chunkDigest + tSess <- liftIO $ mkTransportSession c userId srv chunkDigest (sndId, rIds) <- withClient c tSess "FNEW" $ \xftp -> X.createXFTPChunk xftp replicaKey fileInfo (L.map fst rKeys) auth logServer "<--" c srv "" $ B.unwords ["SIDS", logSecret sndId] pure NewSndChunkReplica {server = srv, replicaId = ChunkReplicaId sndId, replicaKey, rcvIdsKeys = L.toList $ xftpRcvIdsKeys rIds rKeys} -agentXFTPUploadChunk :: AgentMonad m => AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> XFTPChunkSpec -> m () +agentXFTPUploadChunk :: AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> XFTPChunkSpec -> AM () agentXFTPUploadChunk c userId (FileDigest chunkDigest) SndFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} chunkSpec = withXFTPClient c (userId, server, chunkDigest) "FPUT" $ \xftp -> X.uploadXFTPChunk xftp replicaKey fId chunkSpec -agentXFTPAddRecipients :: AgentMonad m => AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> Int -> m (NonEmpty (ChunkReplicaId, C.APrivateAuthKey)) +agentXFTPAddRecipients :: AgentClient -> UserId -> FileDigest -> SndFileChunkReplica -> Int -> AM (NonEmpty (ChunkReplicaId, C.APrivateAuthKey)) agentXFTPAddRecipients c userId (FileDigest chunkDigest) SndFileChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey} n = do rKeys <- xftpRcvKeys n rIds <- withXFTPClient c (userId, server, chunkDigest) "FADD" $ \xftp -> X.addXFTPRecipients xftp replicaKey fId (L.map fst rKeys) pure $ xftpRcvIdsKeys rIds rKeys -agentXFTPDeleteChunk :: AgentMonad m => AgentClient -> UserId -> DeletedSndChunkReplica -> m () +agentXFTPDeleteChunk :: AgentClient -> UserId -> DeletedSndChunkReplica -> AM () agentXFTPDeleteChunk c userId DeletedSndChunkReplica {server, replicaId = ChunkReplicaId fId, replicaKey, chunkDigest = FileDigest chunkDigest} = withXFTPClient c (userId, server, chunkDigest) "FDEL" $ \xftp -> X.deleteXFTPChunk xftp replicaKey fId -xftpRcvKeys :: AgentMonad m => Int -> m (NonEmpty C.AAuthKeyPair) +xftpRcvKeys :: Int -> AM (NonEmpty C.AAuthKeyPair) xftpRcvKeys n = do rKeys <- atomically . replicateM n . C.generateAuthKeyPair C.SEd25519 =<< asks random case L.nonEmpty rKeys of @@ -1329,7 +1367,7 @@ xftpRcvKeys n = do xftpRcvIdsKeys :: NonEmpty ByteString -> NonEmpty C.AAuthKeyPair -> NonEmpty (ChunkReplicaId, C.APrivateAuthKey) xftpRcvIdsKeys rIds rKeys = L.map ChunkReplicaId rIds `L.zip` L.map snd rKeys -agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString +agentCbEncrypt :: SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> AM ByteString agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do cmNonce <- atomically . C.randomCbNonce =<< asks random let paddedLen = maybe SMP.e2eEncMessageLength (const SMP.e2eEncConfirmationLength) e2ePubKey @@ -1340,7 +1378,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} -- add encoding as AgentInvitation'? -agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString +agentCbEncryptOnce :: VersionSMPC -> C.PublicKeyX25519 -> ByteString -> AM ByteString agentCbEncryptOnce clientVersion dhRcvPubKey msg = do g <- asks random (dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g @@ -1354,7 +1392,7 @@ agentCbEncryptOnce clientVersion dhRcvPubKey msg = do -- | NaCl crypto-box decrypt - both for messages received from the server -- and per-queue E2E encrypted messages from the sender that were inside. -agentCbDecrypt :: AgentMonad m => C.DhSecretX25519 -> C.CbNonce -> ByteString -> m ByteString +agentCbDecrypt :: C.DhSecretX25519 -> C.CbNonce -> ByteString -> AM ByteString agentCbDecrypt dhSecret nonce msg = liftEither . first cryptoError $ C.cbDecrypt dhSecret nonce msg @@ -1373,10 +1411,11 @@ cryptoError = \case where c = AGENT . A_CRYPTO -waitForWork :: AgentMonad' m => TMVar () -> m () +waitForWork :: MonadIO m => TMVar () -> m () waitForWork = void . atomically . readTMVar +{-# INLINE waitForWork #-} -withWork :: AgentMonad m => AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> m ()) -> m () +withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM () withWork c doWork getWork action = withStore' c getWork >>= \case Right (Just r) -> action r @@ -1389,12 +1428,15 @@ withWork c doWork getWork action = noWorkToDo :: TMVar () -> IO () noWorkToDo = void . atomically . tryTakeTMVar +{-# INLINE noWorkToDo #-} hasWorkToDo :: Worker -> STM () hasWorkToDo = hasWorkToDo' . doWork +{-# INLINE hasWorkToDo #-} hasWorkToDo' :: TMVar () -> STM () hasWorkToDo' = void . (`tryPutTMVar` ()) +{-# INLINE hasWorkToDo' #-} endAgentOperation :: AgentClient -> AgentOperation -> STM () endAgentOperation c op = endOperation c op $ case op of @@ -1438,6 +1480,7 @@ endOperation c op endedAction = do whenSuspending :: AgentClient -> STM () -> STM () whenSuspending c = whenM ((== ASSuspending) <$> readTVar (agentState c)) +{-# INLINE whenSuspending #-} beginAgentOperation :: AgentClient -> AgentOperation -> STM () beginAgentOperation c op = do @@ -1457,20 +1500,22 @@ agentOperationBracket c op check action = waitUntilForeground :: AgentClient -> STM () waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry +{-# INLINE waitUntilForeground #-} -withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a +withStore' :: AgentClient -> (DB.Connection -> IO a) -> AM a withStore' c action = withStore c $ fmap Right . action +{-# INLINE withStore' #-} -withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a +withStore :: AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> AM a withStore c action = do st <- asks store - liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ + withExceptT storeError . ExceptT . liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $ withTransaction st action `E.catch` handleInternal "" where handleInternal :: String -> E.SomeException -> IO (Either StoreError a) handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr -withStoreBatch :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> m (t (Either AgentErrorType a)) +withStoreBatch :: Traversable t => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> AM' (t (Either AgentErrorType a)) withStoreBatch c actions = do st <- asks store liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $ @@ -1480,8 +1525,9 @@ withStoreBatch c actions = do handleInternal :: E.SomeException -> IO (Either AgentErrorType a) handleInternal = pure . Left . INTERNAL . show -withStoreBatch' :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO a)) -> m (t (Either AgentErrorType a)) +withStoreBatch' :: Traversable t => AgentClient -> (DB.Connection -> t (IO a)) -> AM' (t (Either AgentErrorType a)) withStoreBatch' c actions = withStoreBatch c (fmap (fmap Right) . actions) +{-# INLINE withStoreBatch' #-} storeError :: StoreError -> AgentErrorType storeError = \case @@ -1505,6 +1551,7 @@ incStat AgentClient {agentStats} n k = do incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () incClientStat c userId pc = incClientStatN c userId pc 1 +{-# INLINE incClientStat #-} incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO () incServerStat c userId ProtocolServer {host} cmd res = do @@ -1523,27 +1570,28 @@ userServers :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> TMa userServers c = case protocolTypeI @p of SPSMP -> smpServers c SPXFTP -> xftpServers c +{-# INLINE userServers #-} -pickServer :: forall p m. AgentMonad' m => NonEmpty (ProtoServerWithAuth p) -> m (ProtoServerWithAuth p) +pickServer :: forall p. NonEmpty (ProtoServerWithAuth p) -> AM (ProtoServerWithAuth p) pickServer = \case srv :| [] -> pure srv servers -> do gen <- asks randomServer atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1)) -getNextServer :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> [ProtocolServer p] -> m (ProtoServerWithAuth p) +getNextServer :: forall p. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> [ProtocolServer p] -> AM (ProtoServerWithAuth p) getNextServer c userId usedSrvs = withUserServers c userId $ \srvs -> case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of Just srvs' -> pickServer srvs' _ -> pickServer srvs -withUserServers :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> m a) -> m a +withUserServers :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> AM a) -> AM a withUserServers c userId action = atomically (TM.lookup userId $ userServers c) >>= \case Just srvs -> action srvs _ -> throwError $ INTERNAL "unknown userId - no user servers" -withNextSrv :: forall p m a. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> m a) -> m a +withNextSrv :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> TVar [ProtocolServer p] -> [ProtocolServer p] -> (ProtoServerWithAuth p -> AM a) -> AM a withNextSrv c userId usedSrvs initUsed action = do used <- readTVarIO usedSrvs srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used @@ -1564,7 +1612,7 @@ data SubscriptionsInfo = SubscriptionsInfo } deriving (Show) -getAgentSubscriptions :: MonadIO m => AgentClient -> m SubscriptionsInfo +getAgentSubscriptions :: AgentClient -> IO SubscriptionsInfo getAgentSubscriptions c = do activeSubscriptions <- getSubs activeSubs pendingSubscriptions <- getSubs pendingSubs @@ -1600,7 +1648,7 @@ data WorkersDetails = WorkersDetails } deriving (Show) -getAgentWorkersDetails :: MonadIO m => AgentClient -> m AgentWorkersDetails +getAgentWorkersDetails :: AgentClient -> IO AgentWorkersDetails getAgentWorkersDetails AgentClient {smpClients, ntfClients, xftpClients, smpDeliveryWorkers, asyncCmdWorkers, smpSubWorkers, agentEnv} = do smpClients_ <- textKeys <$> readTVarIO smpClients ntfClients_ <- textKeys <$> readTVarIO ntfClients @@ -1632,7 +1680,7 @@ getAgentWorkersDetails AgentClient {smpClients, ntfClients, xftpClients, smpDeli textKeys = map textKey . M.keys textKey :: StrEncoding k => k -> Text textKey = decodeASCII . strEncode - workerStats :: (StrEncoding k, MonadIO m) => Map k Worker -> m (Map Text WorkersDetails) + workerStats :: StrEncoding k => Map k Worker -> IO (Map Text WorkersDetails) workerStats ws = fmap M.fromList . forM (M.toList ws) $ \(qa, Worker {restarts, doWork, action}) -> do RestartCount {restartCount} <- readTVarIO restarts hasWork <- atomically $ not <$> isEmptyTMVar doWork @@ -1664,7 +1712,7 @@ data WorkersSummary = WorkersSummary } deriving (Show) -getAgentWorkersSummary :: MonadIO m => AgentClient -> m AgentWorkersSummary +getAgentWorkersSummary :: AgentClient -> IO AgentWorkersSummary getAgentWorkersSummary AgentClient {smpClients, ntfClients, xftpClients, smpDeliveryWorkers, asyncCmdWorkers, smpSubWorkers, agentEnv} = do smpClientsCount <- M.size <$> readTVarIO smpClients ntfClientsCount <- M.size <$> readTVarIO ntfClients @@ -1695,7 +1743,7 @@ getAgentWorkersSummary AgentClient {smpClients, ntfClients, xftpClients, smpDeli Env {ntfSupervisor, xftpAgent} = agentEnv NtfSupervisor {ntfWorkers, ntfSMPWorkers} = ntfSupervisor XFTPAgent {xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} = xftpAgent - workerSummary :: MonadIO m => M.Map k Worker -> m WorkersSummary + workerSummary :: M.Map k Worker -> IO WorkersSummary workerSummary = liftIO . foldM byWork WorkersSummary {numActive = 0, numIdle = 0, totalRestarts = 0} where byWork WorkersSummary {numActive, numIdle, totalRestarts} Worker {action, restarts} = do diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 20a378a45..a1d060586 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -11,8 +11,8 @@ {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} module Simplex.Messaging.Agent.Env.SQLite - ( AgentMonad, - AgentMonad', + ( AM', + AM, AgentConfig (..), InitialAgentServers (..), NetworkConfig (..), @@ -21,6 +21,7 @@ module Simplex.Messaging.Agent.Env.SQLite tryAgentError, tryAgentError', catchAgentError, + catchAgentError', agentFinally, Env (..), newSMPAgentEnv, @@ -34,7 +35,6 @@ module Simplex.Messaging.Agent.Env.SQLite ) where -import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader @@ -65,14 +65,14 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..)) import Simplex.Messaging.Transport.Client (defaultSMPPort) -import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors) +import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors') import System.Random (StdGen, newStdGen) import UnliftIO (Async, SomeException) import UnliftIO.STM -type AgentMonad' m = (MonadUnliftIO m, MonadReader Env m) +type AM' a = ReaderT Env IO a -type AgentMonad m = (AgentMonad' m, MonadError AgentErrorType m) +type AM a = ExceptT AgentErrorType (ReaderT Env IO) a data InitialAgentServers = InitialAgentServers { smp :: Map UserId (NonEmpty SMPServerWithAuth), @@ -82,7 +82,7 @@ data InitialAgentServers = InitialAgentServers } data AgentConfig = AgentConfig - { tcpPort :: ServiceName, + { tcpPort :: Maybe ServiceName, rcvAuthAlg :: C.AuthAlg, sndAuthAlg :: C.AuthAlg, connIdBytes :: Int, @@ -149,7 +149,7 @@ defaultMessageRetryInterval = defaultAgentConfig :: AgentConfig defaultAgentConfig = AgentConfig - { tcpPort = "5224", + { tcpPort = Just "5224", -- while the current client version supports X25519, it can only be enabled once support for SMP v6 is dropped, -- and all servers are required to support v7 to be compatible. rcvAuthAlg = C.AuthAlg C.SEd25519, -- this will stay as Ed25519 @@ -250,20 +250,24 @@ newXFTPAgent = do xftpDelWorkers <- TM.empty pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} -tryAgentError :: AgentMonad m => m a -> m (Either AgentErrorType a) +tryAgentError :: AM a -> AM (Either AgentErrorType a) tryAgentError = tryAllErrors mkInternal {-# INLINE tryAgentError #-} -- unlike runExceptT, this ensures we catch IO exceptions as well -tryAgentError' :: AgentMonad' m => ExceptT AgentErrorType m a -> m (Either AgentErrorType a) -tryAgentError' = fmap join . runExceptT . tryAgentError +tryAgentError' :: AM a -> AM' (Either AgentErrorType a) +tryAgentError' = tryAllErrors' mkInternal {-# INLINE tryAgentError' #-} -catchAgentError :: AgentMonad m => m a -> (AgentErrorType -> m a) -> m a +catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a catchAgentError = catchAllErrors mkInternal {-# INLINE catchAgentError #-} -agentFinally :: AgentMonad m => m a -> m b -> m a +catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a +catchAgentError' = catchAllErrors' mkInternal +{-# INLINE catchAgentError' #-} + +agentFinally :: AM a -> AM b -> AM a agentFinally = allFinally mkInternal {-# INLINE agentFinally #-} diff --git a/src/Simplex/Messaging/Agent/Lock.hs b/src/Simplex/Messaging/Agent/Lock.hs index 37b63eb0e..c0647b844 100644 --- a/src/Simplex/Messaging/Agent/Lock.hs +++ b/src/Simplex/Messaging/Agent/Lock.hs @@ -1,15 +1,15 @@ -{-# LANGUAGE NamedFieldPuns #-} - module Simplex.Messaging.Agent.Lock ( Lock, createLock, withLock, + withLock', withGetLock, withGetLocks, ) where import Control.Monad (void) +import Control.Monad.Except (ExceptT (..), runExceptT) import Control.Monad.IO.Unlift import Data.Functor (($>)) import UnliftIO.Async (forConcurrently) @@ -22,8 +22,12 @@ createLock :: STM Lock createLock = newEmptyTMVar {-# INLINE createLock #-} -withLock :: MonadUnliftIO m => Lock -> String -> m a -> m a -withLock lock name = +withLock :: MonadUnliftIO m => Lock -> String -> ExceptT e m a -> ExceptT e m a +withLock lock name = ExceptT . withLock' lock name . runExceptT +{-# INLINE withLock #-} + +withLock' :: MonadUnliftIO m => Lock -> String -> m a -> m a +withLock' lock name = E.bracket_ (atomically $ putTMVar lock name) (void . atomically $ takeTMVar lock) diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index aa2c0e9c6..7e47b5ba6 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -44,7 +44,7 @@ import UnliftIO import UnliftIO.Concurrent (forkIO, threadDelay) import qualified UnliftIO.Exception as E -runNtfSupervisor :: forall m. AgentMonad' m => AgentClient -> m () +runNtfSupervisor :: AgentClient -> AM' () runNtfSupervisor c = do ns <- asks ntfSupervisor forever $ do @@ -54,13 +54,13 @@ runNtfSupervisor c = do Left e -> notifyErr connId e Right _ -> return () where - handleErr :: ConnId -> m () -> m () + handleErr :: ConnId -> AM' () -> AM' () handleErr connId = E.handle $ \(e :: E.SomeException) -> do logError $ "runNtfSupervisor error " <> tshow e notifyErr connId e notifyErr connId e = notifyInternalError c connId $ "runNtfSupervisor error " <> show e -processNtfSub :: forall m. AgentMonad m => AgentClient -> (ConnId, NtfSupervisorCommand) -> m () +processNtfSub :: AgentClient -> (ConnId, NtfSupervisorCommand) -> AM () processNtfSub c (connId, cmd) = do logInfo $ "processNtfSub - connId = " <> tshow connId <> " - cmd = " <> tshow cmd case cmd of @@ -77,11 +77,11 @@ processNtfSub c (connId, cmd) = do Just ClientNtfCreds {notifierId} -> do let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate - void $ getNtfNTFWorker True c ntfServer + lift . void $ getNtfNTFWorker True c ntfServer Nothing -> do let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey - void $ getNtfSMPWorker True c smpServer + lift . void $ getNtfSMPWorker True c smpServer (Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do case (clientNtfCreds, ntfQueueId) of (Just ClientNtfCreds {notifierId}, Just ntfQueueId') @@ -90,7 +90,7 @@ processNtfSub c (connId, cmd) = do (Nothing, Nothing) -> create _ -> rotate where - create :: m () + create :: AM () create = case action_ of -- action was set to NULL after worker internal error Nothing -> resetSubscription @@ -101,60 +101,60 @@ processNtfSub c (connId, cmd) = do then resetSubscription else withTokenServer $ \ntfServer -> do withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate) - void $ getNtfNTFWorker True c ntfServer + lift . void $ getNtfNTFWorker True c ntfServer | otherwise -> case action of - NtfSubNTFAction _ -> void $ getNtfNTFWorker True c subNtfServer - NtfSubSMPAction _ -> void $ getNtfSMPWorker True c smpServer - rotate :: m () + NtfSubNTFAction _ -> lift . void $ getNtfNTFWorker True c subNtfServer + NtfSubSMPAction _ -> lift . void $ getNtfSMPWorker True c smpServer + rotate :: AM () rotate = do withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate) - void $ getNtfNTFWorker True c subNtfServer - resetSubscription :: m () + lift . void $ getNtfNTFWorker True c subNtfServer + resetSubscription :: AM () resetSubscription = withTokenServer $ \ntfServer -> do let sub' = sub {ntfQueueId = Nothing, ntfServer, ntfSubId = Nothing, ntfSubStatus = NASNew} withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NtfSubSMPAction NSASmpKey) - void $ getNtfSMPWorker True c smpServer + lift . void $ getNtfSMPWorker True c smpServer NSCDelete -> do sub_ <- withStore' c $ \db -> do supervisorUpdateNtfAction db connId (NtfSubNTFAction NSADelete) getNtfSubscription db connId logInfo $ "processNtfSub, NSCDelete - sub_ = " <> tshow sub_ case sub_ of - (Just (NtfSubscription {ntfServer}, _)) -> void $ getNtfNTFWorker True c ntfServer + (Just (NtfSubscription {ntfServer}, _)) -> lift . void $ getNtfNTFWorker True c ntfServer _ -> pure () -- err "NSCDelete - no subscription" NSCSmpDelete -> do withStore' c (`getPrimaryRcvQueue` connId) >>= \case Right rq@RcvQueue {server = smpServer} -> do logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq withStore' c $ \db -> supervisorUpdateNtfAction db connId (NtfSubSMPAction NSASmpDelete) - void $ getNtfSMPWorker True c smpServer + lift . void $ getNtfSMPWorker True c smpServer _ -> notifyInternalError c connId "NSCSmpDelete - no rcv queue" - NSCNtfWorker ntfServer -> void $ getNtfNTFWorker True c ntfServer - NSCNtfSMPWorker smpServer -> void $ getNtfSMPWorker True c smpServer + NSCNtfWorker ntfServer -> lift . void $ getNtfNTFWorker True c ntfServer + NSCNtfSMPWorker smpServer -> lift . void $ getNtfSMPWorker True c smpServer -getNtfNTFWorker :: AgentMonad' m => Bool -> AgentClient -> NtfServer -> m Worker +getNtfNTFWorker :: Bool -> AgentClient -> NtfServer -> AM' Worker getNtfNTFWorker hasWork c server = do ws <- asks $ ntfWorkers . ntfSupervisor getAgentWorker "ntf_ntf" hasWork c server ws $ runNtfWorker c server -getNtfSMPWorker :: AgentMonad' m => Bool -> AgentClient -> SMPServer -> m Worker +getNtfSMPWorker :: Bool -> AgentClient -> SMPServer -> AM' Worker getNtfSMPWorker hasWork c server = do ws <- asks $ ntfSMPWorkers . ntfSupervisor getAgentWorker "ntf_smp" hasWork c server ws $ runNtfSMPWorker c server -withTokenServer :: AgentMonad' m => (NtfServer -> m ()) -> m () -withTokenServer action = getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer) +withTokenServer :: (NtfServer -> AM ()) -> AM () +withTokenServer action = lift getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer) -runNtfWorker :: forall m. AgentMonad m => AgentClient -> NtfServer -> Worker -> m () +runNtfWorker :: AgentClient -> NtfServer -> Worker -> AM () runNtfWorker c srv Worker {doWork} = do delay <- asks $ ntfWorkerDelay . config forever $ do waitForWork doWork - agentOperationBracket c AONtfNetwork throwWhenInactive runNtfOperation + ExceptT $ agentOperationBracket c AONtfNetwork throwWhenInactive $ runExceptT runNtfOperation threadDelay delay where - runNtfOperation :: m () + runNtfOperation :: AM () runNtfOperation = withWork c doWork (`getNextNtfSubNTFAction` srv) $ \nextSub@(NtfSubscription {connId}, _, _) -> do @@ -163,13 +163,13 @@ runNtfWorker c srv Worker {doWork} = do withRetryInterval ri $ \_ loop -> processSub nextSub `catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show) - processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> m () + processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM () processSub (sub@NtfSubscription {connId, smpServer, ntfSubId}, action, actionTs) = do ts <- liftIO getCurrentTime - unlessM (rescheduleAction doWork ts actionTs) $ + unlessM (lift $ rescheduleAction doWork ts actionTs) $ case action of NSACreate -> - getNtfToken >>= \case + lift getNtfToken >>= \case Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId) case clientNtfCreds of @@ -182,13 +182,13 @@ runNtfWorker c srv Worker {doWork} = do _ -> workerInternalError c connId "NSACreate - no notifier queue credentials" _ -> workerInternalError c connId "NSACreate - no active token" NSACheck -> - getNtfToken >>= \case + lift getNtfToken >>= \case Just tkn -> case ntfSubId of Just nSubId -> agentNtfCheckSubscription c nSubId tkn >>= \case NSAuth -> do - getNtfServer c >>= \case + lift (getNtfServer c) >>= \case Just ntfServer -> do withStore' c $ \db -> updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts @@ -200,7 +200,7 @@ runNtfWorker c srv Worker {doWork} = do _ -> workerInternalError c connId "NSACheck - no active token" NSADelete -> case ntfSubId of Just nSubId -> - (getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) + (lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) `agentFinally` continueDeletion _ -> continueDeletion where @@ -211,7 +211,7 @@ runNtfWorker c srv Worker {doWork} = do atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) NSARotate -> case ntfSubId of Just nSubId -> - (getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) + (lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) `agentFinally` deleteCreate _ -> deleteCreate where @@ -228,12 +228,14 @@ runNtfWorker c srv Worker {doWork} = do withStore' c $ \db -> updateNtfSubscription db sub {ntfSubStatus = toStatus} toAction actionTs' -runNtfSMPWorker :: forall m. AgentMonad m => AgentClient -> SMPServer -> Worker -> m () +runNtfSMPWorker :: AgentClient -> SMPServer -> Worker -> AM () runNtfSMPWorker c srv Worker {doWork} = do + env <- ask delay <- asks $ ntfSMPWorkerDelay . config forever $ do waitForWork doWork - agentOperationBracket c AONtfNetwork throwWhenInactive runNtfSMPOperation + ExceptT . liftIO . agentOperationBracket c AONtfNetwork throwWhenInactive $ + runReaderT (runExceptT runNtfSMPOperation) env threadDelay delay where runNtfSMPOperation = @@ -244,13 +246,13 @@ runNtfSMPWorker c srv Worker {doWork} = do withRetryInterval ri $ \_ loop -> processSub nextSub `catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show) - processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> m () + processSub :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> AM () processSub (sub@NtfSubscription {connId, ntfServer}, smpAction, actionTs) = do ts <- liftIO getCurrentTime - unlessM (rescheduleAction doWork ts actionTs) $ + unlessM (lift $ rescheduleAction doWork ts actionTs) $ case smpAction of NSASmpKey -> - getNtfToken >>= \case + lift getNtfToken >>= \case Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do rq <- withStore c (`getPrimaryRcvQueue` connId) C.AuthAlg a <- asks (rcvAuthAlg . config) @@ -272,7 +274,7 @@ runNtfSMPWorker c srv Worker {doWork} = do mapM_ (disableQueueNotifications c) rq_ withStore' c $ \db -> deleteNtfSubscription db connId -rescheduleAction :: AgentMonad' m => TMVar () -> UTCTime -> UTCTime -> m Bool +rescheduleAction :: TMVar () -> UTCTime -> UTCTime -> AM' Bool rescheduleAction doWork ts actionTs | actionTs <= ts = pure False | otherwise = do @@ -282,7 +284,7 @@ rescheduleAction doWork ts actionTs atomically $ hasWorkToDo' doWork pure True -retryOnError :: AgentMonad' m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m () +retryOnError :: AgentClient -> Text -> AM () -> (AgentErrorType -> AM ()) -> AgentErrorType -> AM () retryOnError c name loop done e = do logError $ name <> " error: " <> tshow e case e of @@ -296,16 +298,17 @@ retryOnError c name loop done e = do atomically $ beginAgentOperation c AONtfNetwork loop -workerInternalError :: AgentMonad m => AgentClient -> ConnId -> String -> m () +workerInternalError :: AgentClient -> ConnId -> String -> AM () workerInternalError c connId internalErrStr = do withStore' c $ \db -> setNullNtfSubscriptionAction db connId notifyInternalError c connId internalErrStr -- TODO change error -notifyInternalError :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m () +notifyInternalError :: MonadIO m => AgentClient -> ConnId -> String -> m () notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, APC SAEConn $ ERR $ INTERNAL internalErrStr) +{-# INLINE notifyInternalError #-} -getNtfToken :: AgentMonad' m => m (Maybe NtfToken) +getNtfToken :: AM' (Maybe NtfToken) getNtfToken = do tkn <- asks $ ntfTkn . ntfSupervisor readTVarIO tkn @@ -326,14 +329,14 @@ instantNotifications = \case Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> True _ -> False -closeNtfSupervisor :: MonadUnliftIO m => NtfSupervisor -> m () +closeNtfSupervisor :: NtfSupervisor -> IO () closeNtfSupervisor ns = do stopWorkers $ ntfWorkers ns stopWorkers $ ntfSMPWorkers ns where stopWorkers workers = atomically (swapTVar workers M.empty) >>= mapM_ (liftIO . cancelWorker) -getNtfServer :: AgentMonad' m => AgentClient -> m (Maybe NtfServer) +getNtfServer :: AgentClient -> AM' (Maybe NtfServer) getNtfServer c = do ntfServers <- readTVarIO $ ntfServers c case ntfServers of diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index df9907fe0..9c24646e3 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -1897,15 +1897,15 @@ tGetRaw :: Transport c => c -> IO ARawTransmission tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h -- | Send SMP agent protocol command (or response) to TCP connection. -tPut :: (Transport c, MonadIO m) => c -> ATransmission p -> m () +tPut :: Transport c => c -> ATransmission p -> IO () tPut h (corrId, connId, APC _ cmd) = - liftIO $ tPutRaw h (corrId, connId, serializeCommand cmd) + tPutRaw h (corrId, connId, serializeCommand cmd) -- | Receive client and agent transmissions from TCP connection. -tGet :: forall c m p. (Transport c, MonadIO m) => SAParty p -> c -> m (ATransmissionOrError p) +tGet :: forall c p. Transport c => SAParty p -> c -> IO (ATransmissionOrError p) tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody where - tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p) + tParseLoadBody :: ARawTransmission -> IO (ATransmissionOrError p) tParseLoadBody t@(corrId, entId, command) = do let cmd = parseCommand command >>= fromParty >>= tConnId t fullCmd <- either (return . Left) cmdWithMsgBody cmd @@ -1935,7 +1935,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody | B.null entId -> Left $ CMD NO_CONN | otherwise -> Right cmd - cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p)) + cmdWithMsgBody :: APartyCmd p -> IO (Either AgentErrorType (APartyCmd p)) cmdWithMsgBody (APC e cmd) = APC e <$$> case cmd of SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body @@ -1948,7 +1948,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo _ -> pure $ Right cmd - getBody :: ByteString -> m (Either AgentErrorType ByteString) + getBody :: ByteString -> IO (Either AgentErrorType ByteString) getBody binary = case B.unpack binary of ':' : body -> return . Right $ B.pack body diff --git a/src/Simplex/Messaging/Agent/Server.hs b/src/Simplex/Messaging/Agent/Server.hs index a6e15dcc4..368c0a23d 100644 --- a/src/Simplex/Messaging/Agent/Server.hs +++ b/src/Simplex/Messaging/Agent/Server.hs @@ -12,13 +12,13 @@ where import Control.Logger.Simple (logInfo) import Control.Monad -import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader -import Crypto.Random (MonadRandom) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Text.Encoding (decodeUtf8) +import Network.Socket (ServiceName) import Simplex.Messaging.Agent +import Simplex.Messaging.Agent.Client (newAgentClient) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore) @@ -32,7 +32,7 @@ import UnliftIO.STM -- | Runs an SMP agent as a TCP service using passed configuration. -- -- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs -runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> m () +runSMPAgent :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> IO () runSMPAgent t cfg initServers store = runSMPAgentBlocking t cfg initServers store 0 =<< newEmptyTMVarIO @@ -40,44 +40,46 @@ runSMPAgent t cfg initServers store = -- -- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True) -- and when it is disconnected from the TCP socket once the server thread is killed (False). -runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> m () -runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started = do - liftIO (newSMPAgentEnv cfg store) >>= runReaderT (smpAgent t) +runSMPAgentBlocking :: ATransport -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Int -> TMVar Bool -> IO () +runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers store initClientId started = + case tcpPort of + Just port -> newSMPAgentEnv cfg store >>= smpAgent t port + Nothing -> E.throwIO $ userError "no agent port" where - smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' () - smpAgent _ = do + smpAgent :: forall c. Transport c => TProxy c -> ServiceName -> Env -> IO () + smpAgent _ port env = do -- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation - tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile + tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile clientId <- newTVarIO initClientId - runTransportServer started tcpPort tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do - liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion + runTransportServer started port tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do + putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion cId <- atomically $ stateTVar clientId $ \i -> (i + 1, i + 1) - c <- getAgentClient cId initServers + c <- atomically $ newAgentClient cId initServers env logConnection c True - race_ (connectClient h c) (runAgentClient c) - `E.finally` disconnectAgentClient c + race_ (connectClient h c) (runAgentClient c `runReaderT` env) + `E.finally` (disconnectAgentClient c) -connectClient :: Transport c => MonadUnliftIO m => c -> AgentClient -> m () +connectClient :: Transport c => c -> AgentClient -> IO () connectClient h c = race_ (send h c) (receive h c) -receive :: forall c m. (Transport c, MonadUnliftIO m) => c -> AgentClient -> m () +receive :: forall c. Transport c => c -> AgentClient -> IO () receive h c@AgentClient {rcvQ, subQ} = forever $ do (corrId, entId, cmdOrErr) <- tGet SClient h case cmdOrErr of Right cmd -> write rcvQ (corrId, entId, cmd) Left e -> write subQ (corrId, entId, APC SAEConn $ ERR e) where - write :: TBQueue (ATransmission p) -> ATransmission p -> m () + write :: TBQueue (ATransmission p) -> ATransmission p -> IO () write q t = do logClient c "-->" t atomically $ writeTBQueue q t -send :: (Transport c, MonadUnliftIO m) => c -> AgentClient -> m () +send :: Transport c => c -> AgentClient -> IO () send h c@AgentClient {subQ} = forever $ do t <- atomically $ readTBQueue subQ tPut h t logClient c "<--" t -logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m () +logClient :: AgentClient -> ByteString -> ATransmission a -> IO () logClient AgentClient {clientId} dir (corrId, connId, APC _ cmd) = do logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd] diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index d8b202761..b7613f4dc 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -97,7 +97,7 @@ import Data.List (find) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Maybe (fromMaybe) -import Data.Time.Clock (UTCTime, getCurrentTime) +import Data.Time.Clock (UTCTime (..), getCurrentTime) import Network.Socket (ServiceName) import Numeric.Natural import qualified Simplex.Messaging.Crypto as C @@ -138,11 +138,12 @@ data PClient v err msg = PClient msgQ :: Maybe (TBQueue (ServerTransmission v msg)) } -smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg) +smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM SMPClient smpClientStub g sessionId thVersion thAuth = do connected <- newTVar False clientCorrId <- C.newRandomDRG g sentCommands <- TM.empty + pingErrorCount <- newTVar 0 sndQ <- newTBQueue 100 rcvQ <- newTBQueue 100 return @@ -157,15 +158,15 @@ smpClientStub g sessionId thVersion thAuth = do implySessId = thVersion >= authCmdsSMPVersion, batch = True }, - sessionTs = undefined, + sessionTs = UTCTime (read "2024-03-31") 0, client_ = PClient { connected, - transportSession = undefined, - transportHost = undefined, - tcpTimeout = undefined, + transportSession = (1, "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001", Nothing), + transportHost = "localhost", + tcpTimeout = 15_000_000, batchDelay = Nothing, - pingErrorCount = undefined, + pingErrorCount, clientCorrId, sentCommands, sndQ, @@ -239,6 +240,7 @@ defaultNetworkConfig = transportClientConfig :: NetworkConfig -> TransportClientConfig transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} = TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing} +{-# INLINE transportClientConfig #-} -- | protocol client configuration. data ProtocolClientConfig v = ProtocolClientConfig @@ -264,9 +266,11 @@ defaultClientConfig serverVRange = serverVRange, batchDelay = Nothing } +{-# INLINE defaultClientConfig #-} defaultSMPClientConfig :: ProtocolClientConfig SMPVersion defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange +{-# INLINE defaultSMPClientConfig #-} data Request err msg = Request { entityId :: EntityId, @@ -296,12 +300,15 @@ protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err ms protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_ where snd3 (_, s, _) = s +{-# INLINE protocolClientServer #-} transportHost' :: ProtocolClient v err msg -> TransportHost transportHost' = transportHost . client_ +{-# INLINE transportHost' #-} transportSession' :: ProtocolClient v err msg -> TransportSession msg transportSession' = transportSession . client_ +{-# INLINE transportSession' #-} type UserId = Int64 @@ -426,10 +433,12 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize proxyUsername :: TransportSession msg -> ByteString proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_ +{-# INLINE proxyUsername #-} -- | Disconnects client from the server and terminates client threads. closeProtocolClient :: ProtocolClient v err msg -> IO () closeProtocolClient = mapM_ uninterruptibleCancel . action +{-# INLINE closeProtocolClient #-} -- | SMP client error type. data ProtocolClientError err @@ -469,6 +478,7 @@ temporaryClientError = \case PCEResponseTimeout -> True PCEIOError _ -> True _ -> False +{-# INLINE temporaryClientError #-} -- | Create a new SMP queue. -- @@ -536,16 +546,19 @@ getSMPMessage c rpKey rId = -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateAuthKey -> NotifierId -> ExceptT SMPClientError IO () subscribeSMPQueueNotifications = okSMPCommand NSUB +{-# INLINE subscribeSMPQueueNotifications #-} -- | Subscribe to multiple SMP queues notifications batching commands if supported. subscribeSMPQueuesNtfs :: SMPClient -> NonEmpty (NtfPrivateAuthKey, NotifierId) -> IO (NonEmpty (Either SMPClientError ())) subscribeSMPQueuesNtfs = okSMPCommands NSUB +{-# INLINE subscribeSMPQueuesNtfs #-} -- | Secure the SMP queue by adding a sender public key. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command secureSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> SndPublicAuthKey -> ExceptT SMPClientError IO () secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId +{-# INLINE secureSMPQueue #-} -- | Enable notifications for the queue for push notifications server. -- @@ -571,10 +584,12 @@ enableSMPQueuesNtfs c qs = L.map process <$> sendProtocolCommands c cs -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command disableSMPQueueNotifications :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO () disableSMPQueueNotifications = okSMPCommand NDEL +{-# INLINE disableSMPQueueNotifications #-} -- | Disable notifications for multiple queues for push notifications server. disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateAuthKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ())) disableSMPQueuesNtfs = okSMPCommands NDEL +{-# INLINE disableSMPQueuesNtfs #-} -- | Send SMP message. -- @@ -601,16 +616,19 @@ ackSMPMessage c rpKey rId msgId = -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue suspendSMPQueue :: SMPClient -> RcvPrivateAuthKey -> QueueId -> ExceptT SMPClientError IO () suspendSMPQueue = okSMPCommand OFF +{-# INLINE suspendSMPQueue #-} -- | Irreversibly delete SMP queue and all messages in it. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue deleteSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO () deleteSMPQueue = okSMPCommand DEL +{-# INLINE deleteSMPQueue #-} -- | Delete multiple SMP queues batching commands if supported. deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateAuthKey, RecipientId) -> IO (NonEmpty (Either SMPClientError ())) deleteSMPQueues = okSMPCommands DEL +{-# INLINE deleteSMPQueues #-} okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateAuthKey -> QueueId -> ExceptT SMPClientError IO () okSMPCommand cmd c pKey qId = @@ -631,6 +649,7 @@ okSMPCommands cmd c qs = L.map process <$> sendProtocolCommands c cs -- | Send SMP command sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateAuthKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) +{-# INLINE sendSMPCommand #-} type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 73f47648b..8e21aada1 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -18,6 +19,7 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except +import Control.Monad.Trans.Reader import Crypto.Random (ChaChaDRG) import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) @@ -106,12 +108,24 @@ newtype InternalException e = InternalException {unInternalException :: e} instance Exception e => Exception (InternalException e) -instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where - withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b - withRunInIO exceptToIO = +instance Exception e => MonadUnliftIO (ExceptT e IO) where + {-# INLINE withRunInIO #-} + withRunInIO :: ((forall a. ExceptT e IO a -> IO a) -> IO b) -> ExceptT e IO b + withRunInIO inner = + ExceptT . fmap (first unInternalException) . E.try $ + withRunInIO $ \run -> + inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT) + -- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`, + -- the last two lines could be replaced with: + -- inner $ either (E.throwIO . InternalException) pure <=< runExceptT + +instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where + {-# INLINE withRunInIO #-} + withRunInIO :: ((forall a. ExceptT e (ReaderT r IO) a -> IO a) -> IO b) -> ExceptT e (ReaderT r IO) b + withRunInIO inner = withExceptT unInternalException . ExceptT . E.try $ withRunInIO $ \run -> - exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT) + inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT) newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do @@ -147,7 +161,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr Nothing -> Left PCEResponseTimeout newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient - newSMPClient smpVar = tryConnectClient pure tryConnectAsync + newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync) where tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a tryConnectClient successAction retryAction = @@ -163,9 +177,9 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr putTMVar smpVar (Left e) TM.delete srv smpClients throwE e - tryConnectAsync :: ExceptT SMPClientError IO () + tryConnectAsync :: IO () tryConnectAsync = do - a <- async connectAsync + a <- async $ void $ runExceptT connectAsync atomically $ modifyTVar' (asyncClients ca) (a :) connectAsync :: ExceptT SMPClientError IO () connectAsync = @@ -199,11 +213,11 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr serverDown :: Map SMPSub C.APrivateAuthKey -> IO () serverDown ss = unless (M.null ss) $ do notify . CADisconnected srv $ M.keysSet ss - void $ runExceptT reconnectServer + reconnectServer - reconnectServer :: ExceptT SMPClientError IO () + reconnectServer :: IO () reconnectServer = do - a <- async tryReconnectClient + a <- async $ void $ runExceptT tryReconnectClient atomically $ modifyTVar' (reconnections ca) (a :) tryReconnectClient :: ExceptT SMPClientError IO () @@ -247,8 +261,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr notify :: SMPClientAgentEvent -> IO () notify evt = atomically $ writeTBQueue (agentQ ca) evt -closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m () -closeSMPClientAgent c = liftIO $ do +closeSMPClientAgent :: SMPClientAgent -> IO () +closeSMPClientAgent c = do closeSMPServerClients c cancelActions $ reconnections c cancelActions $ asyncClients c diff --git a/src/Simplex/Messaging/Crypto/File.hs b/src/Simplex/Messaging/Crypto/File.hs index 84a1d18e9..2787df58e 100644 --- a/src/Simplex/Messaging/Crypto/File.hs +++ b/src/Simplex/Messaging/Crypto/File.hs @@ -76,7 +76,7 @@ withFile :: CryptoFile -> IOMode -> (CryptoFileHandle -> ExceptT FTCryptoError I withFile (CryptoFile path cfArgs) mode action = do sb <- forM cfArgs $ \(CFArgs key nonce) -> liftEitherWith FTCECryptoError (LC.sbInit key nonce) >>= newTVarIO - IO.withFile path mode $ \h -> action $ CFHandle h sb + ExceptT . IO.withFile path mode $ \h -> runExceptT $ action $ CFHandle h sb hPut :: CryptoFileHandle -> LazyByteString -> IO () hPut (CFHandle h sb_) s = LB.hPut h =<< maybe (pure s) encrypt sb_ diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 1d62665f0..2580e58fd 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -81,7 +81,8 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams - runTransportServer started tcpPort serverParams tCfg (runClient serverSignKey t) + env <- ask + liftIO $ runTransportServer started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M () diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 0d722dcc3..4a93a8a34 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -83,7 +83,7 @@ data NtfEnv = NtfEnv serverStats :: NtfServerStats } -newNtfServerEnv :: (MonadUnliftIO m, MonadRandom m) => NtfServerConfig -> m NtfEnv +newNtfServerEnv :: NtfServerConfig -> IO NtfEnv newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, caCertificateFile, certificateFile, privateKeyFile} = do random <- liftIO C.newRandom store <- atomically newNtfStore diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index a7e25a82b..0299bebfc 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -128,14 +128,15 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do : serverThread s "server ntfSubscribedQ" ntfSubscribedQ Env.notifiers ntfSubscriptions (\_ -> pure ()) : map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg <> controlPortThread_ cfg ) - `finally` withLock (savingLock s) "final" (saveServer False) + `finally` withLock' (savingLock s) "final" (saveServer False) where runServer :: (ServiceName, ATransport) -> M () runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams ss <- asks sockets serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams - runTransportServerState ss started tcpPort serverParams tCfg (runClient serverSignKey t) + env <- ask + liftIO $ runTransportServerState ss started tcpPort serverParams tCfg $ \h -> runClient serverSignKey t h `runReaderT` env fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey saveServer :: Bool -> M () @@ -387,7 +388,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do withLog (`logDeleteQueue` queueId) updateDeletedStats q liftIO . hPutStrLn h $ "ok, " <> show numDeleted <> " messages deleted" - CPSave -> withAdminRole $ withLock (savingLock srv) "control" $ do + CPSave -> withAdminRole $ withLock' (savingLock srv) "control" $ do hPutStrLn h "saving server state..." unliftIO u $ saveServer True hPutStrLn h "server state saved!" @@ -579,7 +580,7 @@ dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/Xo dummyKeyX25519 :: C.PublicKey 'C.X25519 dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg=" -client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m () +client :: Client -> Server -> M () client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Server {subscribedQ, ntfSubscribedQ, notifiers} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands" forever $ @@ -587,7 +588,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv >>= mapM processCommand >>= atomically . writeTBQueue sndQ where - processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg) + processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Transmission BrokerMsg) processCommand (qr_, (corrId, queueId, cmd)) = do st <- asks queueStore case cmd of @@ -616,7 +617,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv OFF -> suspendQueue_ st DEL -> delQueueAndMsgs st where - createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> m (Transmission BrokerMsg) + createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> M (Transmission BrokerMsg) createQueue st recipientKey dhKey subMode = time "NEW" $ do (rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random let rcvDhSecret = C.dh' dhKey privDhKey @@ -634,7 +635,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv (corrId,queueId,) <$> addQueueRetry 3 qik qRec where addQueueRetry :: - Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> m BrokerMsg + Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> M BrokerMsg addQueueRetry 0 _ _ = pure $ ERR INTERNAL addQueueRetry n qik qRec = do ids@(rId, _) <- getIds @@ -659,25 +660,25 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv Right q -> logCreateQueue s q _ -> pure () - getIds :: m (RecipientId, SenderId) + getIds :: M (RecipientId, SenderId) getIds = do n <- asks $ queueIdBytes . config liftM2 (,) (randomId n) (randomId n) - secureQueue_ :: QueueStore -> SndPublicAuthKey -> m (Transmission BrokerMsg) + secureQueue_ :: QueueStore -> SndPublicAuthKey -> M (Transmission BrokerMsg) secureQueue_ st sKey = time "KEY" $ do withLog $ \s -> logSecureQueue s queueId sKey stats <- asks serverStats atomically $ modifyTVar' (qSecured stats) (+ 1) atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey - addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg) + addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg) addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do (rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random let rcvNtfDhSecret = C.dh' dhKey privDhKey (corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret where - addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> m BrokerMsg + addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg addNotifierRetry 0 _ _ = pure $ ERR INTERNAL addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do notifierId <- randomId =<< asks (queueIdBytes . config) @@ -689,17 +690,17 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv withLog $ \s -> logAddNotifier s queueId ntfCreds pure $ NID notifierId rcvPublicDhKey - deleteQueueNotifier_ :: QueueStore -> m (Transmission BrokerMsg) + deleteQueueNotifier_ :: QueueStore -> M (Transmission BrokerMsg) deleteQueueNotifier_ st = do withLog (`logDeleteNotifier` queueId) okResp <$> atomically (deleteQueueNotifier st queueId) - suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg) + suspendQueue_ :: QueueStore -> M (Transmission BrokerMsg) suspendQueue_ st = do withLog (`logSuspendQueue` queueId) okResp <$> atomically (suspendQueue st queueId) - subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg) + subscribeQueue :: QueueRec -> RecipientId -> M (Transmission BrokerMsg) subscribeQueue qr rId = do atomically (TM.lookup rId subscriptions) >>= \case Nothing -> @@ -712,19 +713,19 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv s -> atomically (tryTakeTMVar $ delivered s) >> deliver sub where - newSub :: m (TVar Sub) + newSub :: M (TVar Sub) newSub = time "SUB newSub" . atomically $ do writeTQueue subscribedQ (rId, clnt) sub <- newTVar =<< newSubscription NoSub TM.insert rId sub subscriptions pure sub - deliver :: TVar Sub -> m (Transmission BrokerMsg) + deliver :: TVar Sub -> M (Transmission BrokerMsg) deliver sub = do q <- getStoreMsgQueue "SUB" rId msg_ <- atomically $ tryPeekMsg q deliverMessage "SUB" qr rId sub q msg_ - getMessage :: QueueRec -> m (Transmission BrokerMsg) + getMessage :: QueueRec -> M (Transmission BrokerMsg) getMessage qr = time "GET" $ do atomically (TM.lookup queueId subscriptions) >>= \case Nothing -> @@ -743,7 +744,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv sub <- newTVar s TM.insert queueId sub subscriptions pure s - getMessage_ :: Sub -> m (Transmission BrokerMsg) + getMessage_ :: Sub -> M (Transmission BrokerMsg) getMessage_ s = do q <- getStoreMsgQueue "GET" queueId atomically $ @@ -753,17 +754,17 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv in setDelivered s msg $> (corrId, queueId, MSG encMsg) _ -> pure (corrId, queueId, OK) - withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg) + withQueue :: (QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg) withQueue action = maybe (pure $ err AUTH) action qr_ - subscribeNotifications :: m (Transmission BrokerMsg) + subscribeNotifications :: M (Transmission BrokerMsg) subscribeNotifications = time "NSUB" . atomically $ do unlessM (TM.member queueId ntfSubscriptions) $ do writeTQueue ntfSubscribedQ (queueId, clnt) TM.insert queueId () ntfSubscriptions pure ok - acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg) + acknowledgeMsg :: QueueRec -> MsgId -> M (Transmission BrokerMsg) acknowledgeMsg qr msgId = time "ACK" $ do atomically (TM.lookup queueId subscriptions) >>= \case Nothing -> pure $ err NO_MSG @@ -789,7 +790,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv if msgId == msgId' || B.null msgId then pure $ Just s else putTMVar delivered msgId' $> Nothing - updateStats :: Message -> m () + updateStats :: Message -> M () updateStats = \case MessageQuota {} -> pure () Message {msgFlags} -> do @@ -801,7 +802,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv atomically $ modifyTVar' (msgRecvNtf stats) (+ 1) atomically $ updatePeriodStats (activeQueuesNtf stats) queueId - sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg) + sendMessage :: QueueRec -> MsgFlags -> MsgBody -> M (Transmission BrokerMsg) sendMessage qr msgFlags msgBody | B.length msgBody > maxMessageLength = pure $ err LARGE_MSG | otherwise = case status qr of @@ -827,13 +828,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) pure ok where - mkMessage :: C.MaxLenBS MaxMessageLen -> m Message + mkMessage :: C.MaxLenBS MaxMessageLen -> M Message mkMessage body = do msgId <- randomId =<< asks (msgIdBytes . config) msgTs <- liftIO getSystemTime pure $ Message msgId msgTs msgFlags body - expireMessages :: MsgQueue -> m () + expireMessages :: MsgQueue -> M () expireMessages q = do msgExp <- asks $ messageExpiration . config old <- liftIO $ mapM expireBeforeEpoch msgExp @@ -861,7 +862,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128 pure . (cbNonce,) $ fromRight "" encNMsgMeta - deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg) + deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> M (Transmission BrokerMsg) deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do readTVarIO sub >>= \case s@Sub {subThread = NoSub} -> @@ -872,7 +873,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv _ -> forkSub $> ok _ -> pure ok where - forkSub :: m () + forkSub :: M () forkSub = do atomically . modifyTVar' sub $ \s -> s {subThread = SubPending} t <- mkWeakThreadId =<< forkIO subscriber @@ -890,7 +891,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv void $ setDelivered s msg writeTVar sub $! s {subThread = NoSub} - time :: T.Text -> m a -> m a + time :: T.Text -> M a -> M a time name = timed name queueId encryptMsg :: QueueRec -> Message -> RcvMessage @@ -906,13 +907,13 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv setDelivered :: Sub -> Message -> STM Bool setDelivered s msg = tryPutTMVar (delivered s) (messageId msg) - getStoreMsgQueue :: T.Text -> RecipientId -> m MsgQueue + getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do ms <- asks msgStore quota <- asks $ msgQueueQuota . config atomically $ getMsgQueue ms rId quota - delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg) + delQueueAndMsgs :: QueueStore -> M (Transmission BrokerMsg) delQueueAndMsgs st = do withLog (`logDeleteQueue` queueId) ms <- asks msgStore @@ -929,7 +930,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv okResp :: Either ErrorType () -> Transmission BrokerMsg okResp = either err $ const ok -updateDeletedStats :: (MonadUnliftIO m, MonadReader Env m) => QueueRec -> m () +updateDeletedStats :: QueueRec -> M () updateDeletedStats q = do stats <- asks serverStats let delSel = if isNothing (senderKey q) then qDeletedNew else qDeletedSecured @@ -937,12 +938,12 @@ updateDeletedStats q = do atomically $ modifyTVar' (qDeletedAll stats) (+ 1) atomically $ modifyTVar' (qCount stats) (subtract 1) -withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m () +withLog :: (StoreLog 'WriteMode -> IO a) -> M () withLog action = do env <- ask liftIO . mapM_ action $ storeLog (env :: Env) -timed :: MonadUnliftIO m => T.Text -> RecipientId -> m a -> m a +timed :: T.Text -> RecipientId -> M a -> M a timed name qId a = do t <- liftIO getSystemTime r <- a @@ -954,10 +955,10 @@ timed name qId a = do diff t t' = (systemSeconds t' - systemSeconds t) * sec + fromIntegral (systemNanoseconds t' - systemNanoseconds t) sec = 1000_000000 -randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString +randomId :: Int -> M ByteString randomId n = atomically . C.randomBytes n =<< asks random -saveServerMessages :: (MonadUnliftIO m, MonadReader Env m) => Bool -> m () +saveServerMessages :: Bool -> M () saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessages where saveMessages f = do @@ -972,7 +973,7 @@ saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessag atomically (getMessages ms rId) >>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId) -restoreServerMessages :: forall m. (MonadUnliftIO m, MonadReader Env m) => m Int +restoreServerMessages :: M Int restoreServerMessages = asks (storeMsgsFile . config) >>= \case Just f -> ifM (doesFileExist f) (restoreMessages f) (pure 0) Nothing -> pure 0 @@ -1008,7 +1009,7 @@ restoreServerMessages = asks (storeMsgsFile . config) >>= \case msgErr :: Show e => String -> e -> String msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s) -saveServerStats :: (MonadUnliftIO m, MonadReader Env m) => m () +saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) >>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f) @@ -1018,7 +1019,7 @@ saveServerStats = B.writeFile f $ strEncode stats logInfo "server stats saved" -restoreServerStats :: (MonadUnliftIO m, MonadReader Env m) => Int -> m () +restoreServerStats :: Int -> M () restoreServerStats expiredWhileRestoring = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats where restoreStats f = whenM (doesFileExist f) $ do diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 9d783fba9..baadfc79b 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -173,25 +173,25 @@ newSubscription subThread = do delivered <- newEmptyTMVar return Sub {subThread, delivered} -newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env +newEnv :: ServerConfig -> IO Env newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, storeLogFile} = do server <- atomically newServer queueStore <- atomically newQueueStore msgStore <- atomically newMsgStore random <- liftIO C.newRandom storeLog <- restoreQueues queueStore `mapM` storeLogFile - tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile - Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile + tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile + Fingerprint fp <- loadFingerprint caCertificateFile let serverIdentity = KeyHash fp - serverStats <- atomically . newServerStats =<< liftIO getCurrentTime + serverStats <- atomically . newServerStats =<< getCurrentTime sockets <- atomically newSocketState clientSeq <- newTVarIO 0 clients <- newTVarIO mempty return Env {config, server, serverIdentity, queueStore, msgStore, random, storeLog, tlsServerParams, serverStats, sockets, clientSeq, clients} where - restoreQueues :: QueueStore -> FilePath -> m (StoreLog 'WriteMode) + restoreQueues :: QueueStore -> FilePath -> IO (StoreLog 'WriteMode) restoreQueues QueueStore {queues, senders, notifiers} f = do - (qs, s) <- liftIO $ readWriteStoreLog f + (qs, s) <- readWriteStoreLog f atomically $ do writeTVar queues =<< mapM newTVar qs writeTVar senders $! M.foldr' addSender M.empty qs diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index ddc08ae98..8cca76043 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -23,7 +23,6 @@ where import Control.Applicative (optional) import Control.Logger.Simple (logError) import Control.Monad (when) -import Control.Monad.IO.Unlift import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) @@ -126,10 +125,10 @@ clientTransportConfig TransportClientConfig {logTLSErrors} = TransportConfig {logTLSErrors, transportTimeout = Nothing} -- | Connect to passed TCP host:port and pass handle to the client. -runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a +runTransportClient :: Transport c => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a runTransportClient = runTLSTransportClient supportedParameters Nothing -runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a +runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do serverCert <- newEmptyTMVarIO let hostName = B.unpack $ strEncode host @@ -137,7 +136,7 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, connectTCP = case socksProxy of Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host _ -> connectTCPClient hostName - c <- liftIO $ do + c <- do sock <- connectTCP port mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e) let tCfg = clientTransportConfig cfg @@ -148,7 +147,7 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, closeTLS tls >> error "onServerCertificate failed" Just c -> pure c getClientConnection tCfg chain tls - client c `E.finally` liftIO (closeConnection c) + client c `E.finally` closeConnection c where hostAddr = \case THIPv4 addr -> SocksAddrIPV4 $ tupleToHostAddress addr diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index 66535bf21..542ebbb75 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -26,7 +26,6 @@ where import Control.Applicative ((<|>)) import Control.Logger.Simple import Control.Monad -import Control.Monad.IO.Unlift import qualified Crypto.Store.X509 as SX import Data.Default (def) import Data.List (find) @@ -70,27 +69,26 @@ serverTransportConfig TransportServerConfig {logTLSErrors} = -- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar. -- -- All accepted connections are passed to the passed function. -runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m () +runTransportServer :: forall c. Transport c => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO () runTransportServer started port params cfg server = do ss <- atomically newSocketState runTransportServerState ss started port params cfg server -runTransportServerState :: forall c m. (Transport c, MonadUnliftIO m) => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m () +runTransportServerState :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO () runTransportServerState ss started port = runTransportServerSocketState ss started (startTCPServer started port) (transportName (TProxy :: TProxy c)) -- | Run a transport server with provided connection setup and handler. -runTransportServerSocket :: (MonadUnliftIO m, Transport a) => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m () +runTransportServerSocket :: Transport a => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO () runTransportServerSocket started getSocket threadLabel serverParams cfg server = do ss <- atomically newSocketState runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server -- | Run a transport server with provided connection setup and handler. -runTransportServerSocketState :: (MonadUnliftIO m, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m () +runTransportServerSocketState :: Transport a => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO () runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server = do - u <- askUnliftIO labelMyThread $ "transport server for " <> threadLabel - liftIO . runTCPServerSocket ss started getSocket $ \conn -> - E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection (unliftIO u . server) + runTCPServerSocket ss started getSocket $ \conn -> + E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection server where tCfg = serverTransportConfig cfg setup conn = timeout (tlsSetupTimeout cfg) $ do diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index e9d94f0c2..a880cfaad 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -50,26 +50,18 @@ maybeWord :: (a -> ByteString) -> Maybe a -> ByteString maybeWord f = maybe "" $ B.cons ' ' . f {-# INLINE maybeWord #-} -liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a -liftIOEither a = liftIO a >>= liftEither -{-# INLINE liftIOEither #-} - -liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a -liftError f = liftEitherError f . runExceptT +liftError :: MonadIO m => (e -> e') -> ExceptT e IO a -> ExceptT e' m a +liftError f = liftError' f . runExceptT {-# INLINE liftError #-} -liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a -liftEitherError f a = liftIOEither (first f <$> a) -{-# INLINE liftEitherError #-} +liftError' :: MonadIO m => (e -> e') -> IO (Either e a) -> ExceptT e' m a +liftError' f = ExceptT . fmap (first f) . liftIO +{-# INLINE liftError' #-} -liftEitherWith :: MonadError e' m => (e -> e') -> Either e a -> m a +liftEitherWith :: MonadIO m => (e -> e') -> Either e a -> ExceptT e' m a liftEitherWith f = liftEither . first f {-# INLINE liftEitherWith #-} -liftE :: (e -> e') -> ExceptT e IO a -> ExceptT e' IO a -liftE f a = ExceptT $ first f <$> runExceptT a -{-# INLINE liftE #-} - ifM :: Monad m => m Bool -> m a -> m a -> m a ifM ba t f = ba >>= \b -> if b then t else f {-# INLINE ifM #-} @@ -109,10 +101,18 @@ tryAllErrors :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m tryAllErrors err action = tryError action `UE.catch` (pure . Left . err) {-# INLINE tryAllErrors #-} +tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a) +tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err) +{-# INLINE tryAllErrors' #-} + catchAllErrors :: (MonadUnliftIO m, MonadError e m) => (E.SomeException -> e) -> m a -> (e -> m a) -> m a catchAllErrors err action handler = tryAllErrors err action >>= either handler pure {-# INLINE catchAllErrors #-} +catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a +catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure +{-# INLINE catchAllErrors' #-} + catchThrow :: (MonadUnliftIO m, MonadError e m) => m a -> (E.SomeException -> e) -> m a catchThrow action err = catchAllErrors err action throwError {-# INLINE catchThrow #-} @@ -148,8 +148,8 @@ safeDecodeUtf8 = decodeUtf8With onError where onError _ _ = Just '?' -timeoutThrow :: (MonadUnliftIO m, MonadError e m) => e -> Int -> m a -> m a -timeoutThrow e ms action = timeout ms action >>= maybe (throwError e) pure +timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a +timeoutThrow e ms action = ExceptT (sequence <$> (ms `timeout` runExceptT action)) >>= maybe (throwError e) pure threadDelay' :: Int64 -> IO () threadDelay' time diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index 3cf1050fa..9ef1f820a 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -103,14 +103,14 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct found@(RCCtrlAddress {address} :| _) <- findCtrlAddress c@RCHClient_ {startedPort, announcer} <- liftIO mkClient hostKeys <- atomically genHostKeys - action <- runClient c r hostKeys `putRCError` r + action <- liftIO $ runClient c r hostKeys -- wait for the port to make invitation portNum <- atomically $ readTMVar startedPort signedInv@RCSignedInvitation {invitation} <- maybe (throwError RCETLSStartFailed) (liftIO . mkInvitation hostKeys address) portNum when multicast $ case knownHost of Nothing -> throwError RCENewController Just KnownHostPairing {hostDhPubKey} -> do - ann <- async . liftIO . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation + ann <- liftIO . async . runExceptT $ announceRC drg 60 idPrivKey hostDhPubKey hostKeys invitation atomically $ putTMVar announcer ann pure (found, signedInv, RCHostClient {action, client_ = c}, r) where @@ -125,9 +125,9 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct endSession <- newEmptyTMVarIO hostCAHash <- newEmptyTMVarIO pure RCHClient_ {startedPort, announcer, hostCAHash, endSession} - runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> ExceptT RCErrorType IO (Async ()) + runClient :: RCHClient_ -> RCStepTMVar (SessionCode, TLS, RCStepTMVar (RCHostSession, RCHostHello, RCHostPairing)) -> RCHostKeys -> IO (Async ()) runClient RCHClient_ {startedPort, announcer, hostCAHash, endSession} r hostKeys = do - tlsCreds <- liftIO $ genTLSCredentials drg caKey caCert + tlsCreds <- genTLSCredentials drg caKey caCert startTLSServer port_ startedPort tlsCreds (tlsHooks r knownHost hostCAHash) $ \tls -> void . runExceptT $ do r' <- newEmptyTMVarIO @@ -265,7 +265,7 @@ connectRCCtrl_ :: TVar ChaChaDRG -> RCCtrlPairing -> RCInvitation -> J.Value -> connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, host, port} hostAppInfo = do r <- newEmptyTMVarIO c <- liftIO mkClient - action <- async $ runClient c r `putRCError` r + action <- liftIO . async . void . runExceptT $ runClient c r `putRCError` r pure (RCCtrlClient {action, client_ = c}, r) where mkClient :: IO RCCClient_ @@ -280,7 +280,7 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, TLS.Credentials (creds : _) -> pure $ Just creds _ -> throwError $ RCEInternal "genTLSCredentials must generate credentials" let clientConfig = defaultTransportClientConfig {clientCredentials} - runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> do + ExceptT . runTransportClient clientConfig Nothing host (show port) (Just ca) $ \tls@TLS {tlsBuffer, tlsContext} -> runExceptT $ do -- pump socket to detect connection problems liftIO $ peekBuffered tlsBuffer 100000 (TLS.recvData tlsContext) >>= logDebug . tshow -- should normally be ("", Nothing) here logDebug "Got TLS connection" @@ -360,11 +360,11 @@ prepareCtrlSession -- * Multicast discovery announceRC :: TVar ChaChaDRG -> Int -> C.PrivateKeyEd25519 -> C.PublicKeyX25519 -> RCHostKeys -> RCInvitation -> ExceptT RCErrorType IO () -announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = withSender $ \sender -> do +announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = ExceptT $ withSender $ \sender -> runExceptT $ do replicateM_ maxCount $ do logDebug "Announcing..." nonce <- atomically $ C.randomCbNonce drg - encInvitation <- liftEitherWith undefined $ C.cbEncrypt sharedKey nonce sigInvitation encInvitationSize + encInvitation <- liftEitherWith (const RCEEncrypt) $ C.cbEncrypt sharedKey nonce sigInvitation encInvitationSize liftIO . UDP.send sender $ smpEncode RCEncInvitation {dhPubKey, nonce, encInvitation} threadDelay 1000000 where @@ -375,9 +375,9 @@ announceRC drg maxCount idPrivKey knownDhPub RCHostKeys {sessKeys, dhKeys} inv = discoverRCCtrl :: TMVar Int -> NonEmpty RCCtrlPairing -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation) discoverRCCtrl subscribers pairings = - timeoutThrow RCENotDiscovered 30000000 $ withListener subscribers $ \listener -> - loop $ do - (source, bytes) <- recvAnnounce listener + timeoutThrow RCENotDiscovered 30000000 $ ExceptT $ withListener subscribers $ \listener -> + runExceptT . loop $ do + (source, bytes) <- liftIO $ recvAnnounce listener encInvitation <- liftEitherWith (const RCEInvitation) $ smpDecode bytes r@(_, RCVerifiedInvitation RCInvitation {host}) <- findRCCtrlPairing pairings encInvitation case source of @@ -386,10 +386,7 @@ discoverRCCtrl subscribers pairings = pure r where loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a - loop action = - liftIO (runExceptT action) >>= \case - Left err -> logError (tshow err) >> loop action - Right res -> pure res + loop action = action `catchRCError` \e -> logError (tshow e) >> loop action findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation) findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do diff --git a/src/Simplex/RemoteControl/Discovery.hs b/src/Simplex/RemoteControl/Discovery.hs index 2155a1fba..e70eb1c25 100644 --- a/src/Simplex/RemoteControl/Discovery.hs +++ b/src/Simplex/RemoteControl/Discovery.hs @@ -68,7 +68,7 @@ preferAddress RCCtrlAddress {address, interface} addrs = matchAddr RCCtrlAddress {address = a} = a == address matchIface RCCtrlAddress {interface = i} = i == interface -startTLSServer :: MonadUnliftIO m => Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> m (Async ()) +startTLSServer :: Maybe Word16 -> TMVar (Maybe N.PortNumber) -> TLS.Credentials -> TLS.ServerHooks -> (Transport.TLS -> IO ()) -> IO (Async ()) startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ do started <- newEmptyTMVarIO bracketOnError (startTCPServer started $ maybe "0" show port_) (\_e -> setPort Nothing) $ \socket -> @@ -91,14 +91,14 @@ startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ d TLS.serverSupported = supportedParameters } -withSender :: MonadUnliftIO m => (UDP.UDPSocket -> m a) -> m a -withSender = bracket (liftIO $ UDP.clientSocket MULTICAST_ADDR_V4 DISCOVERY_PORT False) (liftIO . UDP.close) +withSender :: (UDP.UDPSocket -> IO a) -> IO a +withSender = bracket (UDP.clientSocket MULTICAST_ADDR_V4 DISCOVERY_PORT False) (UDP.close) -withListener :: MonadUnliftIO m => TMVar Int -> (UDP.ListenSocket -> m a) -> m a +withListener :: TMVar Int -> (UDP.ListenSocket -> IO a) -> IO a withListener subscribers = bracket (openListener subscribers) (closeListener subscribers) -openListener :: MonadIO m => TMVar Int -> m UDP.ListenSocket -openListener subscribers = liftIO $ do +openListener :: TMVar Int -> IO UDP.ListenSocket +openListener subscribers = do sock <- UDP.serverSocket (MULTICAST_ADDR_V4, read DISCOVERY_PORT) logDebug $ "Discovery listener socket: " <> tshow sock let raw = UDP.listenSocket sock @@ -106,10 +106,9 @@ openListener subscribers = liftIO $ do joinMulticast subscribers raw (listenerHostAddr4 sock) pure sock -closeListener :: MonadIO m => TMVar Int -> UDP.ListenSocket -> m () +closeListener :: TMVar Int -> UDP.ListenSocket -> IO () closeListener subscribers sock = - liftIO $ - partMulticast subscribers (UDP.listenSocket sock) (listenerHostAddr4 sock) `finally` UDP.stop sock + partMulticast subscribers (UDP.listenSocket sock) (listenerHostAddr4 sock) `finally` UDP.stop sock joinMulticast :: TMVar Int -> N.Socket -> N.HostAddress -> IO () joinMulticast subscribers sock group = do @@ -132,7 +131,7 @@ listenerHostAddr4 sock = case UDP.mySockAddr sock of N.SockAddrInet _port host -> host _ -> error "MULTICAST_ADDR_V4 is V4" -recvAnnounce :: MonadIO m => UDP.ListenSocket -> m (N.SockAddr, ByteString) -recvAnnounce sock = liftIO $ do +recvAnnounce :: UDP.ListenSocket -> IO (N.SockAddr, ByteString) +recvAnnounce sock = do (invite, UDP.ClientSockAddr source _cmsg) <- UDP.recvFrom sock pure (source, invite) diff --git a/src/Simplex/RemoteControl/Types.hs b/src/Simplex/RemoteControl/Types.hs index b8a7c1141..e5d885b1d 100644 --- a/src/Simplex/RemoteControl/Types.hs +++ b/src/Simplex/RemoteControl/Types.hs @@ -238,17 +238,6 @@ type SessionCode = ByteString type RCStepTMVar a = TMVar (Either RCErrorType a) -type Tasks = TVar [Async ()] - -asyncRegistered :: MonadUnliftIO m => Tasks -> m () -> m () -asyncRegistered tasks action = async action >>= registerAsync tasks - -registerAsync :: MonadIO m => Tasks -> Async () -> m () -registerAsync tasks = atomically . modifyTVar tasks . (:) - -cancelTasks :: MonadIO m => Tasks -> m () -cancelTasks tasks = readTVarIO tasks >>= mapM_ cancel - $(JQ.deriveJSON (sumTypeJSON $ dropPrefix "RCE") ''RCErrorType) $(JQ.deriveJSON defaultJSON ''RCCtrlAddress) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 1fe002cb0..3fa8becdf 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -233,13 +233,13 @@ inAnyOrder g rs = do expected :: a -> (a -> Bool) -> Bool expected r rp = rp r -createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn) -joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId +joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE ConnId joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQSupportOn -sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId +sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId sendMessage c connId msgFlags msgBody = do (msgId, pqEnc) <- A.sendMessage c connId PQEncOn msgFlags msgBody liftIO $ pqEnc `shouldBe` PQEncOn @@ -664,7 +664,7 @@ testAsyncInitiatingOffline :: HasCallStack => IO () testAsyncInitiatingOffline = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - disposeAgentClient alice + liftIO $ disposeAgentClient alice aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB subscribeConnection alice' bobId @@ -680,7 +680,7 @@ testAsyncJoiningOfflineBeforeActivation = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe - disposeAgentClient bob + liftIO $ disposeAgentClient bob ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" bob' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB2 @@ -694,9 +694,9 @@ testAsyncBothOffline :: HasCallStack => IO () testAsyncBothOffline = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - disposeAgentClient alice + liftIO $ disposeAgentClient alice aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe - disposeAgentClient bob + liftIO $ disposeAgentClient bob alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB subscribeConnection alice' bobId ("", _, CONF confId _ "bob's connInfo") <- get alice' @@ -1067,8 +1067,7 @@ testExpireMessageQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testP b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (aId, bId) <- runRight $ do (aId, bId) <- makeConnection a b - liftIO $ threadDelay 500000 - disposeAgentClient b + liftIO $ threadDelay 500000 >> disposeAgentClient b 4 <- sendMessage a bId SMP.noMsgFlags "1" get a ##> ("", bId, SENT 4) 5 <- sendMessage a bId SMP.noMsgFlags "2" @@ -1091,8 +1090,7 @@ testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} b <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (aId, bId) <- runRight $ do (aId, bId) <- makeConnection a b - liftIO $ threadDelay 500000 - disposeAgentClient b + liftIO $ threadDelay 500000 >> disposeAgentClient b 4 <- sendMessage a bId SMP.noMsgFlags "1" get a ##> ("", bId, SENT 4) 5 <- sendMessage a bId SMP.noMsgFlags "2" @@ -1161,13 +1159,13 @@ setupDesynchronizedRatchet alice bob = do runRight_ $ do subscribeConnection bob2 aliceId - Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False + Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False 8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5" get alice ##> ("", bobId, SENT 8) get bob2 =##> ratchetSyncP aliceId RSRequired - Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6" + Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6" pure () pure (aliceId, bobId, bob2) @@ -1224,7 +1222,7 @@ testRatchetSyncClientRestart t = do ("", "", DOWN _ _) <- nGet bob2 ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted - disposeAgentClient bob2 + liftIO $ disposeAgentClient bob2 bob3 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do runRight_ $ do @@ -1420,12 +1418,12 @@ testSuspendingAgent = get a ##> ("", bId, SENT 4) get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False ackMessage b aId 4 Nothing - suspendAgent b 1000000 + liftIO $ suspendAgent b 1000000 get' b ##> ("", "", SUSPENDED) 5 <- sendMessage a bId SMP.noMsgFlags "hello 2" get a ##> ("", bId, SENT 5) Nothing <- 100000 `timeout` get b - foregroundAgent b + liftIO $ foregroundAgent b get b =##> \case ("", c, Msg "hello 2") -> c == aId; _ -> False testSuspendingAgentCompleteSending :: ATransport -> IO () @@ -1444,7 +1442,7 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do 5 <- sendMessage b aId SMP.noMsgFlags "hello too" 6 <- sendMessage b aId SMP.noMsgFlags "how are you?" liftIO $ threadDelay 100000 - suspendAgent b 5000000 + liftIO $ suspendAgent b 5000000 withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do pGet b =##> \case ("", c, APC _ (SENT 5)) -> c == aId; ("", "", APC _ UP {}) -> True; _ -> False @@ -1473,7 +1471,7 @@ testSuspendingAgentTimeout t = withAgentClients2 $ \a b -> do ("", "", DOWN {}) <- nGet b 5 <- sendMessage b aId SMP.noMsgFlags "hello too" 6 <- sendMessage b aId SMP.noMsgFlags "how are you?" - suspendAgent b 100000 + liftIO $ suspendAgent b 100000 ("", "", SUSPENDED) <- nGet b pure () @@ -2095,7 +2093,7 @@ testSwitchDelete servers = do runRight_ $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId - disposeAgentClient b + liftIO $ disposeAgentClient b stats <- switchConnectionAsync a "" bId liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted] phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing] @@ -2120,7 +2118,7 @@ testAbortSwitchStarted servers = do liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted] phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing] -- repeat switch is prohibited - Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId + Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ switchConnectionAsync a "" bId -- abort current switch stats' <- abortConnectionSwitch a bId liftIO $ rcvSwchStatuses' stats' `shouldMatchList` [Nothing] @@ -2242,7 +2240,7 @@ testCannotAbortSwitchSecured servers = do withA' $ \a -> do phaseRcv a bId SPConfirmed [Just RSSendingQADD, Nothing] phaseRcv a bId SPSecured [Just RSSendingQUSE, Nothing] - Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId + Left A.CMD {cmdErr = PROHIBITED} <- liftIO . runExceptT $ abortConnectionSwitch a bId pure () withA $ \a -> withB $ \b -> runRight_ $ do subscribeConnection a bId @@ -2407,7 +2405,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut testSMPServerConnectionTest t newQueueBasicAuth srv = withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB -- initially passed server is not running - runRight $ testProtocolServer a 1 srv + testProtocolServer a 1 srv testRatchetAdHash :: HasCallStack => IO () testRatchetAdHash = @@ -2551,7 +2549,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetings a bId1' b aId1' a `hasClients` 1 b `hasClients` 1 - setNetworkConfig a nc {sessionMode = TSMEntity} + liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", UP _ _) <- nGet a @@ -2560,7 +2558,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetingsMsgId 6 a bId1 b aId1 exchangeGreetingsMsgId 6 a bId1' b aId1' liftIO $ threadDelay 250000 - setNetworkConfig a nc {sessionMode = TSMUser} + liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", DOWN _ _) <- nGet a @@ -2575,7 +2573,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetings a bId2' b aId2' a `hasClients` 2 b `hasClients` 1 - setNetworkConfig a nc {sessionMode = TSMEntity} + liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", DOWN _ _) <- nGet a @@ -2587,7 +2585,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetingsMsgId 6 a bId2 b aId2 exchangeGreetingsMsgId 6 a bId2' b aId2' liftIO $ threadDelay 250000 - setNetworkConfig a nc {sessionMode = TSMUser} + liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", DOWN _ _) <- nGet a @@ -2625,9 +2623,10 @@ testServerMultipleIdentities = get bob ##> ("", aliceId, CON) exchangeGreetings alice bobId bob aliceId -- this saves queue with second server identity - Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe - disposeAgentClient bob - bob' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB2 + bob' <- liftIO $ do + Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe + disposeAgentClient bob + getSMPAgentClient' 3 agentCfg initAgentServers testDB2 subscribeConnection bob' aliceId exchangeGreetingsMsgId 6 alice bobId bob' aliceId where diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index c884cbd93..722e48d02 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -179,7 +179,7 @@ testNotificationToken APNSMockServer {apnsQ} = do deleteNtfToken a tkn -- agent deleted this token Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn - disposeAgentClient a + liftIO $ disposeAgentClient a (.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString v .-> key = do @@ -211,7 +211,7 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do -- can still use the first verification code, it is the same after decryption verifyNtfToken a tkn nonce verification NTActive <- checkNtfToken a tkn - disposeAgentClient a + liftIO $ disposeAgentClient a testNtfTokenSecondRegistration :: APNSMockServer -> IO () testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do @@ -247,8 +247,9 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do Left (NTF AUTH) <- tryE $ checkNtfToken a tkn -- and the second is active NTActive <- checkNtfToken a' tkn - disposeAgentClient a - disposeAgentClient a' + pure () + disposeAgentClient a + disposeAgentClient a' testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO () testNtfTokenServerRestart t APNSMockServer {apnsQ} = do @@ -277,11 +278,11 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do liftIO $ sendApnsResponse' APNSRespOk verifyNtfToken a' tkn nonce' verification' NTActive <- checkNtfToken a' tkn - disposeAgentClient a' + liftIO $ disposeAgentClient a' -getTestNtfTokenPort :: (MonadUnliftIO m, MonadError AgentErrorType m) => AgentClient -> m String +getTestNtfTokenPort :: AgentClient -> AE String getTestNtfTokenPort a = - runReaderT (withStore' a getSavedNtfToken) (agentEnv a) >>= \case + ExceptT (runExceptT (withStore' a getSavedNtfToken) `runReaderT` agentEnv a) >>= \case Just NtfToken {ntfServer = ProtocolServer {port}} -> pure port Nothing -> error "no active NtfToken" @@ -317,18 +318,18 @@ testNtfTokenChangeServers t APNSMockServer {apnsQ} = a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB tkn <- registerTestToken a "abcd" NMInstant apnsQ NTActive <- checkNtfToken a tkn - setNtfServers a [testNtfServer2] + liftIO $ setNtfServers a [testNtfServer2] NTActive <- checkNtfToken a tkn -- still works on old server - disposeAgentClient a + liftIO $ disposeAgentClient a pure tkn threadDelay 1000000 - a <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB + a <- getSMPAgentClient' 2 agentCfg initAgentServers testDB runRight_ $ do getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort NTActive <- checkNtfToken a tkn1 - setNtfServers a [testNtfServer2] -- just change configured server list + liftIO $ setNtfServers a [testNtfServer2] -- just change configured server list getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort -- not yet changed -- trigger token replace tkn2 <- registerTestToken a "xyzw" NMInstant apnsQ @@ -345,7 +346,7 @@ testRunNTFServerTests :: ATransport -> NtfServer -> IO (Maybe ProtocolTestFailur testRunNTFServerTests t srv = withNtfServerThreadOn t ntfTestPort $ \ntf -> do a <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB - r <- runRight $ testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing + r <- testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing killThread ntf pure r @@ -712,7 +713,7 @@ testNotificationsOldToken APNSMockServer {apnsQ} = do liftIO $ threadDelay 250000 testMessageAB "hello" -- change server - setNtfServers a [testNtfServer2] -- server 2 isn't running now, don't use + liftIO $ setNtfServers a [testNtfServer2] -- server 2 isn't running now, don't use -- replacing token keeps server _ <- registerTestToken a "xyzw" NMInstant apnsQ getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort @@ -738,7 +739,7 @@ testNotificationsNewToken APNSMockServer {apnsQ} oldNtf = do liftIO $ threadDelay 250000 testMessageAB "hello" -- switch - setNtfServers a [testNtfServer2] + liftIO $ setNtfServers a [testNtfServer2] deleteNtfToken a tkn _ <- registerTestToken a "abcd" NMInstant apnsQ getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort2 diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 5a2dbb8de..26981628f 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -16,7 +16,6 @@ module NtfClient where import Control.Monad import Control.Monad.Except (runExceptT) -import Control.Monad.IO.Unlift import Data.Aeson (FromJSON (..), ToJSON (..), (.:)) import qualified Data.Aeson as J import qualified Data.Aeson.Types as JT @@ -71,13 +70,13 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" ntfTestStoreLogFile :: FilePath ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log" -testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a +testNtfClient :: Transport c => (THandleNTF c -> IO a) -> IO a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do - g <- liftIO C.newRandom + g <- C.newRandom ks <- atomically $ C.generateKeyPair g - liftIO (runExceptT $ ntfClientHandshake h ks testKeyHash supportedClientNTFVRange) >>= \case + runExceptT (ntfClientHandshake h ks testKeyHash supportedClientNTFVRange) >>= \case Right th -> client th Left e -> error $ show e diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index c85499c12..59370e654 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -12,7 +12,6 @@ module SMPAgentClient where import Control.Monad import Control.Monad.IO.Unlift -import Crypto.Random import qualified Data.ByteString.Char8 as B import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) @@ -202,7 +201,7 @@ initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer agentCfg :: AgentConfig agentCfg = defaultAgentConfig - { tcpPort = agentTestPort, + { tcpPort = Just agentTestPort, tbqSize = 4, -- database = testDB, smpCfg = defaultSMPClientConfig {qSize = 1, defaultTransport = (testPort, transport @TLS), networkConfig}, @@ -224,11 +223,9 @@ fastRetryInterval = defaultReconnectInterval {initialInterval = 50_000} fastMessageRetryInterval :: RetryInterval2 fastMessageRetryInterval = RetryInterval2 {riFast = fastRetryInterval, riSlow = fastRetryInterval} -type AgentTestMonad m = (MonadUnliftIO m, MonadRandom m, MonadFail m) - -withSmpAgentThreadOn_ :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> m () -> (ThreadId -> m a) -> m a +withSmpAgentThreadOn_ :: ATransport -> (ServiceName, ServiceName, FilePath) -> Int -> IO () -> (ThreadId -> IO a) -> IO a withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess = - let cfg' = agentCfg {tcpPort = port'} + let cfg' = agentCfg {tcpPort = Just port'} initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]} in serverBracket ( \started -> do @@ -241,24 +238,24 @@ withSmpAgentThreadOn_ t (port', smpPort', db') initClientId afterProcess = userServers :: NonEmpty (ProtoServerWithAuth p) -> Map UserId (NonEmpty (ProtoServerWithAuth p)) userServers srvs = M.fromList [(1, srvs)] -withSmpAgentThreadOn :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> m a) -> m a +withSmpAgentThreadOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> (ThreadId -> IO a) -> IO a withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a 0 $ removeFile db' -withSmpAgentOn :: AgentTestMonad m => ATransport -> (ServiceName, ServiceName, FilePath) -> m a -> m a +withSmpAgentOn :: ATransport -> (ServiceName, ServiceName, FilePath) -> IO a -> IO a withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const -withSmpAgent :: AgentTestMonad m => ATransport -> m a -> m a +withSmpAgent :: ATransport -> IO a -> IO a withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB) -testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a +testSMPAgentClientOn :: Transport c => ServiceName -> (c -> IO a) -> IO a testSMPAgentClientOn port' client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do - line <- liftIO $ getLn h + line <- getLn h if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion then client h else do error $ "wrong welcome message: " <> B.unpack line -testSMPAgentClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (c -> m a) -> m a +testSMPAgentClient :: Transport c => (c -> IO a) -> IO a testSMPAgentClient = testSMPAgentClientOn agentTestPort diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 87163483f..330a3f14c 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -13,7 +13,6 @@ module SMPClient where import Control.Monad.Except (runExceptT) -import Control.Monad.IO.Unlift import Data.ByteString.Char8 (ByteString) import Data.List.NonEmpty (NonEmpty) import Network.Socket @@ -68,23 +67,23 @@ xit'' d t = do ci <- runIO $ lookupEnv "CI" (if ci == Just "true" then skip "skipped on CI" . it d else it d) t -testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a +testSMPClient :: Transport c => (THandleSMP c -> IO a) -> IO a testSMPClient = testSMPClientVR supportedClientSMPRelayVRange -testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a +testSMPClientVR :: Transport c => VersionRangeSMP -> (THandleSMP c -> IO a) -> IO a testSMPClientVR vr client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do - g <- liftIO C.newRandom + g <- C.newRandom ks <- atomically $ C.generateKeyPair g - liftIO (runExceptT $ smpClientHandshake h ks testKeyHash vr) >>= \case + runExceptT (smpClientHandshake h ks testKeyHash vr) >>= \case Right th -> client th Left e -> error $ show e cfg :: ServerConfig cfg = ServerConfig - { transports = undefined, + { transports = [], smpHandshakeTimeout = 60000000, tbqSize = 1, -- serverTbqSize = 1, @@ -129,7 +128,7 @@ withSmpServerConfigOn t cfg' port' = withSmpServerThreadOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerThreadOn t = withSmpServerConfigOn t cfg -serverBracket :: (HasCallStack, MonadUnliftIO m) => (TMVar Bool -> m ()) -> m () -> (HasCallStack => ThreadId -> m a) -> m a +serverBracket :: HasCallStack => (TMVar Bool -> IO ()) -> IO () -> (HasCallStack => ThreadId -> IO a) -> IO a serverBracket process afterProcess f = do started <- newEmptyTMVarIO E.bracket diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 008b4da88..951a69771 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -103,7 +103,7 @@ tPut1 h t = do [r] <- tPut h [Right t] pure r -tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd) +tGet1 :: (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (SignedTransmission err cmd) tGet1 h = do [r] <- liftIO $ tGet h pure r diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 59c668d66..2befdcc76 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -103,7 +103,7 @@ testXFTPAgentSendReceive = withXFTPServer $ do sndr <- getSMPAgentClient' 1 agentCfg initAgentServers testDB (rfd1, rfd2) <- runRight $ do (sfId, _, rfd1, rfd2) <- testSend sndr filePath - xftpDeleteSndFileInternal sndr sfId + liftIO $ xftpDeleteSndFileInternal sndr sfId pure (rfd1, rfd2) -- receive file, delete rcv file @@ -112,9 +112,8 @@ testXFTPAgentSendReceive = withXFTPServer $ do where testReceiveDelete clientId rfd originalFilePath = do rcp <- getSMPAgentClient' clientId agentCfg initAgentServers testDB2 - runRight_ $ do - rfId <- testReceive rcp rfd originalFilePath - xftpDeleteRcvFile rcp rfId + rfId <- runRight $ testReceive rcp rfd originalFilePath + xftpDeleteRcvFile rcp rfId disposeAgentClient rcp testXFTPAgentSendReceiveEncrypted :: HasCallStack => IO () @@ -127,7 +126,7 @@ testXFTPAgentSendReceiveEncrypted = withXFTPServer $ do sndr <- getSMPAgentClient' 1 agentCfg initAgentServers testDB (rfd1, rfd2) <- runRight $ do (sfId, _, rfd1, rfd2) <- testSendCF sndr file - xftpDeleteSndFileInternal sndr sfId + liftIO $ xftpDeleteSndFileInternal sndr sfId pure (rfd1, rfd2) -- receive file, delete rcv file testReceiveDelete 2 rfd1 filePath g @@ -136,9 +135,8 @@ testXFTPAgentSendReceiveEncrypted = withXFTPServer $ do testReceiveDelete clientId rfd originalFilePath g = do rcp <- getSMPAgentClient' clientId agentCfg initAgentServers testDB2 cfArgs <- atomically $ Just <$> CF.randomArgs g - runRight_ $ do - rfId <- testReceiveCF rcp rfd cfArgs originalFilePath - xftpDeleteRcvFile rcp rfId + rfId <- runRight $ testReceiveCF rcp rfd cfArgs originalFilePath + xftpDeleteRcvFile rcp rfId disposeAgentClient rcp testXFTPAgentSendReceiveRedirect :: HasCallStack => IO () @@ -468,11 +466,9 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $ length <$> listDirectory xftpServerFiles `shouldReturn` 6 -- delete file - runRight $ do - xftpStartWorkers sndr (Just senderFiles) - xftpDeleteSndFileRemote sndr 1 sfId sndDescr - Nothing <- liftIO $ 100000 `timeout` sfGet sndr - pure () + runRight_ $ xftpStartWorkers sndr (Just senderFiles) + xftpDeleteSndFileRemote sndr 1 sfId sndDescr + Nothing <- 100000 `timeout` sfGet sndr disposeAgentClient rcp1 threadDelay 1000000 @@ -505,10 +501,9 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do -- delete file - should not succeed with server down sndr <- getSMPAgentClient' 3 agentCfg initAgentServers testDB - runRight $ do - xftpStartWorkers sndr (Just senderFiles) - xftpDeleteSndFileRemote sndr 1 sfId sndDescr - liftIO $ timeout 300000 (get sndr) `shouldReturn` Nothing -- wait for worker attempt + runRight_ $ xftpStartWorkers sndr (Just senderFiles) + xftpDeleteSndFileRemote sndr 1 sfId sndDescr + timeout 300000 (get sndr) `shouldReturn` Nothing -- wait for worker attempt disposeAgentClient sndr threadDelay 300000 @@ -636,4 +631,4 @@ testXFTPServerTest :: HasCallStack => Maybe BasicAuth -> XFTPServerWithAuth -> I testXFTPServerTest newFileBasicAuth srv = withXFTPServerCfg testXFTPServerConfig {newFileBasicAuth, xftpPort = xftpTestPort2} $ \_ -> do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB -- initially passed server is not running - runRight $ testProtocolServer a 1 srv + testProtocolServer a 1 srv diff --git a/tests/XFTPServerTests.hs b/tests/XFTPServerTests.hs index 71700280a..e2d447274 100644 --- a/tests/XFTPServerTests.hs +++ b/tests/XFTPServerTests.hs @@ -28,7 +28,6 @@ import qualified Simplex.Messaging.Crypto.Lazy as LC import qualified Simplex.Messaging.Encoding.Base64.URL as U import Simplex.Messaging.Protocol (BasicAuth, SenderId) import Simplex.Messaging.Server.Expiration (ExpirationConfig (..)) -import Simplex.Messaging.Util (liftIOEither) import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive, removeFile) import System.FilePath (()) import Test.Hspec @@ -220,7 +219,7 @@ testFileChunkExpiration = withXFTPServerCfg testXFTPServerConfig {fileExpiration testInactiveClientExpiration :: Expectation testInactiveClientExpiration = withXFTPServerCfg testXFTPServerConfig {inactiveClientExpiration} $ \_ -> runRight_ $ do disconnected <- newEmptyTMVarIO - c <- liftIOEither $ getXFTPClient (1, testXFTPServer, Nothing) testXFTPClientConfig (\_ -> atomically $ putTMVar disconnected ()) + c <- ExceptT $ getXFTPClient (1, testXFTPServer, Nothing) testXFTPClientConfig (\_ -> atomically $ putTMVar disconnected ()) pingXFTP c liftIO $ do threadDelay 100000