diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 9a9e094ea..36abd1091 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -16,7 +16,6 @@ logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} main :: IO () main = do - setLogLevel LogDebug cfgPath <- getEnvPath "SMP_SERVER_CFG_PATH" defaultCfgPath logPath <- getEnvPath "SMP_SERVER_LOG_PATH" defaultLogPath withGlobalLogging logCfg $ smpServerCLI_ Static.generateSite Static.serveStaticFiles Static.attachStaticFiles cfgPath logPath diff --git a/rfcs/2025-05-05-client-certificates.md b/rfcs/2025-05-05-client-certificates.md new file mode 100644 index 000000000..00de2f9b3 --- /dev/null +++ b/rfcs/2025-05-05-client-certificates.md @@ -0,0 +1,142 @@ +# Service certificates for high volume servers and services connecting to SMP servers + +## Problem + +The absense of user and client identification benefits privacy, but it requires separately authorizing subscription for each messaging queue, that doesn't scale when a high volume server or service acts as a client for SMP server even for the current traffic and network size. + +These servers/services include: +- operators' chat relays (aka super-peers), +- notification servers, +- high-traffic service chat bots, +- high-traffic business support clients. + +The future chat relays would reduce the number of subscriptions required for the usual clients, by replacing connections with each group member to 1-3 connections with chat relays per group/community, it would shift the burden to the chat relays, that are also clients. + +Self-hosted chat relays may want to retain privacy, so they will not use client certificates, but this privacy is not needed (and counter-productive) for the chat relays provided by network operators. + +Even today, directory service subscribing to all queues may take 15-20 minutes, which is experienced as downtime by the end users. + +Notification servers also acting as clients to messaging servers also take 15-20 minutes to subscribe to all notifications, during which time notifications are not delivered. + +Not only these subscription take a lot of time, they also consume a large amount of memory both in the clients and in the servers, as association between clients and queues is currently session-scoped and not persisted anywhere (and it should not be, because end-users' clients do need privacy). + +## Solution + +High volume "clients" (operators' chat relays, directory service, SimpleX Chat team support client, SimpleX Status bot, etc.) that don't need privacy will identify themselves to the messaging servers at a point of connection by providing client sertificate, both in TLS handshake and in SMP handshake (the same certificate must be provided). + +All the new queues and subscriptions made in this session will be creating a permanent association of the messaging queue with the client, and on subsequent reconnections the client can "subscribe" to all their queues with a single client subscription command. + +This will save a lot of time subscribing and resubscribing on server and client restarts, servers' bandwidth, servers' traffic spikes, and memory of both clients and servers. + +## Protocol + +An ephemeral per-session signature key signed by long-term client certificate is used for client authorization – this session signature key will be passed in SMP handshake. + +To transition existing queues, the subscription command will have to be double-signed - by the queue key, and then by client key. + +When server receives such "hand-over" subscription it would create a permanent association between the client certificate and the queue, and on subsequent re-connections the client can subscribe to all the existing queues still associated with the client with one command. + +The server will respond to the client with the number of queues it was subscribed to - it would both inform the client that it has to re-connect in case of interruption, and can be used for client and server statistics. + +When client creates a new queue, it would also sign the request with both keys, per-queue and client's. Other queue operations (e.g., deletion, or changing associated queue data for short links) would still require two signatures, both the queue key and the client key. + +The open question is whether there is any value in allowing to remove the association between the client and the queue. Probably not, as threat model should assume that the server would retain this information, and the use-case for users controlling their servers is narrow. + +## Protocol connection handshake + +Currently, the types for handshakes are: + +```haskell +data ServerHandshake = ServerHandshake + { smpVersionRange :: VersionRangeSMP, + sessionId :: SessionId, + -- pub key to agree shared secrets for command authorization and entity ID encryption. + -- todo C.PublicKeyX25519 + authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey) + } + +data ClientHandshake = ClientHandshake + { -- | agreed SMP server protocol version + smpVersion :: VersionSMP, + -- | server identity - CA certificate fingerprint + keyHash :: C.KeyHash, + -- | pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. + authPubKey :: Maybe C.PublicKeyX25519, + -- | Whether connecting client is a proxy server (send from SMP v12). + -- This property, if True, disables additional transport encrytion inside TLS. + -- (Proxy server connection already has additional encryption, so this layer is not needed there). + proxyServer :: Bool + } +``` + +`ServerHandshake` already contains `authPubKey` with the server certificate chain and the signed key for connection encryption and creating a shared secret for denable authorization (with client entity key) and session encryption layer. + +`ClientHandshake` contains only ephemeral `authPubKey` to compute a shared secret for session encryption layer, so we need an additional field for an optional client certificate: + +```haskell +serviceCertKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey) +``` + +Certificate here defines client identity. The actual key to be used to sign commands is session-scoped, and is signed by the certificate key. In case of notification server it MUST be the same certificate that is used for server TLS connections. + +For operators' clients we may optionally include operators' certificate in the chain, and that would allow servers to identify operators if either wants to. This would improve end-user security, as not only the server would validate that its certificate matches the address, but it would also validate that it is operated by SimpleX Chat or by Flux, preventing any server impersonation (e.g., via DNS manipulations) - the client could then report that the files are hosted on SimpleX Chat servers, but then can stop and show additional warning in case certificate does not match the domain - same as the browsers do with CA stores in the client. + +## Protocol transmissions + +Each transport block can contain one or several protocol transmissions. + +Each transmission has this structure: + +```abnf +transmission = authenticator authorized +; authenticator - Ed25519 signature for recipients or X25519 authenticator for senders, to provide repudiation. +; authenticator authorizes the rest of the transmission. +authorized = sessId corrId entityId command. +; sessId is tls-unique channel binding, its presense in the transmission prevents replay attacks. +``` + +The proposed change would replace authenticator with exactly one or two authenticators, where the first one will remain resource-level authorization (queue key), and the optional second one will be client authorization with the client key. + +```abnf +authenticator = queue_authenticator ("0" / "1" service_authenticator) +; "0" and "1" characters (digit characters, not x00 or x01) are conventionally used for Maybe types in the protocol. +``` + +In case service_authenticator is present, queue_authenticator should authorize over `fingerprint authorized` (concatenation of service identity certificate fingerprint and the rest of the transmission). + +All queues created with client key will have to be double-authorized with both the queue key and the client key - both the client and the server would have to maintain this knowledge, whether the queue is associated with the client or not. + +Asymmetric retries have to be supported - the first request creating this association may succeed on the server and timeout on the client. + +## Subscription + +To subscribe to all associated queues the client has to send a single command authorized with the client key passed in handshake. + +The command and response: + +```haskell +SUBS :: Command Recipient -- to enable all client subscriptions, empty entity ID in the transmission, signed by client key - it must be the same as was used in handover subscription signature. +NSUBS :: Command Recipient -- notification subscription +SOK :: Maybe ServiceId -- new subscription response +SOKS :: Int64 -> BrokerMsg -- response from the server, includes the number of subscribed queues +ENDS :: Int64 -> BrokerMsg -- when another session subscribes with the same certificate +``` + +Open questions: +- What should used as an entity ID for `SUBS` transmission - certificate fingerprint or an empty string? +- Should there be a command to get the list of all associated queues? It is likely to be useful for debugging? +- What should happen when `SUB` is sent for a single already associated queue? What if it is signed with the correct session key, but that is different from existing association? The current approach is that once associated, this associaiton would require authorization for single subscriptions, with the same certificate as already associated. + +## Ephemeral client-session association + +This was considered to reduce costs for the usual clients to re-subscribe. Currently it's a big problem, because of groups, and with transition to chat relays it won't be. + +For some very busy end-user clients it may help. + +Given that server has access to an ephemeral association between recipient client session and queues anyway (even with clients connecting via Tor, unless per-connection transport isolation is used), introducing `sessionPubKey` to allow resubscription to the previously subscribed queues may reduce the traffic. This won't change threat model as the server would only keep this association in memory, and not persist it. Clients on another hand may safely persist this association for fast resubscription on client restarts. + +This is not planned for the forseable future, as migrating to chat relays would solve most of the problem. + +Assuming an average active user has 20 contacts and 20 groups, and they would need ~3 subscriptions for each (for redundancy), so about 120 subscription to reconnect. The single 16kb transport block allows to send ~136 subscriptions. Which means that ephemeral sessions would create no value for clients at all, unless they are super active. + +Further, improving transport efficiency for super-active non-identified clients may help network abuse, so ephemeral sessions may have negative value. diff --git a/simplexmq.cabal b/simplexmq.cabal index cb35b4b4d..c8e7c970a 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -147,6 +147,7 @@ library Simplex.Messaging.Transport.HTTP2.Server Simplex.Messaging.Transport.KeepAlive Simplex.Messaging.Transport.Server + Simplex.Messaging.Transport.Shared Simplex.Messaging.Util Simplex.Messaging.Version Simplex.Messaging.Version.Internal diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 2aa7d4757..4df783134 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -116,7 +116,7 @@ getXFTPClient transportSession@(_, srv, _) config@XFTPClientConfig {clientALPN, let HTTP2Client {sessionId, sessionALPN} = http2Client v = VersionXFTP 1 thServerVRange = versionToRange v - thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = v, thServerVRange, thAuth = Nothing, implySessId = False, encryptBlock = Nothing, batch = True} + thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = v, thServerVRange, thAuth = Nothing, implySessId = False, encryptBlock = Nothing, batch = True, serviceAuth = False} logDebug $ "Client negotiated handshake protocol: " <> tshow sessionALPN thParams@THandleParams {thVersion} <- case sessionALPN of Just "xftp/1" -> xftpClientHandshakeV1 serverVRange keyHash http2Client thParams0 @@ -132,7 +132,8 @@ xftpClientHandshakeV1 serverVRange keyHash@(C.KeyHash kh) c@HTTP2Client {session (vr, sk) <- processServerHandshake shs let v = maxVersion vr sendClientHandshake XFTPClientHandshake {xftpVersion = v, keyHash} - pure thParams0 {thAuth = Just THAuthClient {serverPeerPubKey = sk, serverCertKey = ck, sessSecret = Nothing}, thVersion = v, thServerVRange = vr} + let thAuth = Just THAuthClient {peerServerPubKey = sk, peerServerCertKey = ck, clientService = Nothing, sessSecret = Nothing} + pure thParams0 {thAuth, thVersion = v, thServerVRange = vr} where getServerHandshake :: ExceptT XFTPClientError IO XFTPServerHandshake getServerHandshake = do diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index 9bf552732..4a9b0086a 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -144,10 +144,15 @@ instance Protocol XFTPVersion XFTPErrorType FileResponse where type ProtoCommand FileResponse = FileCmd type ProtoType FileResponse = 'PXFTP protocolClientHandshake = xftpClientHandshakeStub + {-# INLINE protocolClientHandshake #-} + useServiceAuth _ = False + {-# INLINE useServiceAuth #-} protocolPing = FileCmd SFRecipient PING + {-# INLINE protocolPing #-} protocolError = \case FRErr e -> Just e _ -> Nothing + {-# INLINE protocolError #-} data FileCommand (p :: FileParty) where FNEW :: FileInfo -> NonEmpty RcvPublicAuthKey -> Maybe BasicAuth -> FileCommand FSender @@ -227,6 +232,7 @@ instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where {-# INLINE fromProtocolError #-} checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c + {-# INLINE checkCredentials #-} instance Encoding FileInfo where smpEncode FileInfo {sndKey, size, digest} = smpEncode (sndKey, size, digest) @@ -332,7 +338,7 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion 'TClient -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString xftpEncodeAuthTransmission thParams@THandleParams {thAuth} pKey (corrId, fId, msg) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg) - xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth (Just pKey) (C.cbNonce $ bs corrId) tForAuth + xftpEncodeBatch1 . (,tToSend) =<< authTransmission thAuth False (Just pKey) (C.cbNonce $ bs corrId) tForAuth xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion p -> Transmission c -> Either TransportError ByteString xftpEncodeTransmission thParams (corrId, fId, msg) = do @@ -341,7 +347,7 @@ xftpEncodeTransmission thParams (corrId, fId, msg) = do -- this function uses batch syntax but puts only one transmission in the batch xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString -xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 t) xftpBlockSize +xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 False t) xftpBlockSize xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion p -> ByteString -> Either XFTPErrorType (SignedTransmission e c) xftpDecodeTransmission thParams t = do diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index b4e71ec19..9a385739b 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -26,12 +26,12 @@ import Data.ByteString.Builder (Builder, byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Int (Int64) -import Data.List (intercalate) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust) import qualified Data.Text as T +import qualified Data.Text.IO as T import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime) import Data.Time.Format.ISO8601 (iso8601Show) import Data.Word (Word32) @@ -53,7 +53,7 @@ import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (CorrId (..), BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TransmissionAuth, pattern NoEntity) +import Simplex.Messaging.Protocol (CorrId (..), BlockingInfo, EntityId (..), RcvPublicAuthKey, RcvPublicDhKey, RecipientId, TAuthorizations, pattern NoEntity) import Simplex.Messaging.Server (dummyVerifyCmd, verifyCmdAuthorization) import Simplex.Messaging.Server.Control (CPClientRole (..)) import Simplex.Messaging.Server.Expiration @@ -112,7 +112,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira srvCreds@(chain, pk) <- asks tlsServerCreds signKey <- liftIO $ case C.x509ToPrivate' pk of Right pk' -> pure pk' - Left e -> putStrLn ("servers has no valid key: " <> show e) >> exitFailure + Left e -> putStrLn ("Server has no valid key: " <> show e) >> exitFailure env <- ask sessions <- liftIO TM.emptyIO let cleanup sessionId = atomically $ TM.delete sessionId sessions @@ -120,7 +120,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira reqBody <- getHTTP2Body r xftpBlockSize let v = VersionXFTP 1 thServerVRange = versionToRange v - thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = v, thServerVRange, thAuth = Nothing, implySessId = False, encryptBlock = Nothing, batch = True} + thParams0 = THandleParams {sessionId, blockSize = xftpBlockSize, thVersion = v, thServerVRange, thAuth = Nothing, implySessId = False, encryptBlock = Nothing, batch = True, serviceAuth = False} req0 = XFTPTransportRequest {thParams = thParams0, request = r, reqBody, sendResponse} flip runReaderT env $ case sessionALPN of Nothing -> processRequest req0 @@ -158,7 +158,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira unless (keyHash == kh) $ throwE HANDSHAKE case compatibleVRange' xftpServerVRange v of Just (Compatible vr) -> do - let auth = THAuthServer {serverPrivKey = pk, sessSecret' = Nothing} + let auth = THAuthServer {serverPrivKey = pk, peerClientService = Nothing, sessSecret' = Nothing} thParams = thParams0 {thAuth = Just auth, thVersion = v, thServerVRange = vr} atomically $ TM.insert sessionId (HandshakeAccepted thParams) sessions #ifdef slow_servers @@ -221,22 +221,22 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira fileDownloadAcks' <- atomicSwapIORef fileDownloadAcks 0 filesCount' <- readIORef filesCount filesSize' <- readIORef filesSize - hPutStrLn h $ - intercalate + T.hPutStrLn h $ + T.intercalate "," - [ iso8601Show $ utctDay fromTime', - show filesCreated', - show fileRecipients', - show filesUploaded', - show filesDeleted', + [ T.pack $ iso8601Show $ utctDay fromTime', + tshow filesCreated', + tshow fileRecipients', + tshow filesUploaded', + tshow filesDeleted', dayCount files, weekCount files, monthCount files, - show fileDownloads', - show fileDownloadAcks', - show filesCount', - show filesSize', - show filesExpired' + tshow fileDownloads', + tshow fileDownloadAcks', + tshow filesCount', + tshow filesSize', + tshow filesExpired' ] liftIO $ threadDelay' interval @@ -361,7 +361,7 @@ randomDelay = do data VerificationResult = VRVerified XFTPRequest | VRFailed XFTPErrorType -verifyXFTPTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult +verifyXFTPTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> XFTPFileId -> FileCmd -> M VerificationResult verifyXFTPTransmission auth_ tAuth authorized fId cmd = case cmd of FileCmd SFSender (FNEW file rcps auth') -> pure $ XFTPReqNew file rcps auth' `verifyWith` sndKey file diff --git a/src/Simplex/FileTransfer/Server/Main.hs b/src/Simplex/FileTransfer/Server/Main.hs index e2abc55ac..e8b818f5b 100644 --- a/src/Simplex/FileTransfer/Server/Main.hs +++ b/src/Simplex/FileTransfer/Server/Main.hs @@ -191,7 +191,8 @@ xftpServerCLI cfgPath logPath = do transportConfig = mkTransportServerConfig (fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini) - (Just alpnSupportedXFTPhandshakes), + (Just alpnSupportedXFTPhandshakes) + False, responseDelay = 0 } diff --git a/src/Simplex/FileTransfer/Transport.hs b/src/Simplex/FileTransfer/Transport.hs index ce0190f1f..80a2cc020 100644 --- a/src/Simplex/FileTransfer/Transport.hs +++ b/src/Simplex/FileTransfer/Transport.hs @@ -57,7 +57,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol (BlockingInfo, CommandError) -import Simplex.Messaging.Transport (ALPN, CertChainPubKey, SessionId, THandle (..), THandleParams (..), TransportError (..), TransportPeer (..)) +import Simplex.Messaging.Transport (ALPN, CertChainPubKey, ServiceCredentials, SessionId, THandle (..), THandleParams (..), TransportError (..), TransportPeer (..)) import Simplex.Messaging.Transport.HTTP2.File import Simplex.Messaging.Util (bshow, tshow) import Simplex.Messaging.Version @@ -101,8 +101,8 @@ supportedFileServerVRange :: VersionRangeXFTP supportedFileServerVRange = mkVersionRange initialXFTPVersion currentXFTPVersion -- XFTP protocol does not use this handshake method -xftpClientHandshakeStub :: c 'TClient -> Maybe C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> Bool -> ExceptT TransportError IO (THandle XFTPVersion c 'TClient) -xftpClientHandshakeStub _c _ks _keyHash _xftpVRange _proxyServer = throwE TEVersion +xftpClientHandshakeStub :: c 'TClient -> Maybe C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> Bool -> Maybe (ServiceCredentials, C.KeyPairEd25519) -> ExceptT TransportError IO (THandle XFTPVersion c 'TClient) +xftpClientHandshakeStub _c _ks _keyHash _xftpVRange _proxyServer _serviceKeys = throwE TEVersion alpnSupportedXFTPhandshakes :: [ALPN] alpnSupportedXFTPhandshakes = ["xftp/1"] diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index a3f5cf371..69ca55b87 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -73,6 +73,7 @@ module Simplex.Messaging.Agent getNotificationConns, resubscribeConnection, resubscribeConnections, + subscribeClientService, sendMessage, sendMessages, sendMessagesB, @@ -367,7 +368,7 @@ deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' {-# INLINE deleteConnectionsAsync #-} -- | Create SMP agent connection (NEW command) -createConnection :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe ConnInfo -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, CreatedConnLink c) +createConnection :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe ConnInfo -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) createConnection c userId enableNtfs = withAgentEnv c .::. newConn c userId enableNtfs {-# INLINE createConnection #-} @@ -410,7 +411,7 @@ prepareConnectionToAccept c enableNtfs = withAgentEnv c .: newConnToAccept c "" {-# INLINE prepareConnectionToAccept #-} -- | Join SMP agent connection (JOIN command). -joinConnection :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE SndQueueSecured +joinConnection :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (SndQueueSecured, Maybe ClientServiceId) joinConnection c userId connId enableNtfs = withAgentEnv c .:: joinConn c userId connId enableNtfs {-# INLINE joinConnection #-} @@ -420,7 +421,7 @@ allowConnection c = withAgentEnv c .:. allowConnection' c {-# INLINE allowConnection #-} -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentClient -> ConnId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE SndQueueSecured +acceptContact :: AgentClient -> ConnId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (SndQueueSecured, Maybe ClientServiceId) acceptContact c connId enableNtfs = withAgentEnv c .:: acceptContact' c connId enableNtfs {-# INLINE acceptContact #-} @@ -430,12 +431,12 @@ rejectContact c = withAgentEnv c .: rejectContact' c {-# INLINE rejectContact #-} -- | Subscribe to receive connection messages (SUB command) -subscribeConnection :: AgentClient -> ConnId -> AE () +subscribeConnection :: AgentClient -> ConnId -> AE (Maybe ClientServiceId) subscribeConnection c = withAgentEnv c . subscribeConnection' c {-# INLINE subscribeConnection #-} -- | Subscribe to receive connection messages from multiple connections, batching commands when possible -subscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) +subscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) subscribeConnections c = withAgentEnv c . subscribeConnections' c {-# INLINE subscribeConnections #-} @@ -449,14 +450,19 @@ getNotificationConns :: AgentClient -> C.CbNonce -> ByteString -> AE (NonEmpty N getNotificationConns c = withAgentEnv c .: getNotificationConns' c {-# INLINE getNotificationConns #-} -resubscribeConnection :: AgentClient -> ConnId -> AE () +resubscribeConnection :: AgentClient -> ConnId -> AE (Maybe ClientServiceId) resubscribeConnection c = withAgentEnv c . resubscribeConnection' c {-# INLINE resubscribeConnection #-} -resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType ())) +resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) resubscribeConnections c = withAgentEnv c . resubscribeConnections' c {-# INLINE resubscribeConnections #-} +-- TODO [certs rcv] how to communicate that service ID changed - as error or as result? +subscribeClientService :: AgentClient -> ClientServiceId -> AE Int +subscribeClientService c = withAgentEnv c . subscribeClientService' c +{-# INLINE subscribeClientService #-} + -- | Send message to the connection (SEND command) sendMessage :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AE (AgentMsgId, PQEncryption) sendMessage c = withAgentEnv c .:: sendMessage' c @@ -826,7 +832,7 @@ switchConnectionAsync' c corrId connId = pure . connectionStats $ DuplexConnection cData rqs' sqs _ -> throwE $ CMD PROHIBITED "switchConnectionAsync: not duplex" -newConn :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe ConnInfo -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, CreatedConnLink c) +newConn :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe ConnInfo -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) newConn c userId enableNtfs cMode userData_ clientData pqInitKeys subMode = do srv <- getSMPServer c userId connId <- newConnNoQueues c userId enableNtfs cMode (CR.connPQEncryption pqInitKeys) @@ -929,7 +935,7 @@ changeConnectionUser' c oldUserId connId newUserId = do where updateConn = withStore' c $ \db -> setConnUserId db oldUserId connId newUserId -newRcvConnSrv :: forall c. ConnectionModeI c => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe ConnInfo -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (CreatedConnLink c) +newRcvConnSrv :: forall c. ConnectionModeI c => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe ConnInfo -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (CreatedConnLink c, Maybe ClientServiceId) newRcvConnSrv c userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srvWithAuth@(ProtoServerWithAuth srv _) = do case (cMode, pqInitKeys) of (SCMContact, CR.IKUsePQ) -> throwE $ CMD PROHIBITED "newRcvConnSrv" @@ -939,11 +945,13 @@ newRcvConnSrv c userId connId enableNtfs cMode userData_ clientData pqInitKeys s Just d -> do (nonce, qUri, cReq, qd) <- prepareLinkData d $ fst e2eKeys (rq, qUri') <- createRcvQueue (Just nonce) qd e2eKeys - connReqWithShortLink qUri cReq qUri' (shortLink rq) + ccLink <- connReqWithShortLink qUri cReq qUri' (shortLink rq) + pure (ccLink, clientServiceId rq) Nothing -> do let qd = case cMode of SCMContact -> CQRContact Nothing; SCMInvitation -> CQRMessaging Nothing - (_, qUri) <- createRcvQueue Nothing qd e2eKeys - (`CCLink` Nothing) <$> createConnReq qUri + (rq, qUri) <- createRcvQueue Nothing qd e2eKeys + cReq <- createConnReq qUri + pure (CCLink cReq Nothing, clientServiceId rq) where createRcvQueue :: Maybe C.CbNonce -> ClntQueueReqData -> C.KeyPairX25519 -> AM (RcvQueue, SMPQueueUri) createRcvQueue nonce_ qd e2eKeys = do @@ -1033,7 +1041,7 @@ newConnToAccept c connId enableNtfs invId pqSup = do newConnToJoin c userId connId enableNtfs connReq pqSup _ -> throwE $ CMD PROHIBITED "newConnToAccept" -joinConn :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM SndQueueSecured +joinConn :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do srv <- getNextSMPServer c userId [qServer $ connReqQueue cReq] joinConnSrv c userId connId enableNtfs cReq cInfo pqSupport subMode srv @@ -1113,7 +1121,7 @@ versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ {-# INLINE versionPQSupport_ #-} -joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM SndQueueSecured +joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM (SndQueueSecured, Maybe ClientServiceId) joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do SomeConn cType conn <- withStore c (`getConn` connId) @@ -1123,7 +1131,7 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMod DuplexConnection _ (RcvQueue {status = New} :| _) (sq@SndQueue {status = New} :| _) -> doJoin $ Just sq _ -> throwE $ CMD PROHIBITED $ "joinConnSrv: bad connection " <> show cType where - doJoin :: Maybe SndQueue -> AM SndQueueSecured + doJoin :: Maybe SndQueue -> AM (SndQueueSecured, Maybe ClientServiceId) doJoin sq_ = do (cData, sq, e2eSndParams, lnkId_) <- startJoinInvitation c userId connId sq_ enableNtfs inv pqSup secureConfirmQueue c cData sq srv cInfo (Just e2eSndParams) subMode @@ -1131,9 +1139,9 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMod joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = lift (compatibleContactUri cReqUri) >>= \case Just (qInfo, vrsn) -> do - CCLink cReq _ <- newRcvConnSrv c userId connId enableNtfs SCMInvitation Nothing Nothing (CR.IKNoPQ pqSup) subMode srv + (CCLink cReq _, service) <- newRcvConnSrv c userId connId enableNtfs SCMInvitation Nothing Nothing (CR.IKNoPQ pqSup) subMode srv void $ sendInvitation c userId connId qInfo vrsn cReq cInfo - pure False + pure (False, service) Nothing -> throwE $ AGENT A_VERSION delInvSL :: AgentClient -> ConnId -> SMPServerWithAuth -> SMP.LinkId -> AM () @@ -1141,7 +1149,7 @@ delInvSL c connId srv lnkId = withStore' c (\db -> deleteInvShortLink db (protoServer srv) lnkId) `catchE` \e -> liftIO $ nonBlockingWriteTBQueue (subQ c) ("", connId, AEvt SAEConn (ERR $ INTERNAL $ "error deleting short link " <> show e)) -joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM SndQueueSecured +joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM (SndQueueSecured, Maybe ClientServiceId) joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do SomeConn cType conn <- withStore c (`getConn` connId) case conn of @@ -1149,7 +1157,7 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo SndConnection _ sq -> doJoin $ Just sq _ -> throwE $ CMD PROHIBITED $ "joinConnSrvAsync: bad connection " <> show cType where - doJoin :: Maybe SndQueue -> AM SndQueueSecured + doJoin :: Maybe SndQueue -> AM (SndQueueSecured, Maybe ClientServiceId) doJoin sq_ = do (cData, sq, e2eSndParams, lnkId_) <- startJoinInvitation c userId connId sq_ enableNtfs inv pqSupport secureConfirmQueueAsync c cData sq srv cInfo (Just e2eSndParams) subMode @@ -1157,7 +1165,7 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqSupport _srv = do throwE $ CMD PROHIBITED "joinConnSrvAsync" -createReplyQueue :: AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM SMPQueueInfo +createReplyQueue :: AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM (SMPQueueInfo, Maybe ClientServiceId) createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do -- TODO [notifications] send correct NTF credentials here (rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srv (versionToRange smpClientVersion) SCMInvitation subMode -- Nothing @@ -1168,7 +1176,7 @@ createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVers when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (NSCCreate, [connId]) - pure qInfo + pure (qInfo, clientServiceId rq') -- | Approve confirmation (LET command) in Reader monad allowConnection' :: AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> AM () @@ -1181,14 +1189,14 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwE $ CMD PROHIBITED "allowConnection" -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM SndQueueSecured +acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c $ \db -> getInvitation db "acceptContact'" invId withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do - sqSecured <- joinConn c userId connId enableNtfs connReq ownConnInfo pqSupport subMode + r <- joinConn c userId connId enableNtfs connReq ownConnInfo pqSupport subMode withStore' c $ \db -> acceptInvitation db invId ownConnInfo - pure sqSecured + pure r _ -> throwE $ CMD PROHIBITED "acceptContact" -- | Reject contact (RJCT command) in Reader monad @@ -1198,19 +1206,23 @@ rejectContact' c contactConnId invId = {-# INLINE rejectContact' #-} -- | Subscribe to receive connection messages (SUB command) in Reader monad -subscribeConnection' :: AgentClient -> ConnId -> AM () +subscribeConnection' :: AgentClient -> ConnId -> AM (Maybe ClientServiceId) subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId] {-# INLINE subscribeConnection' #-} -toConnResult :: ConnId -> Map ConnId (Either AgentErrorType ()) -> AM () +toConnResult :: ConnId -> Map ConnId (Either AgentErrorType a) -> AM a toConnResult connId rs = case M.lookup connId rs of - Just (Right ()) -> when (M.size rs > 1) $ logError $ T.pack $ "too many results " <> show (M.size rs) + Just (Right r) -> r <$ when (M.size rs > 1) (logError $ T.pack $ "too many results " <> show (M.size rs)) Just (Left e) -> throwE e _ -> throwE $ INTERNAL $ "no result for connection " <> B.unpack connId -type QCmdResult = (QueueStatus, Either AgentErrorType ()) +type QCmdResult a = (QueueStatus, Either AgentErrorType a) -subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) +type QDelResult = QCmdResult () + +type QSubResult = QCmdResult (Maybe SMP.ServiceId) + +subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) subscribeConnections' _ [] = pure M.empty subscribeConnections' c connIds = do conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConns` connIds) @@ -1220,41 +1232,45 @@ subscribeConnections' c connIds = do resumeDelivery cs lift $ resumeConnCmds c $ M.keys cs rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) + rcvRs' <- storeClientServiceAssocs rcvRs ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) - lift $ when (instantNotifications tkn) . void . forkIO . void $ sendNtfCreate ns rcvRs cs - let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())]) + lift $ when (instantNotifications tkn) . void . forkIO . void $ sendNtfCreate ns rcvRs' cs + let rs = M.unions ([errs', subRs, rcvRs'] :: [Map ConnId (Either AgentErrorType (Maybe ClientServiceId))]) notifyResultError rs pure rs where - rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue] + rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType (Maybe ClientServiceId)) [RcvQueue] rcvQueueOrResult (SomeConn _ conn) = case conn of DuplexConnection _ rqs _ -> Right $ L.toList rqs SndConnection _ sq -> Left $ sndSubResult sq RcvConnection _ rq -> Right [rq] ContactConnection _ rq -> Right [rq] - NewConnection _ -> Left (Right ()) - sndSubResult :: SndQueue -> Either AgentErrorType () + NewConnection _ -> Left (Right Nothing) + sndSubResult :: SndQueue -> Either AgentErrorType (Maybe ClientServiceId) sndSubResult SndQueue {status} = case status of - Confirmed -> Right () + Confirmed -> Right Nothing Active -> Left $ CONN SIMPLEX _ -> Left $ INTERNAL "unexpected queue status" - connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ()) + connResults :: [(RcvQueue, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult + addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QSubResult addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs -- combines two results for one connection, by using only Active queues (if there is at least one Active queue) - combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult + combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' combineRes r' _ = Just r' - order :: QCmdResult -> Int + order :: QSubResult -> Int order (Active, Right _) = 1 order (Active, _) = 2 order (_, Right _) = 3 order _ = 4 - sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId SomeConn -> AM' () + -- TODO [certs rcv] store associations of queues with client service ID + storeClientServiceAssocs :: Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) + storeClientServiceAssocs = pure . M.map (Nothing <$) + sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> Map ConnId SomeConn -> AM' () sendNtfCreate ns rcvRs cs = do let oks = M.keysSet $ M.filter (either temporaryAgentError $ const True) rcvRs cs' = M.restrictKeys cs oks @@ -1272,25 +1288,29 @@ subscribeConnections' c connIds = do DuplexConnection cData _ sqs -> Just (cData, sqs) SndConnection cData sq -> Just (cData, [sq]) _ -> Nothing - notifyResultError :: Map ConnId (Either AgentErrorType ()) -> AM () + notifyResultError :: Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> AM () notifyResultError rs = do let actual = M.size rs expected = length connIds when (actual /= expected) . atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected) -resubscribeConnection' :: AgentClient -> ConnId -> AM () +resubscribeConnection' :: AgentClient -> ConnId -> AM (Maybe ClientServiceId) resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId] {-# INLINE resubscribeConnection' #-} -resubscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) +resubscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) resubscribeConnections' _ [] = pure M.empty resubscribeConnections' c connIds = do - let r = M.fromList . zip connIds . repeat $ Right () + let r = M.fromList . zip connIds . repeat $ Right Nothing connIds' <- filterM (fmap not . atomically . hasActiveSubscription c) connIds -- union is left-biased, so results returned by subscribeConnections' take precedence (`M.union` r) <$> subscribeConnections' c connIds' +-- TODO [certs rcv] +subscribeClientService' :: AgentClient -> ClientServiceId -> AM Int +subscribeClientService' = undefined + -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage @@ -1444,13 +1464,13 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do triedHosts <- newTVarIO S.empty tryCommand . withNextSrv c userId storageSrvs triedHosts [] $ \srv -> do - CCLink cReq _ <- newRcvConnSrv c userId connId enableNtfs cMode Nothing Nothing pqEnc subMode srv - notify $ INV (ACR cMode cReq) + (CCLink cReq _, service) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing Nothing pqEnc subMode srv + notify $ INV (ACR cMode cReq) service JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) pqEnc subMode connInfo -> noServer $ do triedHosts <- newTVarIO S.empty tryCommand . withNextSrv c userId storageSrvs triedHosts [qServer q] $ \srv -> do - sqSecured <- joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv - notify $ JOINED sqSecured + (sqSecured, service) <- joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv + notify $ JOINED sqSecured service LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK SWCH -> @@ -2114,13 +2134,13 @@ deleteConnQueues c waitDelivery ntf rqs = do connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult + addResult :: Map ConnId QDelResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QDelResult addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs -- combines two results for one connection, by prioritizing errors in Active queues - combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult + combineRes :: QDelResult -> Maybe QDelResult -> Maybe QDelResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' combineRes r' _ = Just r' - order :: QCmdResult -> Int + order :: QDelResult -> Int order (Active, Left _) = 1 order (_, Left _) = 2 order _ = 3 @@ -2448,7 +2468,7 @@ debugAgentLocks AgentClient {connLocks = cs, invLocks = is, deleteLock = d} = do delLock <- atomically $ tryReadTMVar d pure AgentLocks {connLocks, invLocks, delLock} where - getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls) + getLocks ls = atomically $ M.mapKeys (safeDecodeUtf8 . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls) getSMPServer :: AgentClient -> UserId -> AM SMPServerWithAuth getSMPServer c userId = getNextSMPServer c userId [] @@ -2553,6 +2573,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId withRcvConn entId $ \rq conn -> case cmd of SMP.SUB -> case respOrErr of Right SMP.OK -> processSubOk rq upConnIds + -- TODO [certs rcv] associate queue with the service + Right (SMP.SOK serviceId_) -> processSubOk rq upConnIds Right msg@SMP.MSG {} -> do processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails runProcessSMP rq conn (toConnData conn) msg @@ -3154,20 +3176,22 @@ connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo sq_ (qInfo :| _ (sq, _) <- lift $ newSndQueue userId connId qInfo' Nothing withStore c $ \db -> upgradeRcvConnToDuplex db connId sq -secureConfirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured +secureConfirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) secureConfirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do sqSecured <- agentSecureSndQueue c cData sq - storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode + (qInfo, service) <- mkAgentConfirmation c cData sq srv connInfo subMode + storeConfirmation c cData sq e2eEncryption_ qInfo lift $ submitPendingMsg c cData sq - pure sqSecured + pure (sqSecured, service) -secureConfirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured +secureConfirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) secureConfirmQueue c cData@ConnData {connId, connAgentVersion, pqSupport} sq srv connInfo e2eEncryption_ subMode = do sqSecured <- agentSecureSndQueue c cData sq - msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode + (qInfo, service) <- mkAgentConfirmation c cData sq srv connInfo subMode + msg <- mkConfirmation qInfo void $ sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed - pure sqSecured + pure (sqSecured, service) where mkConfirmation :: AgentMessage -> AM MsgBody mkConfirmation aMessage = do @@ -3193,10 +3217,10 @@ agentSecureSndQueue c ConnData {connAgentVersion} sq@SndQueue {queueMode, status sndSecure = senderCanSecure queueMode initiatorRatchetOnConf = connAgentVersion >= ratchetOnConfSMPAgentVersion -mkAgentConfirmation :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM AgentMessage +mkAgentConfirmation :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM (AgentMessage, Maybe ClientServiceId) mkAgentConfirmation c cData sq srv connInfo subMode = do - qInfo <- createReplyQueue c cData sq subMode srv - pure $ AgentConnInfoReply (qInfo :| []) connInfo + (qInfo, service) <- createReplyQueue c cData sq subMode srv + pure (AgentConnInfoReply (qInfo :| []) connInfo, service) enqueueConfirmation :: AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AM () enqueueConfirmation c cData sq connInfo e2eEncryption_ = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 5e93391e2..3f770f468 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -246,6 +246,7 @@ import Simplex.Messaging.Protocol ( AProtocolType (..), BrokerMsg, EntityId (..), + ServiceId, ErrorType, MsgFlags (..), MsgId, @@ -457,9 +458,9 @@ data AgentState = ASForeground | ASSuspending | ASSuspended deriving (Eq, Show) data AgentLocks = AgentLocks - { connLocks :: Map String String, - invLocks :: Map String String, - delLock :: Maybe String + { connLocks :: Map Text Text, + invLocks :: Map Text Text, + delLock :: Maybe Text } deriving (Show) @@ -985,32 +986,32 @@ closeXFTPServerClient :: AgentClient -> UserId -> XFTPServer -> FileDigest -> IO closeXFTPServerClient c userId server (FileDigest chunkDigest) = mkTransportSession c userId server chunkDigest >>= closeClient c xftpClients -withConnLock :: AgentClient -> ConnId -> String -> AM a -> AM a +withConnLock :: AgentClient -> ConnId -> Text -> AM a -> AM a withConnLock c connId name = ExceptT . withConnLock' c connId name . runExceptT {-# INLINE withConnLock #-} -withConnLock' :: AgentClient -> ConnId -> String -> AM' a -> AM' a +withConnLock' :: AgentClient -> ConnId -> Text -> AM' a -> AM' a withConnLock' _ "" _ = id withConnLock' AgentClient {connLocks} connId name = withLockMap connLocks connId name {-# INLINE withConnLock' #-} -withInvLock :: AgentClient -> ByteString -> String -> AM a -> AM a +withInvLock :: AgentClient -> ByteString -> Text -> AM a -> AM a withInvLock c key name = ExceptT . withInvLock' c key name . runExceptT {-# INLINE withInvLock #-} -withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a +withInvLock' :: AgentClient -> ByteString -> Text -> AM' a -> AM' a withInvLock' AgentClient {invLocks} = withLockMap invLocks {-# INLINE withInvLock' #-} -withConnLocks :: AgentClient -> Set ConnId -> String -> AM' a -> AM' a +withConnLocks :: AgentClient -> Set ConnId -> Text -> AM' a -> AM' a withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks {-# INLINE withConnLocks #-} -withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a +withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> Text -> m a -> m a withLockMap = withGetLock . getMapLock {-# INLINE withLockMap #-} -withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> String -> m a -> m a +withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> Text -> m a -> m a withLocksMap_ = withGetLocks . getMapLock {-# INLINE withLocksMap_ #-} @@ -1196,6 +1197,7 @@ protocolClientError protocolError_ host = \case PCEIncompatibleHost -> BROKER host HOST PCETransportError e -> BROKER host $ TRANSPORT e e@PCECryptoError {} -> INTERNAL $ show e + PCEServiceUnavailable {} -> BROKER host NO_SERVICE PCEIOError {} -> BROKER host NETWORK data ProtocolTestStep @@ -1377,9 +1379,10 @@ newRcvQueue_ c userId connId (ProtoServerWithAuth srv auth) vRange cqrd subMode logServer "-->" c srv NoEntity "NEW" tSess <- mkTransportSession c userId srv connId -- TODO [notifications] - r@(thParams', QIK {rcvId, sndId, rcvPublicDhKey, queueMode}) <- + r@(thParams', QIK {rcvId, sndId, rcvPublicDhKey, queueMode, serviceId}) <- withClient c tSess $ \(SMPConnectedClient smp _) -> (thParams smp,) <$> createSMPQueue smp nonce_ rKeys dhKey auth subMode (queueReqData cqrd) + -- TODO [certs rcv] validate that serviceId is the same as in the client session liftIO . logServer "<--" c srv NoEntity $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] shortLink <- mkShortLinkCreds r let rq = @@ -1395,6 +1398,7 @@ newRcvQueue_ c userId connId (ProtoServerWithAuth srv auth) vRange cqrd subMode sndId, queueMode, shortLink, + clientService = ClientService DBNewEntity <$> serviceId, status = New, dbQueueId = DBNewEntity, primary = True, @@ -1434,13 +1438,13 @@ newRcvQueue_ c userId connId (ProtoServerWithAuth srv auth) vRange cqrd subMode newErr :: String -> AM (Maybe ShortLinkCreds) newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>) -processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError () -> STM () +processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError (Maybe ServiceId) -> STM () processSubResult c sessId rq@RcvQueue {userId, server, connId} = \case Left e -> unless (temporaryClientError e) $ do incSMPServerStat c userId server connSubErrs failSubscription c rq e - Right () -> + Right _serviceId -> -- TODO [certs rcv] store association with the service ifM (hasPendingSubscription c connId) (incSMPServerStat c userId server connSubscribed >> addSubscription c sessId rq) @@ -1479,7 +1483,7 @@ serverHostError = \case _ -> False -- | Subscribe to queues. The list of results can have a different order. -subscribeQueues :: AgentClient -> [RcvQueue] -> AM' ([(RcvQueue, Either AgentErrorType ())], Maybe SessionId) +subscribeQueues :: AgentClient -> [RcvQueue] -> AM' ([(RcvQueue, Either AgentErrorType (Maybe ServiceId))], Maybe SessionId) subscribeQueues c qs = do (errs, qs') <- partitionEithers <$> mapM checkQueue qs atomically $ do @@ -1494,7 +1498,7 @@ subscribeQueues c qs = do checkQueue rq = do prohibited <- liftIO $ hasGetLock c rq pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED "subscribeQueues") else Right rq - subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError ()) + subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError (Maybe ServiceId)) subscribeQueues_ env session smp qs' = do let (userId, srv, _) = transportSession' smp atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs' @@ -1514,7 +1518,7 @@ subscribeQueues c qs = do tSess = transportSession' smp sessId = sessionId $ thParams smp hasTempErrors = any (either temporaryClientError (const False) . snd) - processSubResults :: NonEmpty (RcvQueue, Either SMPClientError ()) -> STM () + processSubResults :: NonEmpty (RcvQueue, Either SMPClientError (Maybe ServiceId)) -> STM () processSubResults = mapM_ $ uncurry $ processSubResult c sessId resubscribe = resubscribeSMPSession c tSess `runReaderT` env @@ -1551,7 +1555,7 @@ sendTSessionBatches statCmd toRQ action c qs = where agentError = second . first $ protocolClientError SMP $ clientServer smp -sendBatch :: (SMPClient -> NonEmpty (SMP.RecipientId, SMP.RcvPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError ()) +sendBatch :: (SMPClient -> NonEmpty (SMP.RecipientId, SMP.RcvPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError a))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError a) sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvId, rcvPrivateKey) diff --git a/src/Simplex/Messaging/Agent/Lock.hs b/src/Simplex/Messaging/Agent/Lock.hs index 43b2358fd..18a6f0985 100644 --- a/src/Simplex/Messaging/Agent/Lock.hs +++ b/src/Simplex/Messaging/Agent/Lock.hs @@ -16,11 +16,12 @@ import Control.Monad.IO.Unlift import Data.Functor (($>)) import Data.Set (Set) import qualified Data.Set as S +import Data.Text (Text) import UnliftIO.Async (forConcurrently) import qualified UnliftIO.Exception as E import UnliftIO.STM -type Lock = TMVar String +type Lock = TMVar Text createLock :: STM Lock createLock = newEmptyTMVar @@ -30,24 +31,24 @@ createLockIO :: IO Lock createLockIO = newEmptyTMVarIO {-# INLINE createLockIO #-} -withLock :: MonadUnliftIO m => Lock -> String -> ExceptT e m a -> ExceptT e m a +withLock :: MonadUnliftIO m => Lock -> Text -> ExceptT e m a -> ExceptT e m a withLock lock name = ExceptT . withLock' lock name . runExceptT {-# INLINE withLock #-} -withLock' :: MonadUnliftIO m => Lock -> String -> m a -> m a +withLock' :: MonadUnliftIO m => Lock -> Text -> m a -> m a withLock' lock name = E.bracket_ (atomically $ putTMVar lock name) (void . atomically $ takeTMVar lock) -withGetLock :: MonadUnliftIO m => (k -> STM Lock) -> k -> String -> m a -> m a +withGetLock :: MonadUnliftIO m => (k -> STM Lock) -> k -> Text -> m a -> m a withGetLock getLock key name a = E.bracket (atomically $ getPutLock getLock key name) (atomically . takeTMVar) (const a) -withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> Set k -> String -> m a -> m a +withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> Set k -> Text -> m a -> m a withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const where holdLocks = forConcurrently (S.toList keys) $ \key -> atomically $ getPutLock getLock key name @@ -55,5 +56,5 @@ withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const -- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists, -- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock. -getPutLock :: (k -> STM Lock) -> k -> String -> STM Lock +getPutLock :: (k -> STM Lock) -> k -> Text -> STM Lock getPutLock getLock key name = getLock key >>= \l -> putTMVar l name $> l diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index cfdb4eeaa..8a29a9299 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -122,6 +122,9 @@ module Simplex.Messaging.Agent.Protocol ContactConnType (..), ShortLinkScheme (..), LinkKey (..), + StoredClientService (..), + ClientService, + ClientServiceId, sameConnReqContact, sameShortLinkContact, simplexChat, @@ -193,12 +196,13 @@ import Data.Time.Clock.System (SystemTime) import Data.Type.Equality import Data.Typeable (Typeable) import Data.Word (Word16, Word32) -import Simplex.Messaging.Agent.Store.DB (Binary (..), FromField (..), ToField (..), blobFieldDecoder, fromTextField_) import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.FileTransfer.Types (FileErrorType) import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..), FromField (..), ToField (..), blobFieldDecoder, fromTextField_) +import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Client (ProxyClientError) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet @@ -367,7 +371,7 @@ type SndQueueSecured = Bool -- | Parameterized type for SMP agent events data AEvent (e :: AEntity) where - INV :: AConnectionRequestUri -> AEvent AEConn + INV :: AConnectionRequestUri -> Maybe ClientServiceId -> AEvent AEConn CONF :: ConfirmationId -> PQSupport -> [SMPServer] -> ConnInfo -> AEvent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake REQ :: InvitationId -> PQSupport -> NonEmpty SMPServer -> ConnInfo -> AEvent AEConn -- ConnInfo is from sender INFO :: PQSupport -> ConnInfo -> AEvent AEConn @@ -393,7 +397,7 @@ data AEvent (e :: AEntity) where DEL_USER :: Int64 -> AEvent AENone STAT :: ConnectionStats -> AEvent AEConn OK :: AEvent AEConn - JOINED :: SndQueueSecured -> AEvent AEConn + JOINED :: SndQueueSecured -> Maybe ClientServiceId -> AEvent AEConn ERR :: AgentErrorType -> AEvent AEConn ERRS :: [(ConnId, AgentErrorType)] -> AEvent AENone SUSPENDED :: AEvent AENone @@ -493,7 +497,7 @@ aCommandTag = \case aEventTag :: AEvent e -> AEventTag e aEventTag = \case - INV _ -> INV_ + INV {} -> INV_ CONF {} -> CONF_ REQ {} -> REQ_ INFO {} -> INFO_ @@ -519,7 +523,7 @@ aEventTag = \case DEL_USER _ -> DEL_USER_ STAT _ -> STAT_ OK -> OK_ - JOINED _ -> JOINED_ + JOINED {} -> JOINED_ ERR _ -> ERR_ ERRS _ -> ERRS_ SUSPENDED -> SUSPENDED_ @@ -1512,7 +1516,7 @@ instance StrEncoding AConnShortLink where <|> "https://" *> ((SLSServer,) . Just <$> strP) <|> fail "bad short link scheme" contactTypeP = do - Just <$> (A.anyChar >>= ctTypeP . toUpper) + Just <$> (A.anyChar >>= ctTypeP . toUpper) <|> A.char 'i' $> Nothing <|> fail "unknown short link type" serverQueryP h_ = @@ -1549,7 +1553,7 @@ ctTypeP :: Char -> Parser ContactConnType ctTypeP = \case 'A' -> pure CCTContact 'C' -> pure CCTChannel - 'G' -> pure CCTGroup + 'G' -> pure CCTGroup _ -> fail "unknown contact address type" {-# INLINE ctTypeP #-} @@ -1702,6 +1706,16 @@ instance Encoding AConnLinkData where userData <- smpP pure $ ACLD SCMContact ContactLinkData {agentVRange, direct, owners, relays, userData} +data StoredClientService (s :: DBStored) = ClientService + { dbServiceId :: DBEntityId' s, + serviceId :: SMP.ServiceId + } + deriving (Eq, Show) + +type ClientService = StoredClientService 'DBStored + +type ClientServiceId = DBEntityId + -- | SMP queue status. data QueueStatus = -- | queue is created diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 1c46cb60c..1613cc190 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -84,6 +84,8 @@ data StoredRcvQueue (q :: DBStored) = RcvQueue queueMode :: Maybe QueueMode, -- | short link ID and credentials shortLink :: Maybe ShortLinkCreds, + -- | associated client service + clientService :: Maybe (StoredClientService q), -- | queue status status :: QueueStatus, -- | database queue ID (within connection) @@ -109,6 +111,10 @@ data ShortLinkCreds = ShortLinkCreds } deriving (Show) +clientServiceId :: RcvQueue -> Maybe ClientServiceId +clientServiceId = fmap dbServiceId . clientService +{-# INLINE clientServiceId #-} + rcvQueueInfo :: RcvQueue -> RcvQueueInfo rcvQueueInfo rq@RcvQueue {server, rcvSwchStatus} = RcvQueueInfo {rcvServer = server, rcvSwitchStatus = rcvSwchStatus, canAbortSwitch = canAbortRcvSwitch rq} diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 2095a4462..137c320c8 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -381,6 +381,7 @@ createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode createNewConn db gVar cData cMode = do fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) +-- TODO [certs rcv] store clientServiceId from NewRcvQueue updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) updateNewConnRcv db connId rq = getConn db connId $>>= \case @@ -473,6 +474,7 @@ upgradeRcvConnToDuplex db connId sq = (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq (SomeConn c _) -> pure . Left . SEBadConnType $ connType c +-- TODO [certs rcv] store clientServiceId from NewRcvQueue upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) upgradeSndConnToDuplex db connId rq = getConn db connId >>= \case @@ -480,6 +482,7 @@ upgradeSndConnToDuplex db connId rq = Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c _ -> pure $ Left SEConnNotFound +-- TODO [certs rcv] store clientServiceId from NewRcvQueue addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) addConnRcvQueue db connId rq = getConn db connId >>= \case @@ -1976,7 +1979,8 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} serverKeyHash_ = do :. (sndId, queueMode, status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_) :. (shortLinkId <$> shortLink, shortLinkKey <$> shortLink, linkPrivSigKey <$> shortLink, linkEncFixedData <$> shortLink) ) - pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId} + -- TODO [certs rcv] save client service + pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId, clientService = Nothing} -- * createSndConn helpers @@ -2170,7 +2174,8 @@ toRcvQueue shortLink = case (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_) of (Just shortLinkId, Just shortLinkKey, Just linkPrivSigKey, Just linkEncFixedData) -> Just ShortLinkCreds {shortLinkId, shortLinkKey, linkPrivSigKey, linkEncFixedData} _ -> Nothing - in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} + -- TODO [certs rcv] read client service + in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, clientService = Nothing, status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} getRcvQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue) getRcvQueueById db connId dbRcvId = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs new file mode 100644 index 000000000..48f847091 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20250517_service_certs.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250517_service_certs where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20250517_service_certs :: Query +m20250517_service_certs = + [sql| +CREATE TABLE server_certs( + server_cert_id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES users ON UPDATE RESTRICT ON DELETE CASCADE, + host TEXT NOT NULL, + port TEXT NOT NULL, + certificate BLOB NOT NULL, + priv_key BLOB NOT NULL, + service_id BLOB, + FOREIGN KEY(host, port) REFERENCES servers ON UPDATE CASCADE ON DELETE RESTRICT, +); + +CREATE UNIQUE INDEX idx_server_certs_user_id_host_port ON server_certs(user_id, host, port); + +CREATE INDEX idx_server_certs_host_port ON server_certs(host, port); + +ALTER TABLE rcv_queues ADD COLUMN rcv_service_id BLOB; + |] + +down_m20250517_service_certs :: Query +down_m20250517_service_certs = + [sql| +ALTER TABLE rcv_queues DROP COLUMN rcv_service_id; + +DROP INDEX idx_server_certs_host_port; + +DROP INDEX idx_server_certs_user_id_host_port; + +DROP TABLE server_certs; + |] diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index fc818577a..8c048589c 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -46,6 +46,8 @@ module Simplex.Messaging.Client getSMPMessage, subscribeSMPQueueNotifications, subscribeSMPQueuesNtfs, + subscribeService, + smpClientService, secureSMPQueue, secureSndSMPQueue, proxySecureSndSMPQueue, @@ -92,6 +94,7 @@ module Simplex.Messaging.Client clientSocksCredentials, chooseTransportHost, temporaryClientError, + smpClientServiceError, smpProxyError, textToHostMode, ServerTransmissionBatch, @@ -153,7 +156,7 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (SocksAuth (..), SocksProxyWithAuth (..), TransportClientConfig (..), TransportHost (..), defaultSMPPort, defaultTcpConnectTimeout, runTransportClient) import Simplex.Messaging.Transport.KeepAlive -import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tryWriteTBQueue, tshow, whenM) +import Simplex.Messaging.Util import Simplex.Messaging.Version import System.Mem.Weak (Weak, deRefWeak) import System.Timeout (timeout) @@ -207,7 +210,8 @@ smpClientStub g sessionId thVersion thAuth = do blockSize = smpBlockSize, implySessId = thVersion >= authCmdsSMPVersion, encryptBlock = Nothing, - batch = True + batch = True, + serviceAuth = thVersion >= serviceCertsSMPVersion }, sessionTs = ts, client_ = @@ -428,6 +432,7 @@ data ProtocolClientConfig v = ProtocolClientConfig -- | network configuration networkConfig :: NetworkConfig, clientALPN :: Maybe [ALPN], + serviceCredentials :: Maybe ServiceCredentials, -- | client-server protocol version range serverVRange :: VersionRange v, -- | agree shared session secret (used in SMP proxy for additional encryption layer) @@ -446,6 +451,7 @@ defaultClientConfig clientALPN useSNI serverVRange = defaultTransport = ("443", transport @TLS), networkConfig = defaultNetworkConfig, clientALPN, + serviceCredentials = Nothing, serverVRange, agreeSecret = False, proxyServer = False, @@ -518,7 +524,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe ByteString) -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> [HostName] -> Maybe (TBQueue (ServerTransmissionBatch v err msg)) -> UTCTime -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) -getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, clientALPN, serverVRange, agreeSecret, proxyServer, useSNI} presetDomains msgQ proxySessTs disconnected = do +getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, clientALPN, serviceCredentials, serverVRange, agreeSecret, proxyServer, useSNI} presetDomains msgQ proxySessTs disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> (getCurrentTime >>= mkProtocolClient useHost >>= runClient useTransport useHost) @@ -556,7 +562,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize runClient :: (ServiceName, ATransport 'TClient) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) runClient (port', ATransport t) useHost c = do cVar <- newEmptyTMVarIO - let tcConfig = transportClientConfig networkConfig useHost useSNI clientALPN + let tcConfig = (transportClientConfig networkConfig useHost useSNI clientALPN) {clientCredentials = serviceCreds <$> serviceCredentials} socksCreds = clientSocksCredentials networkConfig proxySessTs transportSession tId <- runTransportClient tcConfig socksCreds useHost port' (Just $ keyHash srv) (client t c cVar) @@ -584,7 +590,8 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize client :: forall c. Transport c => TProxy c 'TClient -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c 'TClient -> IO () client _ c cVar h = do ks <- if agreeSecret then Just <$> atomically (C.generateKeyPair g) else pure Nothing - runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange proxyServer) >>= \case + serviceKeys_ <- mapM (\creds -> (creds,) <$> atomically (C.generateKeyPair g)) serviceCredentials + runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange proxyServer serviceKeys_) >>= \case Left e -> atomically . putTMVar cVar . Left $ PCETransportError e Right th@THandle {params} -> do sessionTs <- getCurrentTime @@ -702,6 +709,8 @@ data ProtocolClientError err PCENetworkError | -- | No host compatible with network configuration PCEIncompatibleHost + | -- | Service is unavailable for command that requires service connection + PCEServiceUnavailable | -- | TCP transport handshake or some other transport error. -- Forwarded to the agent client as `ERR BROKER TRANSPORT e`. PCETransportError TransportError @@ -721,6 +730,14 @@ temporaryClientError = \case _ -> False {-# INLINE temporaryClientError #-} +smpClientServiceError :: SMPClientError -> Bool +smpClientServiceError = \case + PCEServiceUnavailable -> True + PCETransportError (TEHandshake BAD_SERVICE) -> True -- TODO [certs] this error may be temporary, so we should possibly resubscribe. + PCEProtocolError SERVICE -> True + PCEProtocolError (PROXY (BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen. + _ -> False + -- converts error of client running on proxy to the error sent to client connected to proxy smpProxyError :: SMPClientError -> ErrorType smpProxyError = \case @@ -730,6 +747,7 @@ smpProxyError = \case PCEResponseTimeout -> PROXY $ BROKER TIMEOUT PCENetworkError -> PROXY $ BROKER NETWORK PCEIncompatibleHost -> PROXY $ BROKER HOST + PCEServiceUnavailable -> PROXY $ BROKER $ NO_SERVICE -- for completeness, it cannot happen. PCETransportError t -> PROXY $ BROKER $ TRANSPORT t PCECryptoError _ -> CRYPTO PCEIOError _ -> INTERNAL @@ -756,34 +774,34 @@ createSMPQueue c nonce_ (rKey, rpKey) dhKey auth subMode qrd = -- | Subscribe to the SMP queue. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue -subscribeSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO () +subscribeSMPQueue :: SMPClient -> RcvPrivateAuthKey -> RecipientId -> ExceptT SMPClientError IO (Maybe ServiceId) subscribeSMPQueue c rpKey rId = do liftIO $ enablePings c - sendSMPCommand c (Just rpKey) rId SUB >>= \case - OK -> pure () - cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd - r -> throwE $ unexpectedResponse r + sendSMPCommand c (Just rpKey) rId SUB >>= liftIO . processSUBResponse_ c rId >>= except -- | Subscribe to multiple SMP queues batching commands if supported. -subscribeSMPQueues :: SMPClient -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError ())) +subscribeSMPQueues :: SMPClient -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError (Maybe ServiceId))) subscribeSMPQueues c qs = do liftIO $ enablePings c sendProtocolCommands c cs >>= mapM (processSUBResponse c) where cs = L.map (\(rId, rpKey) -> (rId, Just rpKey, Cmd SRecipient SUB)) qs -streamSubscribeSMPQueues :: SMPClient -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> ([(RecipientId, Either SMPClientError ())] -> IO ()) -> IO () +streamSubscribeSMPQueues :: SMPClient -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> ([(RecipientId, Either SMPClientError (Maybe ServiceId))] -> IO ()) -> IO () streamSubscribeSMPQueues c qs cb = streamProtocolCommands c cs $ mapM process >=> cb where cs = L.map (\(rId, rpKey) -> (rId, Just rpKey, Cmd SRecipient SUB)) qs process r@(Response rId _) = (rId,) <$> processSUBResponse c r -processSUBResponse :: SMPClient -> Response ErrorType BrokerMsg -> IO (Either SMPClientError ()) -processSUBResponse c (Response rId r) = case r of - Right OK -> pure $ Right () - Right cmd@MSG {} -> writeSMPMessage c rId cmd $> Right () - Right r' -> pure . Left $ unexpectedResponse r' - Left e -> pure $ Left e +processSUBResponse :: SMPClient -> Response ErrorType BrokerMsg -> IO (Either SMPClientError (Maybe ServiceId)) +processSUBResponse c (Response rId r) = pure r $>>= processSUBResponse_ c rId + +processSUBResponse_ :: SMPClient -> RecipientId -> BrokerMsg -> IO (Either SMPClientError (Maybe ServiceId)) +processSUBResponse_ c rId = \case + OK -> pure $ Right Nothing + SOK serviceId_ -> pure $ Right serviceId_ + cmd@MSG {} -> writeSMPMessage c rId cmd $> Right Nothing + r' -> pure . Left $ unexpectedResponse r' writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c [(rId, STEvent (Right msg))]) (msgQ $ client_ c) @@ -806,18 +824,47 @@ getSMPMessage c rpKey rId = -- | Subscribe to the SMP queue notifications. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications -subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateAuthKey -> NotifierId -> ExceptT SMPClientError IO () +subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateAuthKey -> NotifierId -> ExceptT SMPClientError IO (Maybe ServiceId) subscribeSMPQueueNotifications c npKey nId = do liftIO $ enablePings c - okSMPCommand NSUB c npKey nId -{-# INLINE subscribeSMPQueueNotifications #-} + sendSMPCommand c (Just npKey) nId NSUB >>= except . nsubResponse_ -- | Subscribe to multiple SMP queues notifications batching commands if supported. -subscribeSMPQueuesNtfs :: SMPClient -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError ())) +subscribeSMPQueuesNtfs :: SMPClient -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError (Maybe ServiceId))) subscribeSMPQueuesNtfs c qs = do liftIO $ enablePings c - okSMPCommands NSUB c qs -{-# INLINE subscribeSMPQueuesNtfs #-} + L.map nsubResponse <$> sendProtocolCommands c cs + where + cs = L.map (\(nId, npKey) -> (nId, Just npKey, Cmd SNotifier NSUB)) qs + +nsubResponse :: Response ErrorType BrokerMsg -> Either SMPClientError (Maybe ServiceId) +nsubResponse (Response _ r) = r >>= nsubResponse_ +{-# INLINE nsubResponse #-} + +nsubResponse_ :: BrokerMsg -> Either SMPClientError (Maybe ServiceId) +nsubResponse_ = \case + OK -> Right Nothing + SOK serviceId_ -> Right serviceId_ + r' -> Left $ unexpectedResponse r' +{-# INLINE nsubResponse_ #-} + +subscribeService :: forall p. (PartyI p, SubscriberParty p) => SMPClient -> SParty p -> ExceptT SMPClientError IO Int64 +subscribeService c party = case smpClientService c of + Just THClientService {serviceId, serviceKey} -> do + liftIO $ enablePings c + sendSMPCommand c (Just (C.APrivateAuthKey C.SEd25519 serviceKey)) serviceId subCmd >>= \case + SOKS n -> pure n + r -> throwE $ unexpectedResponse r + where + subCmd :: Command p + subCmd = case party of + SRecipient -> SUBS + SNotifier -> NSUBS + Nothing -> throwE PCEServiceUnavailable + +smpClientService :: SMPClient -> Maybe THClientService +smpClientService = thAuth . thParams >=> clientService +{-# INLINE smpClientService #-} enablePings :: SMPClient -> IO () enablePings ProtocolClient {client_ = PClient {sendPings}} = atomically $ writeTVar sendPings True @@ -1049,15 +1096,16 @@ proxySMPCommand :: ExceptT SMPClientError IO (Either ProxyClientError BrokerMsg) proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {clientCorrId = g, tcpTimeout}} (ProxiedRelay sessionId v _ serverKey) spKey sId command = do -- prepare params - let serverThAuth = (\ta -> ta {serverPeerPubKey = serverKey}) <$> thAuth proxyThParams + let serverThAuth = (\ta -> ta {peerServerPubKey = serverKey}) <$> thAuth proxyThParams serverThParams = smpTHParamsSetVersion v proxyThParams {sessionId, thAuth = serverThAuth} (cmdPubKey, cmdPrivKey) <- liftIO . atomically $ C.generateKeyPair @'C.X25519 g let cmdSecret = C.dh' serverKey cmdPrivKey nonce@(C.CbNonce corrId) <- liftIO . atomically $ C.randomCbNonce g -- encode let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth serverThParams (CorrId corrId, sId, Cmd (sParty @p) command) - auth <- liftEitherWith PCETransportError $ authTransmission serverThAuth spKey nonce tForAuth - b <- case batchTransmissions (batch serverThParams) (blockSize serverThParams) [Right (auth, tToSend)] of + -- serviceAuth is False here – proxied commands are not used with service certificates + auth <- liftEitherWith PCETransportError $ authTransmission serverThAuth False spKey nonce tForAuth + b <- case batchTransmissions serverThParams [Right (auth, tToSend)] of [] -> throwE $ PCETransportError TELargeMsg TBError e _ : _ -> throwE $ PCETransportError e TBTransmission s _ : _ -> pure s @@ -1100,7 +1148,7 @@ forwardSMPTransmission c@ProtocolClient {thParams, client_ = PClient {clientCorr let fwdT = FwdTransmission {fwdCorrId, fwdVersion, fwdKey, fwdTransmission} eft = EncFwdTransmission $ C.cbEncryptNoPad sessSecret nonce (smpEncode fwdT) -- send - sendProtocolCommand_ c (Just nonce) Nothing Nothing NoEntity (Cmd SSender (RFWD eft)) >>= \case + sendProtocolCommand_ c (Just nonce) Nothing Nothing NoEntity (Cmd SProxyService (RFWD eft)) >>= \case RRES (EncFwdResponse efr) -> do -- unwrap r' <- liftEitherWith PCECryptoError $ C.cbDecryptNoPad sessSecret (C.reverseNonce nonce) efr @@ -1139,8 +1187,8 @@ type PCTransmission err msg = (Either TransportError SentRawTransmission, Reques -- | Send multiple commands with batching and collect responses sendProtocolCommands :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) -sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do - bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs +sendProtocolCommands c@ProtocolClient {thParams} cs = do + bs <- batchTransmissions' thParams <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs where validate :: [Response err msg] -> IO (NonEmpty (Response err msg)) @@ -1156,8 +1204,8 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz diff = L.length cs - length rs streamProtocolCommands :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () -streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do - bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs +streamProtocolCommands c@ProtocolClient {thParams} cs cb = do + bs <- batchTransmissions' thParams <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg] @@ -1186,7 +1234,7 @@ sendProtocolCommand c = sendProtocolCommand_ c Nothing Nothing -- -- Please note: if nonce is passed it is also used as a correlation ID sendProtocolCommand_ :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> Maybe C.CbNonce -> Maybe Int -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg -sendProtocolCommand_ c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} nonce_ tOut pKey entId cmd = +sendProtocolCommand_ c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize, serviceAuth}} nonce_ tOut pKey entId cmd = ExceptT $ uncurry sendRecv =<< mkTransmission_ c nonce_ (entId, pKey, cmd) where -- two separate "atomically" needed to avoid blocking @@ -1200,8 +1248,8 @@ sendProtocolCommand_ c@ProtocolClient {client_ = PClient {sndQ}, thParams = THan response <$> getResponse c tOut r where s - | batch = tEncodeBatch1 t - | otherwise = tEncode t + | batch = tEncodeBatch1 serviceAuth t + | otherwise = tEncode serviceAuth t nonBlockingWriteTBQueue :: TBQueue a -> a -> IO () nonBlockingWriteTBQueue q x = do @@ -1221,14 +1269,14 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, timeoutErrorCount}} t Nothing -> modifyTVar' timeoutErrorCount (+ 1) $> Left PCEResponseTimeout pure Response {entityId, response} -mkTransmission :: Protocol v err msg => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) +mkTransmission :: Protocol v err msg => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) mkTransmission c = mkTransmission_ c Nothing mkTransmission_ :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> Maybe C.CbNonce -> ClientCommand msg -> IO (PCTransmission err msg) mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} nonce_ (entityId, pKey_, command) = do nonce@(C.CbNonce corrId) <- maybe (atomically $ C.randomCbNonce clientCorrId) pure nonce_ let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, entityId, command) - auth = authTransmission (thAuth thParams) pKey_ nonce tForAuth + auth = authTransmission (thAuth thParams) (useServiceAuth command) pKey_ nonce tForAuth r <- mkRequest (CorrId corrId) pure ((,tToSend) <$> auth, r) where @@ -1247,18 +1295,25 @@ mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentC atomically $ TM.insert corrId r sentCommands pure r -authTransmission :: Maybe (THandleAuth 'TClient) -> Maybe C.APrivateAuthKey -> C.CbNonce -> ByteString -> Either TransportError (Maybe TransmissionAuth) -authTransmission thAuth pKey_ nonce t = traverse authenticate pKey_ +authTransmission :: Maybe (THandleAuth 'TClient) -> Bool -> Maybe C.APrivateAuthKey -> C.CbNonce -> ByteString -> Either TransportError (Maybe TAuthorizations) +authTransmission thAuth serviceAuth pKey_ nonce t = traverse authenticate pKey_ where - authenticate :: C.APrivateAuthKey -> Either TransportError TransmissionAuth - authenticate (C.APrivateAuthKey a pk) = case a of + authenticate :: C.APrivateAuthKey -> Either TransportError TAuthorizations + authenticate (C.APrivateAuthKey a pk) = (,serviceSig) <$> case a of C.SX25519 -> case thAuth of - Just THAuthClient {serverPeerPubKey = k} -> Right $ TAAuthenticator $ C.cbAuthenticate k pk nonce t + Just THAuthClient {peerServerPubKey = k} -> Right $ TAAuthenticator $ C.cbAuthenticate k pk nonce t' Nothing -> Left TENoServerAuth C.SEd25519 -> sign pk C.SEd448 -> sign pk + -- When command is signed by both entity key and service key, + -- entity key must sign over both transmission and service certificate hash, + -- to prevent any service substitution via MITM inside TLS. + (t', serviceSig) = case clientService =<< thAuth of + Just THClientService {serviceCertHash = XV.Fingerprint fp, serviceKey} | serviceAuth -> + (fp <> t, Just $ C.sign' serviceKey t) -- service key only needs to sign transmission itself + _ -> (t, Nothing) sign :: forall a. (C.AlgorithmI a, C.SignatureAlgorithm a) => C.PrivateKey a -> Either TransportError TransmissionAuth - sign pk = Right $ TASignature $ C.ASignature (C.sAlgorithm @a) (C.sign' pk t) + sign pk = Right $ TASignature $ C.ASignature (C.sAlgorithm @a) (C.sign' pk t') data TBQueueInfo = TBQueueInfo { qLength :: Int, diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 1340eeeb0..9c2238551 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -11,7 +11,27 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -module Simplex.Messaging.Client.Agent where +module Simplex.Messaging.Client.Agent + ( SMPClientAgent (..), + SMPClientAgentConfig (..), + SMPClientAgentEvent (..), + OwnServer, + defaultSMPClientAgentConfig, + newSMPClientAgent, + getSMPServerClient'', + getConnectedSMPServerClient, + closeSMPClientAgent, + lookupSMPServerClient, + isOwnServer, + subscribeServiceNtfs, + subscribeQueuesNtfs, + activeClientSession', + removeActiveSub, + removeActiveSubs, + removePendingSub, + removePendingSubs, + ) +where import Control.Concurrent (forkIO) import Control.Concurrent.Async (Async, uninterruptibleCancel) @@ -24,10 +44,13 @@ import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Constraint (Dict (..)) +import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Data.Maybe (isJust, isNothing) import qualified Data.Set as S import Data.Text.Encoding import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) @@ -36,7 +59,22 @@ import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, NotifierId, NtfPrivateAuthKey, Party (..), ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer, SParty (..), SubscriberParty) +import Simplex.Messaging.Protocol + ( BrokerMsg, + ErrorType, + NotifierId, + NtfPrivateAuthKey, + Party (..), + PartyI, + ProtocolServer (..), + QueueId, + SMPServer, + SParty (..), + SubscriberParty, + subscriberParty, + subscriberServiceRole + ) +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -50,10 +88,18 @@ import UnliftIO.STM type SMPClientVar = SessionVar (Either (SMPClientError, Maybe UTCTime) (OwnServer, SMPClient)) data SMPClientAgentEvent - = CAConnected SMPServer + = CAConnected SMPServer (Maybe ServiceId) | CADisconnected SMPServer (NonEmpty QueueId) - | CASubscribed SMPServer (NonEmpty QueueId) + | CASubscribed SMPServer (Maybe ServiceId) (NonEmpty QueueId) | CASubError SMPServer (NonEmpty (QueueId, SMPClientError)) + | CAServiceDisconnected SMPServer (ServiceId, Int64) + | CAServiceSubscribed SMPServer (ServiceId, Int64) Int64 + | CAServiceSubError SMPServer (ServiceId, Int64) SMPClientError + -- CAServiceUnavailable is used when service ID in pending subscription is different from the current service in connection. + -- This will require resubscribing to all queues associated with this service ID individually, creating new associations. + -- It may happen if, for example, SMP server deletes service information (e.g. via downgrade and upgrade) + -- and assigns different service ID to the service certificate. + | CAServiceUnavailable SMPServer (ServiceId, Int64) data SMPClientAgentConfig = SMPClientAgentConfig { smpCfg :: ProtocolClientConfig SMPVersion, @@ -94,7 +140,14 @@ data SMPClientAgent p = SMPClientAgent randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, smpSessions :: TMap SessionId (OwnServer, SMPClient), + -- Only one service subscription can exist per server with this agent. + -- With correctly functioning SMP server, queue and service subscriptions can't be + -- active at the same time. + activeServiceSubs :: TMap SMPServer (TVar (Maybe ((ServiceId, Int64), SessionId))), activeQueueSubs :: TMap SMPServer (TMap QueueId (SessionId, C.APrivateAuthKey)), + -- Pending service subscriptions can co-exist with pending queue subscriptions + -- on the same SMP server during subscriptions being transitioned from per-queue to service. + pendingServiceSubs :: TMap SMPServer (TVar (Maybe (ServiceId, Int64))), pendingQueueSubs :: TMap SMPServer (TMap QueueId C.APrivateAuthKey), smpSubWorkers :: TMap SMPServer (SessionVar (Async ())), workerSeq :: TVar Int @@ -110,7 +163,9 @@ newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize agentQ <- newTBQueueIO agentQSize smpClients <- TM.emptyIO smpSessions <- TM.emptyIO + activeServiceSubs <- TM.emptyIO activeQueueSubs <- TM.emptyIO + pendingServiceSubs <- TM.emptyIO pendingQueueSubs <- TM.emptyIO smpSubWorkers <- TM.emptyIO workerSeq <- newTVarIO 0 @@ -125,7 +180,9 @@ newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize randomDrg, smpClients, smpSessions, + activeServiceSubs, activeQueueSubs, + pendingServiceSubs, pendingQueueSubs, smpSubWorkers, workerSeq @@ -170,7 +227,8 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke atomically $ do putTMVar (sessionVar v) (Right c) TM.insert (sessionId $ thParams smp) c smpSessions - notify ca $ CAConnected srv + let serviceId_ = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp + notify ca $ CAConnected srv serviceId_ pure $ Right c Left e -> do let ei = persistErrorInterval agentCfg @@ -196,27 +254,46 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random where clientDisconnected :: SMPClient -> IO () clientDisconnected smp = do - removeClientAndSubs smp >>= (`forM_` serverDown) + removeClientAndSubs smp >>= serverDown logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeClientAndSubs :: SMPClient -> IO (Maybe (Map QueueId C.APrivateAuthKey)) - removeClientAndSubs smp = atomically $ do - TM.delete sessId smpSessions - removeSessVar v srv smpClients - TM.lookup srv (activeQueueSubs ca) >>= mapM updateSubs + removeClientAndSubs :: SMPClient -> IO (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) + removeClientAndSubs smp = do + -- Looking up subscription vars outside of STM transaction to reduce re-evaluation. + -- It is possible because these vars are never removed, they are only added. + sVar_ <- TM.lookupIO srv $ activeServiceSubs ca + qVar_ <- TM.lookupIO srv $ activeQueueSubs ca + atomically $ do + TM.delete sessId smpSessions + removeSessVar v srv smpClients + sSub <- pure sVar_ $>>= updateServiceSub + qSubs <- pure qVar_ $>>= updateQueueSubs + pure (sSub, qSubs) where sessId = sessionId $ thParams smp - updateSubs sVar = do + updateServiceSub sVar = do -- (sub, sessId') + -- We don't change active subscription in case session ID is different from disconnected client + serviceSub_ <- stateTVar sVar $ \case + Just (serviceSub, sessId') | sessId == sessId' -> (Just serviceSub, Nothing) + s -> (Nothing, s) + -- We don't reset pending subscription to Nothing here to avoid any race conditions + -- with subsequent client sessions that might have set pending already. + when (isJust serviceSub_) $ setPendingServiceSub ca srv serviceSub_ + pure serviceSub_ + updateQueueSubs qVar = do -- removing subscriptions that have matching sessionId to disconnected client -- and keep the other ones (they can be made by the new client) - pending <- M.map snd <$> stateTVar sVar (M.partition ((sessId ==) . fst)) - addSubs_ (pendingQueueSubs ca) srv pending - pure pending + subs <- M.map snd <$> stateTVar qVar (M.partition ((sessId ==) . fst)) + if M.null subs + then pure Nothing + else Just subs <$ addSubs_ (pendingQueueSubs ca) srv subs - serverDown :: Map QueueId C.APrivateAuthKey -> IO () - serverDown ss = forM_ (L.nonEmpty $ M.keys ss) $ \qIds -> do - notify ca $ CADisconnected srv qIds - reconnectClient ca srv + serverDown :: (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> IO () + serverDown (sSub, qSubs) = do + mapM_ (notify ca . CAServiceDisconnected srv) sSub + let qIds = L.nonEmpty . M.keys =<< qSubs + mapM_ (notify ca . CADisconnected srv) qIds + when (isJust sSub || isJust qIds) $ reconnectClient ca srv -- | Spawn reconnect worker if needed reconnectClient :: SMPClientAgent p -> SMPServer -> IO () @@ -226,7 +303,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s where getWorkerVar ts = ifM - (noPending) + (noPending <$> getPending TM.lookup readTVar) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call (Just <$> getSessVar workerSeq srv smpSubWorkers ts) newSubWorker :: SessionVar (Async ()) -> IO () @@ -235,13 +312,17 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s atomically $ putTMVar (sessionVar v) a runSubWorker = withRetryInterval (reconnectInterval agentCfg) $ \_ loop -> do - pending <- liftIO getPending - unless (null pending) $ whenM (readTVarIO active) $ do - void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv pending) + subs <- getPending TM.lookupIO readTVarIO + unless (noPending subs) $ whenM (readTVarIO active) $ do + void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv subs) loop ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg - noPending = maybe (pure True) (fmap M.null . readTVar) =<< TM.lookup srv (pendingQueueSubs ca) - getPending = maybe (pure M.empty) readTVarIO =<< TM.lookupIO srv (pendingQueueSubs ca) + noPending (sSub, qSubs) = isNothing sSub && maybe True M.null qSubs + getPending :: Monad m => (forall a. SMPServer -> TMap SMPServer a -> m (Maybe a)) -> (forall a. TVar a -> m a) -> m (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) + getPending lkup rd = do + sSub <- lkup srv (pendingServiceSubs ca) $>>= rd + qSubs <- lkup srv (pendingQueueSubs ca) >>= mapM rd + pure (sSub, qSubs) cleanup :: SessionVar (Async ()) -> STM () cleanup v = do -- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar. @@ -249,19 +330,20 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s whenM (isEmptyTMVar $ sessionVar v) retry removeSessVar v srv smpSubWorkers -reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> Map QueueId C.APrivateAuthKey -> ExceptT SMPClientError IO () -reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv subs = - withSMP ca srv $ \smp -> liftIO $ case agentParty of - SRecipient -> resubscribe SRecipient smp - SNotifier -> resubscribe SNotifier smp - _ -> pure () +reconnectSMPClient :: forall p. SMPClientAgent p -> SMPServer -> (Maybe (ServiceId, Int64), Maybe (Map QueueId C.APrivateAuthKey)) -> ExceptT SMPClientError IO () +reconnectSMPClient ca@SMPClientAgent {agentCfg, agentParty} srv (sSub_, qSubs_) = + withSMP ca srv $ \smp -> liftIO $ case subscriberParty agentParty of + Just Dict -> resubscribe smp + Nothing -> pure () where - resubscribe :: SubscriberParty p => SParty p -> SMPClient -> IO () - resubscribe _ smp = do - currSubs_ <- mapM readTVarIO =<< TM.lookupIO srv (activeQueueSubs ca) - let subs' :: [(QueueId, C.APrivateAuthKey)] = - maybe id (\currSubs -> filter ((`M.notMember` currSubs) . fst)) currSubs_ $ M.assocs subs - mapM_ (smpSubscribeQueues ca smp srv) $ toChunks (agentSubsBatchSize agentCfg) subs' + resubscribe :: (PartyI p, SubscriberParty p) => SMPClient -> IO () + resubscribe smp = do + mapM_ (smpSubscribeService ca smp srv) sSub_ + forM_ qSubs_ $ \qSubs -> do + currSubs_ <- mapM readTVarIO =<< TM.lookupIO srv (activeQueueSubs ca) + let qSubs' :: [(QueueId, C.APrivateAuthKey)] = + maybe id (\currSubs -> filter ((`M.notMember` currSubs) . fst)) currSubs_ $ M.assocs qSubs + mapM_ (smpSubscribeQueues @p ca smp srv) $ toChunks (agentSubsBatchSize agentCfg) qSubs' notify :: MonadIO m => SMPClientAgent p -> SMPClientAgentEvent -> m () notify ca evt = atomically $ writeTBQueue (agentQ ca) evt @@ -313,10 +395,6 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode $ host srv) <> "): " <> tshow e throwE e -subscribeQueuesSMP :: SMPClientAgent 'Recipient -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO () -subscribeQueuesSMP = subscribeQueues_ -{-# INLINE subscribeQueuesSMP #-} - subscribeQueuesNtfs :: SMPClientAgent 'Notifier -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO () subscribeQueuesNtfs = subscribeQueues_ {-# INLINE subscribeQueuesNtfs #-} @@ -340,28 +418,36 @@ smpSubscribeQueues ca smp srv subs = do (Just <$> processSubscriptions rs) (pure Nothing) case rs' of - Just (tempErrs, finalErrs, oks, _) -> do - notify_ CASubscribed $ map fst oks + Just (tempErrs, finalErrs, (qOks, sQs), _) -> do + notify_ (`CASubscribed` Nothing) $ map fst qOks + when (isJust smpServiceId) $ notify_ (`CASubscribed` smpServiceId) sQs notify_ CASubError finalErrs when tempErrs $ reconnectClient ca srv Nothing -> reconnectClient ca srv where - processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) + processSubscriptions :: NonEmpty (Either SMPClientError (Maybe ServiceId)) -> STM (Bool, [(QueueId, SMPClientError)], ([(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]), [QueueId]) processSubscriptions rs = do pending <- maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingQueueSubs ca) - let acc@(_, _, oks, notPending) = foldr (groupSub pending) (False, [], [], []) (L.zip subs rs) - unless (null oks) $ addActiveSubs ca srv oks + let acc@(_, _, (qOks, sQs), notPending) = foldr (groupSub pending) (False, [], ([], []), []) (L.zip subs rs) + unless (null qOks) $ addActiveSubs ca srv qOks + unless (null sQs) $ forM_ smpServiceId $ \serviceId -> + updateActiveServiceSub ca srv ((serviceId, fromIntegral $ length sQs), sessId) unless (null notPending) $ removePendingSubs ca srv notPending pure acc sessId = sessionId $ thParams smp + smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp groupSub :: Map QueueId C.APrivateAuthKey -> - ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> - (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) -> - (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) - groupSub pending ((qId, pk), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of - Right () - | M.member qId pending -> (tempErrs, finalErrs, (qId, (sessId, pk)) : oks, qId : notPending) + ((QueueId, C.APrivateAuthKey), Either SMPClientError (Maybe ServiceId)) -> + (Bool, [(QueueId, SMPClientError)], ([(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]), [QueueId]) -> + (Bool, [(QueueId, SMPClientError)], ([(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]), [QueueId]) + groupSub pending ((qId, pk), r) acc@(!tempErrs, finalErrs, oks@(qOks, sQs), notPending) = case r of + Right serviceId_ + | M.member qId pending -> + let oks' = case (smpServiceId, serviceId_) of + (Just sId, Just sId') | sId == sId' -> (qOks, qId : sQs) + _ -> ((qId, (sessId, pk)) : qOks, sQs) + in (tempErrs, finalErrs, oks', qId : notPending) | otherwise -> acc Left e | temporaryClientError e -> (True, finalErrs, oks, notPending) @@ -369,6 +455,48 @@ smpSubscribeQueues ca smp srv subs = do notify_ :: (SMPServer -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO () notify_ evt qs = mapM_ (notify ca . evt srv) $ L.nonEmpty qs +subscribeServiceNtfs :: SMPClientAgent 'Notifier -> SMPServer -> (ServiceId, Int64) -> IO () +subscribeServiceNtfs = subscribeService_ +{-# INLINE subscribeServiceNtfs #-} + +subscribeService_ :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPServer -> (ServiceId, Int64) -> IO () +subscribeService_ ca srv serviceSub = do + atomically $ setPendingServiceSub ca srv $ Just serviceSub + runExceptT (getSMPServerClient' ca srv) >>= \case + Right smp -> smpSubscribeService ca smp srv serviceSub + Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that + +smpSubscribeService :: (PartyI p, SubscriberParty p) => SMPClientAgent p -> SMPClient -> SMPServer -> (ServiceId, Int64) -> IO () +smpSubscribeService ca smp srv serviceSub@(serviceId, _) = case smpClientService smp of + Just service | serviceAvailable service -> subscribe + _ -> notifyUnavailable + where + subscribe = do + r <- runExceptT $ subscribeService smp $ agentParty ca + ok <- + atomically $ + ifM + (activeClientSession ca smp srv) + (True <$ processSubscription r) + (pure False) + if ok + then case r of + Right n -> notify ca $ CAServiceSubscribed srv serviceSub n + Left e + | smpClientServiceError e -> notifyUnavailable + | temporaryClientError e -> reconnectClient ca srv + | otherwise -> notify ca $ CAServiceSubError srv serviceSub e + else reconnectClient ca srv + processSubscription = mapM_ $ \n -> do + setActiveServiceSub ca srv $ Just ((serviceId, n), sessId) + setPendingServiceSub ca srv Nothing + serviceAvailable THClientService {serviceRole, serviceId = serviceId'} = + serviceId == serviceId' && subscriberServiceRole (agentParty ca) == serviceRole + notifyUnavailable = do + atomically $ setPendingServiceSub ca srv Nothing + notify ca $ CAServiceUnavailable srv serviceSub -- this will resubscribe all queues directly + sessId = sessionId $ thParams smp + activeClientSession' :: SMPClientAgent p -> SessionId -> SMPServer -> STM Bool activeClientSession' ca sessId srv = sameSess <$> tryReadSessVar srv (smpClients ca) where @@ -400,7 +528,35 @@ addSubs_ :: TMap SMPServer (TMap QueueId s) -> SMPServer -> Map QueueId s -> STM addSubs_ subs srv ss = TM.lookup srv subs >>= \case Just m -> TM.union ss m - _ -> newTVar ss >>= \v -> TM.insert srv v subs + _ -> TM.insertM srv (newTVar ss) subs + +setActiveServiceSub :: SMPClientAgent p -> SMPServer -> Maybe ((ServiceId, Int64), SessionId) -> STM () +setActiveServiceSub = setServiceSub_ activeServiceSubs +{-# INLINE setActiveServiceSub #-} + +setPendingServiceSub :: SMPClientAgent p -> SMPServer -> Maybe (ServiceId, Int64) -> STM () +setPendingServiceSub = setServiceSub_ pendingServiceSubs +{-# INLINE setPendingServiceSub #-} + +setServiceSub_ :: + (SMPClientAgent p -> TMap SMPServer (TVar (Maybe sub))) -> + SMPClientAgent p -> + SMPServer -> + Maybe sub -> + STM () +setServiceSub_ subsSel ca srv sub = + TM.lookup srv (subsSel ca) >>= \case + Just v -> writeTVar v sub + Nothing -> TM.insertM srv (newTVar sub) (subsSel ca) + +updateActiveServiceSub :: SMPClientAgent p -> SMPServer -> ((ServiceId, Int64), SessionId) -> STM () +updateActiveServiceSub ca srv sub@((serviceId', n'), sessId') = + TM.lookup srv (activeServiceSubs ca) >>= \case + Just v -> modifyTVar' v $ \case + Just ((serviceId, n), sessId) | serviceId == serviceId' && sessId == sessId' -> + Just ((serviceId, n + n'), sessId) + _ -> Just sub + Nothing -> TM.insertM srv (newTVar $ Just sub) (activeServiceSubs ca) removeActiveSub :: SMPClientAgent p -> SMPServer -> QueueId -> STM () removeActiveSub = removeSub_ . activeQueueSubs diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index c114cee8e..922ff2266 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -39,6 +39,8 @@ import Data.Time.Clock (UTCTime) import Data.Time.Clock.System (SystemTime (..)) import Data.Time.Format.ISO8601 import Data.Word (Word16, Word32) +import qualified Data.X509 as X +import qualified Data.X509.Validation as XV import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (bshow, (<$?>)) @@ -146,6 +148,18 @@ instance StrEncoding UTCTime where strEncode = B.pack . iso8601Show strP = maybe (Left "bad UTCTime") Right . iso8601ParseM . B.unpack <$?> A.takeTill (\c -> c == ' ' || c == '\n' || c == ',' || c == ';') +instance StrEncoding X.CertificateChain where + strEncode = (\(X.CertificateChainRaw blobs) -> strEncodeList blobs) . X.encodeCertificateChain + {-# INLINE strEncode #-} + strP = either (fail . show) pure . X.decodeCertificateChain . X.CertificateChainRaw =<< strListP + {-# INLINE strP #-} + +instance StrEncoding XV.Fingerprint where + strEncode (XV.Fingerprint s) = strEncode s + {-# INLINE strEncode #-} + strP = XV.Fingerprint <$> strP + {-# INLINE strP #-} + -- lists encode/parse as comma-separated strings strEncodeList :: StrEncoding a => [a] -> ByteString strEncodeList = B.intercalate "," . map strEncode diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 769c35510..f63f8e54c 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -154,10 +154,15 @@ instance Protocol NTFVersion ErrorType NtfResponse where type ProtoCommand NtfResponse = NtfCmd type ProtoType NtfResponse = 'PNTF protocolClientHandshake c _ks = ntfClientHandshake c + {-# INLINE protocolClientHandshake #-} + useServiceAuth _ = False + {-# INLINE useServiceAuth #-} protocolPing = NtfCmd SSubscription PING + {-# INLINE protocolPing #-} protocolError = \case NRErr e -> Just e _ -> Nothing + {-# INLINE protocolError #-} data NtfCommand (e :: NtfEntity) where -- | register new device token for notifications @@ -478,6 +483,8 @@ data NtfSubStatus NSDeleted | -- | SMP AUTH error NSAuth + | -- | SMP SERVICE error - rejected service signature on individual subscriptions + NSService | -- | SMP error other than AUTH NSErr ByteString deriving (Eq, Ord, Show) @@ -491,6 +498,7 @@ ntfShouldSubscribe = \case NSEnd -> False NSDeleted -> False NSAuth -> False + NSService -> True NSErr _ -> False instance Encoding NtfSubStatus where @@ -502,6 +510,7 @@ instance Encoding NtfSubStatus where NSEnd -> "END" NSDeleted -> "DELETED" NSAuth -> "AUTH" + NSService -> "SERVICE" NSErr err -> "ERR " <> err smpP = A.takeTill (== ' ') >>= \case @@ -512,6 +521,7 @@ instance Encoding NtfSubStatus where "END" -> pure NSEnd "DELETED" -> pure NSDeleted "AUTH" -> pure NSAuth + "SERVICE" -> pure NSService "ERR" -> NSErr <$> (A.space *> A.takeByteString) _ -> fail "bad NtfSubStatus" diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index cb5203341..7b325e10d 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -32,11 +32,11 @@ import Data.Functor (($>)) import Data.IORef import Data.Int (Int64) import qualified Data.IntSet as IS -import Data.List (foldl', intercalate) +import Data.List (foldl') import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (mapMaybe) +import Data.Maybe (isJust, mapMaybe) import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T @@ -62,7 +62,7 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessag import Simplex.Messaging.Notifications.Server.Store.Postgres import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Transport -import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGet, tPut) +import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), NotifierId, Party (..), ProtocolServer (host), SMPServer, ServiceId, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGet, tPut) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server import Simplex.Messaging.Server.Control (CPClientRole (..)) @@ -187,31 +187,31 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} ntfVrfInvalidTkn' <- atomicSwapIORef ntfVrfInvalidTkn 0 tkn <- liftIO $ periodStatCounts activeTokens ts sub <- liftIO $ periodStatCounts activeSubs ts - hPutStrLn h $ - intercalate + T.hPutStrLn h $ + T.intercalate "," - [ iso8601Show $ utctDay fromTime', - show tknCreated', - show tknVerified', - show tknDeleted', - show subCreated', - show subDeleted', - show ntfReceived', - show ntfDelivered', + [ T.pack $ iso8601Show $ utctDay fromTime', + tshow tknCreated', + tshow tknVerified', + tshow tknDeleted', + tshow subCreated', + tshow subDeleted', + tshow ntfReceived', + tshow ntfDelivered', dayCount tkn, weekCount tkn, monthCount tkn, dayCount sub, weekCount sub, monthCount sub, - show tknReplaced', - show ntfFailed', - show ntfCronDelivered', - show ntfCronFailed', - show ntfVrfQueued', - show ntfVrfDelivered', - show ntfVrfFailed', - show ntfVrfInvalidTkn' + tshow tknReplaced', + tshow ntfFailed', + tshow ntfCronDelivered', + tshow ntfCronFailed', + tshow ntfVrfQueued', + tshow ntfVrfDelivered', + tshow ntfVrfFailed', + tshow ntfVrfInvalidTkn' ] liftIO $ threadDelay' interval @@ -253,38 +253,66 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} #endif let NtfSubscriber {smpSubscribers, smpAgent = a} = subscriber NtfPushServer {pushQ} = pushServer - SMPClientAgent {smpClients, smpSessions, activeQueueSubs, pendingQueueSubs, smpSubWorkers} = a + SMPClientAgent {smpClients, smpSessions, smpSubWorkers} = a srvSubscribers <- getSMPWorkerMetrics a smpSubscribers srvClients <- getSMPWorkerMetrics a smpClients srvSubWorkers <- getSMPWorkerMetrics a smpSubWorkers - ntfActiveSubs <- getSMPSubMetrics a activeQueueSubs - ntfPendingSubs <- getSMPSubMetrics a pendingQueueSubs + ntfActiveServiceSubs <- getSMPServiceSubMetrics a activeServiceSubs $ snd . fst + ntfActiveQueueSubs <- getSMPSubMetrics a activeQueueSubs + ntfPendingServiceSubs <- getSMPServiceSubMetrics a pendingServiceSubs snd + ntfPendingQueueSubs <- getSMPSubMetrics a pendingQueueSubs smpSessionCount <- M.size <$> readTVarIO smpSessions apnsPushQLength <- atomically $ lengthTBQueue pushQ - pure NtfRealTimeMetrics {threadsCount, srvSubscribers, srvClients, srvSubWorkers, ntfActiveSubs, ntfPendingSubs, smpSessionCount, apnsPushQLength} + pure + NtfRealTimeMetrics + { threadsCount, + srvSubscribers, + srvClients, + srvSubWorkers, + ntfActiveServiceSubs, + ntfActiveQueueSubs, + ntfPendingServiceSubs, + ntfPendingQueueSubs, + smpSessionCount, + apnsPushQLength + } where - getSMPSubMetrics :: SMPClientAgent 'Notifier -> TMap SMPServer (TMap NotifierId a) -> IO NtfSMPSubMetrics - getSMPSubMetrics a v = do - subs <- readTVarIO v + getSMPServiceSubMetrics :: forall sub. SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TVar (Maybe sub))) -> (sub -> Int64) -> IO NtfSMPSubMetrics + getSMPServiceSubMetrics a sel subQueueCount = getSubMetrics_ a sel countSubs + where + countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TVar (Maybe sub)) -> IO (NtfSMPSubMetrics, S.Set Text) + countSubs acc (srv, serviceSubs) = maybe acc (subMetricsResult a acc srv . fromIntegral . subQueueCount) <$> readTVarIO serviceSubs + + getSMPSubMetrics :: SMPClientAgent 'Notifier -> (SMPClientAgent 'Notifier -> TMap SMPServer (TMap NotifierId a)) -> IO NtfSMPSubMetrics + getSMPSubMetrics a sel = getSubMetrics_ a sel countSubs + where + countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TMap NotifierId a) -> IO (NtfSMPSubMetrics, S.Set Text) + countSubs acc (srv, queueSubs) = subMetricsResult a acc srv . M.size <$> readTVarIO queueSubs + + getSubMetrics_ :: + SMPClientAgent 'Notifier -> + (SMPClientAgent 'Notifier -> TVar (M.Map SMPServer sub')) -> + ((NtfSMPSubMetrics, S.Set Text) -> (SMPServer, sub') -> IO (NtfSMPSubMetrics, S.Set Text)) -> + IO NtfSMPSubMetrics + getSubMetrics_ a sel countSubs = do + subs <- readTVarIO $ sel a let metrics = NtfSMPSubMetrics {ownSrvSubs = M.empty, otherServers = 0, otherSrvSubCount = 0} (metrics', otherSrvs) <- foldM countSubs (metrics, S.empty) $ M.assocs subs pure (metrics' :: NtfSMPSubMetrics) {otherServers = S.size otherSrvs} + + subMetricsResult :: SMPClientAgent 'Notifier -> (NtfSMPSubMetrics, S.Set Text) -> SMPServer -> Int -> (NtfSMPSubMetrics, S.Set Text) + subMetricsResult a acc@(metrics, !otherSrvs) srv@(SMPServer (h :| _) _ _) cnt + | isOwnServer a srv = + let !ownSrvSubs' = M.alter (Just . maybe cnt (+ cnt)) host ownSrvSubs + metrics' = metrics {ownSrvSubs = ownSrvSubs'} :: NtfSMPSubMetrics + in (metrics', otherSrvs) + | cnt == 0 = acc + | otherwise = + let metrics' = metrics {otherSrvSubCount = otherSrvSubCount + cnt} :: NtfSMPSubMetrics + in (metrics', S.insert host otherSrvs) where - countSubs :: (NtfSMPSubMetrics, S.Set Text) -> (SMPServer, TMap NotifierId a) -> IO (NtfSMPSubMetrics, S.Set Text) - countSubs acc@(metrics, !otherSrvs) (srv@(SMPServer (h :| _) _ _), activeQueueSubs) = - result . M.size <$> readTVarIO activeQueueSubs - where - result cnt - | isOwnServer a srv = - let !ownSrvSubs' = M.alter (Just . maybe cnt (+ cnt)) host ownSrvSubs - metrics' = metrics {ownSrvSubs = ownSrvSubs'} :: NtfSMPSubMetrics - in (metrics', otherSrvs) - | cnt == 0 = acc - | otherwise = - let metrics' = metrics {otherSrvSubCount = otherSrvSubCount + cnt} :: NtfSMPSubMetrics - in (metrics', S.insert host otherSrvs) - NtfSMPSubMetrics {ownSrvSubs, otherSrvSubCount} = metrics - host = safeDecodeUtf8 $ strEncode h + NtfSMPSubMetrics {ownSrvSubs, otherSrvSubCount} = metrics + host = safeDecodeUtf8 $ strEncode h getSMPWorkerMetrics :: SMPClientAgent 'Notifier -> TMap SMPServer a -> IO NtfSMPWorkerMetrics getSMPWorkerMetrics a v = workerMetrics a . M.keys <$> readTVarIO v @@ -372,20 +400,21 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} logError "Unauthorized control port command" hPutStrLn h "AUTH" r -> do - NtfRealTimeMetrics {threadsCount, srvSubscribers, srvClients, srvSubWorkers, ntfActiveSubs, ntfPendingSubs, smpSessionCount, apnsPushQLength} <- - getNtfRealTimeMetrics =<< unliftIO u ask + rtm <- getNtfRealTimeMetrics =<< unliftIO u ask #if MIN_VERSION_base(4,18,0) - hPutStrLn h $ "Threads: " <> show threadsCount + hPutStrLn h $ "Threads: " <> show (threadsCount rtm) #else hPutStrLn h "Threads: not available on GHC 8.10" #endif - putSMPWorkers "SMP subcscribers" srvSubscribers - putSMPWorkers "SMP clients" srvClients - putSMPWorkers "SMP subscription workers" srvSubWorkers - hPutStrLn h $ "SMP sessions count: " <> show smpSessionCount - putSMPSubs "SMP subscriptions" ntfActiveSubs - putSMPSubs "Pending SMP subscriptions" ntfPendingSubs - hPutStrLn h $ "Push notifications queue length: " <> show apnsPushQLength + putSMPWorkers "SMP subcscribers" $ srvSubscribers rtm + putSMPWorkers "SMP clients" $ srvClients rtm + putSMPWorkers "SMP subscription workers" $ srvSubWorkers rtm + hPutStrLn h $ "SMP sessions count: " <> show (smpSessionCount rtm) + putSMPSubs "SMP service subscriptions" $ ntfActiveServiceSubs rtm + putSMPSubs "SMP queue subscriptions" $ ntfActiveQueueSubs rtm + putSMPSubs "Pending SMP service subscriptions" $ ntfPendingServiceSubs rtm + putSMPSubs "Pending SMP queue subscriptions" $ ntfPendingQueueSubs rtm + hPutStrLn h $ "Push notifications queue length: " <> show (apnsPushQLength rtm) where putSMPSubs :: Text -> NtfSMPSubMetrics -> IO () putSMPSubs name NtfSMPSubMetrics {ownSrvSubs, otherServers, otherSrvSubCount} = do @@ -423,35 +452,39 @@ resubscribe NtfSubscriber {smpAgent = ca} = do liftIO $ do srvs <- getUsedSMPServers st logNote $ "Starting SMP resubscriptions for " <> tshow (length srvs) <> " servers..." - counts <- mapConcurrently (subscribeSrvSubs st batchSize) srvs + counts <- mapConcurrently (subscribeSrvSubs ca st batchSize) srvs logNote $ "Completed all SMP resubscriptions for " <> tshow (length srvs) <> " servers (" <> tshow (sum counts) <> " subscriptions)" + +subscribeSrvSubs :: SMPClientAgent 'Notifier -> NtfPostgresStore -> Int -> (SMPServer, Int64, Maybe (ServiceId, Int64)) -> IO Int +subscribeSrvSubs ca st batchSize (srv, srvId, service_) = do + let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv) + logNote $ "Starting SMP resubscriptions for " <> srvStr + forM_ service_ $ \(serviceId, n) -> do + logNote $ "Subscribing service to " <> srvStr <> " with " <> tshow n <> " associated queues" + subscribeServiceNtfs ca srv (serviceId, n) + n <- subscribeLoop 0 Nothing + logNote $ "Completed SMP resubscriptions for " <> srvStr <> " (" <> tshow n <> " subscriptions)" + pure n where - subscribeSrvSubs st batchSize srv = do - let srvStr = safeDecodeUtf8 (strEncode $ L.head $ host srv) - logNote $ "Starting SMP resubscriptions for " <> srvStr - n <- loop 0 Nothing - logNote $ "Completed SMP resubscriptions for " <> srvStr <> " (" <> tshow n <> " subscriptions)" - pure n - where - dbBatchSize = batchSize * 100 - loop n afterSubId_ = - getServerNtfSubscriptions st srv afterSubId_ dbBatchSize >>= \case - Left _ -> exitFailure - Right [] -> pure n - Right subs -> do - mapM_ (subscribeQueuesNtfs ca srv . L.map snd) $ toChunks batchSize subs - let len = length subs - n' = n + len - afterSubId_' = Just $ fst $ last subs - if len < dbBatchSize then pure n' else loop n' afterSubId_' + dbBatchSize = batchSize * 100 + subscribeLoop n afterSubId_ = + getServerNtfSubscriptions st srvId afterSubId_ dbBatchSize >>= \case + Left _ -> exitFailure + Right [] -> pure n + Right subs -> do + mapM_ (subscribeQueuesNtfs ca srv . L.map snd) $ toChunks batchSize subs + let len = length subs + n' = n + len + afterSubId_' = Just $ fst $ last subs + if len < dbBatchSize then pure n' else subscribeLoop n' afterSubId_' -- this function is concurrency-safe - only onle subscriber per server can be created at a time, -- other threads would wait for the first thread to create it. -subscribeNtfs :: NtfSubscriber -> NtfPostgresStore -> SMPServer -> NonEmpty ServerNtfSub -> IO () -subscribeNtfs NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent = ca} st smpServer ntfSubs = +subscribeNtfs :: NtfSubscriber -> NtfPostgresStore -> SMPServer -> ServerNtfSub -> IO () +subscribeNtfs NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent = ca} st smpServer ntfSub = getSubscriberVar >>= either createSMPSubscriber waitForSMPSubscriber - >>= mapM_ (\sub -> atomically $ writeTQueue (subscriberSubQ sub) ntfSubs) + >>= mapM_ (\sub -> atomically $ writeTQueue (subscriberSubQ sub) ntfSub) where getSubscriberVar :: IO (Either SMPSubscriberVar SMPSubscriberVar) getSubscriberVar = atomically . getSessVar subscriberSeq smpServer smpSubscribers =<< getCurrentTime @@ -477,14 +510,13 @@ subscribeNtfs NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent = ca} st sm atomically $ removeSessVar v smpServer smpSubscribers pure Nothing - runSMPSubscriber :: TQueue (NonEmpty ServerNtfSub) -> IO () + runSMPSubscriber :: TQueue ServerNtfSub -> IO () runSMPSubscriber q = forever $ do -- TODO [ntfdb] possibly, the subscriptions can be batched here and sent every say 5 seconds -- this should be analysed once we have prometheus stats - subs <- atomically $ readTQueue q - updated <- batchUpdateSubStatus st subs NSPending - logSubStatus smpServer "subscribing" (L.length subs) updated - subscribeQueuesNtfs ca smpServer $ L.map snd subs + (nId, sub) <- atomically $ readTQueue q + void $ updateSubStatus st nId NSPending + subscribeQueuesNtfs ca smpServer [sub] ntfSubscriber :: NtfSubscriber -> M () ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = @@ -520,34 +552,62 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = receiveAgent = do st <- asks store + batchSize <- asks $ subsBatchSize . config liftIO $ forever $ atomically (readTBQueue agentQ) >>= \case - CAConnected srv -> - logInfo $ "SMP server reconnected " <> showServer' srv + CAConnected srv serviceId -> do + let asService = if isJust serviceId then "as service " else "" + logInfo $ "SMP server reconnected " <> asService <> showServer' srv CADisconnected srv nIds -> do - updated <- batchUpdateSrvSubStatus st srv nIds NSInactive + updated <- batchUpdateSrvSubStatus st srv Nothing nIds NSInactive logSubStatus srv "disconnected" (L.length nIds) updated - CASubscribed srv nIds -> do - updated <- batchUpdateSrvSubStatus st srv nIds NSActive - logSubStatus srv "subscribed" (L.length nIds) updated + CASubscribed srv serviceId nIds -> do + updated <- batchUpdateSrvSubStatus st srv serviceId nIds NSActive + let asService = if isJust serviceId then " as service" else "" + logSubStatus srv ("subscribed" <> asService) (L.length nIds) updated CASubError srv errs -> do - forM_ (L.nonEmpty $ mapMaybe (\(nId, err) -> (nId,) <$> subErrorStatus err) $ L.toList errs) $ \subStatuses -> do - updated <- batchUpdateSrvSubStatuses st srv subStatuses + forM_ (L.nonEmpty $ mapMaybe (\(nId, err) -> (nId,) <$> queueSubErrorStatus err) $ L.toList errs) $ \subStatuses -> do + updated <- batchUpdateSrvSubErrors st srv subStatuses logSubErrors srv subStatuses updated + -- TODO [certs] resubscribe queues with statuses NSErr and NSService + CAServiceDisconnected srv serviceSub -> + logNote $ "SMP server service disconnected " <> showService srv serviceSub + CAServiceSubscribed srv serviceSub@(_, expected) n + | expected == n -> logNote msg + | otherwise -> logWarn $ msg <> ", confirmed subs: " <> tshow n + where + msg = "SMP server service subscribed " <> showService srv serviceSub + CAServiceSubError srv serviceSub e -> + -- Errors that require re-subscribing queues directly are reported as CAServiceUnavailable. + -- See smpSubscribeService in Simplex.Messaging.Client.Agent + logError $ "SMP server service subscription error " <> showService srv serviceSub <> ": " <> tshow e + CAServiceUnavailable srv serviceSub -> do + logError $ "SMP server service unavailable: " <> showService srv serviceSub + removeServiceAssociation st srv >>= \case + Right (srvId, updated) -> do + logSubStatus srv "removed service association" updated updated + void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing) + Left e -> logError $ "SMP server update and resubscription error " <> tshow e + where + showService srv (serviceId, n) = showServer' srv <> ", service ID " <> decodeLatin1 (strEncode serviceId) <> ", " <> tshow n <> " subs" - logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int64 -> IO () + logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int -> IO () logSubErrors srv subs updated = forM_ (L.group $ L.sort $ L.map snd subs) $ \ss -> do logError $ "SMP server subscription errors " <> showServer' srv <> ": " <> tshow (L.head ss) <> " (" <> tshow (length ss) <> " errors, " <> tshow updated <> " subs updated)" - subErrorStatus :: SMPClientError -> Maybe NtfSubStatus - subErrorStatus = \case + queueSubErrorStatus :: SMPClientError -> Maybe NtfSubStatus + queueSubErrorStatus = \case PCEProtocolError AUTH -> Just NSAuth + -- TODO [certs] we could allow making individual subscriptions within service session to handle SERVICE error. + -- This would require full stack changes in SMP server, SMP client and SMP service agent. + PCEProtocolError SERVICE -> Just NSService PCEProtocolError e -> updateErr "SMP error " e PCEResponseError e -> updateErr "ResponseError " e PCEUnexpectedResponse r -> updateErr "UnexpectedResponse " r PCETransportError e -> updateErr "TransportError " e PCECryptoError e -> updateErr "CryptoError " e PCEIncompatibleHost -> Just $ NSErr "IncompatibleHost" + PCEServiceUnavailable -> Just NSService -- this error should not happen on individual subscriptions PCEResponseTimeout -> Nothing PCENetworkError -> Nothing PCEIOError _ -> Nothing @@ -556,7 +616,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = updateErr :: Show e => ByteString -> e -> Maybe NtfSubStatus updateErr errType e = Just $ NSErr $ errType <> bshow e -logSubStatus :: SMPServer -> T.Text -> Int -> Int64 -> IO () +logSubStatus :: SMPServer -> T.Text -> Int -> Int -> IO () logSubStatus srv event n updated = logInfo $ "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subs, " <> tshow updated <> " subs updated)" @@ -796,7 +856,7 @@ client NtfServerClient {rcvQ, sndQ} ns@NtfSubscriber {smpAgent = ca} NtfPushServ withNtfStore (`addNtfSubscription` sub) $ \case True -> do st <- asks store - liftIO $ subscribeNtfs ns st srv [(subId, (nId, nKey))] + liftIO $ subscribeNtfs ns st srv (subId, (nId, nKey)) incNtfStat subCreated pure $ NRSubId subId False -> pure $ NRErr AUTH diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 07d02a502..e99350c7d 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -17,10 +17,11 @@ import Data.List.NonEmpty (NonEmpty) import qualified Data.Text as T import Data.Time.Clock (getCurrentTime) import Data.Time.Clock.System (SystemTime) -import Data.X509.Validation (Fingerprint (..)) +import qualified Data.X509.Validation as XV import Network.Socket import qualified Network.TLS as TLS import Numeric.Natural +import Simplex.Messaging.Client (ProtocolClientConfig (..)) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol @@ -39,7 +40,7 @@ import Simplex.Messaging.Server.StoreLog (closeStoreLog) import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (ASrvTransport, THandleParams, TransportPeer (..)) +import Simplex.Messaging.Transport (ASrvTransport, SMPServiceRole (..), ServiceCredentials (..), THandleParams, TransportPeer (..)) import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials, TransportServerConfig, loadFingerprint, loadServerCredential) import System.Exit (exitFailure) import System.Mem.Weak (Weak) @@ -60,6 +61,8 @@ data NtfServerConfig = NtfServerConfig inactiveClientExpiration :: Maybe ExpirationConfig, dbStoreConfig :: PostgresStoreCfg, ntfCredentials :: ServerCredentials, + -- send service credentials and use service subscriptions when SMP server supports them + useServiceCreds :: Bool, periodicNtfsInterval :: Int, -- seconds -- stats config - see SMP server config logStatsInterval :: Maybe Int64, @@ -93,14 +96,23 @@ data NtfEnv = NtfEnv } newNtfServerEnv :: NtfServerConfig -> IO NtfEnv -newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, startOptions} = do +newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, useServiceCreds, startOptions} = do when (compactLog startOptions) $ compactDbStoreLog $ dbStoreLogPath dbStoreConfig random <- C.newRandom store <- newNtfDbStore dbStoreConfig - subscriber <- newNtfSubscriber smpAgentCfg random - pushServer <- newNtfPushServer pushQSize apnsConfig tlsServerCreds <- loadServerCredential ntfCredentials - Fingerprint fp <- loadFingerprint ntfCredentials + serviceCertHash@(XV.Fingerprint fp) <- loadFingerprint ntfCredentials + smpAgentCfg' <- + if useServiceCreds + then do + serviceSignKey <- case C.x509ToPrivate' $ snd tlsServerCreds of + Right pk -> pure pk + Left e -> putStrLn ("Server has no valid key: " <> show e) >> exitFailure + let service = ServiceCredentials {serviceRole = SRNotifier, serviceCreds = tlsServerCreds, serviceCertHash, serviceSignKey} + pure smpAgentCfg {smpCfg = (smpCfg smpAgentCfg) {serviceCredentials = Just service}} + else pure smpAgentCfg + subscriber <- newNtfSubscriber smpAgentCfg' random + pushServer <- newNtfPushServer pushQSize apnsConfig serverStats <- newNtfServerStats =<< getCurrentTime pure NtfEnv {config, subscriber, pushServer, store, random, tlsServerCreds, serverIdentity = C.KeyHash fp, serverStats} where @@ -129,7 +141,7 @@ newNtfSubscriber smpAgentCfg random = do data SMPSubscriber = SMPSubscriber { smpServer :: SMPServer, - subscriberSubQ :: TQueue (NonEmpty ServerNtfSub), + subscriberSubQ :: TQueue ServerNtfSub, subThreadId :: Weak ThreadId } diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index 45f76d002..344456251 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -99,9 +99,9 @@ ntfServerCLI cfgPath logPath = restoreServerLastNtfs stmStore defaultLastNtfsFile let storeCfg = PostgresStoreCfg {dbOpts = dbOpts {createSchema = True}, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} ps <- newNtfDbStore storeCfg - (tCnt, sCnt, nCnt) <- importNtfSTMStore ps stmStore skipTokens + (tCnt, sCnt, nCnt, serviceCnt) <- importNtfSTMStore ps stmStore skipTokens renameFile storeLogFile $ storeLogFile <> ".bak" - putStrLn $ "Import completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show nCnt <> " last token notifications." + putStrLn $ "Import completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show serviceCnt <> " service associations, " <> show nCnt <> " last token notifications." putStrLn "Configure database options in INI file." SCExport | schemaExists && storeLogExists -> exitConfigureNtfStore connstr schema @@ -195,6 +195,8 @@ ntfServerCLI cfgPath logPath = \# socks_mode: onion\n\n\ \# The domain suffixes of the relays you operate (space-separated) to count as separate proxy statistics.\n\ \# own_server_domains: \n\n\ + \# User service subscriptions with server certificate\n\n\ + \# use_service_credentials: off\n\n\ \[INACTIVE_CLIENTS]\n\ \# TTL and interval to check inactive clients\n\ \disconnect: off\n" @@ -265,6 +267,7 @@ ntfServerCLI cfgPath logPath = privateKeyFile = c serverKeyFile, certificateFile = c serverCrtFile }, + useServiceCreds = fromMaybe False $ iniOnOff "SUBSCRIBER" "use_service_credentials" ini, periodicNtfsInterval = 5 * 60, -- 5 minutes logStatsInterval = logStats $> 86400, -- seconds logStatsStartTime = 0, -- seconds from 00:00 UTC @@ -276,7 +279,8 @@ ntfServerCLI cfgPath logPath = transportConfig = mkTransportServerConfig (fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini) - (Just alpnSupportedNTFHandshakes), + (Just alpnSupportedNTFHandshakes) + False, startOptions } iniDeletedTTL ini = readIniDefault (86400 * defaultDeletedTTL) "STORE_LOG" "db_deleted_ttl" ini diff --git a/src/Simplex/Messaging/Notifications/Server/Prometheus.hs b/src/Simplex/Messaging/Notifications/Server/Prometheus.hs index 78d5b4d38..faaa56951 100644 --- a/src/Simplex/Messaging/Notifications/Server/Prometheus.hs +++ b/src/Simplex/Messaging/Notifications/Server/Prometheus.hs @@ -17,6 +17,7 @@ import Numeric.Natural (Natural) import Simplex.Messaging.Notifications.Server.Stats import Simplex.Messaging.Server.Stats (PeriodStatCounts (..)) import Simplex.Messaging.Transport (simplexMQVersion) +import Simplex.Messaging.Util (tshow) data NtfServerMetrics = NtfServerMetrics { statsData :: NtfServerStatsData, @@ -36,8 +37,10 @@ data NtfRealTimeMetrics = NtfRealTimeMetrics srvSubscribers :: NtfSMPWorkerMetrics, srvClients :: NtfSMPWorkerMetrics, srvSubWorkers :: NtfSMPWorkerMetrics, - ntfActiveSubs :: NtfSMPSubMetrics, - ntfPendingSubs :: NtfSMPSubMetrics, + ntfActiveServiceSubs :: NtfSMPSubMetrics, + ntfActiveQueueSubs :: NtfSMPSubMetrics, + ntfPendingServiceSubs :: NtfSMPSubMetrics, + ntfPendingQueueSubs :: NtfSMPSubMetrics, smpSessionCount :: Int, apnsPushQLength :: Natural } @@ -57,8 +60,10 @@ ntfPrometheusMetrics sm rtm ts = srvSubscribers, srvClients, srvSubWorkers, - ntfActiveSubs, - ntfPendingSubs, + ntfActiveServiceSubs, + ntfActiveQueueSubs, + ntfPendingServiceSubs, + ntfPendingQueueSubs, smpSessionCount, apnsPushQLength } = rtm @@ -148,8 +153,10 @@ ntfPrometheusMetrics sm rtm ts = \# TYPE simplex_ntf_subscriptions_approx_total gauge\n\ \simplex_ntf_subscriptions_approx_total " <> mshow approxSubCount <> "\n# approxSubCount\n\ \\n" - <> showSubMetric ntfActiveSubs "simplex_ntf_smp_subscription_active_" "Active" - <> showSubMetric ntfPendingSubs "simplex_ntf_smp_subscription_pending_" "Pending" + <> showSubMetric ntfActiveServiceSubs "simplex_ntf_smp_service_subscription_active_" "Active" + <> showSubMetric ntfActiveQueueSubs "simplex_ntf_smp_subscription_active_" "Active" + <> showSubMetric ntfPendingServiceSubs "simplex_ntf_smp_service_subscription_pending_" "Pending" + <> showSubMetric ntfPendingQueueSubs "simplex_ntf_smp_subscription_pending_" "Pending" notifications = "# Notifications\n\ \# -------------\n\ @@ -244,9 +251,9 @@ ntfPrometheusMetrics sm rtm ts = \" <> name <> param <> " " <> mshow value <> "\n# " <> codeRef <> "\n\ \\n" metricHost host = "{server=\"" <> host <> "\"}" - mstr a = T.pack a <> " " <> tsEpoch + mstr a = a <> " " <> tsEpoch mshow :: Show a => a -> Text - mshow = mstr . show - tsEpoch = T.pack $ show @Int64 $ floor @Double $ realToFrac (ts `diffUTCTime` epoch) * 1000 + mshow = mstr . tshow + tsEpoch = tshow @Int64 $ floor @Double $ realToFrac (ts `diffUTCTime` epoch) * 1000 epoch = UTCTime systemEpochDay 0 {-# FOURMOLU_ENABLE\n#-} diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 201e477d6..63a81ac0f 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -24,7 +24,7 @@ import Data.Word (Word16) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Protocol (NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer) +import Simplex.Messaging.Protocol (NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId) import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -37,7 +37,8 @@ data NtfSTMStore = NtfSTMStore subscriptions :: TMap NtfSubscriptionId NtfSubData, tokenSubscriptions :: TMap NtfTokenId (TVar (Set NtfSubscriptionId)), subscriptionLookup :: TMap SMPQueueNtf NtfSubscriptionId, - tokenLastNtfs :: TMap NtfTokenId (TVar (NonEmpty PNMessageData)) + tokenLastNtfs :: TMap NtfTokenId (TVar (NonEmpty PNMessageData)), + ntfServices :: TMap SMPServer ServiceId } newNtfSTMStore :: IO NtfSTMStore @@ -48,7 +49,8 @@ newNtfSTMStore = do tokenSubscriptions <- TM.emptyIO subscriptionLookup <- TM.emptyIO tokenLastNtfs <- TM.emptyIO - pure NtfSTMStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup, tokenLastNtfs} + ntfServices <- TM.emptyIO + pure NtfSTMStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup, tokenLastNtfs, ntfServices} data NtfTknData = NtfTknData { ntfTknId :: NtfTokenId, @@ -74,7 +76,8 @@ data NtfSubData = NtfSubData smpQueue :: SMPQueueNtf, notifierKey :: NtfPrivateAuthKey, tokenId :: NtfTokenId, - subStatus :: TVar NtfSubStatus + subStatus :: TVar NtfSubStatus, + ntfServiceAssoc :: TVar Bool } ntfSubServer :: NtfSubData -> SMPServer @@ -183,6 +186,10 @@ stmStoreTokenLastNtf (NtfSTMStore {tokens, tokenLastNtfs}) tknId ntf = do whenM (TM.member tknId tokens) $ TM.insertM tknId (newTVar [ntf]) tokenLastNtfs +stmSetNtfService :: NtfSTMStore -> SMPServer -> Maybe ServiceId -> STM () +stmSetNtfService (NtfSTMStore {ntfServices}) srv serviceId = + maybe (TM.delete srv) (TM.insert srv) serviceId ntfServices + data TokenNtfMessageRecord = TNMRv1 NtfTokenId PNMessageData instance StrEncoding TokenNtfMessageRecord where diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs index 700be059f..226a02dc6 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -11,7 +11,8 @@ import Text.RawString.QQ (r) ntfServerSchemaMigrations :: [(String, Text, Maybe Text)] ntfServerSchemaMigrations = - [ ("20250417_initial", m20250417_initial, Nothing) + [ ("20250417_initial", m20250417_initial, Nothing), + ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert) ] -- | The list of migrations in ascending order by date @@ -79,3 +80,27 @@ CREATE INDEX idx_last_notifications_subscription_id ON last_notifications(subscr CREATE UNIQUE INDEX idx_last_notifications_token_subscription ON last_notifications(token_id, subscription_id); |] + +m20250517_service_cert :: Text +m20250517_service_cert = + T.pack + [r| +ALTER TABLE smp_servers ADD COLUMN ntf_service_id BYTEA; + +ALTER TABLE subscriptions ADD COLUMN ntf_service_assoc BOOLEAN NOT NULL DEFAULT FALSE; + +DROP INDEX idx_subscriptions_smp_server_id_status; +CREATE INDEX idx_subscriptions_smp_server_id_ntf_service_status ON subscriptions(smp_server_id, ntf_service_assoc, status); + |] + +down_m20250517_service_cert :: Text +down_m20250517_service_cert = + T.pack + [r| +DROP INDEX idx_subscriptions_smp_server_id_ntf_service_status; +CREATE INDEX idx_subscriptions_smp_server_id_status ON subscriptions(smp_server_id, status); + +ALTER TABLE smp_servers DROP COLUMN ntf_service_id; + +ALTER TABLE subscriptions DROP COLUMN ntf_service_assoc; + |] diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 9a201ff2a..aaa4e5932 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -37,7 +37,7 @@ import Data.List (findIndex, foldl') import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe, mapMaybe) +import Data.Maybe (fromMaybe, isJust, mapMaybe) import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T @@ -63,7 +63,7 @@ import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Server.StoreLog import Simplex.Messaging.Parsers (parseAll) -import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, pattern SMPServer) +import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer) import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate) import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) @@ -237,32 +237,43 @@ updateTknCronInterval st tknId cronInt = -- Reads servers that have subscriptions that need subscribing. -- It is executed on server start, and it is supposed to crash on database error -getUsedSMPServers :: NtfPostgresStore -> IO [SMPServer] -getUsedSMPServers st = +getUsedSMPServers :: NtfPostgresStore -> IO [(SMPServer, Int64, Maybe (ServiceId, Int64))] +getUsedSMPServers st = withTransaction (dbStore st) $ \db -> - map rowToSrv <$> + map rowToSrvSubs <$> DB.query db [sql| - SELECT p.smp_host, p.smp_port, p.smp_keyhash + SELECT + p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id, + SUM(CASE WHEN s.ntf_service_assoc THEN s.subs_count ELSE 0 END) :: BIGINT as service_subs_count FROM smp_servers p - WHERE EXISTS ( - SELECT 1 FROM subscriptions s - WHERE s.smp_server_id = p.smp_server_id - AND s.status IN ? - ) + JOIN ( + SELECT + smp_server_id, + ntf_service_assoc, + COUNT(1) as subs_count + FROM subscriptions + WHERE status IN ? + GROUP BY smp_server_id, ntf_service_assoc + ) s ON s.smp_server_id = p.smp_server_id + GROUP BY p.smp_host, p.smp_port, p.smp_keyhash, p.smp_server_id, p.ntf_service_id |] (Only (In [NSNew, NSPending, NSActive, NSInactive])) + where + rowToSrvSubs :: SMPServerRow :. (Int64, Maybe ServiceId, Int64) -> (SMPServer, Int64, Maybe (ServiceId, Int64)) + rowToSrvSubs ((host, port, kh) :. (srvId, serviceId_, subsCount)) = + (SMPServer host port kh, srvId, (,subsCount) <$> serviceId_) -getServerNtfSubscriptions :: NtfPostgresStore -> SMPServer -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub]) -getServerNtfSubscriptions st srv afterSubId_ count = +getServerNtfSubscriptions :: NtfPostgresStore -> Int64 -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub]) +getServerNtfSubscriptions st srvId afterSubId_ count = withDB' "getServerNtfSubscriptions" st $ \db -> do subs <- map toServerNtfSub <$> case afterSubId_ of Nothing -> - DB.query db (query <> orderLimit) (srvToRow srv :. (statusIn, count)) + DB.query db (query <> orderLimit) (srvId, statusIn, count) Just afterSubId -> - DB.query db (query <> " AND s.subscription_id > ?" <> orderLimit) (srvToRow srv :. (statusIn, afterSubId, count)) + DB.query db (query <> " AND subscription_id > ?" <> orderLimit) (srvId, statusIn, afterSubId, count) void $ DB.executeMany db @@ -278,13 +289,11 @@ getServerNtfSubscriptions st srv afterSubId_ count = where query = [sql| - SELECT s.subscription_id, s.smp_notifier_id, s.smp_notifier_key - FROM subscriptions s - JOIN smp_servers p ON p.smp_server_id = s.smp_server_id - WHERE p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? - AND s.status IN ? + SELECT subscription_id, smp_notifier_id, smp_notifier_key + FROM subscriptions + WHERE smp_server_id = ? AND NOT ntf_service_assoc AND status IN ? |] - orderLimit = " ORDER BY s.subscription_id LIMIT ?" + orderLimit = " ORDER BY subscription_id LIMIT ?" statusIn = In [NSNew, NSPending, NSActive, NSInactive] toServerNtfSub (ntfSubId, notifierId, notifierKey) = (ntfSubId, (notifierId, notifierKey)) @@ -301,7 +310,7 @@ findNtfSubscription st tknId q = DB.query db [sql| - SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status + SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status, s.ntf_service_assoc FROM subscriptions s JOIN smp_servers p ON p.smp_server_id = s.smp_server_id WHERE p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? @@ -320,33 +329,33 @@ getNtfSubscription st subId = db [sql| SELECT t.token_id, t.push_provider, t.push_provider_token, t.status, t.verify_key, t.dh_priv_key, t.dh_secret, t.reg_code, t.cron_interval, t.updated_at, - s.subscription_id, s.smp_notifier_key, s.status, + s.subscription_id, s.smp_notifier_key, s.status, s.ntf_service_assoc, p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id FROM subscriptions s JOIN tokens t ON t.token_id = s.token_id JOIN smp_servers p ON p.smp_server_id = s.smp_server_id WHERE s.subscription_id = ? - |] + |] (Only subId) liftIO $ updateTokenDate st db tkn unless (allowNtfSubCommands tknStatus) $ throwE AUTH pure r -type NtfSubRow = (NtfSubscriptionId, NtfPrivateAuthKey, NtfSubStatus) +type NtfSubRow = (NtfSubscriptionId, NtfPrivateAuthKey, NtfSubStatus, NtfAssociatedService) rowToNtfTknSub :: NtfTknRow :. NtfSubRow :. SMPQueueNtfRow -> (NtfTknRec, NtfSubRec) -rowToNtfTknSub (tknRow :. (ntfSubId, notifierKey, subStatus) :. qRow) = +rowToNtfTknSub (tknRow :. (ntfSubId, notifierKey, subStatus, ntfServiceAssoc) :. qRow) = let tkn@NtfTknRec {ntfTknId = tokenId} = rowToNtfTkn tknRow smpQueue = rowToSMPQueue qRow - in (tkn, NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus}) + in (tkn, NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus, ntfServiceAssoc}) rowToNtfSub :: SMPQueueNtf -> Only NtfTokenId :. NtfSubRow -> NtfSubRec -rowToNtfSub smpQueue (Only tokenId :. (ntfSubId, notifierKey, subStatus)) = - NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus} +rowToNtfSub smpQueue (Only tokenId :. (ntfSubId, notifierKey, subStatus, ntfServiceAssoc)) = + NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus, ntfServiceAssoc} mkNtfSubRec :: NtfSubscriptionId -> NewNtfEntity 'Subscription -> NtfSubRec mkNtfSubRec ntfSubId (NewNtfSub tokenId smpQueue notifierKey) = - NtfSubRec {ntfSubId, tokenId, smpQueue, subStatus = NSNew, notifierKey} + NtfSubRec {ntfSubId, tokenId, smpQueue, subStatus = NSNew, notifierKey, ntfServiceAssoc = False} updateTknStatus :: NtfPostgresStore -> NtfTknRec -> NtfTknStatus -> IO (Either ErrorType ()) updateTknStatus st tkn status = @@ -408,14 +417,14 @@ addNtfSubscription st sub = getServer = maybeFirstRow fromOnly $ DB.query - db + db [sql| SELECT smp_server_id FROM smp_servers WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? |] (srvToRow srv) - insertServer = + insertServer = firstRow fromOnly (STORE "error inserting SMP server when adding subscription") $ DB.query db @@ -430,13 +439,13 @@ addNtfSubscription st sub = insertNtfSubQuery :: Query insertNtfSubQuery = [sql| - INSERT INTO subscriptions (token_id, smp_server_id, smp_notifier_id, subscription_id, smp_notifier_key, status) - VALUES (?,?,?,?,?,?) + INSERT INTO subscriptions (token_id, smp_server_id, smp_notifier_id, subscription_id, smp_notifier_key, status, ntf_service_assoc) + VALUES (?,?,?,?,?,?,?) |] ntfSubToRow :: Int64 -> NtfSubRec -> (NtfTokenId, Int64, NotifierId) :. NtfSubRow -ntfSubToRow srvId NtfSubRec {ntfSubId, tokenId, smpQueue = SMPQueueNtf _ nId, notifierKey, subStatus} = - (tokenId, srvId, nId) :. (ntfSubId, notifierKey, subStatus) +ntfSubToRow srvId NtfSubRec {ntfSubId, tokenId, smpQueue = SMPQueueNtf _ nId, notifierKey, subStatus, ntfServiceAssoc} = + (tokenId, srvId, nId) :. (ntfSubId, notifierKey, subStatus, ntfServiceAssoc) deleteNtfSubscription :: NtfPostgresStore -> NtfSubscriptionId -> IO (Either ErrorType ()) deleteNtfSubscription st subId = @@ -445,11 +454,27 @@ deleteNtfSubscription st subId = DB.execute db "DELETE FROM subscriptions WHERE subscription_id = ?" (Only subId) withLog "deleteNtfSubscription" st (`logDeleteSubscription` subId) +updateSubStatus :: NtfPostgresStore -> NotifierId -> NtfSubStatus -> IO (Either ErrorType ()) +updateSubStatus st nId status = + withFastDB' "updateSubStatus" st $ \db -> do + sub_ :: Maybe (NtfSubscriptionId, NtfAssociatedService) <- + maybeFirstRow id $ + DB.query + db + [sql| + UPDATE subscriptions SET status = ? + WHERE smp_notifier_id = ? AND status != ? + RETURNING subscription_id, ntf_service_assoc + |] + (status, nId, status) + forM_ sub_ $ \(subId, serviceAssoc) -> + withLog "updateSubStatus" st $ \sl -> logSubscriptionStatus sl (subId, status, serviceAssoc) + updateSrvSubStatus :: NtfPostgresStore -> SMPQueueNtf -> NtfSubStatus -> IO (Either ErrorType ()) updateSrvSubStatus st q status = withFastDB' "updateSrvSubStatus" st $ \db -> do - subId_ :: Maybe NtfSubscriptionId <- - maybeFirstRow fromOnly $ + sub_ :: Maybe (NtfSubscriptionId, NtfAssociatedService) <- + maybeFirstRow id $ DB.query db [sql| @@ -459,57 +484,39 @@ updateSrvSubStatus st q status = WHERE p.smp_server_id = s.smp_server_id AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? AND s.status != ? - RETURNING s.subscription_id + RETURNING s.subscription_id, s.ntf_service_assoc |] (Only status :. smpQueueToRow q :. Only status) - forM_ subId_ $ \subId -> - withLog "updateSrvSubStatus" st $ \sl -> logSubscriptionStatus sl subId status + forM_ sub_ $ \(subId, serviceAssoc) -> + withLog "updateSrvSubStatus" st $ \sl -> logSubscriptionStatus sl (subId, status, serviceAssoc) -batchUpdateSrvSubStatus :: NtfPostgresStore -> SMPServer -> NonEmpty NotifierId -> NtfSubStatus -> IO Int64 -batchUpdateSrvSubStatus st srv nIds status = - batchUpdateStatus_ st srv $ \srvId -> - -- without executeMany - -- L.toList $ L.map (status,srvId,,status) nIds - L.toList $ L.map (status,srvId,) nIds - -batchUpdateSrvSubStatuses :: NtfPostgresStore -> SMPServer -> NonEmpty (NotifierId, NtfSubStatus) -> IO Int64 -batchUpdateSrvSubStatuses st srv subs = - batchUpdateStatus_ st srv $ \srvId -> - -- without executeMany - -- L.toList $ L.map (\(nId, status) -> (status, srvId, nId, status)) subs - L.toList $ L.map (\(nId, status) -> (status, srvId, nId)) subs - --- without executeMany --- batchUpdateStatus_ :: NtfPostgresStore -> SMPServer -> (Int64 -> [(NtfSubStatus, Int64, NotifierId, NtfSubStatus)]) -> IO Int64 -batchUpdateStatus_ :: NtfPostgresStore -> SMPServer -> (Int64 -> [(NtfSubStatus, Int64, NotifierId)]) -> IO Int64 -batchUpdateStatus_ st srv mkParams = - fmap (fromRight (-1)) $ withDB "batchUpdateStatus_" st $ \db -> runExceptT $ do - srvId <- ExceptT $ getSMPServerId db - let params = mkParams srvId - subs <- - liftIO $ - DB.returning +batchUpdateSrvSubStatus :: NtfPostgresStore -> SMPServer -> Maybe ServiceId -> NonEmpty NotifierId -> NtfSubStatus -> IO Int +batchUpdateSrvSubStatus st srv newServiceId nIds status = + fmap (fromRight (-1)) $ withDB "batchUpdateSrvSubStatus" st $ \db -> runExceptT $ do + (srvId, currServiceId) <- ExceptT $ getSMPServerService db + unless (currServiceId == newServiceId) $ liftIO $ void $ + DB.execute db "UPDATE smp_servers SET ntf_service_id = ? WHERE smp_server_id = ?" (newServiceId, srvId) + let params = L.toList $ L.map (srvId,isJust newServiceId,status,) nIds + batchUpdateStatus_ st db params + where + getSMPServerService db = + firstRow id AUTH $ + DB.query db [sql| - UPDATE subscriptions s - SET status = upd.status - FROM (VALUES(?, ?, ?)) AS upd(status, smp_server_id, smp_notifier_id) - WHERE s.smp_server_id = upd.smp_server_id - AND s.smp_notifier_id = (upd.smp_notifier_id :: BYTEA) - AND s.status != upd.status - RETURNING s.subscription_id, s.status + SELECT smp_server_id, ntf_service_id + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + FOR UPDATE |] - params - -- TODO [ntfdb] below is equivalent without using executeMany. - -- executeMany "works", and logs updates. - -- We do not have tests that validate correct subscription status, - -- and the potential problem is BYTEA conversation - VALUES are inserted as TEXT in this case for some reason. - -- subs <- - -- liftIO $ fmap catMaybes $ forM params $ - -- maybeFirstRow id . DB.query db "UPDATE subscriptions SET status = ? WHERE smp_server_id = ? AND smp_notifier_id = ? AND status != ? RETURNING subscription_id, status" - -- logWarn $ "batchUpdateStatus_: " <> tshow (length subs) - withLog "batchUpdateStatus_" st $ forM_ subs . uncurry . logSubscriptionStatus - pure $ fromIntegral $ length subs + (srvToRow srv) + +batchUpdateSrvSubErrors :: NtfPostgresStore -> SMPServer -> NonEmpty (NotifierId, NtfSubStatus) -> IO Int +batchUpdateSrvSubErrors st srv subs = + fmap (fromRight (-1)) $ withDB "batchUpdateSrvSubErrors" st $ \db -> runExceptT $ do + srvId <- ExceptT $ getSMPServerId db + let params = L.toList $ L.map (\(nId, status) -> (srvId, False, status, nId)) subs + batchUpdateStatus_ st db params where getSMPServerId db = firstRow fromOnly AUTH $ @@ -522,31 +529,55 @@ batchUpdateStatus_ st srv mkParams = |] (srvToRow srv) -batchUpdateSubStatus :: NtfPostgresStore -> NonEmpty ServerNtfSub -> NtfSubStatus -> IO Int64 -batchUpdateSubStatus st subs status = - fmap (fromRight (-1)) $ withFastDB' "batchUpdateSubStatus" st $ \db -> do - let params = L.toList $ L.map (\(subId, _) -> (status, subId)) subs - subIds <- +batchUpdateStatus_ :: NtfPostgresStore -> DB.Connection -> [(Int64, NtfAssociatedService, NtfSubStatus, NotifierId)] -> ExceptT ErrorType IO Int +batchUpdateStatus_ st db params = do + subs <- + liftIO $ DB.returning db [sql| UPDATE subscriptions s - SET status = upd.status - FROM (VALUES(?, ?)) AS upd(status, subscription_id) - WHERE s.subscription_id = (upd.subscription_id :: BYTEA) - AND s.status != upd.status - RETURNING s.subscription_id + SET status = upd.status, ntf_service_assoc = upd.ntf_service_assoc + FROM (VALUES(?, ?, ?, ?)) AS upd(smp_server_id, ntf_service_assoc, status, smp_notifier_id) + WHERE s.smp_server_id = upd.smp_server_id + AND s.smp_notifier_id = (upd.smp_notifier_id :: BYTEA) + AND (s.status != upd.status OR s.ntf_service_assoc != upd.ntf_service_assoc) + RETURNING s.subscription_id, s.status, s.ntf_service_assoc |] params - -- TODO [ntfdb] below is equivalent without using executeMany - see comment above. - -- let params = L.toList $ L.map (\NtfSubRec {ntfSubId} -> (status, ntfSubId, status)) subs - -- subIds <- - -- fmap catMaybes $ forM params $ - -- maybeFirstRow id . DB.query db "UPDATE subscriptions SET status = ? WHERE subscription_id = ? AND status != ? RETURNING subscription_id" - -- logWarn $ "batchUpdateSubStatus: " <> tshow (length subIds) - withLog "batchUpdateSubStatus" st $ \sl -> - forM_ subIds $ \(Only subId) -> logSubscriptionStatus sl subId status - pure $ fromIntegral $ length subIds + withLog "batchUpdateStatus_" st $ forM_ subs . logSubscriptionStatus + pure $ length subs + +removeServiceAssociation :: NtfPostgresStore -> SMPServer -> IO (Either ErrorType (Int64, Int)) +removeServiceAssociation st srv = do + withDB "removeServiceAssociation" st $ \db -> runExceptT $ do + srvId <- ExceptT $ removeServerService db + subs <- + liftIO $ + DB.query + db + [sql| + UPDATE subscriptions s + SET status = ?, ntf_service_assoc = FALSE + WHERE smp_server_id = ? + AND (s.status != ? OR s.ntf_service_assoc != FALSE) + RETURNING s.subscription_id, s.status, s.ntf_service_assoc + |] + (NSInactive, srvId, NSInactive) + withLog "removeServiceAssociation" st $ forM_ subs . logSubscriptionStatus + pure (srvId, length subs) + where + removeServerService db = + firstRow fromOnly AUTH $ + DB.query + db + [sql| + UPDATE smp_servers + SET ntf_service_id = NULL + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + RETURNING smp_server_id + |] + (srvToRow srv) addTokenLastNtf :: NtfPostgresStore -> PNMessageData -> IO (Either ErrorType (NtfTknRec, NonEmpty PNMessageData)) addTokenLastNtf st newNtf = @@ -626,15 +657,16 @@ getEntityCounts st = pure (tCnt, sCnt, nCnt) where count (Only n : _) = n - count [] = 0 + count [] = 0 -importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> S.Set NtfTokenId -> IO (Int64, Int64, Int64) +importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> S.Set NtfTokenId -> IO (Int64, Int64, Int64, Int64) importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore skipTokens = do (tIds, tCnt) <- importTokens subLookup <- readTVarIO $ subscriptionLookup stmStore sCnt <- importSubscriptions tIds subLookup nCnt <- importLastNtfs tIds subLookup - pure (tCnt, sCnt, nCnt) + serviceCnt <- importNtfServiceIds + pure (tCnt, sCnt, nCnt, serviceCnt) where importTokens = do allTokens <- M.elems <$> readTVarIO (tokens stmStore) @@ -697,7 +729,7 @@ importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore skipTokens = do filterSubs allSubs = do let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs skipped = length allSubs - length subs - when (skipped /= 0) $ putStrLn $ "Skipped " <> show skipped <> " subscriptions of missing tokens" + when (skipped /= 0) $ putStrLn $ "Skipped " <> show skipped <> " subscriptions of missing tokens" let (removedSubTokens, removeSubs, dupQueues) = foldl' addSubToken (S.empty, S.empty, S.empty) subs unless (null removeSubs) $ putStrLn $ "Skipped " <> show (S.size removeSubs) <> " duplicate subscriptions of " <> show (S.size removedSubTokens) <> " tokens for " <> show (S.size dupQueues) <> " queues" pure $ filter (\NtfSubData {ntfSubId} -> S.notMember ntfSubId removeSubs) subs @@ -761,10 +793,22 @@ importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore skipTokens = do else (S.insert tId stIds, cnt', acc) where ntfRow (!qs, !rows) PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of - Just ntfSubId -> + Just ntfSubId -> let row = (tId, ntfSubId, systemToUTCTime ntfTs, nmsgNonce, Binary encNMsgMeta) in (qs, row : rows) Nothing -> (S.insert smpQueue qs, rows) + importNtfServiceIds = do + ss <- M.assocs <$> readTVarIO (ntfServices stmStore) + withConnection s $ \db -> DB.executeMany db serviceQuery $ map serviceToRow ss + where + serviceQuery = + [sql| + INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash, ntf_service_id) + VALUES (?, ?, ?, ?) + ON CONFLICT (smp_host, smp_port, smp_keyhash) + DO UPDATE SET ntf_service_id = EXCLUDED.ntf_service_id + |] + serviceToRow (srv, serviceId) = srvToRow srv :. Only serviceId checkCount name expected inserted | fromIntegral expected == inserted = do putStrLn $ "Imported " <> show inserted <> " " <> name <> "s." @@ -799,15 +843,15 @@ exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFi where ntfSubQuery = [sql| - SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status, + SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status, s.ntf_service_assoc, p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id FROM subscriptions s JOIN smp_servers p ON p.smp_server_id = s.smp_server_id |] toNtfSub :: Only NtfTokenId :. NtfSubRow :. SMPQueueNtfRow -> NtfSubRec - toNtfSub (Only tokenId :. (ntfSubId, notifierKey, subStatus) :. qRow) = + toNtfSub (Only tokenId :. (ntfSubId, notifierKey, subStatus, ntfServiceAssoc) :. qRow) = let smpQueue = rowToSMPQueue qRow - in NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus} + in NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus, ntfServiceAssoc} exportLastNtfs = withFile lastNtfsFile WriteMode $ \h -> withConnection s $ \db -> DB.fold_ db lastNtfsQuery 0 $ \ !i (Only tknId :. ntfRow) -> @@ -825,32 +869,32 @@ exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFi |] encodeLastNtf tknId ntf = strEncode (TNMRv1 tknId ntf) `B.snoc` '\n' -withFastDB' :: String -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) +withFastDB' :: Text -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) withFastDB' op st action = withFastDB op st $ fmap Right . action {-# INLINE withFastDB' #-} -withDB' :: String -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) +withDB' :: Text -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) withDB' op st action = withDB op st $ fmap Right . action {-# INLINE withDB' #-} -withFastDB :: forall a. String -> NtfPostgresStore -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) +withFastDB :: forall a. Text -> NtfPostgresStore -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) withFastDB op st = withDB_ op st True {-# INLINE withFastDB #-} -withDB :: forall a. String -> NtfPostgresStore -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) +withDB :: forall a. Text -> NtfPostgresStore -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) withDB op st = withDB_ op st False {-# INLINE withDB #-} -withDB_ :: forall a. String -> NtfPostgresStore -> Bool -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) +withDB_ :: forall a. Text -> NtfPostgresStore -> Bool -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) withDB_ op st priority action = E.uninterruptibleMask_ $ E.try (withTransactionPriority (dbStore st) priority action) >>= either logErr pure where logErr :: E.SomeException -> IO (Either ErrorType a) - logErr e = logError ("STORE: " <> T.pack err) $> Left (STORE err) + logErr e = logError ("STORE: " <> err) $> Left (STORE err) where - err = op <> ", withDB, " <> show e + err = op <> ", withDB, " <> tshow e -withLog :: MonadIO m => String -> NtfPostgresStore -> (StoreLog 'WriteMode -> IO ()) -> m () +withLog :: MonadIO m => Text -> NtfPostgresStore -> (StoreLog 'WriteMode -> IO ()) -> m () withLog op NtfPostgresStore {dbStoreLog} = withLog_ op dbStoreLog {-# INLINE withLog #-} diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Types.hs b/src/Simplex/Messaging/Notifications/Server/Store/Types.hs index 76233290b..39e303340 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Types.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Types.hs @@ -9,6 +9,7 @@ module Simplex.Messaging.Notifications.Server.Store.Types where import Control.Applicative (optional) import Control.Concurrent.STM import qualified Data.ByteString.Char8 as B +import Data.Maybe (fromMaybe) import Data.Word (Word16) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String @@ -77,30 +78,36 @@ data NtfSubRec = NtfSubRec smpQueue :: SMPQueueNtf, notifierKey :: NtfPrivateAuthKey, tokenId :: NtfTokenId, - subStatus :: NtfSubStatus + subStatus :: NtfSubStatus, + ntfServiceAssoc :: NtfAssociatedService -- Bool } deriving (Show) type ServerNtfSub = (NtfSubscriptionId, (NotifierId, NtfPrivateAuthKey)) +type NtfAssociatedService = Bool + mkSubData :: NtfSubRec -> IO NtfSubData -mkSubData NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do +mkSubData NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status, ntfServiceAssoc = serviceAssoc} = do subStatus <- newTVarIO status - pure NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} + ntfServiceAssoc <- newTVarIO serviceAssoc + pure NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus, ntfServiceAssoc} mkSubRec :: NtfSubData -> IO NtfSubRec -mkSubRec NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do +mkSubRec NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status, ntfServiceAssoc = serviceAssoc} = do subStatus <- readTVarIO status - pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} + ntfServiceAssoc <- readTVarIO serviceAssoc + pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus, ntfServiceAssoc} instance StrEncoding NtfSubRec where - strEncode NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} = + strEncode NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus, ntfServiceAssoc} = B.unwords [ "subId=" <> strEncode ntfSubId, "smpQueue=" <> strEncode smpQueue, "notifierKey=" <> strEncode notifierKey, "tknId=" <> strEncode tokenId, - "subStatus=" <> strEncode subStatus + "subStatus=" <> strEncode subStatus, + "serviceAssoc=" <> strEncode ntfServiceAssoc ] strP = do ntfSubId <- "subId=" *> strP_ @@ -108,4 +115,5 @@ instance StrEncoding NtfSubRec where notifierKey <- "notifierKey=" *> strP_ tokenId <- "tknId=" *> strP_ subStatus <- "subStatus=" *> strP - pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} + ntfServiceAssoc <- fromMaybe False <$> optional (" serviceAssoc=" *> strP) + pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus, ntfServiceAssoc} diff --git a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql index 4c98a1161..3b155fa1a 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql +++ b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql @@ -52,7 +52,8 @@ CREATE TABLE ntf_server.smp_servers ( smp_server_id bigint NOT NULL, smp_host text NOT NULL, smp_port text NOT NULL, - smp_keyhash bytea NOT NULL + smp_keyhash bytea NOT NULL, + ntf_service_id bytea ); @@ -74,7 +75,8 @@ CREATE TABLE ntf_server.subscriptions ( smp_server_id bigint, smp_notifier_id bytea NOT NULL, smp_notifier_key bytea NOT NULL, - status text NOT NULL + status text NOT NULL, + ntf_service_assoc boolean DEFAULT false NOT NULL ); @@ -140,7 +142,7 @@ CREATE UNIQUE INDEX idx_subscriptions_smp_server_id_notifier_id ON ntf_server.su -CREATE INDEX idx_subscriptions_smp_server_id_status ON ntf_server.subscriptions USING btree (smp_server_id, status); +CREATE INDEX idx_subscriptions_smp_server_id_ntf_service_status ON ntf_server.subscriptions USING btree (smp_server_id, ntf_service_assoc, status); diff --git a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs index 87c09826e..e71ebaf57 100644 --- a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs @@ -24,17 +24,21 @@ module Simplex.Messaging.Notifications.Server.StoreLog ) where +import Control.Applicative (optional, (<|>)) import Control.Concurrent.STM import Control.Monad import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteString.Base64.URL as B64 import qualified Data.ByteString.Char8 as B +import Data.Functor (($>)) +import qualified Data.Map.Strict as M +import Data.Maybe (fromMaybe) import Data.Word (Word16) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Store import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Protocol (EntityId (..)) +import Simplex.Messaging.Protocol (EntityId (..), SMPServer, ServiceId) import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) import Simplex.Messaging.Server.StoreLog import System.IO @@ -47,8 +51,9 @@ data NtfStoreLogRecord | DeleteToken NtfTokenId | UpdateTokenTime NtfTokenId RoundedSystemTime | CreateSubscription NtfSubRec - | SubscriptionStatus NtfSubscriptionId NtfSubStatus + | SubscriptionStatus NtfSubscriptionId NtfSubStatus NtfAssociatedService | DeleteSubscription NtfSubscriptionId + | SetNtfService SMPServer (Maybe ServiceId) deriving (Show) instance StrEncoding NtfStoreLogRecord where @@ -60,8 +65,11 @@ instance StrEncoding NtfStoreLogRecord where DeleteToken tknId -> strEncode (Str "TDELETE", tknId) UpdateTokenTime tknId ts -> strEncode (Str "TTIME", tknId, ts) CreateSubscription subRec -> strEncode (Str "SCREATE", subRec) - SubscriptionStatus subId subStatus -> strEncode (Str "SSTATUS", subId, subStatus) + SubscriptionStatus subId subStatus serviceAssoc -> strEncode (Str "SSTATUS", subId, subStatus) <> serviceStr + where + serviceStr = if serviceAssoc then " service=" <> strEncode True else "" DeleteSubscription subId -> strEncode (Str "SDELETE", subId) + SetNtfService srv serviceId -> strEncode (Str "SERVICE", srv) <> " service=" <> maybe "off" strEncode serviceId strP = A.choice [ "TCREATE " *> (CreateToken <$> strP), @@ -71,8 +79,9 @@ instance StrEncoding NtfStoreLogRecord where "TDELETE " *> (DeleteToken <$> strP), "TTIME " *> (UpdateTokenTime <$> strP_ <*> strP), "SCREATE " *> (CreateSubscription <$> strP), - "SSTATUS " *> (SubscriptionStatus <$> strP_ <*> strP), - "SDELETE " *> (DeleteSubscription <$> strP) + "SSTATUS " *> (SubscriptionStatus <$> strP_ <*> strP <*> (fromMaybe False <$> optional (" service=" *> strP))), + "SDELETE " *> (DeleteSubscription <$> strP), + "SERVICE " *> (SetNtfService <$> strP <* " service=" <*> ("off" $> Nothing <|> strP)) ] logNtfStoreRecord :: StoreLog 'WriteMode -> NtfStoreLogRecord -> IO () @@ -100,12 +109,15 @@ logUpdateTokenTime s tknId t = logNtfStoreRecord s $ UpdateTokenTime tknId t logCreateSubscription :: StoreLog 'WriteMode -> NtfSubRec -> IO () logCreateSubscription s = logNtfStoreRecord s . CreateSubscription -logSubscriptionStatus :: StoreLog 'WriteMode -> NtfSubscriptionId -> NtfSubStatus -> IO () -logSubscriptionStatus s subId subStatus = logNtfStoreRecord s $ SubscriptionStatus subId subStatus +logSubscriptionStatus :: StoreLog 'WriteMode -> (NtfSubscriptionId, NtfSubStatus, NtfAssociatedService) -> IO () +logSubscriptionStatus s (subId, subStatus, serviceAssoc) = logNtfStoreRecord s $ SubscriptionStatus subId subStatus serviceAssoc logDeleteSubscription :: StoreLog 'WriteMode -> NtfSubscriptionId -> IO () logDeleteSubscription s subId = logNtfStoreRecord s $ DeleteSubscription subId +logSetNtfService :: StoreLog 'WriteMode -> SMPServer -> Maybe ServiceId -> IO () +logSetNtfService s srv serviceId = logNtfStoreRecord s $ SetNtfService srv serviceId + readWriteNtfSTMStore :: Bool -> FilePath -> NtfSTMStore -> IO (StoreLog 'WriteMode) readWriteNtfSTMStore tty = readWriteStoreLog (readNtfStore tty) writeNtfStore @@ -147,13 +159,19 @@ readNtfStore tty f st = readLogLines tty f $ \_ -> processLine Nothing -> B.putStrLn $ "Warning: no token " <> enc tokenId <> ", subscription " <> enc ntfSubId where enc = B64.encode . unEntityId - SubscriptionStatus subId status -> do - stmGetNtfSubscriptionIO st subId - >>= mapM_ (\NtfSubData {subStatus} -> atomically $ writeTVar subStatus status) + SubscriptionStatus subId status serviceAssoc -> do + stmGetNtfSubscriptionIO st subId >>= mapM_ update + where + update NtfSubData {subStatus, ntfServiceAssoc} = atomically $ do + writeTVar subStatus status + writeTVar ntfServiceAssoc serviceAssoc DeleteSubscription subId -> atomically $ stmDeleteNtfSubscription st subId + SetNtfService srv serviceId -> + atomically $ stmSetNtfService st srv serviceId writeNtfStore :: StoreLog 'WriteMode -> NtfSTMStore -> IO () -writeNtfStore s NtfSTMStore {tokens, subscriptions} = do +writeNtfStore s NtfSTMStore {tokens, subscriptions, ntfServices} = do mapM_ (logCreateToken s <=< mkTknRec) =<< readTVarIO tokens mapM_ (logCreateSubscription s <=< mkSubRec) =<< readTVarIO subscriptions + mapM_ (\(srv, serviceId) -> logSetNtfService s srv $ Just serviceId) . M.assocs =<< readTVarIO ntfServices diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index fb5258933..15f923102 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -126,8 +126,8 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do Nothing -> throwE TEVersion -- | Notifcations server client transport handshake. -ntfClientHandshake :: forall c. Transport c => c 'TClient -> C.KeyHash -> VersionRangeNTF -> Bool -> ExceptT TransportError IO (THandleNTF c 'TClient) -ntfClientHandshake c keyHash ntfVRange _proxyServer = do +ntfClientHandshake :: forall c. Transport c => c 'TClient -> C.KeyHash -> VersionRangeNTF -> Bool -> Maybe (ServiceCredentials, C.KeyPairEd25519) -> ExceptT TransportError IO (THandleNTF c 'TClient) +ntfClientHandshake c keyHash ntfVRange _proxyServer _serviceKeys = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th if sessionId /= sessId @@ -145,12 +145,13 @@ ntfClientHandshake c keyHash ntfVRange _proxyServer = do ntfThHandleServer :: forall c. THandleNTF c 'TServer -> VersionNTF -> VersionRangeNTF -> C.PrivateKeyX25519 -> THandleNTF c 'TServer ntfThHandleServer th v vr pk = - let thAuth = THAuthServer {serverPrivKey = pk, sessSecret' = Nothing} + let thAuth = THAuthServer {serverPrivKey = pk, peerClientService = Nothing, sessSecret' = Nothing} in ntfThHandle_ th v vr (Just thAuth) ntfThHandleClient :: forall c. THandleNTF c 'TClient -> VersionNTF -> VersionRangeNTF -> Maybe (C.PublicKeyX25519, CertChainPubKey) -> THandleNTF c 'TClient ntfThHandleClient th v vr ck_ = - let thAuth = (\(k, ck) -> THAuthClient {serverPeerPubKey = k, serverCertKey = ck, sessSecret = Nothing}) <$> ck_ + let thAuth = clientTHParams <$> ck_ + clientTHParams (k, ck) = THAuthClient {peerServerPubKey = k, peerServerCertKey = ck, clientService = Nothing, sessSecret = Nothing} in ntfThHandle_ th v vr thAuth ntfThHandle_ :: forall c p. THandleNTF c p -> VersionNTF -> VersionRangeNTF -> Maybe (THandleAuth p) -> THandleNTF c p @@ -173,5 +174,6 @@ ntfTHandle c = THandle {connection = c, params} thAuth = Nothing, implySessId = False, encryptBlock = Nothing, - batch = False + batch = False, + serviceAuth = False } diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 41fbaf37f..943a49822 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -68,6 +68,7 @@ module Simplex.Messaging.Protocol Cmd (..), DirectParty, SubscriberParty, + ASubscriberParty (..), BrokerMsg (..), SParty (..), PartyI (..), @@ -80,6 +81,7 @@ module Simplex.Messaging.Protocol BlockingInfo (..), BlockingReason (..), Transmission, + TAuthorizations, TransmissionAuth (..), SignedTransmission, SentRawTransmission, @@ -117,6 +119,7 @@ module Simplex.Messaging.Protocol SenderId, LinkId, NotifierId, + ServiceId, RcvPrivateAuthKey, RcvPublicAuthKey, RcvPublicDhKey, @@ -150,6 +153,8 @@ module Simplex.Messaging.Protocol currentSMPClientVersion, senderCanSecure, queueReqMode, + subscriberParty, + subscriberServiceRole, userProtocol, rcvMessageMeta, noMsgFlags, @@ -198,7 +203,6 @@ where import Control.Applicative (optional, (<|>)) import Control.Exception (Exception) -import Control.Monad import Control.Monad.Except import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as J @@ -211,11 +215,13 @@ import qualified Data.ByteString.Char8 as B import Data.Char (isPrint, isSpace) import Data.Constraint (Dict (..)) import Data.Functor (($>)) +import Data.Int (Int64) import Data.Kind import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Maybe (isJust, isNothing) import Data.String +import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock.System (SystemTime (..), systemToUTCTime) @@ -297,7 +303,7 @@ e2eEncMessageLength :: Int e2eEncMessageLength = 16000 -- 15988 .. 16005 -- | SMP protocol clients -data Party = Recipient | Sender | Notifier | LinkClient | ProxiedClient +data Party = Recipient | Sender | Notifier | LinkClient | ProxiedClient | ProxyService deriving (Show) -- | Singleton types for SMP protocol clients @@ -307,6 +313,7 @@ data SParty :: Party -> Type where SNotifier :: SParty Notifier SSenderLink :: SParty LinkClient SProxiedClient :: SParty ProxiedClient + SProxyService :: SParty ProxyService instance TestEquality SParty where testEquality SRecipient SRecipient = Just Refl @@ -314,6 +321,7 @@ instance TestEquality SParty where testEquality SNotifier SNotifier = Just Refl testEquality SSenderLink SSenderLink = Just Refl testEquality SProxiedClient SProxiedClient = Just Refl + testEquality SProxyService SProxyService = Just Refl testEquality _ _ = Nothing deriving instance Show (SParty p) @@ -330,11 +338,14 @@ instance PartyI LinkClient where sParty = SSenderLink instance PartyI ProxiedClient where sParty = SProxiedClient +instance PartyI ProxyService where sParty = SProxyService + type family DirectParty (p :: Party) :: Constraint where DirectParty Recipient = () DirectParty Sender = () DirectParty Notifier = () DirectParty LinkClient = () + DirectParty ProxyService = () DirectParty p = (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not direct")) @@ -344,6 +355,40 @@ type family SubscriberParty (p :: Party) :: Constraint where SubscriberParty p = (Int ~ Bool, TypeError (Type.Text "Party " :<>: ShowType p :<>: Type.Text " is not subscriber")) +data ASubscriberParty = forall p. (PartyI p, SubscriberParty p) => ASP (SParty p) + +deriving instance Show ASubscriberParty + +instance Eq ASubscriberParty where + ASP p == ASP p' = isJust $ testEquality p p' + +instance Encoding ASubscriberParty where + smpEncode = \case + ASP SRecipient -> "R" + ASP SNotifier -> "N" + smpP = + A.anyChar >>= \case + 'R' -> pure $ ASP SRecipient + 'N' -> pure $ ASP SNotifier + _ -> fail "bad ASubscriberParty" + +instance StrEncoding ASubscriberParty where + strEncode = smpEncode + strP = smpP + +subscriberParty :: SParty p -> Maybe (Dict (PartyI p, SubscriberParty p)) +subscriberParty = \case + SRecipient -> Just Dict + SNotifier -> Just Dict + _ -> Nothing +{-# INLINE subscriberParty #-} + +subscriberServiceRole :: SubscriberParty p => SParty p -> SMPServiceRole +subscriberServiceRole = \case + SRecipient -> SRMessaging + SNotifier -> SRNotifier +{-# INLINE subscriberServiceRole #-} + -- | Type for client command of any participant. data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p) @@ -353,13 +398,14 @@ deriving instance Show Cmd type Transmission c = (CorrId, EntityId, c) -- | signed parsed transmission, with original raw bytes and parsing error. -type SignedTransmission e c = (Maybe TransmissionAuth, Signed, Transmission (Either e c)) +type SignedTransmission e c = (Maybe TAuthorizations, Signed, Transmission (Either e c)) type Signed = ByteString -- | unparsed SMP transmission with signature. data RawTransmission = RawTransmission { authenticator :: ByteString, -- signature or encrypted transmission hash + serviceSig :: Maybe (C.Signature 'C.Ed25519), -- optional second signature with the key of the client service authorized :: ByteString, -- authorized transmission sessId :: SessionId, corrId :: CorrId, @@ -368,32 +414,36 @@ data RawTransmission = RawTransmission } deriving (Show) +type TAuthorizations = (TransmissionAuth, Maybe (C.Signature 'C.Ed25519)) + data TransmissionAuth = TASignature C.ASignature | TAAuthenticator C.CbAuthenticator deriving (Show) --- this encoding is backwards compatible with v6 that used Maybe C.ASignature instead of TAuthorization -tAuthBytes :: Maybe TransmissionAuth -> ByteString -tAuthBytes = \case - Nothing -> "" - Just (TASignature s) -> C.signatureBytes s - Just (TAAuthenticator (C.CbAuthenticator s)) -> s +-- this encoding is backwards compatible with v6 that used Maybe C.ASignature instead of TransmissionAuth +tEncodeAuth :: Bool -> Maybe TAuthorizations -> ByteString +tEncodeAuth serviceAuth = \case + Nothing -> smpEncode B.empty + Just (auth, sig) + | serviceAuth -> smpEncode (authBytes auth, sig) + | otherwise -> smpEncode (authBytes auth) + where + authBytes = \case + TASignature s -> C.signatureBytes s + TAAuthenticator (C.CbAuthenticator s) -> s -decodeTAuthBytes :: ByteString -> Either String (Maybe TransmissionAuth) -decodeTAuthBytes s +decodeTAuthBytes :: ByteString -> Maybe (C.Signature 'C.Ed25519) -> Either String (Maybe TAuthorizations) +decodeTAuthBytes s serviceSig | B.null s = Right Nothing - | B.length s == C.cbAuthenticatorSize = Right . Just . TAAuthenticator $ C.CbAuthenticator s - | otherwise = Just . TASignature <$> C.decodeSignature s - -instance IsString (Maybe TransmissionAuth) where - fromString = parseString $ B64.decode >=> C.decodeSignature >=> pure . fmap TASignature + | B.length s == C.cbAuthenticatorSize = Right $ Just (TAAuthenticator (C.CbAuthenticator s), serviceSig) + | otherwise = (\sig -> Just (TASignature sig, serviceSig)) <$> C.decodeSignature s -- | unparsed sent SMP transmission with signature, without session ID. -type SignedRawTransmission = (Maybe TransmissionAuth, CorrId, EntityId, ByteString) +type SignedRawTransmission = (Maybe TAuthorizations, CorrId, EntityId, ByteString) -- | unparsed sent SMP transmission with signature. -type SentRawTransmission = (Maybe TransmissionAuth, ByteString) +type SentRawTransmission = (Maybe TAuthorizations, ByteString) -- | SMP queue ID for the recipient. type RecipientId = QueueId @@ -409,14 +459,6 @@ type LinkId = QueueId -- | SMP queue ID on the server. type QueueId = EntityId --- this type is used for server entities only -newtype EntityId = EntityId {unEntityId :: ByteString} - deriving (Eq, Ord, Show) - deriving newtype (Encoding, StrEncoding) - -pattern NoEntity :: EntityId -pattern NoEntity = EntityId "" - -- | Parameterized type for SMP protocol commands from all clients. data Command (p :: Party) where -- SMP recipient commands @@ -426,6 +468,8 @@ data Command (p :: Party) where -- RcvPublicAuthKey is defined as C.APublicKey - it can be either signature or DH public keys. NEW :: NewQueueReq -> Command Recipient SUB :: Command Recipient + -- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command. + SUBS :: Command Recipient KEY :: SndPublicAuthKey -> Command Recipient RKEY :: NonEmpty RcvPublicAuthKey -> Command Recipient LSET :: LinkId -> QueueLinkData -> Command Recipient @@ -448,6 +492,8 @@ data Command (p :: Party) where LGET :: Command LinkClient -- SMP notification subscriber commands NSUB :: Command Notifier + -- | subscribe all associated queues. Service ID must be used as entity ID, and service session key must sign the command. + NSUBS :: Command Notifier PRXY :: SMPServer -> Maybe BasicAuth -> Command ProxiedClient -- request a relay server connection by URI -- Transmission to proxy: -- - entity ID: ID of the session with relay returned in PKEY (response to PRXY) @@ -458,7 +504,7 @@ data Command (p :: Party) where -- Transmission forwarded to relay: -- - entity ID: empty -- - corrId: unique correlation ID between proxy and relay, also used as a nonce to encrypt forwarded transmission - RFWD :: EncFwdTransmission -> Command Sender -- use CorrId as CbNonce, proxy to relay + RFWD :: EncFwdTransmission -> Command ProxyService -- use CorrId as CbNonce, proxy to relay deriving instance Show (Command p) @@ -574,6 +620,10 @@ data BrokerMsg where -- SMP broker messages (responses, client messages, notifications) IDS :: QueueIdsKeys -> BrokerMsg LNK :: SenderId -> QueueLinkData -> BrokerMsg + -- | Service subscription success - confirms when queue was associated with the service + SOK :: Maybe ServiceId -> BrokerMsg + -- | The number of queues subscribed with SUBS command + SOKS :: Int64 -> BrokerMsg -- MSG v1/2 has to be supported for encoding/decoding -- v1: MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg -- v2: MsgId -> SystemTime -> MsgFlags -> MsgBody -> BrokerMsg @@ -585,6 +635,7 @@ data BrokerMsg where RRES :: EncFwdResponse -> BrokerMsg -- relay to proxy PRES :: EncResponse -> BrokerMsg -- proxy to client END :: BrokerMsg + ENDS :: Int64 -> BrokerMsg DELD :: BrokerMsg INFO :: QueueInfo -> BrokerMsg OK :: BrokerMsg @@ -778,6 +829,7 @@ noMsgFlags = MsgFlags {notification = False} data CommandTag (p :: Party) where NEW_ :: CommandTag Recipient SUB_ :: CommandTag Recipient + SUBS_ :: CommandTag Recipient KEY_ :: CommandTag Recipient RKEY_ :: CommandTag Recipient LSET_ :: CommandTag Recipient @@ -796,8 +848,9 @@ data CommandTag (p :: Party) where LGET_ :: CommandTag LinkClient PRXY_ :: CommandTag ProxiedClient PFWD_ :: CommandTag ProxiedClient - RFWD_ :: CommandTag Sender + RFWD_ :: CommandTag ProxyService NSUB_ :: CommandTag Notifier + NSUBS_ :: CommandTag Notifier data CmdTag = forall p. PartyI p => CT (SParty p) (CommandTag p) @@ -808,6 +861,8 @@ deriving instance Show CmdTag data BrokerMsgTag = IDS_ | LNK_ + | SOK_ + | SOKS_ | MSG_ | NID_ | NMSG_ @@ -815,6 +870,7 @@ data BrokerMsgTag | RRES_ | PRES_ | END_ + | ENDS_ | DELD_ | INFO_ | OK_ @@ -834,6 +890,7 @@ instance PartyI p => Encoding (CommandTag p) where smpEncode = \case NEW_ -> "NEW" SUB_ -> "SUB" + SUBS_ -> "SUBS" KEY_ -> "KEY" RKEY_ -> "RKEY" LSET_ -> "LSET" @@ -854,12 +911,14 @@ instance PartyI p => Encoding (CommandTag p) where PFWD_ -> "PFWD" RFWD_ -> "RFWD" NSUB_ -> "NSUB" + NSUBS_ -> "NSUBS" smpP = messageTagP instance ProtocolMsgTag CmdTag where decodeTag = \case "NEW" -> Just $ CT SRecipient NEW_ "SUB" -> Just $ CT SRecipient SUB_ + "SUBS" -> Just $ CT SRecipient SUBS_ "KEY" -> Just $ CT SRecipient KEY_ "RKEY" -> Just $ CT SRecipient RKEY_ "LSET" -> Just $ CT SRecipient LSET_ @@ -878,8 +937,9 @@ instance ProtocolMsgTag CmdTag where "LGET" -> Just $ CT SSenderLink LGET_ "PRXY" -> Just $ CT SProxiedClient PRXY_ "PFWD" -> Just $ CT SProxiedClient PFWD_ - "RFWD" -> Just $ CT SSender RFWD_ + "RFWD" -> Just $ CT SProxyService RFWD_ "NSUB" -> Just $ CT SNotifier NSUB_ + "NSUBS" -> Just $ CT SNotifier NSUBS_ _ -> Nothing instance Encoding CmdTag where @@ -893,6 +953,8 @@ instance Encoding BrokerMsgTag where smpEncode = \case IDS_ -> "IDS" LNK_ -> "LNK" + SOK_ -> "SOK" + SOKS_ -> "SOKS" MSG_ -> "MSG" NID_ -> "NID" NMSG_ -> "NMSG" @@ -900,6 +962,7 @@ instance Encoding BrokerMsgTag where RRES_ -> "RRES" PRES_ -> "PRES" END_ -> "END" + ENDS_ -> "ENDS" DELD_ -> "DELD" INFO_ -> "INFO" OK_ -> "OK" @@ -911,6 +974,8 @@ instance ProtocolMsgTag BrokerMsgTag where decodeTag = \case "IDS" -> Just IDS_ "LNK" -> Just LNK_ + "SOK" -> Just SOK_ + "SOKS" -> Just SOKS_ "MSG" -> Just MSG_ "NID" -> Just NID_ "NMSG" -> Just NMSG_ @@ -918,6 +983,7 @@ instance ProtocolMsgTag BrokerMsgTag where "RRES" -> Just RRES_ "PRES" -> Just PRES_ "END" -> Just END_ + "ENDS" -> Just ENDS_ "DELD" -> Just DELD_ "INFO" -> Just INFO_ "OK" -> Just OK_ @@ -1257,7 +1323,8 @@ data QueueIdsKeys = QIK sndId :: SenderId, rcvPublicDhKey :: RcvPublicDhKey, queueMode :: Maybe QueueMode, -- TODO remove Maybe when min version is 9 (sndAuthKeySMPVersion) - linkId :: Maybe LinkId + linkId :: Maybe LinkId, + serviceId :: Maybe ServiceId -- TODO [notifications] -- serverNtfCreds :: Maybe ServerNtfCreds } @@ -1327,12 +1394,14 @@ data ErrorType AUTH | -- | command with the entity that was blocked BLOCKED {blockInfo :: BlockingInfo} + | -- | service signature is not allowed for command or session; service command is sent not in service session + SERVICE | -- | encryption/decryption error in proxy protocol CRYPTO | -- | SMP queue capacity is exceeded on the server QUOTA | -- | SMP server storage error - STORE {storeErr :: String} + STORE {storeErr :: Text} | -- | ACK command is sent without message to be acknowledged NO_MSG | -- | sent message is too large (> maxMessageLength = 16088 bytes) @@ -1353,9 +1422,10 @@ instance StrEncoding ErrorType where PROXY e -> "PROXY " <> strEncode e AUTH -> "AUTH" BLOCKED info -> "BLOCKED " <> strEncode info + SERVICE -> "SERVICE" CRYPTO -> "CRYPTO" QUOTA -> "QUOTA" - STORE e -> "STORE " <> encodeUtf8 (T.pack e) + STORE e -> "STORE " <> encodeUtf8 e NO_MSG -> "NO_MSG" LARGE_MSG -> "LARGE_MSG" EXPIRED -> "EXPIRED" @@ -1369,9 +1439,10 @@ instance StrEncoding ErrorType where "PROXY " *> (PROXY <$> strP), "AUTH" $> AUTH, "BLOCKED " *> strP, + "SERVICE" $> SERVICE, "CRYPTO" $> CRYPTO, "QUOTA" $> QUOTA, - "STORE " *> (STORE . T.unpack . safeDecodeUtf8 <$> A.takeByteString), + "STORE " *> (STORE . safeDecodeUtf8 <$> A.takeByteString), "NO_MSG" $> NO_MSG, "LARGE_MSG" $> LARGE_MSG, "EXPIRED" $> EXPIRED, @@ -1385,7 +1456,7 @@ data CommandError UNKNOWN | -- | error parsing command SYNTAX - | -- | command is not allowed (SUB/GET cannot be used with the same queue in the same TCP connection) + | -- | command is not allowed (bad service role, or SUB/GET used with the same queue in the same TCP session) PROHIBITED | -- | transmission has no required credentials (signature or queue ID) NO_AUTH @@ -1417,6 +1488,8 @@ data BrokerErrorType NETWORK | -- | no compatible server host (e.g. onion when public is required, or vice versa) HOST + | -- | service unavailable client-side - used in agent errors + NO_SERVICE | -- | handshake or other transport error TRANSPORT {transportErr :: TransportError} | -- | command response timeout @@ -1456,23 +1529,25 @@ instance FromJSON BlockingReason where -- | SMP transmission parser. transmissionP :: THandleParams v p -> Parser RawTransmission -transmissionP THandleParams {sessionId, implySessId} = do +transmissionP THandleParams {sessionId, implySessId, serviceAuth} = do authenticator <- smpP + serviceSig <- if serviceAuth && not (B.null authenticator) then smpP else pure Nothing authorized <- A.takeByteString - either fail pure $ parseAll (trn authenticator authorized) authorized + either fail pure $ parseAll (trn authenticator serviceSig authorized) authorized where - trn authenticator authorized = do + trn authenticator serviceSig authorized = do sessId <- if implySessId then pure "" else smpP let authorized' = if implySessId then smpEncode sessionId <> authorized else authorized corrId <- smpP entityId <- smpP command <- A.takeByteString - pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command} + pure RawTransmission {authenticator, serviceSig, authorized = authorized', sessId, corrId, entityId, command} class (ProtocolTypeI (ProtoType msg), ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where type ProtoCommand msg = cmd | cmd -> msg type ProtoType msg = (sch :: ProtocolType) | sch -> msg - protocolClientHandshake :: forall c. Transport c => c 'TClient -> Maybe C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> Bool -> ExceptT TransportError IO (THandle v c 'TClient) + protocolClientHandshake :: Transport c => c 'TClient -> Maybe C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> Bool -> Maybe (ServiceCredentials, C.KeyPairEd25519) -> ExceptT TransportError IO (THandle v c 'TClient) + useServiceAuth :: ProtoCommand msg -> Bool protocolPing :: ProtoCommand msg protocolError :: msg -> Maybe err @@ -1482,10 +1557,19 @@ instance Protocol SMPVersion ErrorType BrokerMsg where type ProtoCommand BrokerMsg = Cmd type ProtoType BrokerMsg = 'PSMP protocolClientHandshake = smpClientHandshake + {-# INLINE protocolClientHandshake #-} + useServiceAuth = \case + Cmd _ (NEW _) -> True + Cmd _ SUB -> True + Cmd _ NSUB -> True + _ -> False + {-# INLINE useServiceAuth #-} protocolPing = Cmd SSender PING + {-# INLINE protocolPing #-} protocolError = \case ERR e -> Just e _ -> Nothing + {-# INLINE protocolError #-} class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where type Tag msg @@ -1505,6 +1589,7 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where new = e (NEW_, ' ', rKey, dhKey) auth = maybe "" (e . ('A',)) auth_ SUB -> e SUB_ + SUBS -> e SUBS_ KEY k -> e (KEY_, ' ', k) RKEY ks -> e (RKEY_, ' ', ks) LSET lnkId d -> e (LSET_, ' ', lnkId, d) @@ -1520,6 +1605,7 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where SEND flags msg -> e (SEND_, ' ', flags, ' ', Tail msg) PING -> e PING_ NSUB -> e NSUB_ + NSUBS -> e NSUBS_ LKEY k -> e (LKEY_, ' ', k) LGET -> e LGET_ PRXY host auth_ -> e (PRXY_, ' ', host, auth_) @@ -1549,6 +1635,8 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where PRXY {} -> noAuthCmd PFWD {} -> entityCmd RFWD _ -> noAuthCmd + SUB -> serviceCmd + NSUB -> serviceCmd -- other client commands must have both signature and queue ID _ | isNothing auth || B.null entId -> Left $ CMD NO_AUTH @@ -1564,10 +1652,15 @@ instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where | B.null entId = Left $ CMD NO_ENTITY | isNothing auth = Right cmd | otherwise = Left $ CMD HAS_AUTH + serviceCmd :: Either ErrorType (Command p) + serviceCmd + | isNothing auth || B.null entId = Left $ CMD NO_AUTH + | otherwise = Right cmd instance ProtocolEncoding SMPVersion ErrorType Cmd where type Tag Cmd = CmdTag encodeProtocol v (Cmd _ c) = encodeProtocol v c + {-# INLINE encodeProtocol #-} protocolP v = \case CT SRecipient tag -> @@ -1589,6 +1682,7 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where auth = optional (A.char 'A' *> smpP) qReq sndSecure = Just $ if sndSecure then QRMessaging Nothing else QRContact Nothing SUB_ -> pure SUB + SUBS_ -> pure SUBS KEY_ -> KEY <$> _smpP RKEY_ -> RKEY <$> _smpP LSET_ -> LSET <$> _smpP <*> smpP @@ -1605,7 +1699,8 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where SKEY_ -> SKEY <$> _smpP SEND_ -> SEND <$> _smpP <*> (unTail <$> _smpP) PING_ -> pure PING - RFWD_ -> RFWD <$> (EncFwdTransmission . unTail <$> _smpP) + CT SProxyService RFWD_ -> + Cmd SProxyService . RFWD . EncFwdTransmission . unTail <$> _smpP CT SSenderLink tag -> Cmd SSenderLink <$> case tag of LKEY_ -> LKEY <$> _smpP @@ -1614,23 +1709,32 @@ instance ProtocolEncoding SMPVersion ErrorType Cmd where Cmd SProxiedClient <$> case tag of PFWD_ -> PFWD <$> _smpP <*> smpP <*> (EncTransmission . unTail <$> smpP) PRXY_ -> PRXY <$> _smpP <*> smpP - CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB + CT SNotifier tag -> + pure $ Cmd SNotifier $ case tag of + NSUB_ -> NSUB + NSUBS_ -> NSUBS fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c + {-# INLINE checkCredentials #-} instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where type Tag BrokerMsg = BrokerMsgTag encodeProtocol v = \case - IDS QIK {rcvId, sndId, rcvPublicDhKey = srvDh, queueMode, linkId} + IDS QIK {rcvId, sndId, rcvPublicDhKey = srvDh, queueMode, linkId, serviceId} + | v >= serviceCertsSMPVersion -> ids <> e queueMode <> e linkId <> e serviceId | v >= shortLinksSMPVersion -> ids <> e queueMode <> e linkId | v >= sndAuthKeySMPVersion -> ids <> e (senderCanSecure queueMode) | otherwise -> ids where ids = e (IDS_, ' ', rcvId, sndId, srvDh) LNK sId d -> e (LNK_, ' ', sId, d) + SOK serviceId_ + | v >= serviceCertsSMPVersion -> e (SOK_, ' ', serviceId_) + | otherwise -> e OK_ -- won't happen, the association with the service requires v >= serviceCertsSMPVersion + SOKS n -> e (SOKS_, ' ', n) MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e (MSG_, ' ', msgId, Tail body) NID nId srvNtfDh -> e (NID_, ' ', nId, srvNtfDh) @@ -1639,6 +1743,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where RRES (EncFwdResponse encBlock) -> e (RRES_, ' ', Tail encBlock) PRES (EncResponse encBlock) -> e (PRES_, ' ', Tail encBlock) END -> e END_ + ENDS n -> e (ENDS_, ' ', n) DELD | v >= deletedEventSMPVersion -> e DELD_ | otherwise -> e END_ @@ -1659,28 +1764,33 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where where bodyP = EncRcvMsgBody . unTail <$> smpP IDS_ - | v >= shortLinksSMPVersion -> ids smpP smpP - | v >= sndAuthKeySMPVersion -> ids (qm <$> smpP) nothing - | otherwise -> ids nothing nothing + | v >= serviceCertsSMPVersion -> ids smpP smpP smpP + | v >= shortLinksSMPVersion -> ids smpP smpP nothing + | v >= sndAuthKeySMPVersion -> ids (qm <$> smpP) nothing nothing + | otherwise -> ids nothing nothing nothing where qm sndSecure = Just $ if sndSecure then QMMessaging else QMContact nothing = pure Nothing - ids p1 p2 = do + ids p1 p2 p3 = do rcvId <- _smpP sndId <- smpP rcvPublicDhKey <- smpP queueMode <- p1 linkId <- p2 + serviceId <- p3 -- TODO [notifications] -- serverNtfCreds <- p3 - pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId} + pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId, serviceId} LNK_ -> LNK <$> _smpP <*> smpP + SOK_ -> SOK <$> _smpP + SOKS_ -> SOKS <$> _smpP NID_ -> NID <$> _smpP <*> smpP NMSG_ -> NMSG <$> _smpP <*> smpP PKEY_ -> PKEY <$> _smpP <*> smpP <*> smpP RRES_ -> RRES <$> (EncFwdResponse . unTail <$> _smpP) PRES_ -> PRES <$> (EncResponse . unTail <$> _smpP) END_ -> pure END + ENDS_ -> ENDS <$> _smpP DELD_ -> pure DELD INFO_ -> INFO <$> _smpP OK_ -> pure OK @@ -1737,9 +1847,10 @@ instance Encoding ErrorType where PROXY err -> "PROXY " <> smpEncode err AUTH -> "AUTH" BLOCKED info -> "BLOCKED " <> smpEncode info + SERVICE -> "SERVICE" CRYPTO -> "CRYPTO" QUOTA -> "QUOTA" - STORE err -> "STORE " <> smpEncode err + STORE err -> "STORE " <> encodeUtf8 err EXPIRED -> "EXPIRED" NO_MSG -> "NO_MSG" LARGE_MSG -> "LARGE_MSG" @@ -1754,9 +1865,10 @@ instance Encoding ErrorType where "PROXY" -> PROXY <$> _smpP "AUTH" -> pure AUTH "BLOCKED" -> BLOCKED <$> _smpP + "SERVICE" -> pure SERVICE "CRYPTO" -> pure CRYPTO "QUOTA" -> pure QUOTA - "STORE" -> STORE <$> _smpP + "STORE" -> STORE . safeDecodeUtf8 <$> (A.space *> A.takeByteString) "EXPIRED" -> pure EXPIRED "NO_MSG" -> pure NO_MSG "LARGE_MSG" -> pure LARGE_MSG @@ -1819,6 +1931,7 @@ instance Encoding BrokerErrorType where NETWORK -> "NETWORK" TIMEOUT -> "TIMEOUT" HOST -> "HOST" + NO_SERVICE -> "NO_SERVICE" smpP = A.takeTill (== ' ') >>= \case "RESPONSE" -> RESPONSE <$> _smpP @@ -1827,6 +1940,7 @@ instance Encoding BrokerErrorType where "NETWORK" -> pure NETWORK "TIMEOUT" -> pure TIMEOUT "HOST" -> pure HOST + "NO_SERVICE" -> pure NO_SERVICE _ -> fail "bad BrokerErrorType" instance StrEncoding BrokerErrorType where @@ -1837,6 +1951,7 @@ instance StrEncoding BrokerErrorType where NETWORK -> "NETWORK" TIMEOUT -> "TIMEOUT" HOST -> "HOST" + NO_SERVICE -> "NO_SERVICE" strP = A.takeTill (== ' ') >>= \case "RESPONSE" -> RESPONSE <$> _textP @@ -1845,13 +1960,14 @@ instance StrEncoding BrokerErrorType where "NETWORK" -> pure NETWORK "TIMEOUT" -> pure TIMEOUT "HOST" -> pure HOST + "NO_SERVICE" -> pure NO_SERVICE _ -> fail "bad BrokerErrorType" where _textP = A.space *> (T.unpack . safeDecodeUtf8 <$> A.takeByteString) -- | Send signed SMP transmission to TCP transport. tPut :: Transport c => THandle v c p -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] -tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params) +tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions params where tPutBatch :: TransportBatch () -> IO [Either TransportError ()] tPutBatch = \case @@ -1870,13 +1986,13 @@ tPutLog th s = do -- ByteString in TBTransmissions includes byte with transmissions count data TransportBatch r = TBTransmissions ByteString Int [r] | TBTransmission ByteString r | TBError TransportError r -batchTransmissions :: Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission) -> [TransportBatch ()] -batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,()) +batchTransmissions :: THandleParams v p -> NonEmpty (Either TransportError SentRawTransmission) -> [TransportBatch ()] +batchTransmissions params = batchTransmissions' params . L.map (,()) -- | encodes and batches transmissions into blocks -batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission, r) -> [TransportBatch r] -batchTransmissions' batch bSize ts - | batch = batchTransmissions_ bSize $ L.map (first $ fmap tEncodeForBatch) ts +batchTransmissions' :: forall v p r. THandleParams v p -> NonEmpty (Either TransportError SentRawTransmission, r) -> [TransportBatch r] +batchTransmissions' THandleParams {batch, blockSize = bSize, serviceAuth} ts + | batch = batchTransmissions_ bSize $ L.map (first $ fmap $ tEncodeForBatch serviceAuth) ts | otherwise = map mkBatch1 $ L.toList ts where mkBatch1 :: (Either TransportError SentRawTransmission, r) -> TransportBatch r @@ -1887,7 +2003,7 @@ batchTransmissions' batch bSize ts | B.length s <= bSize - 2 -> TBTransmission s r | otherwise -> TBError TELargeMsg r where - s = tEncode t + s = tEncode serviceAuth t -- | Pack encoded transmissions into batches batchTransmissions_ :: Int -> NonEmpty (Either TransportError ByteString, r) -> [TransportBatch r] @@ -1910,16 +2026,16 @@ batchTransmissions_ bSize = addBatch . foldr addTransmission ([], 0, 0, [], []) where b = B.concat $ B.singleton (lenEncode n) : ss -tEncode :: SentRawTransmission -> ByteString -tEncode (auth, t) = smpEncode (tAuthBytes auth) <> t +tEncode :: Bool -> SentRawTransmission -> ByteString +tEncode serviceAuth (auth, t) = tEncodeAuth serviceAuth auth <> t {-# INLINE tEncode #-} -tEncodeForBatch :: SentRawTransmission -> ByteString -tEncodeForBatch = smpEncode . Large . tEncode +tEncodeForBatch :: Bool -> SentRawTransmission -> ByteString +tEncodeForBatch serviceAuth = smpEncode . Large . tEncode serviceAuth {-# INLINE tEncodeForBatch #-} -tEncodeBatch1 :: SentRawTransmission -> ByteString -tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t +tEncodeBatch1 :: Bool -> SentRawTransmission -> ByteString +tEncodeBatch1 serviceAuth t = lenEncode 1 `B.cons` tEncodeForBatch serviceAuth t {-# INLINE tEncodeBatch1 #-} -- tForAuth is lazy to avoid computing it when there is no key to sign @@ -1967,9 +2083,9 @@ tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th tDecodeParseValidate :: forall v p err cmd. ProtocolEncoding v err cmd => THandleParams v p -> Either TransportError RawTransmission -> SignedTransmission err cmd tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case - Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command} + Right RawTransmission {authenticator, serviceSig, authorized, sessId, corrId, entityId, command} | implySessId || sessId == sessionId -> - let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator + let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator serviceSig in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission | otherwise -> (Nothing, "", (corrId, NoEntity, Left $ fromProtocolError @v @err @cmd PESession)) Left _ -> tError "" diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 735643969..5c05e9984 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -51,6 +51,7 @@ import Control.Monad.IO.Unlift import Control.Monad.Reader import Control.Monad.Trans.Except import Control.Monad.STM (retry) +import Crypto.Random (ChaChaDRG) import Data.Bifunctor (first) import Data.ByteString.Base64 (encode) import qualified Data.ByteString.Builder as BLD @@ -67,9 +68,11 @@ import qualified Data.IntSet as IS import Data.List (foldl', intercalate, mapAccumR) import Data.List.NonEmpty (NonEmpty (..), (<|)) import qualified Data.List.NonEmpty as L +import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing) import Data.Semigroup (Sum (..)) +import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) @@ -80,6 +83,7 @@ import Data.Time.Format.ISO8601 (iso8601Show) import Data.Type.Equality import Data.Typeable (cast) import qualified Data.X509 as X +import qualified Data.X509.Validation as XV import GHC.Conc.Signal import GHC.IORef (atomicSwapIORef) import GHC.Stats (getRTSStats) @@ -88,7 +92,7 @@ import Network.Socket (ServiceName, Socket, socketToHandle) import qualified Network.TLS as TLS import Numeric.Natural (Natural) import Simplex.Messaging.Agent.Lock -import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, nonBlockingWriteTBQueue, smpProxyError, temporaryClientError) +import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, smpProxyError, temporaryClientError) import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -149,6 +153,14 @@ runSMPServerBlocking started cfg attachHTTP_ = newEnv cfg >>= runReaderT (smpSer type M s a = ReaderT (Env s) IO a type AttachHTTP = Socket -> TLS.Context -> IO () +-- actions used in serverThread to reduce STM transaction scope +data ClientSubAction + = CSAEndSub QueueId -- end single direct queue subscription + | CSAEndServiceSub -- end service subscription to one queue + | CSADecreaseSubs Int64 -- reduce service subscriptions when cancelling. Fixed number is used to correctly handle race conditions when service resubscribes + +type PrevClientSub s = (Client s, ClientSubAction, (EntityId, BrokerMsg)) + smpServer :: forall s. MsgStoreClass s => TMVar Bool -> ServerConfig s -> Maybe AttachHTTP -> M s () smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOptions} attachHTTP_ = do s <- asks server @@ -163,8 +175,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt stopServer s liftIO $ exitSuccess raceAny_ - ( serverThread "server subscribers" s subscribers subscriptions cancelSub - : serverThread "server ntfSubscribers" s ntfSubscribers ntfSubscriptions (\_ -> pure ()) + ( serverThread "server subscribers" s subscribers subscriptions serviceSubsCount (Just cancelSub) + : serverThread "server ntfSubscribers" s ntfSubscribers ntfSubscriptions ntfServiceSubsCount Nothing : deliverNtfsThread s : sendPendingEvtsThread s : receiveFromProxyAgent pa @@ -228,84 +240,151 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt closeServer = asks (smpAgent . proxyAgent) >>= liftIO . closeSMPClientAgent serverThread :: - forall sub. - String -> + forall sub. String -> Server s -> (Server s -> ServerSubscribers s) -> (Client s -> TMap QueueId sub) -> - (sub -> IO ()) -> + (Client s -> TVar Int64) -> + Maybe (sub -> IO ()) -> M s () - serverThread label srv srvSubscribers clientSubs unsub = do + serverThread label srv srvSubscribers clientSubs clientServiceSubs unsub_ = do labelMyThread label liftIO . forever $ do -- Reading clients outside of `updateSubscribers` transaction to avoid transaction re-evaluation on each new connected client. -- In case client disconnects during the transaction (its `connected` property is read), -- the transaction will still be re-evaluated, and the client won't be stored as subscribed. - sub@(_, clntId, _) <- atomically $ readTQueue subQ + sub@(_, clntId) <- atomically $ readTQueue subQ c_ <- getServerClient clntId srv atomically (updateSubscribers c_ sub) - $>>= endPreviousSubscriptions - >>= mapM_ unsub + >>= endPreviousSubscriptions where - ServerSubscribers {subQ, queueSubscribers, subClients, pendingEvents} = srvSubscribers srv - updateSubscribers :: Maybe (Client s) -> (QueueId, ClientId, Subscribed) -> STM (Maybe ((QueueId, BrokerMsg), Client s)) - updateSubscribers c_ (qId, clntId, subscribed) = updateSub $>>= clientToBeNotified + ServerSubscribers {subQ, queueSubscribers, serviceSubscribers, totalServiceSubs, subClients, pendingEvents} = srvSubscribers srv + updateSubscribers :: Maybe (Client s) -> (ClientSub, ClientId) -> STM [PrevClientSub s] + updateSubscribers c_ (clntSub, clntId) = case c_ of + Just c@Client {connected} -> ifM (readTVar connected) (updateSubConnected c) updateSubDisconnected + Nothing -> updateSubDisconnected where - updateSub = case c_ of - Just c@Client {connected} -> ifM (readTVar connected) (updateSubConnected c) updateSubDisconnected - Nothing -> updateSubDisconnected - updateSubConnected c - | subscribed = do - modifyTVar' subClients $ IS.insert clntId -- add client to server's subscribed cients - upsertSubscribedClient qId c queueSubscribers - | otherwise = do - removeWhenNoSubs c - lookupDeleteSubscribedClient qId queueSubscribers - -- do not insert client if it is already disconnected, but send END to any other client - updateSubDisconnected = lookupDeleteSubscribedClient qId queueSubscribers - clientToBeNotified c@Client {clientId, connected} - | clntId == clientId = pure Nothing - | otherwise = (\yes -> if yes then Just ((qId, subEvt), c) else Nothing) <$> readTVar connected - where - subEvt = if subscribed then END else DELD - endPreviousSubscriptions :: ((QueueId, BrokerMsg), Client s) -> IO (Maybe sub) - endPreviousSubscriptions (evt@(qId, _), c) = do + updateSubConnected c = case clntSub of + CSClient qId prevServiceId serviceId_ -> do + modifyTVar' subClients $ IS.insert clntId -- add ID to server's subscribed cients + as'' <- if prevServiceId == serviceId_ then pure [] else endServiceSub prevServiceId qId END + case serviceId_ of + Just serviceId -> do + as <- endQueueSub qId END + as' <- cancelServiceSubs serviceId =<< upsertSubscribedClient serviceId c serviceSubscribers + pure $ as ++ as' ++ as'' + Nothing -> do + as <- prevSub qId END (CSAEndSub qId) =<< upsertSubscribedClient qId c queueSubscribers + pure $ as ++ as'' + CSDeleted qId serviceId -> do + removeWhenNoSubs c + as <- endQueueSub qId DELD + as' <- endServiceSub serviceId qId DELD + pure $ as ++ as' + CSService serviceId -> do + modifyTVar' subClients $ IS.insert clntId -- add ID to server's subscribed cients + cancelServiceSubs serviceId =<< upsertSubscribedClient serviceId c serviceSubscribers + updateSubDisconnected = case clntSub of + -- do not insert client if it is already disconnected, but send END/DELD to any other client subscribed to this queue or service + CSClient qId prevServiceId serviceId -> do + as <- endQueueSub qId END + as' <- endServiceSub serviceId qId END + as'' <- if prevServiceId == serviceId then pure [] else endServiceSub prevServiceId qId END + pure $ as ++ as' ++ as'' + CSDeleted qId serviceId -> do + as <- endQueueSub qId DELD + as' <- endServiceSub serviceId qId DELD + pure $ as ++ as' + CSService serviceId -> cancelServiceSubs serviceId =<< lookupSubscribedClient serviceId serviceSubscribers + endQueueSub :: QueueId -> BrokerMsg -> STM [PrevClientSub s] + endQueueSub qId msg = prevSub qId msg (CSAEndSub qId) =<< lookupDeleteSubscribedClient qId queueSubscribers + endServiceSub :: Maybe ServiceId -> QueueId -> BrokerMsg -> STM [PrevClientSub s] + endServiceSub Nothing _ _ = pure [] + endServiceSub (Just serviceId) qId msg = prevSub qId msg CSAEndServiceSub =<< lookupSubscribedClient serviceId serviceSubscribers + prevSub :: QueueId -> BrokerMsg -> ClientSubAction -> Maybe (Client s) -> STM [PrevClientSub s] + prevSub qId msg action = + checkAnotherClient $ \c -> pure [(c, action, (qId, msg))] + cancelServiceSubs :: ServiceId -> Maybe (Client s) -> STM [PrevClientSub s] + cancelServiceSubs serviceId = + checkAnotherClient $ \c -> do + n <- swapTVar (clientServiceSubs c) 0 + pure [(c, CSADecreaseSubs n, (serviceId, ENDS n))] + checkAnotherClient :: (Client s -> STM [PrevClientSub s]) -> Maybe (Client s) -> STM [PrevClientSub s] + checkAnotherClient mkSub = \case + Just c@Client {clientId, connected} | clntId /= clientId -> + ifM (readTVar connected) (mkSub c) (pure []) + _ -> pure [] + + endPreviousSubscriptions :: [PrevClientSub s] -> IO () + endPreviousSubscriptions = mapM_ $ \(c, subAction, evt) -> do atomically $ modifyTVar' pendingEvents $ IM.alter (Just . maybe [evt] (evt <|)) (clientId c) - atomically $ do - sub <- TM.lookupDelete qId (clientSubs c) - removeWhenNoSubs c $> sub + case subAction of + CSAEndSub qId -> atomically (endSub c qId) >>= a unsub_ + where + a (Just unsub) (Just s) = unsub s + a _ _ = pure () + CSAEndServiceSub -> atomically $ do + modifyTVar' (clientServiceSubs c) decrease + modifyTVar' totalServiceSubs decrease + where + decrease n = max 0 (n - 1) + -- TODO [certs rcv] for SMP subscriptions CSADecreaseSubs should also remove all delivery threads of the passed client + CSADecreaseSubs n' -> atomically $ modifyTVar' totalServiceSubs $ \n -> max 0 (n - n') + where + endSub :: Client s -> QueueId -> STM (Maybe sub) + endSub c qId = TM.lookupDelete qId (clientSubs c) >>= (removeWhenNoSubs c $>) -- remove client from server's subscribed cients - removeWhenNoSubs c = whenM (null <$> readTVar (clientSubs c)) $ modifyTVar' subClients $ IS.delete (clientId c) + removeWhenNoSubs c = do + noClientSubs <- null <$> readTVar (clientSubs c) + noServiceSubs <- (0 ==) <$> readTVar (clientServiceSubs c) + when (noClientSubs && noServiceSubs) $ modifyTVar' subClients $ IS.delete (clientId c) deliverNtfsThread :: Server s -> M s () - deliverNtfsThread srv@Server {ntfSubscribers} = do + deliverNtfsThread srv@Server {ntfSubscribers = ServerSubscribers {subClients, serviceSubscribers}} = do ntfInt <- asks $ ntfDeliveryInterval . config - NtfStore ns <- asks ntfStore + ms <- asks msgStore + ns' <- asks ntfStore stats <- asks serverStats liftIO $ forever $ do threadDelay ntfInt - cIds <- IS.toList <$> readTVarIO (subClients ntfSubscribers) - forM_ cIds $ \cId -> getServerClient cId srv >>= mapM_ (deliverNtfs ns stats) + runDeliverNtfs ms ns' stats where - deliverNtfs ns stats Client {clientId, ntfSubscriptions, sndQ, connected} = - whenM (currentClient readTVarIO) $ do - subs <- readTVarIO ntfSubscriptions - ntfQs <- M.assocs . M.filterWithKey (\nId _ -> M.member nId subs) <$> readTVarIO ns - tryAny (atomically $ flushSubscribedNtfs ntfQs) >>= \case - Right len -> updateNtfStats len - Left e -> logDebug $ "NOTIFICATIONS: cancelled for client #" <> tshow clientId <> ", reason: " <> tshow e + runDeliverNtfs :: s -> NtfStore -> ServerStats -> IO () + runDeliverNtfs ms (NtfStore ns) stats = do + ntfs <- M.assocs <$> readTVarIO ns + unless (null ntfs) $ + getQueueNtfServices @(StoreQueue s) (queueStore ms) ntfs >>= \case + Left e -> logError $ "NOTIFICATIONS: getQueueNtfServices error " <> tshow e + Right (sNtfs, deleted) -> do + forM_ sNtfs $ \(serviceId_, ntfs') -> case serviceId_ of + Just sId -> getSubscribedClient sId serviceSubscribers >>= mapM_ (deliverServiceNtfs ntfs') + Nothing -> do -- legacy code that does almost the same as before for non-service subscribers + cIds <- IS.toList <$> readTVarIO subClients + forM_ cIds $ \cId -> getServerClient cId srv >>= mapM_ (deliverQueueNtfs ntfs') + atomically $ modifyTVar' ns (`M.withoutKeys` S.fromList (map fst deleted)) where - flushSubscribedNtfs :: [(NotifierId, TVar [MsgNtf])] -> STM Int - flushSubscribedNtfs ntfQs = do - ts_ <- foldM addNtfs [] ntfQs + deliverQueueNtfs ntfs' c@Client {ntfSubscriptions} = + whenM (currentClient readTVarIO c) $ do + subs <- readTVarIO ntfSubscriptions + unless (M.null subs) $ do + let ntfs'' = filter (\(nId, _) -> M.member nId subs) ntfs' + tryAny (atomically $ flushSubscribedNtfs ntfs'' c) >>= updateNtfStats c + deliverServiceNtfs ntfs' cv = readTVarIO cv >>= mapM_ deliver + where + deliver c = tryAny (atomically $ withSubscribed $ flushSubscribedNtfs ntfs') >>= updateNtfStats c + withSubscribed :: (Client s -> STM Int) -> STM Int + withSubscribed a = readTVar cv >>= maybe (throwSTM $ userError "service unsubscribed") a + flushSubscribedNtfs :: [(NotifierId, TVar [MsgNtf])] -> Client s' -> STM Int + flushSubscribedNtfs ntfs c@Client {sndQ} = do + ts_ <- foldM addNtfs [] ntfs forM_ (L.nonEmpty ts_) $ \ts -> do let cancelNtfs s = throwSTM $ userError $ s <> ", " <> show (length ts_) <> " ntfs kept" - unlessM (currentClient readTVar) $ cancelNtfs "not current client" + unlessM (currentClient readTVar c) $ cancelNtfs "not current client" whenM (isFullTBQueue sndQ) $ cancelNtfs "sending queue full" writeTBQueue sndQ ts pure $ length ts_ - currentClient :: Monad m => (forall a. TVar a -> m a) -> m Bool - currentClient rd = (&&) <$> rd connected <*> (IS.member clientId <$> rd (subClients ntfSubscribers)) + currentClient :: Monad m => (forall a. TVar a -> m a) -> Client s' -> m Bool + currentClient rd Client {clientId, connected} = (&&) <$> rd connected <*> (IS.member clientId <$> rd subClients) addNtfs :: [Transmission BrokerMsg] -> (NotifierId, TVar [MsgNtf]) -> STM [Transmission BrokerMsg] addNtfs acc (nId, v) = readTVar v >>= \case @@ -314,11 +393,14 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt writeTVar v [] pure $ foldl' (\acc' ntf -> nmsg nId ntf : acc') acc ntfs -- reverses, to order by time nmsg nId MsgNtf {ntfNonce, ntfEncMeta} = (CorrId "", nId, NMSG ntfNonce ntfEncMeta) - updateNtfStats 0 = pure () - updateNtfStats len = liftIO $ do - atomicModifyIORef'_ (ntfCount stats) (subtract len) - atomicModifyIORef'_ (msgNtfs stats) (+ len) - atomicModifyIORef'_ (msgNtfsB stats) (+ (len `div` 80 + 1)) -- up to 80 NMSG in the batch + updateNtfStats :: Client s' -> Either SomeException Int -> IO () + updateNtfStats Client {clientId} = \case + Right 0 -> pure () + Right len -> do + atomicModifyIORef'_ (ntfCount stats) (subtract len) + atomicModifyIORef'_ (msgNtfs stats) (+ len) + atomicModifyIORef'_ (msgNtfsB stats) (+ (len `div` 80 + 1)) -- up to 80 NMSG in the batch + Left e -> logNote $ "NOTIFICATIONS: cancelled for client #" <> tshow clientId <> ", reason: " <> tshow e sendPendingEvtsThread :: Server s -> M s () sendPendingEvtsThread srv@Server {subscribers, ntfSubscribers} = do @@ -334,11 +416,16 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt unless (null pending) $ forM_ (IM.assocs pending) $ \(cId, evts) -> getServerClient cId srv >>= mapM_ (enqueueEvts evts) where - enqueueEvts evts Client {connected, sndQ} = - whenM (readTVarIO connected) $ - nonBlockingWriteTBQueue sndQ ts >> updateEndStats + enqueueEvts evts c@Client {connected, sndQ} = + whenM (readTVarIO connected) $ do + sent <- atomically $ tryWriteTBQueue sndQ ts + if sent + then updateEndStats + else -- if queue is full it can block + forkClient c ("sendPendingEvtsThread.queueEvts") $ + atomically (writeTBQueue sndQ ts) >> updateEndStats where - ts = L.map (\(qId, evt) -> (CorrId "", qId, evt)) evts + ts = L.map (\(entId, evt) -> (CorrId "", entId, evt)) evts -- this accounts for both END and DELD events updateEndStats = do let len = L.length evts @@ -350,10 +437,17 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt receiveFromProxyAgent ProxyAgent {smpAgent = SMPClientAgent {agentQ}} = forever $ atomically (readTBQueue agentQ) >>= \case - CAConnected srv -> logInfo $ "SMP server connected " <> showServer' srv - CADisconnected srv qIds -> logError $ "SMP server disconnected " <> showServer' srv <> " / subscriptions: " <> tshow (L.length qIds) - CASubscribed srv qIds -> logError $ "SMP server subscribed " <> showServer' srv <> " / subscriptions: " <> tshow (L.length qIds) - CASubError srv errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (L.length errs) + CAConnected srv _service_ -> logInfo $ "SMP server connected " <> showServer' srv + CADisconnected srv qIds -> logError $ "SMP server disconnected " <> showServer' srv <> " / subscriptions: " <> tshow (length qIds) + -- the errors below should never happen - messaging proxy does not make any subscriptions + CASubscribed srv serviceId qIds -> logError $ "SMP server subscribed queues " <> asService <> showServer' srv <> " / subscriptions: " <> tshow (length qIds) + where + asService = if isJust serviceId then "as service " else "" + CASubError srv errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs) + CAServiceDisconnected {} -> logError "CAServiceDisconnected" + CAServiceSubscribed {} -> logError "CAServiceSubscribed" + CAServiceSubError {} -> logError "CAServiceSubError" + CAServiceUnavailable {} -> logError "CAServiceUnavailable" where showServer' = decodeLatin1 . strEncode . host @@ -408,10 +502,11 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0) - ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedAllB, qDeletedNew, qDeletedSecured, qSub, qSubAllB, qSubAuth, qSubDuplicate, qSubProhibited, qSubEnd, qSubEndB, ntfCreated, ntfDeleted, ntfDeletedB, ntfSub, ntfSubB, ntfSubAuth, ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgRecvGet, msgGet, msgGetNoMsg, msgGetAuth, msgGetDuplicate, msgGetProhibited, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, ntfCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv} + ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedAllB, qDeletedNew, qDeletedSecured, qSub, qSubAllB, qSubAuth, qSubDuplicate, qSubProhibited, qSubEnd, qSubEndB, ntfCreated, ntfDeleted, ntfDeletedB, ntfSub, ntfSubB, ntfSubAuth, ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgRecvGet, msgGet, msgGetNoMsg, msgGetAuth, msgGetDuplicate, msgGetProhibited, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, ntfCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv, rcvServices, ntfServices} <- asks serverStats st <- asks msgStore - QueueCounts {queueCount, notifierCount} <- liftIO $ queueCounts @(StoreQueue s) $ queueStore st + EntityCounts {queueCount, notifierCount, rcvServiceCount, ntfServiceCount, rcvServiceQueuesCount, ntfServiceQueuesCount} <- + liftIO $ getEntityCounts @(StoreQueue s) $ queueStore st let interval = 1000000 * logInterval forever $ do withFile statsFilePath AppendMode $ \h -> liftIO $ do @@ -464,81 +559,91 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt pMsgFwds' <- getResetProxyStatsData pMsgFwds pMsgFwdsOwn' <- getResetProxyStatsData pMsgFwdsOwn pMsgFwdsRecv' <- atomicSwapIORef pMsgFwdsRecv 0 + rcvServices' <- getServiceStatsData rcvServices + ntfServices' <- getServiceStatsData ntfServices qCount' <- readIORef qCount msgCount' <- readIORef msgCount ntfCount' <- readIORef ntfCount - hPutStrLn h $ - intercalate + T.hPutStrLn h $ + T.intercalate "," - ( [ iso8601Show $ utctDay fromTime', - show qCreated', - show qSecured', - show qDeletedAll', - show msgSent', - show msgRecv', + ( [ T.pack $ iso8601Show $ utctDay fromTime', + tshow qCreated', + tshow qSecured', + tshow qDeletedAll', + tshow msgSent', + tshow msgRecv', dayCount ps, weekCount ps, monthCount ps, - show msgSentNtf', - show msgRecvNtf', + tshow msgSentNtf', + tshow msgRecvNtf', dayCount psNtf, weekCount psNtf, monthCount psNtf, - show qCount', - show msgCount', - show msgExpired', - show qDeletedNew', - show qDeletedSecured' + tshow qCount', + tshow msgCount', + tshow msgExpired', + tshow qDeletedNew', + tshow qDeletedSecured' ] <> showProxyStats pRelays' <> showProxyStats pRelaysOwn' <> showProxyStats pMsgFwds' <> showProxyStats pMsgFwdsOwn' - <> [ show pMsgFwdsRecv', - show qSub', - show qSubAuth', - show qSubDuplicate', - show qSubProhibited', - show msgSentAuth', - show msgSentQuota', - show msgSentLarge', - show msgNtfs', - show msgNtfNoSub', - show msgNtfLost', + <> [ tshow pMsgFwdsRecv', + tshow qSub', + tshow qSubAuth', + tshow qSubDuplicate', + tshow qSubProhibited', + tshow msgSentAuth', + tshow msgSentQuota', + tshow msgSentLarge', + tshow msgNtfs', + tshow msgNtfNoSub', + tshow msgNtfLost', "0", -- qSubNoMsg' is removed for performance. -- Use qSubAllB for the approximate number of all subscriptions. -- Average observed batch size is 25-30 subscriptions. - show msgRecvGet', - show msgGet', - show msgGetNoMsg', - show msgGetAuth', - show msgGetDuplicate', - show msgGetProhibited', + tshow msgRecvGet', + tshow msgGet', + tshow msgGetNoMsg', + tshow msgGetAuth', + tshow msgGetDuplicate', + tshow msgGetProhibited', "0", -- dayCount psSub; psSub is removed to reduce memory usage "0", -- weekCount psSub "0", -- monthCount psSub - show queueCount, - show ntfCreated', - show ntfDeleted', - show ntfSub', - show ntfSubAuth', - show ntfSubDuplicate', - show notifierCount, - show qDeletedAllB', - show qSubAllB', - show qSubEnd', - show qSubEndB', - show ntfDeletedB', - show ntfSubB', - show msgNtfsB', - show msgNtfExpired', - show ntfCount' + tshow queueCount, + tshow ntfCreated', + tshow ntfDeleted', + tshow ntfSub', + tshow ntfSubAuth', + tshow ntfSubDuplicate', + tshow notifierCount, + tshow qDeletedAllB', + tshow qSubAllB', + tshow qSubEnd', + tshow qSubEndB', + tshow ntfDeletedB', + tshow ntfSubB', + tshow msgNtfsB', + tshow msgNtfExpired', + tshow ntfCount', + tshow rcvServiceCount, + tshow ntfServiceCount, + tshow rcvServiceQueuesCount, + tshow ntfServiceQueuesCount ] + <> showServiceStats rcvServices' + <> showServiceStats ntfServices' ) liftIO $ threadDelay' interval where showProxyStats ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} = - [show _pRequests, show _pSuccesses, show _pErrorsConnect, show _pErrorsCompat, show _pErrorsOther] + map tshow [_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther] + showServiceStats ServiceStatsData {_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd} = + map tshow [_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd] prometheusMetricsThread_ :: ServerConfig s -> [M s ()] prometheusMetricsThread_ ServerConfig {prometheusInterval = Just interval, prometheusMetricsFile} = @@ -566,8 +671,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt d <- getServerStatsData ss let ps = periodStatDataCounts $ _activeQueues d psNtf = periodStatDataCounts $ _activeQueuesNtf d - QueueCounts {queueCount, notifierCount} <- queueCounts @(StoreQueue s) $ queueStore st - pure ServerMetrics {statsData = d, activeQueueCounts = ps, activeNtfCounts = psNtf, queueCount, notifierCount, rtsOptions} + entityCounts <- getEntityCounts @(StoreQueue s) $ queueStore st + pure ServerMetrics {statsData = d, activeQueueCounts = ps, activeNtfCounts = psNtf, entityCounts, rtsOptions} getRealTimeMetrics :: Env s -> IO RealTimeMetrics getRealTimeMetrics Env {sockets, msgStore_ = ms, server = srv@Server {subscribers, ntfSubscribers}} = do @@ -583,21 +688,33 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt loadedCounts <- loadedQueueCounts $ fromMsgStore ms pure RealTimeMetrics {socketStats, threadsCount, clientsCount, smpSubs, ntfSubs, loadedCounts} where - getSubscribersMetrics ServerSubscribers {queueSubscribers, subClients} = do + getSubscribersMetrics ServerSubscribers {queueSubscribers, serviceSubscribers, subClients} = do subsCount <- M.size <$> getSubscribedClients queueSubscribers subClientsCount <- IS.size <$> readTVarIO subClients - pure RTSubscriberMetrics {subsCount, subClientsCount} + subServicesCount <- M.size <$> getSubscribedClients serviceSubscribers + pure RTSubscriberMetrics {subsCount, subClientsCount, subServicesCount} runClient :: Transport c => X.CertificateChain -> C.APrivateSignKey -> TProxy c 'TServer -> c 'TServer -> M s () runClient srvCert srvSignKey tp h = do + ms <- asks msgStore + g <- asks random + idSize <- asks $ queueIdBytes . config kh <- asks serverIdentity ks <- atomically . C.generateKeyPair =<< asks random ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config labelMyThread $ "smp handshake for " <> transportName tp - liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake srvCert srvSignKey h ks kh smpServerVRange) >>= \case + liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake srvCert srvSignKey h ks kh smpServerVRange $ getClientService ms g idSize) >>= \case Just (Right th) -> runClientTransport th _ -> pure () + getClientService :: s -> TVar ChaChaDRG -> Int -> SMPServiceRole -> X.CertificateChain -> XV.Fingerprint -> ExceptT TransportError IO ServiceId + getClientService ms g idSize role cert fp = do + newServiceId <- EntityId <$> atomically (C.randomBytes idSize g) + ts <- liftIO getSystemDate + let sr = ServiceRec {serviceId = newServiceId, serviceRole = role, serviceCert = cert, serviceCertHash = fp, serviceCreatedAt = ts} + withExceptT (const $ TEHandshake BAD_SERVICE) $ ExceptT $ + getCreateService @(StoreQueue s) (queueStore ms) sr + controlPortThread_ :: ServerConfig s -> [M s ()] controlPortThread_ ServerConfig {controlPort = Just port} = [runCPServer port] controlPortThread_ _ = [] @@ -647,7 +764,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt CPClients -> withAdminRole $ do cls <- getServerClients srv hPutStrLn h "clientId,sessionId,connected,createdAt,rcvActiveAt,sndActiveAt,age,subscriptions" - forM_ (IM.toList cls) $ \(cid, Client {sessionId, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions}) -> do + forM_ (IM.toList cls) $ \(cid, Client {clientTHParams = THandleParams {sessionId}, connected, createdAt, rcvActiveAt, sndActiveAt, subscriptions}) -> do connected' <- bshow <$> readTVarIO connected rcvActiveAt' <- strEncode <$> readTVarIO rcvActiveAt sndActiveAt' <- strEncode <$> readTVarIO sndActiveAt @@ -658,7 +775,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt CPStats -> withUserRole $ do ss <- unliftIO u $ asks serverStats st <- unliftIO u $ asks msgStore - QueueCounts {queueCount, notifierCount} <- queueCounts @(StoreQueue s) $ queueStore st + EntityCounts {queueCount, notifierCount, rcvServiceCount, ntfServiceCount, rcvServiceQueuesCount, ntfServiceQueuesCount} <- + getEntityCounts @(StoreQueue s) $ queueStore st let getStat :: (ServerStats -> IORef a) -> IO a getStat var = readIORef (var ss) putStat :: Show a => String -> (ServerStats -> IORef a) -> IO () @@ -711,6 +829,10 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt putProxyStat "pMsgFwds" pMsgFwds putProxyStat "pMsgFwdsOwn" pMsgFwdsOwn putStat "pMsgFwdsRecv" pMsgFwdsRecv + hPutStrLn h $ "rcvServiceCount: " <> show rcvServiceCount + hPutStrLn h $ "ntfServiceCount: " <> show ntfServiceCount + hPutStrLn h $ "rcvServiceQueuesCount: " <> show rcvServiceQueuesCount + hPutStrLn h $ "ntfServiceQueuesCount: " <> show ntfServiceQueuesCount CPStatsRTS -> getRTSStats >>= hPrint h CPThreads -> withAdminRole $ do #if MIN_VERSION_base(4,18,0) @@ -781,14 +903,15 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt putSubscribersInfo protoName ServerSubscribers {queueSubscribers, subClients} showIds = do activeSubs <- getSubscribedClients queueSubscribers hPutStrLn h $ protoName <> " subscriptions: " <> show (M.size activeSubs) + -- TODO [certs] service subscriptions clnts <- countSubClients activeSubs hPutStrLn h $ protoName <> " subscribed clients: " <> show (IS.size clnts) <> (if showIds then " " <> show (IS.toList clnts) else "") clnts' <- readTVarIO subClients hPutStrLn h $ protoName <> " subscribed clients count 2: " <> show (IS.size clnts') <> (if showIds then " " <> show clnts' else "") where - countSubClients :: M.Map QueueId (TVar (Maybe (Client s))) -> IO IS.IntSet + countSubClients :: Map QueueId (TVar (Maybe (Client s))) -> IO IS.IntSet countSubClients = foldM (\ !s c -> maybe s ((`IS.insert` s) . clientId) <$> readTVarIO c) IS.empty - countClientSubs :: (Client s -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap (Client s) -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + countClientSubs :: (Client s -> TMap QueueId a) -> Maybe (Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap (Client s) -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0)) where addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Client s -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) @@ -813,7 +936,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt sl <- atomically $ lengthTBQueue sndQ ml <- atomically $ lengthTBQueue msgQ pure (rl, sl, ml) - countSMPSubs :: M.Map QueueId Sub -> IO (Int, Int, Int, Int) + countSMPSubs :: Map QueueId Sub -> IO (Int, Int, Int, Int) countSMPSubs = foldM countSubs (0, 0, 0, 0) where countSubs (c1, c2, c3, c4) Sub {subThread} = case subThread of @@ -880,12 +1003,12 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt hPutStrLn h "AUTH" runClientTransport :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> M s () -runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessionId}} = do +runClientTransport h@THandle {params = thParams@THandleParams {sessionId}} = do q <- asks $ tbqSize . config ts <- liftIO getSystemTime nextClientId <- asks clientSeq clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1) - c <- liftIO $ newClient clientId q thVersion sessionId ts + c <- liftIO $ newClient clientId q thParams ts runClientThreads c `finally` clientDisconnected c where runClientThreads :: Client s -> M s () @@ -896,18 +1019,17 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio expCfg <- asks $ inactiveClientExpiration . config th <- newMVar h -- put TH under a fair lock to interleave messages and command responses labelMyThread . B.unpack $ "client $" <> encode sessionId - raceAny_ $ [liftIO $ send th c, liftIO $ sendMsg th c, client thParams s ms c, receive h ms c] <> disconnectThread_ c s expCfg + raceAny_ $ [liftIO $ send th c, liftIO $ sendMsg th c, client s ms c, receive h ms c] <> disconnectThread_ c s expCfg disconnectThread_ :: Client s -> Server s -> Maybe ExpirationConfig -> [M s ()] disconnectThread_ c s (Just expCfg) = [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c s)] disconnectThread_ _ _ _ = [] - noSubscriptions Client {clientId} Server {subscribers, ntfSubscribers} = do - hasSubs <- IS.member clientId <$> readTVarIO (subClients subscribers) - if hasSubs - then pure False - else not . IS.member clientId <$> readTVarIO (subClients ntfSubscribers) + noSubscriptions Client {clientId} s = + not <$> anyM [hasSubs (subscribers s), hasSubs (ntfSubscribers s)] + where + hasSubs ServerSubscribers {subClients} = IS.member clientId <$> readTVarIO subClients clientDisconnected :: forall s. Client s -> M s () -clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, connected, sessionId, endThreads} = do +clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, serviceSubsCount, ntfServiceSubsCount, connected, clientTHParams = THandleParams {sessionId, thAuth}, endThreads} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc" -- these can be in separate transactions, -- because the client already disconnected and they won't change @@ -917,16 +1039,26 @@ clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, connecte liftIO $ mapM_ cancelSub subs whenM (asks serverActive >>= readTVarIO) $ do srv@Server {subscribers, ntfSubscribers} <- asks server - liftIO $ updateSubscribers subs subscribers - liftIO $ updateSubscribers ntfSubs ntfSubscribers - liftIO $ deleteServerClient clientId srv + liftIO $ do + deleteServerClient clientId srv + updateSubscribers subs subscribers + updateSubscribers ntfSubs ntfSubscribers + case peerClientService =<< thAuth of + Just THClientService {serviceId, serviceRole} + | serviceRole == SRMessaging -> updateServiceSubs serviceId serviceSubsCount subscribers + | serviceRole == SRNotifier -> updateServiceSubs serviceId ntfServiceSubsCount ntfSubscribers + _ -> pure () tIds <- atomically $ swapTVar endThreads IM.empty liftIO $ mapM_ (mapM_ killThread <=< deRefWeak) tIds where - updateSubscribers :: M.Map QueueId a -> ServerSubscribers s -> IO () + updateSubscribers :: Map QueueId a -> ServerSubscribers s -> IO () updateSubscribers subs ServerSubscribers {queueSubscribers, subClients} = do mapM_ (\qId -> deleteSubcribedClient qId c queueSubscribers) (M.keys subs) atomically $ modifyTVar' subClients $ IS.delete clientId + updateServiceSubs :: ServiceId -> TVar Int64 -> ServerSubscribers s -> IO () + updateServiceSubs serviceId subsCount ServerSubscribers {totalServiceSubs, serviceSubscribers} = do + deleteSubcribedClient serviceId c serviceSubscribers + atomically . modifyTVar' totalServiceSubs . subtract =<< readTVarIO subsCount cancelSub :: Sub -> IO () cancelSub s = case subThread s of @@ -937,20 +1069,21 @@ cancelSub s = case subThread s of ProhibitSub -> pure () receive :: forall c s. (Transport c, MsgStoreClass s) => THandleSMP c 'TServer -> s -> Client s -> M s () -receive h@THandle {params = THandleParams {thAuth}} ms Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do +receive h@THandle {params = THandleParams {thAuth, sessionId}} ms Client {rcvQ, sndQ, rcvActiveAt} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" sa <- asks serverActive - forever $ do - ts <- L.toList <$> liftIO (tGet h) + stats <- asks serverStats + liftIO $ forever $ do + ts <- tGet h unlessM (readTVarIO sa) $ throwIO $ userError "server stopped" - atomically . (writeTVar rcvActiveAt $!) =<< liftIO getSystemTime - stats <- asks serverStats - (errs, cmds) <- partitionEithers <$> mapM (cmdAction stats) ts + atomically . (writeTVar rcvActiveAt $!) =<< getSystemTime + let service = peerClientService =<< thAuth + (errs, cmds) <- partitionEithers <$> mapM (cmdAction stats service) (L.toList ts) updateBatchStats stats cmds write sndQ errs write rcvQ cmds where - updateBatchStats :: ServerStats -> [(Maybe (StoreQueue s, QueueRec), Transmission Cmd)] -> M s () + updateBatchStats :: ServerStats -> [(Maybe (StoreQueue s, QueueRec), Transmission Cmd)] -> IO () updateBatchStats stats = \case (_, (_, _, (Cmd _ cmd))) : _ -> do let sel_ = case cmd of @@ -961,26 +1094,26 @@ receive h@THandle {params = THandleParams {thAuth}} ms Client {rcvQ, sndQ, rcvAc _ -> Nothing mapM_ (\sel -> incStat $ sel stats) sel_ [] -> pure () - cmdAction :: ServerStats -> SignedTransmission ErrorType Cmd -> M s (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd)) - cmdAction stats (tAuth, authorized, (corrId, entId, cmdOrError)) = + cmdAction :: ServerStats -> Maybe THPeerClientService -> SignedTransmission ErrorType Cmd -> IO (Either (Transmission BrokerMsg) (Maybe (StoreQueue s, QueueRec), Transmission Cmd)) + cmdAction stats service (tAuth, authorized, (corrId, entId, cmdOrError)) = case cmdOrError of Left e -> pure $ Left (corrId, entId, ERR e) - Right cmd -> verified =<< verifyTransmission ms ((,C.cbNonce (bs corrId)) <$> thAuth) tAuth authorized entId cmd + Right cmd -> verified =<< verifyTransmission ms service ((,C.cbNonce (bs corrId)) <$> thAuth) tAuth authorized entId cmd where verified = \case VRVerified q -> pure $ Right (q, (corrId, entId, cmd)) - VRFailed -> do + VRFailed e -> do case cmd of Cmd _ SEND {} -> incStat $ msgSentAuth stats Cmd _ SUB -> incStat $ qSubAuth stats Cmd _ NSUB -> incStat $ ntfSubAuth stats Cmd _ GET -> incStat $ msgGetAuth stats _ -> pure () - pure $ Left (corrId, entId, ERR AUTH) + pure $ Left (corrId, entId, ERR e) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty send :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO () -send th c@Client {sndQ, msgQ, sessionId} = do +send th c@Client {sndQ, msgQ, clientTHParams = THandleParams {sessionId}} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send" forever $ atomically (readTBQueue sndQ) >>= sendTransmissions where @@ -1005,7 +1138,7 @@ send th c@Client {sndQ, msgQ, sessionId} = do _ -> (msgs, t) sendMsg :: Transport c => MVar (THandleSMP c 'TServer) -> Client s -> IO () -sendMsg th c@Client {msgQ, sessionId} = do +sendMsg th c@Client {msgQ, clientTHParams = THandleParams {sessionId}} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " sendMsg" forever $ atomically (readTBQueue msgQ) >>= mapM_ (\t -> tSend th c [t]) @@ -1028,7 +1161,7 @@ disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcv ts <- max <$> readTVarIO rcvActiveAt <*> readTVarIO sndActiveAt if systemSeconds ts < old then closeConnection connection else loop -data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFailed +data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFailed ErrorType -- This function verifies queue command authorization, with the objective to have constant time between the three AUTH error scenarios: -- - the queue and party key exist, and the provided authorization has type matching queue key, but it is made with the different key. @@ -1036,37 +1169,68 @@ data VerificationResult s = VRVerified (Maybe (StoreQueue s, QueueRec)) | VRFail -- - the queue or party key do not exist. -- In all cases, the time of the verification should depend only on the provided authorization type, -- a dummy key is used to run verification in the last two cases, and failure is returned irrespective of the result. -verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> QueueId -> Cmd -> M s (VerificationResult s) -verifyTransmission ms auth_ tAuth authorized queueId cmd = - case cmd of - Cmd SRecipient (NEW NewQueueReq {rcvAuthKey = k}) -> pure $ Nothing `verifiedWith` k - Cmd SRecipient _ -> verifyQueue (\q -> Just q `verifiedWithKeys` recipientKeys (snd q)) <$> get SRecipient - Cmd SSender (SKEY k) -> verifySecure SSender k - -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command - Cmd SSender SEND {} -> verifyQueue (\q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified (Just q) else VRFailed) <$> get SSender - Cmd SSender PING -> pure $ VRVerified Nothing - Cmd SSender RFWD {} -> pure $ VRVerified Nothing - Cmd SSenderLink (LKEY k) -> verifySecure SSenderLink k - Cmd SSenderLink LGET -> verifyQueue (\q -> if isContactQueue (snd q) then VRVerified (Just q) else VRFailed) <$> get SSenderLink - -- NSUB will not be accepted without authorization - Cmd SNotifier NSUB -> verifyQueue (\q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier $ snd q)) <$> get SNotifier - Cmd SProxiedClient _ -> pure $ VRVerified Nothing +verifyTransmission :: forall s. MsgStoreClass s => s -> Maybe THPeerClientService -> Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> QueueId -> Cmd -> IO (VerificationResult s) +verifyTransmission ms service auth_ tAuth authorized queueId command@(Cmd party cmd) + | verifyServiceSig = case party of + SRecipient | hasRole SRMessaging -> case cmd of + NEW NewQueueReq {rcvAuthKey = k} -> pure $ Nothing `verifiedWith` k + SUB -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q) + SUBS -> pure verifyServiceCmd + _ -> verifyQueue SRecipient $ \q -> Just q `verifiedWithKeys` recipientKeys (snd q) + SSender | hasRole SRMessaging -> case cmd of + SKEY k -> verifySecure SSender k + -- SEND will be accepted without authorization before the queue is secured with KEY, SKEY or LSKEY command + SEND {} -> verifyQueue SSender $ \q -> if maybe (isNothing tAuth) verify (senderKey $ snd q) then VRVerified (Just q) else VRFailed AUTH + PING -> pure $ VRVerified Nothing + SSenderLink | hasRole SRMessaging -> case cmd of + LKEY k -> verifySecure SSenderLink k + LGET -> verifyQueue SSenderLink $ \q -> if isContactQueue (snd q) then VRVerified (Just q) else VRFailed AUTH + SNotifier | hasRole SRNotifier -> case cmd of + NSUB -> verifyQueue SNotifier $ \q -> maybe dummyVerify (\n -> Just q `verifiedWith` notifierKey n) (notifier $ snd q) + NSUBS -> pure verifyServiceCmd + SProxiedClient | hasRole SRMessaging -> pure $ VRVerified Nothing + SProxyService | hasRole SRProxy -> pure $ VRVerified Nothing + _ -> pure $ VRFailed $ CMD PROHIBITED + | otherwise = pure $ VRFailed SERVICE where - verify = verifyCmdAuthorization auth_ tAuth authorized - dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed - verifyQueue :: ((StoreQueue s, QueueRec) -> VerificationResult s) -> Either ErrorType (StoreQueue s, QueueRec) -> VerificationResult s - verifyQueue = either (const dummyVerify) - verifySecure :: DirectParty p => SParty p -> SndPublicAuthKey -> M s (VerificationResult s) - verifySecure p k = verifyQueue (\q -> if k `allowedKey` snd q then Just q `verifiedWith` k else dummyVerify) <$> get p + hasRole role = case service of + Just THClientService {serviceRole} -> serviceRole == role + Nothing -> True + verify = verifyCmdAuthorization auth_ tAuth authorized' + verifyServiceCmd :: VerificationResult s + verifyServiceCmd = case (service, tAuth) of + (Just THClientService {serviceKey = k}, Just (TASignature (C.ASignature C.SEd25519 s), Nothing)) + | C.verify' k s authorized -> VRVerified Nothing + _ -> VRFailed SERVICE + -- this function verify service signature for commands that use it in service sessions + verifyServiceSig + | useServiceAuth command = case (service, serviceSig) of + (Just THClientService {serviceKey = k}, Just s) -> C.verify' k s authorized + (Nothing, Nothing) -> True + _ -> False + | otherwise = isNothing serviceSig + serviceSig = snd =<< tAuth + authorized' = case (service, serviceSig) of + (Just THClientService {serviceCertHash = XV.Fingerprint fp}, Just _) -> fp <> authorized + _ -> authorized + dummyVerify :: VerificationResult s + dummyVerify = verify (dummyAuthKey tAuth) `seq` VRFailed AUTH + verifyQueue :: DirectParty p => SParty p -> ((StoreQueue s, QueueRec) -> VerificationResult s) -> IO (VerificationResult s) + verifyQueue p v = either err v <$> getQueueRec ms p queueId + where + -- this prevents reporting any STORE errors as AUTH errors + err = \case + AUTH -> dummyVerify + e -> VRFailed e + verifySecure :: DirectParty p => SParty p -> SndPublicAuthKey -> IO (VerificationResult s) + verifySecure p k = verifyQueue p $ \q -> if k `allowedKey` snd q then Just q `verifiedWith` k else dummyVerify verifiedWith :: Maybe (StoreQueue s, QueueRec) -> C.APublicAuthKey -> VerificationResult s - verifiedWith q_ k = if verify k then VRVerified q_ else VRFailed + verifiedWith q_ k = if verify k then VRVerified q_ else VRFailed AUTH verifiedWithKeys :: Maybe (StoreQueue s, QueueRec) -> NonEmpty C.APublicAuthKey -> VerificationResult s - verifiedWithKeys q_ ks = if any verify ks then VRVerified q_ else VRFailed + verifiedWithKeys q_ ks = if any verify ks then VRVerified q_ else VRFailed AUTH allowedKey k = \case QueueRec {queueMode = Just QMMessaging, senderKey} -> maybe True (k ==) senderKey _ -> False - get :: DirectParty p => SParty p -> M s (Either ErrorType (StoreQueue s, QueueRec)) - get party = liftIO $ getQueueRec ms party queueId isContactQueue :: QueueRec -> Bool isContactQueue QueueRec {queueMode, senderKey} = case queueMode of @@ -1079,15 +1243,15 @@ isSecuredMsgQueue QueueRec {queueMode, senderKey} = case queueMode of Just QMContact -> False _ -> isJust senderKey -verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TransmissionAuth -> ByteString -> C.APublicAuthKey -> Bool +verifyCmdAuthorization :: Maybe (THandleAuth 'TServer, C.CbNonce) -> Maybe TAuthorizations -> ByteString -> C.APublicAuthKey -> Bool verifyCmdAuthorization auth_ tAuth authorized key = maybe False (verify key) tAuth where - verify :: C.APublicAuthKey -> TransmissionAuth -> Bool + verify :: C.APublicAuthKey -> TAuthorizations -> Bool verify (C.APublicAuthKey a k) = \case - TASignature (C.ASignature a' s) -> case testEquality a a' of + (TASignature (C.ASignature a' s), _) -> case testEquality a a' of Just Refl -> C.verify' k s authorized _ -> C.verify' (dummySignKey a') s authorized `seq` False - TAAuthenticator s -> case a of + (TAAuthenticator s, _) -> case a of C.SX25519 -> verifyCmdAuth auth_ k s authorized _ -> verifyCmdAuth auth_ dummyKeyX25519 s authorized `seq` False @@ -1096,10 +1260,10 @@ verifyCmdAuth auth_ k authenticator authorized = case auth_ of Just (THAuthServer {serverPrivKey = pk}, nonce) -> C.cbVerify k pk nonce authenticator authorized Nothing -> False -dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TransmissionAuth -> Bool +dummyVerifyCmd :: Maybe (THandleAuth 'TServer, C.CbNonce) -> ByteString -> TAuthorizations -> Bool dummyVerifyCmd auth_ authorized = \case - TASignature (C.ASignature a s) -> C.verify' (dummySignKey a) s authorized - TAAuthenticator s -> verifyCmdAuth auth_ dummyKeyX25519 s authorized + (TASignature (C.ASignature a s), _) -> C.verify' (dummySignKey a) s authorized + (TAAuthenticator s, _) -> verifyCmdAuth auth_ dummyKeyX25519 s authorized -- These dummy keys are used with `dummyVerify` function to mitigate timing attacks -- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes @@ -1108,9 +1272,9 @@ dummySignKey = \case C.SEd25519 -> dummyKeyEd25519 C.SEd448 -> dummyKeyEd448 -dummyAuthKey :: Maybe TransmissionAuth -> C.APublicAuthKey +dummyAuthKey :: Maybe TAuthorizations -> C.APublicAuthKey dummyAuthKey = \case - Just (TASignature (C.ASignature a _)) -> case a of + Just (TASignature (C.ASignature a _), _) -> case a of C.SEd25519 -> C.APublicAuthKey C.SEd25519 dummyKeyEd25519 C.SEd448 -> C.APublicAuthKey C.SEd448 dummyKeyEd448 _ -> C.APublicAuthKey C.SX25519 dummyKeyX25519 @@ -1124,7 +1288,7 @@ dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/Xo dummyKeyX25519 :: C.PublicKey 'C.X25519 dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg=" -forkClient :: Client s -> String -> M s () -> M s () +forkClient :: MonadUnliftIO m => Client s -> String -> m () -> m () forkClient Client {endThreads, endThreadSeq} label action = do tId <- atomically $ stateTVar endThreadSeq $ \next -> (next, next + 1) t <- forkIO $ do @@ -1132,17 +1296,18 @@ forkClient Client {endThreads, endThreadSeq} label action = do action `finally` atomically (modifyTVar' endThreads $ IM.delete tId) mkWeakThreadId t >>= atomically . modifyTVar' endThreads . IM.insert tId -client :: forall s. MsgStoreClass s => THandleParams SMPVersion 'TServer -> Server s -> s -> Client s -> M s () +client :: forall s. MsgStoreClass s => Server s -> s -> Client s -> M s () client - thParams' + -- TODO [certs rcv] rcv subscriptions Server {subscribers, ntfSubscribers} ms - clnt@Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} = do + clnt@Client {clientId, subscriptions, ntfSubscriptions, serviceSubsCount = _todo', ntfServiceSubsCount, rcvQ, sndQ, clientTHParams = thParams'@THandleParams {sessionId}, procThreads} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands" let THandleParams {thVersion} = thParams' + service = peerClientService =<< thAuth thParams' forever $ atomically (readTBQueue rcvQ) - >>= mapM (processCommand thVersion) + >>= mapM (processCommand service thVersion) >>= mapM_ reply . L.nonEmpty . catMaybes . L.toList where reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m () @@ -1185,7 +1350,7 @@ client -- Cap the destination relay version range to prevent client version fingerprinting. -- See comment for proxiedSMPRelayVersion. Just (Compatible vr) | thVersion >= sendingProxySMPVersion -> case thAuth of - Just THAuthClient {serverCertKey} -> PKEY srvSessId vr serverCertKey + Just THAuthClient {peerServerCertKey} -> PKEY srvSessId vr peerServerCertKey Nothing -> ERR $ transportErr TENoServerAuth _ -> ERR $ transportErr TEVersion PFWD fwdV pubKey encBlock -> do @@ -1227,18 +1392,24 @@ client mkIncProxyStats ps psOwn own sel = do incStat $ sel ps when own $ incStat $ sel psOwn - processCommand :: VersionSMP -> (Maybe (StoreQueue s, QueueRec), Transmission Cmd) -> M s (Maybe (Transmission BrokerMsg)) - processCommand clntVersion (q_, (corrId, entId, cmd)) = case cmd of + processCommand :: Maybe THPeerClientService -> VersionSMP -> (Maybe (StoreQueue s, QueueRec), Transmission Cmd) -> M s (Maybe (Transmission BrokerMsg)) + processCommand service clntVersion (q_, (corrId, entId, cmd)) = case cmd of Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command) Cmd SSender command -> Just <$> case command of SKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k SEND flags msgBody -> withQueue_ False $ sendMessage flags msgBody PING -> pure (corrId, NoEntity, PONG) - RFWD encBlock -> (corrId, NoEntity,) <$> processForwardedCommand encBlock + Cmd SProxyService (RFWD encBlock) -> Just . (corrId, NoEntity,) <$> processForwardedCommand encBlock Cmd SSenderLink command -> Just <$> case command of LKEY k -> withQueue $ \q qr -> checkMode QMMessaging qr $ secureQueue_ q k $>> getQueueLink_ q qr LGET -> withQueue $ \q qr -> checkContact qr $ getQueueLink_ q qr - Cmd SNotifier NSUB -> Just <$> subscribeNotifications + Cmd SNotifier command -> Just . (corrId,entId,) <$> case command of + NSUB -> case q_ of + Just (q, QueueRec {notifier = Just ntfCreds}) -> subscribeNotifications q ntfCreds + _ -> pure $ ERR INTERNAL + NSUBS -> case service of + Just s -> subscribeServiceNotifications s + Nothing -> pure $ ERR INTERNAL Cmd SRecipient command -> Just <$> case command of NEW nqr@NewQueueReq {auth_} -> @@ -1248,6 +1419,7 @@ client ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config pure $ allowNewQueues && maybe True ((== auth_) . Just) newQueueBasicAuth SUB -> withQueue subscribeQueue + SUBS -> pure $ err (CMD PROHIBITED) -- "TODO [certs rcv]" GET -> withQueue getMessage ACK msgId -> withQueue $ acknowledgeMsg msgId KEY sKey -> withQueue $ \q _ -> either err (corrId,entId,) <$> secureQueue_ q sKey @@ -1268,65 +1440,69 @@ client QUE -> withQueue $ \q qr -> (corrId,entId,) <$> getQueueInfo q qr where createQueue :: NewQueueReq -> M s (Transmission BrokerMsg) - createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData} = time "NEW" $ do - g <- asks random - idSize <- asks $ queueIdBytes . config - updatedAt <- Just <$> liftIO getSystemDate - (rcvPublicDhKey, privDhKey) <- atomically $ C.generateKeyPair g - -- TODO [notifications] - -- ntfKeys_ <- forM ntfCreds $ \(NewNtfCreds notifierKey dhKey) -> do - -- (ntfPubDhKey, ntfPrivDhKey) <- atomically $ C.generateKeyPair g - -- pure (notifierKey, C.dh' dhKey ntfPrivDhKey, ntfPubDhKey) - let randId = EntityId <$> atomically (C.randomBytes idSize g) - -- TODO [notifications] the remaining 24 bytes are reserver for notifier ID - sndId' = B.take 24 $ C.sha3_384 (bs corrId) - tryCreate 0 = pure $ ERR INTERNAL - tryCreate n = do - (sndId, clntIds, queueData) <- case queueReqData of - Just (QRMessaging (Just (sId, d))) -> (\linkId -> (sId, True, Just (linkId, d))) <$> randId - Just (QRContact (Just (linkId, (sId, d)))) -> pure (sId, True, Just (linkId, d)) - _ -> (,False,Nothing) <$> randId - -- The condition that client-provided sender ID must match hash of correlation ID - -- prevents "ID oracle" attack, when creating queue with supplied ID can be used to check - -- if queue with this ID still exists. - if clntIds && unEntityId sndId /= sndId' - then pure $ ERR $ CMD PROHIBITED - else do - rcvId <- randId - -- TODO [notifications] - -- ntf <- forM ntfKeys_ $ \(notifierKey, rcvNtfDhSecret, rcvPubDhKey) -> do - -- notifierId <- randId - -- pure (NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}, ServerNtfCreds notifierId rcvPubDhKey) - let queueMode = queueReqMode <$> queueReqData - qr = - QueueRec - { senderId = sndId, - recipientKeys = [rcvAuthKey], - rcvDhSecret = C.dh' rcvDhKey privDhKey, - senderKey = Nothing, - queueMode, - queueData, - -- TODO [notifications] - notifier = Nothing, -- fst <$> ntf, - status = EntityActive, - updatedAt - } - liftIO (addQueue ms rcvId qr) >>= \case - Left DUPLICATE_ -- TODO [short links] possibly, we somehow need to understand which IDs caused collision to retry if it's not client-supplied? - | clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied - | otherwise -> tryCreate (n - 1) - Left e -> pure $ ERR e - Right q -> do - stats <- asks serverStats - incStat $ qCreated stats - incStat $ qCount stats + createQueue NewQueueReq {rcvAuthKey, rcvDhKey, subMode, queueReqData} + | isJust service && subMode == SMOnlyCreate = pure (corrId, entId, ERR $ CMD PROHIBITED) + | otherwise = time "NEW" $ do + g <- asks random + idSize <- asks $ queueIdBytes . config + updatedAt <- Just <$> liftIO getSystemDate + (rcvPublicDhKey, privDhKey) <- atomically $ C.generateKeyPair g + -- TODO [notifications] + -- ntfKeys_ <- forM ntfCreds $ \(NewNtfCreds notifierKey dhKey) -> do + -- (ntfPubDhKey, ntfPrivDhKey) <- atomically $ C.generateKeyPair g + -- pure (notifierKey, C.dh' dhKey ntfPrivDhKey, ntfPubDhKey) + let randId = EntityId <$> atomically (C.randomBytes idSize g) + -- TODO [notifications] the remaining 24 bytes are reserver for notifier ID + sndId' = B.take 24 $ C.sha3_384 (bs corrId) + tryCreate 0 = pure $ ERR INTERNAL + tryCreate n = do + (sndId, clntIds, queueData) <- case queueReqData of + Just (QRMessaging (Just (sId, d))) -> (\linkId -> (sId, True, Just (linkId, d))) <$> randId + Just (QRContact (Just (linkId, (sId, d)))) -> pure (sId, True, Just (linkId, d)) + _ -> (,False,Nothing) <$> randId + -- The condition that client-provided sender ID must match hash of correlation ID + -- prevents "ID oracle" attack, when creating queue with supplied ID can be used to check + -- if queue with this ID still exists. + if clntIds && unEntityId sndId /= sndId' + then pure $ ERR $ CMD PROHIBITED + else do + rcvId <- randId -- TODO [notifications] - -- when (isJust ntf) $ incStat $ ntfCreated stats - case subMode of - SMOnlyCreate -> pure () - SMSubscribe -> void $ subscribeQueue q qr - pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData} -- , serverNtfCreds = snd <$> ntf - (corrId,entId,) <$> tryCreate (3 :: Int) + -- ntf <- forM ntfKeys_ $ \(notifierKey, rcvNtfDhSecret, rcvPubDhKey) -> do + -- notifierId <- randId + -- pure (NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}, ServerNtfCreds notifierId rcvPubDhKey) + let queueMode = queueReqMode <$> queueReqData + rcvServiceId = (\THClientService {serviceId} -> serviceId) <$> service + qr = + QueueRec + { senderId = sndId, + recipientKeys = [rcvAuthKey], + rcvDhSecret = C.dh' rcvDhKey privDhKey, + senderKey = Nothing, + queueMode, + queueData, + -- TODO [notifications] + notifier = Nothing, -- fst <$> ntf, + status = EntityActive, + updatedAt, + rcvServiceId + } + liftIO (addQueue ms rcvId qr) >>= \case + Left DUPLICATE_ -- TODO [short links] possibly, we somehow need to understand which IDs caused collision to retry if it's not client-supplied? + | clntIds -> pure $ ERR AUTH -- no retry on collision if sender ID is client-supplied + | otherwise -> tryCreate (n - 1) + Left e -> pure $ ERR e + Right q -> do + stats <- asks serverStats + incStat $ qCreated stats + incStat $ qCount stats + -- TODO [notifications] + -- when (isJust ntf) $ incStat $ ntfCreated stats + case subMode of + SMOnlyCreate -> pure () + SMSubscribe -> void $ subscribeQueue q qr + pure $ IDS QIK {rcvId, sndId, rcvPublicDhKey, queueMode, linkId = fst <$> queueData, serviceId = rcvServiceId} -- , serverNtfCreds = snd <$> ntf + (corrId,entId,) <$> tryCreate (3 :: Int) -- this check allows to support contact queues created prior to SKEY, -- using `queueMode == Just QMContact` would prevent it, as they have queueMode `Nothing`. @@ -1358,24 +1534,25 @@ client addNotifierRetry 0 _ _ = pure $ ERR INTERNAL addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do notifierId <- randomId =<< asks (queueIdBytes . config) - let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} + let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret, ntfServiceId = Nothing} liftIO (addQueueNotifier (queueStore ms) q ntfCreds) >>= \case Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret Left e -> pure $ ERR e - Right nId_ -> do + Right nc_ -> do incStat . ntfCreated =<< asks serverStats - forM_ nId_ $ \nId -> atomically $ writeTQueue (subQ ntfSubscribers) (nId, clientId, False) + forM_ nc_ $ \NtfCreds {notifierId = nId, ntfServiceId} -> + atomically $ writeTQueue (subQ ntfSubscribers) (CSDeleted nId ntfServiceId, clientId) pure $ NID notifierId rcvPublicDhKey deleteQueueNotifier_ :: StoreQueue s -> M s (Transmission BrokerMsg) deleteQueueNotifier_ q = liftIO (deleteQueueNotifier (queueStore ms) q) >>= \case - Right (Just nId) -> do + Right (Just NtfCreds {notifierId = nId, ntfServiceId}) -> do -- Possibly, the same should be done if the queue is suspended, but currently we do not use it stats <- asks serverStats deleted <- asks ntfStore >>= liftIO . (`deleteNtfs` nId) when (deleted > 0) $ liftIO $ atomicModifyIORef'_ (ntfCount stats) (subtract deleted) - atomically $ writeTQueue (subQ ntfSubscribers) (nId, clientId, False) + atomically $ writeTQueue (subQ ntfSubscribers) (CSDeleted nId ntfServiceId, clientId) incStat $ ntfDeleted stats pure ok Right Nothing -> pure ok @@ -1384,8 +1561,9 @@ client suspendQueue_ :: (StoreQueue s, QueueRec) -> M s (Transmission BrokerMsg) suspendQueue_ (q, _) = liftIO $ either err (const ok) <$> suspendQueue (queueStore ms) q + -- TODO [certs rcv] if serviceId is passed, associate with the service and respond with SOK subscribeQueue :: StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) - subscribeQueue q qr = + subscribeQueue q qr@QueueRec {rcvServiceId} = liftIO (TM.lookupIO rId subscriptions) >>= \case Nothing -> newSub >>= deliver True Just s@Sub {subThread} -> do @@ -1402,7 +1580,7 @@ client rId = recipientId q newSub :: M s Sub newSub = time "SUB newSub" . atomically $ do - writeTQueue (subQ subscribers) (rId, clientId, True) + writeTQueue (subQ subscribers) (CSClient rId rcvServiceId Nothing, clientId) sub <- newSubscription NoSub TM.insert rId sub subscriptions pure sub @@ -1466,20 +1644,74 @@ client then action q qr else liftIO (updateQueueTime (queueStore ms) q t) >>= either (pure . err) (action q) - subscribeNotifications :: M s (Transmission BrokerMsg) - subscribeNotifications = do - statCount <- - time "NSUB" . atomically $ do - ifM - (TM.member entId ntfSubscriptions) - (pure ntfSubDuplicate) - (newSub $> ntfSub) - incStat . statCount =<< asks serverStats - pure ok - where - newSub = do - writeTQueue (subQ ntfSubscribers) (entId, clientId, True) - TM.insert entId () ntfSubscriptions + subscribeNotifications :: StoreQueue s -> NtfCreds -> M s BrokerMsg + subscribeNotifications q NtfCreds {ntfServiceId} = do + stats <- asks serverStats + let incNtfSrvStat sel = incStat $ sel $ ntfServices stats + case service of + Just THClientService {serviceId} + | ntfServiceId == Just serviceId -> do + -- duplicate queue-service association - can only happen in case of response error/timeout + hasSub <- atomically $ ifM hasServiceSub (pure True) (False <$ newServiceQueueSub) + unless hasSub $ do + incNtfSrvStat srvSubCount + incNtfSrvStat srvSubQueues + incNtfSrvStat srvAssocDuplicate + pure $ SOK $ Just serviceId + | otherwise -> + -- new or updated queue-service association + liftIO (setQueueService (queueStore ms) q SNotifier (Just serviceId)) >>= \case + Left e -> pure $ ERR e + Right () -> do + hasSub <- atomically $ (<$ newServiceQueueSub) =<< hasServiceSub + unless hasSub $ incNtfSrvStat srvSubCount + incNtfSrvStat srvSubQueues + incNtfSrvStat $ maybe srvAssocNew (const srvAssocUpdated) ntfServiceId + pure $ SOK $ Just serviceId + where + hasServiceSub = (0 /=) <$> readTVar ntfServiceSubsCount + -- This function is used when queue is associated with the service. + newServiceQueueSub = do + writeTQueue (subQ ntfSubscribers) (CSClient entId ntfServiceId (Just serviceId), clientId) + modifyTVar' ntfServiceSubsCount (+ 1) -- service count + modifyTVar' (totalServiceSubs ntfSubscribers) (+ 1) -- server count for all services + Nothing -> case ntfServiceId of + Just _ -> + liftIO (setQueueService (queueStore ms) q SNotifier Nothing) >>= \case + Left e -> pure $ ERR e + Right () -> do + -- hasSubscription should never be True in this branch, because queue was associated with service. + -- So unless storage and session states diverge, this check is redundant. + hasSub <- atomically $ hasSubscription >>= newSub + incNtfSrvStat srvAssocRemoved + sok hasSub + Nothing -> do + hasSub <- atomically $ ifM hasSubscription (pure True) (newSub False) + sok hasSub + where + hasSubscription = TM.member entId ntfSubscriptions + newSub hasSub = do + writeTQueue (subQ ntfSubscribers) (CSClient entId ntfServiceId Nothing, clientId) + unless (hasSub) $ TM.insert entId () ntfSubscriptions + pure hasSub + sok hasSub = do + incStat $ if hasSub then ntfSubDuplicate stats else ntfSub stats + pure $ SOK Nothing + + subscribeServiceNotifications :: THPeerClientService -> M s BrokerMsg + subscribeServiceNotifications THClientService {serviceId} = do + srvSubs <- readTVarIO ntfServiceSubsCount + if srvSubs == 0 + then + liftIO (getNtfServiceQueueCount @(StoreQueue s) (queueStore ms) serviceId) >>= \case + Left e -> pure $ ERR e + Right count -> do + atomically $ do + modifyTVar' ntfServiceSubsCount (+ count) -- service count + modifyTVar' (totalServiceSubs ntfSubscribers) (+ count) -- server count for all services + atomically $ writeTQueue (subQ ntfSubscribers) (CSService serviceId, clientId) + pure $ SOKS count + else pure $ SOKS srvSubs acknowledgeMsg :: MsgId -> StoreQueue s -> QueueRec -> M s (Transmission BrokerMsg) acknowledgeMsg msgId q qr = time "ACK" $ do @@ -1643,7 +1875,6 @@ client enqueueNotification :: NtfCreds -> Message -> M s () enqueueNotification _ MessageQuota {} = pure () enqueueNotification NtfCreds {notifierId = nId, rcvNtfDhSecret} Message {msgId, msgTs} = do - -- stats <- asks serverStats ns <- asks ntfStore ntf <- mkMessageNotification msgId msgTs rcvNtfDhSecret liftIO $ storeNtf ns nId ntf @@ -1671,16 +1902,16 @@ client t' <- case tParse clntTHParams b of t :| [] -> pure $ tDecodeParseValidate clntTHParams t _ -> throwE BLOCK - let clntThAuth = Just $ THAuthServer {serverPrivKey, sessSecret' = Just clientSecret} + let clntThAuth = Just $ THAuthServer {serverPrivKey, peerClientService = Nothing, sessSecret' = Just clientSecret} -- process forwarded command r <- lift (rejectOrVerify clntThAuth t') >>= \case Left r -> pure r -- rejectOrVerify filters allowed commands, no need to repeat it here. -- INTERNAL is used because processCommand never returns Nothing for sender commands (could be extracted for better types). - Right t''@(_, (corrId', entId', _)) -> fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand fwdVersion t'') + Right t''@(_, (corrId', entId', _)) -> fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand Nothing fwdVersion t'') -- encode response - r' <- case batchTransmissions (batch clntTHParams) (blockSize clntTHParams) [Right (Nothing, encodeTransmission clntTHParams r)] of + r' <- case batchTransmissions clntTHParams [Right (Nothing, encodeTransmission clntTHParams r)] of [] -> throwE INTERNAL -- at least 1 item is guaranteed from NonEmpty/Right TBError _ _ : _ -> throwE BLOCK TBTransmission b' _ : _ -> pure b' @@ -1699,7 +1930,7 @@ client case cmdOrError of Left e -> pure $ Left (corrId', entId', ERR e) Right cmd' - | allowed -> verified <$> verifyTransmission ms ((,C.cbNonce (bs corrId')) <$> clntThAuth) tAuth authorized entId' cmd' + | allowed -> liftIO $ verified <$> verifyTransmission ms Nothing ((,C.cbNonce (bs corrId')) <$> clntThAuth) tAuth authorized entId' cmd' | otherwise -> pure $ Left (corrId', entId', ERR $ CMD PROHIBITED) where allowed = case cmd' of @@ -1710,7 +1941,8 @@ client _ -> False verified = \case VRVerified q -> Right (q, (corrId', entId', cmd')) - VRFailed -> Left (corrId', entId', ERR AUTH) + VRFailed e -> Left (corrId', entId', ERR e) + deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> IO (Transmission BrokerMsg) deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $ case subThread of @@ -1740,22 +1972,22 @@ client setDelivered s msg = tryPutTMVar (delivered s) $! messageId msg delQueueAndMsgs :: (StoreQueue s, QueueRec) -> M s (Transmission BrokerMsg) - delQueueAndMsgs (q, _) = do + delQueueAndMsgs (q, QueueRec {rcvServiceId}) = do liftIO (deleteQueue ms q) >>= \case Right qr -> do -- Possibly, the same should be done if the queue is suspended, but currently we do not use it atomically $ do - writeTQueue (subQ subscribers) (entId, clientId, False) + writeTQueue (subQ subscribers) (CSDeleted entId rcvServiceId, clientId) -- queue is usually deleted by the same client that is currently subscribed, -- we delete subscription here, so the client with no subscriptions can be disconnected. TM.delete entId subscriptions - forM_ (notifierId <$> notifier qr) $ \nId -> do + forM_ (notifier qr) $ \NtfCreds {notifierId = nId, ntfServiceId} -> do -- queue is deleted by a different client from the one subscribed to notifications, -- so we don't need to remove subscription from the current client. stats <- asks serverStats deleted <- asks ntfStore >>= liftIO . (`deleteNtfs` nId) when (deleted > 0) $ liftIO $ atomicModifyIORef'_ (ntfCount stats) (subtract deleted) - atomically $ writeTQueue (subQ ntfSubscribers) (nId, clientId, False) + atomically $ writeTQueue (subQ ntfSubscribers) (CSDeleted nId ntfServiceId, clientId) updateDeletedStats qr pure ok Left e -> pure $ err e @@ -1886,10 +2118,10 @@ importMessages tty ms f old_ skipWarnings = do renameFile f $ f <> ".bak" mapM_ setOverQuota_ overQuota logQueueStates ms - QueueCounts {queueCount} <- liftIO $ queueCounts @(StoreQueue s) $ queueStore ms + EntityCounts {queueCount} <- liftIO $ getEntityCounts @(StoreQueue s) $ queueStore ms pure MessageStats {storedMsgsCount, expiredMsgsCount, storedQueues = queueCount} where - restoreMsg :: (Maybe (RecipientId, StoreQueue s), (Int, Int, M.Map RecipientId (StoreQueue s))) -> Bool -> ByteString -> IO (Maybe (RecipientId, StoreQueue s), (Int, Int, M.Map RecipientId (StoreQueue s))) + restoreMsg :: (Maybe (RecipientId, StoreQueue s), (Int, Int, Map RecipientId (StoreQueue s))) -> Bool -> ByteString -> IO (Maybe (RecipientId, StoreQueue s), (Int, Int, Map RecipientId (StoreQueue s))) restoreMsg (q_, counts@(!stored, !expired, !overQuota)) eof s = case strDecode s of Right (MLRv3 rId msg) -> runExceptT (addToMsgQueue rId msg) >>= either (exitErr . tshow) pure Left e @@ -2014,7 +2246,7 @@ restoreServerStats msgStats_ ntfStats = asks (serverStatsBackupFile . config) >> Right d@ServerStatsData {_qCount = statsQCount, _msgCount = statsMsgCount, _ntfCount = statsNtfCount} -> do s <- asks serverStats st <- asks msgStore - QueueCounts {queueCount = _qCount} <- liftIO $ queueCounts @(StoreQueue s) $ queueStore st + EntityCounts {queueCount = _qCount} <- liftIO $ getEntityCounts @(StoreQueue s) $ queueStore st let _msgCount = maybe statsMsgCount storedMsgsCount msgStats_ _ntfCount = storedMsgsCount ntfStats _msgExpired' = _msgExpired d + maybe 0 expiredMsgsCount msgStats_ diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 361d0e6f4..7819c297e 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -32,7 +32,7 @@ module Simplex.Messaging.Server.Env.STM ProxyAgent (..), Client (..), ClientId, - Subscribed, + ClientSub (..), Sub (..), ServerSub (..), SubscriptionThread (..), @@ -51,6 +51,7 @@ module Simplex.Messaging.Server.Env.STM getSubscribedClients, getSubscribedClient, upsertSubscribedClient, + lookupSubscribedClient, lookupDeleteSubscribedClient, deleteSubcribedClient, sameClientId, @@ -78,7 +79,6 @@ import Control.Logger.Simple import Control.Monad import qualified Crypto.PubKey.RSA as RSA import Crypto.Random -import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.IntMap.Strict (IntMap) import qualified Data.IntMap.Strict as IM @@ -120,7 +120,7 @@ import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.Server.StoreLog.ReadWrite import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (ASrvTransport, VersionRangeSMP, VersionSMP) +import Simplex.Messaging.Transport (ASrvTransport, SMPVersion, THandleParams, TransportPeer (..), VersionRangeSMP) import Simplex.Messaging.Transport.Server import Simplex.Messaging.Util (ifM, whenM, ($>>=)) import System.Directory (doesFileExist) @@ -299,8 +299,6 @@ data MsgStore s where StoreMemory :: STMMsgStore -> MsgStore STMMsgStore StoreJournal :: JournalMsgStore qs -> MsgStore (JournalMsgStore qs) -type Subscribed = Bool - data Server s = Server { clients :: ServerClients s, subscribers :: ServerSubscribers s, @@ -312,9 +310,11 @@ data Server s = Server newtype ServerClients s = ServerClients {serverClients :: TVar (IntMap (Client s))} data ServerSubscribers s = ServerSubscribers - { subQ :: TQueue (QueueId, ClientId, Subscribed), + { subQ :: TQueue (ClientSub, ClientId), queueSubscribers :: SubscribedClients s, - subClients :: TVar IntSet, + serviceSubscribers :: SubscribedClients s, -- service clients with long-term certificates that have subscriptions + totalServiceSubs :: TVar Int64, + subClients :: TVar IntSet, -- clients with individual or service subscriptions pendingEvents :: TVar (IntMap (NonEmpty (EntityId, BrokerMsg))) } @@ -344,10 +344,15 @@ upsertSubscribedClient entId c (SubscribedClients cs) = Just c' | sameClientId c c' -> pure Nothing c_ -> c_ <$ writeTVar cv (Just c) +lookupSubscribedClient :: EntityId -> SubscribedClients s -> STM (Maybe (Client s)) +lookupSubscribedClient entId (SubscribedClients cs) = TM.lookup entId cs $>>= readTVar +{-# INLINE lookupSubscribedClient #-} + -- lookup and delete currently subscribed client lookupDeleteSubscribedClient :: EntityId -> SubscribedClients s -> STM (Maybe (Client s)) lookupDeleteSubscribedClient entId (SubscribedClients cs) = TM.lookupDelete entId cs $>>= (`swapTVar` Nothing) +{-# INLINE lookupDeleteSubscribedClient #-} deleteSubcribedClient :: EntityId -> Client s -> SubscribedClients s -> IO () deleteSubcribedClient entId c (SubscribedClients cs) = @@ -368,6 +373,11 @@ sameClient :: Client s -> TVar (Maybe (Client s)) -> STM Bool sameClient c cv = maybe False (sameClientId c) <$> readTVar cv {-# INLINE sameClient #-} +data ClientSub + = CSClient QueueId (Maybe ServiceId) (Maybe ServiceId) -- includes previous and new associated service IDs + | CSDeleted QueueId (Maybe ServiceId) -- includes previously associated service IDs + | CSService ServiceId -- only send END to idividual client subs on message delivery, not of SSUB/NSSUB + newtype ProxyAgent = ProxyAgent { smpAgent :: SMPClientAgent 'Sender } @@ -378,14 +388,15 @@ data Client s = Client { clientId :: ClientId, subscriptions :: TMap RecipientId Sub, ntfSubscriptions :: TMap NotifierId (), + serviceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count + ntfServiceSubsCount :: TVar Int64, -- only one service can be subscribed, based on its certificate, this is subscription count rcvQ :: TBQueue (NonEmpty (Maybe (StoreQueue s, QueueRec), Transmission Cmd)), sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), msgQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), procThreads :: TVar Int, endThreads :: TVar (IntMap (Weak ThreadId)), endThreadSeq :: TVar Int, - thVersion :: VersionSMP, - sessionId :: ByteString, + clientTHParams :: THandleParams SMPVersion 'TServer, connected :: TVar Bool, createdAt :: SystemTime, rcvActiveAt :: TVar SystemTime, @@ -434,14 +445,18 @@ newServerSubscribers :: IO (ServerSubscribers s) newServerSubscribers = do subQ <- newTQueueIO queueSubscribers <- SubscribedClients <$> TM.emptyIO + serviceSubscribers <- SubscribedClients <$> TM.emptyIO + totalServiceSubs <- newTVarIO 0 subClients <- newTVarIO IS.empty pendingEvents <- newTVarIO IM.empty - pure ServerSubscribers {subQ, queueSubscribers, subClients, pendingEvents} + pure ServerSubscribers {subQ, queueSubscribers, serviceSubscribers, totalServiceSubs, subClients, pendingEvents} -newClient :: ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO (Client s) -newClient clientId qSize thVersion sessionId createdAt = do +newClient :: ClientId -> Natural -> THandleParams SMPVersion 'TServer -> SystemTime -> IO (Client s) +newClient clientId qSize clientTHParams createdAt = do subscriptions <- TM.emptyIO ntfSubscriptions <- TM.emptyIO + serviceSubsCount <- newTVarIO 0 + ntfServiceSubsCount <- newTVarIO 0 rcvQ <- newTBQueueIO qSize sndQ <- newTBQueueIO qSize msgQ <- newTBQueueIO qSize @@ -456,14 +471,15 @@ newClient clientId qSize thVersion sessionId createdAt = do { clientId, subscriptions, ntfSubscriptions, + serviceSubsCount, + ntfServiceSubsCount, rcvQ, sndQ, msgQ, procThreads, endThreads, endThreadSeq, - thVersion, - sessionId, + clientTHParams, connected, createdAt, rcvActiveAt, @@ -623,5 +639,5 @@ newSMPProxyAgent smpAgentCfg random = do smpAgent <- newSMPClientAgent SSender smpAgentCfg random pure ProxyAgent {smpAgent} -readWriteQueueStore :: forall q s. QueueStoreClass q s => Bool -> (RecipientId -> QueueRec -> IO q) -> FilePath -> s -> IO (StoreLog 'WriteMode) +readWriteQueueStore :: forall q. StoreQueueClass q => Bool -> (RecipientId -> QueueRec -> IO q) -> FilePath -> STMQueueStore q -> IO (StoreLog 'WriteMode) readWriteQueueStore tty mkQ = readWriteStoreLog (readQueueStore tty mkQ) (writeQueueStore @q) diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 25a9123bd..8ffc7c9e2 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -30,6 +30,7 @@ import Data.Ini (Ini, lookupValue, readIniFile) import Data.Int (Int64) import Data.List (find, isPrefixOf) import qualified Data.List.NonEmpty as L +import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust, isNothing) import Data.Text (Text) import qualified Data.Text as T @@ -71,9 +72,10 @@ import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) import Simplex.Messaging.Server.MsgStore.Types (QSType (..)) import Simplex.Messaging.Server.MsgStore.Journal (postgresQueueStore) -import Simplex.Messaging.Server.QueueStore.Postgres (batchInsertQueues, foldQueueRecs) +import Simplex.Messaging.Server.QueueStore.Postgres (batchInsertQueues, batchInsertServices, foldQueueRecs, foldServiceRecs) +import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..)) import Simplex.Messaging.Server.QueueStore.Types -import Simplex.Messaging.Server.StoreLog (closeStoreLog, logCreateQueue, openWriteStoreLog) +import Simplex.Messaging.Server.StoreLog (closeStoreLog, logNewService, logCreateQueue, openWriteStoreLog) import System.Directory (renameFile) #endif @@ -180,8 +182,8 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = confirmOrExit ("WARNING: store log file " <> storeLogFile <> " will be compacted and imported to PostrgreSQL database: " <> B.unpack connstr <> ", schema: " <> B.unpack schema) "Queue records not imported" - qCnt <- importStoreLogToDatabase logPath storeLogFile dbOpts - putStrLn $ "Import completed: " <> show qCnt <> " queues" + (sCnt, qCnt) <- importStoreLogToDatabase logPath storeLogFile dbOpts + putStrLn $ "Import completed: " <> show sCnt <> " services, " <> show qCnt <> " queues" putStrLn $ case readStoreType ini of Right (ASType SQSMemory SMSMemory) -> setToDbStr <> "\nstore_messages set to `memory`, import messages to journal to use PostgreSQL database for queues (`smp-server journal import`)" Right (ASType SQSMemory SMSJournal) -> setToDbStr @@ -202,8 +204,8 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = confirmOrExit ("WARNING: PostrgreSQL database schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath) "Queue records not exported" - qCnt <- exportDatabaseToStoreLog logPath dbOpts storeLogFilePath - putStrLn $ "Export completed: " <> show qCnt <> " queues" + (sCnt, qCnt) <- exportDatabaseToStoreLog logPath dbOpts storeLogFilePath + putStrLn $ "Export completed: " <> show sCnt <> " services, " <> show qCnt <> " queues" putStrLn $ case readStoreType ini of Right (ASType SQSPostgres SMSJournal) -> "store_queues set to `database`, update it to `memory` in INI file." Right (ASType SQSMemory _) -> "store_queues set to `memory`, start the server" @@ -442,12 +444,13 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = prometheusInterval = eitherToMaybe $ read . T.unpack <$> lookupValue "STORE_LOG" "prometheus_interval" ini, prometheusMetricsFile = combine logPath "smp-server-metrics.txt", pendingENDInterval = 15000000, -- 15 seconds - ntfDeliveryInterval = 3000000, -- 3 seconds + ntfDeliveryInterval = 1500000, -- 1.5 second smpServerVRange = supportedServerSMPRelayVRange, transportConfig = mkTransportServerConfig (fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini) - (Just alpnSupportedSMPHandshakes), + (Just alpnSupportedSMPHandshakes) + (fromMaybe True $ iniOnOff "TRANSPORT" "accept_service_credentials" ini), -- TODO [certs] remove this option controlPort = eitherToMaybe $ T.unpack <$> lookupValue "TRANSPORT" "control_port" ini, smpAgentCfg = defaultSMPClientAgentConfig @@ -554,26 +557,30 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = putStrLn "Configure queue storage." exitFailure -importStoreLogToDatabase :: FilePath -> FilePath -> DBOpts -> IO Int64 +importStoreLogToDatabase :: FilePath -> FilePath -> DBOpts -> IO (Int64, Int64) importStoreLogToDatabase logPath storeLogFile dbOpts = do ms <- newJournalMsgStore logPath MQStoreCfg - sl <- readWriteQueueStore True (mkQueue ms False) storeLogFile (queueStore ms) + let st = stmQueueStore ms + sl <- readWriteQueueStore True (mkQueue ms False) storeLogFile st closeStoreLog sl - queues <- readTVarIO $ loadedQueues $ stmQueueStore ms + queues <- readTVarIO $ loadedQueues st + services' <- M.elems <$> readTVarIO (services st) let storeCfg = PostgresStoreCfg {dbOpts = dbOpts {createSchema = True}, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL} ps <- newJournalMsgStore logPath $ PQStoreCfg storeCfg + sCnt <- batchInsertServices services' $ postgresQueueStore ps qCnt <- batchInsertQueues @(JournalQueue 'QSMemory) True queues $ postgresQueueStore ps renameFile storeLogFile $ storeLogFile <> ".bak" - pure qCnt + pure (sCnt, qCnt) -exportDatabaseToStoreLog :: FilePath -> DBOpts -> FilePath -> IO Int +exportDatabaseToStoreLog :: FilePath -> DBOpts -> FilePath -> IO (Int, Int) exportDatabaseToStoreLog logPath dbOpts storeLogFilePath = do let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL} ps <- newJournalMsgStore logPath $ PQStoreCfg storeCfg sl <- openWriteStoreLog False storeLogFilePath + Sum sCnt <- foldServiceRecs (postgresQueueStore ps) $ \sr -> logNewService sl sr $> Sum (1 :: Int) Sum qCnt <- foldQueueRecs True True (postgresQueueStore ps) Nothing $ \(rId, qr) -> logCreateQueue sl rId qr $> Sum (1 :: Int) closeStoreLog sl - pure qCnt + pure (sCnt, qCnt) #endif newJournalMsgStore :: FilePath -> QStoreCfg s -> IO (JournalMsgStore s) diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index d28300a75..c1fc94c08 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -61,11 +61,12 @@ import qualified Data.ByteString.Char8 as B import Data.Either (fromRight) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (intercalate, sort) +import Data.List (sort) import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust, isNothing, mapMaybe) import Data.Text (Text) import qualified Data.Text as T +import Data.Text.Encoding (decodeLatin1) import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Time.Format.ISO8601 (iso8601Show, iso8601ParseM) @@ -296,7 +297,7 @@ instance StoreQueueClass (JournalQueue s) where {-# INLINE queueRec #-} msgQueue = msgQueue' {-# INLINE msgQueue #-} - withQueueLock :: JournalQueue s -> String -> IO a -> IO a + withQueueLock :: JournalQueue s -> Text -> IO a -> IO a withQueueLock JournalQueue {recipientId', queueLock, sharedLock} = withLockWaitShared recipientId' queueLock sharedLock {-# INLINE withQueueLock #-} @@ -317,8 +318,8 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where {-# INLINE loadedQueues #-} compactQueues = withQS (compactQueues @(JournalQueue s)) {-# INLINE compactQueues #-} - queueCounts = withQS (queueCounts @(JournalQueue s)) - {-# INLINE queueCounts #-} + getEntityCounts = withQS (getEntityCounts @(JournalQueue s)) + {-# INLINE getEntityCounts #-} addQueue_ = withQS addQueue_ {-# INLINE addQueue_ #-} getQueue_ = withQS getQueue_ @@ -347,6 +348,14 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where {-# INLINE updateQueueTime #-} deleteStoreQueue = withQS deleteStoreQueue {-# INLINE deleteStoreQueue #-} + getCreateService = withQS (getCreateService @(JournalQueue s)) + {-# INLINE getCreateService #-} + setQueueService = withQS setQueueService + {-# INLINE setQueueService #-} + getQueueNtfServices = withQS (getQueueNtfServices @(JournalQueue s)) + {-# INLINE getQueueNtfServices #-} + getNtfServiceQueueCount = withQS (getNtfServiceQueueCount @(JournalQueue s)) + {-# INLINE getNtfServiceQueueCount #-} makeQueue_ :: JournalMsgStore s -> RecipientId -> QueueRec -> Lock -> IO (JournalQueue s) makeQueue_ JournalMsgStore {sharedLock} rId qr queueLock = do @@ -377,7 +386,7 @@ instance MsgStoreClass (JournalMsgStore s) where queueLocks <- TM.emptyIO sharedLock <- newEmptyTMVarIO queueStore_ <- newQueueStore @(JournalQueue s) queueStoreCfg - openedQueueCount <- newTVarIO 0 + openedQueueCount <- newTVarIO 0 expireBackupsBefore <- addUTCTime (- expireBackupsAfter config) <$> getCurrentTime pure JournalMsgStore {config, random, queueLocks, sharedLock, queueStore_, openedQueueCount, expireBackupsBefore} @@ -396,7 +405,7 @@ instance MsgStoreClass (JournalMsgStore s) where -- It does not cache queues and is NOT concurrency safe. unsafeWithAllMsgQueues :: Monoid a => Bool -> Bool -> JournalMsgStore s -> (JournalQueue s -> IO a) -> IO a unsafeWithAllMsgQueues tty withData ms action = case queueStore_ ms of - MQStore st -> withLoadedQueues st run + MQStore st -> withLoadedQueues st run #if defined(dbServerPostgres) PQStore st -> foldQueueRecs tty withData st Nothing $ uncurry (mkQueue ms False) >=> run #endif @@ -638,28 +647,28 @@ instance MsgStoreClass (JournalMsgStore s) where $>>= \len -> readTVarIO handles $>>= \hs -> updateReadPos q mq logState len hs $> Just () - isolateQueue :: JournalQueue s -> String -> StoreIO s a -> ExceptT ErrorType IO a + isolateQueue :: JournalQueue s -> Text -> StoreIO s a -> ExceptT ErrorType IO a isolateQueue sq op = tryStore' op (recipientId' sq) . withQueueLock sq op . unStoreIO - unsafeRunStore :: JournalQueue s -> String -> StoreIO s a -> IO a + unsafeRunStore :: JournalQueue s -> Text -> StoreIO s a -> IO a unsafeRunStore sq op a = unStoreIO a `E.catch` \e -> storeError op (recipientId' sq) e >> E.throwIO e updateActiveAt :: JournalQueue s -> IO () updateActiveAt q = atomically . writeTVar (activeAt q) . systemSeconds =<< getSystemTime -tryStore' :: String -> RecipientId -> IO a -> ExceptT ErrorType IO a +tryStore' :: Text -> RecipientId -> IO a -> ExceptT ErrorType IO a tryStore' op rId = tryStore op rId . fmap Right -tryStore :: forall a. String -> RecipientId -> IO (Either ErrorType a) -> ExceptT ErrorType IO a +tryStore :: forall a. Text -> RecipientId -> IO (Either ErrorType a) -> ExceptT ErrorType IO a tryStore op rId a = ExceptT $ E.mask_ $ a `E.catch` storeError op rId -storeError :: String -> RecipientId -> E.SomeException -> IO (Either ErrorType a) +storeError :: Text -> RecipientId -> E.SomeException -> IO (Either ErrorType a) storeError op rId e = - let e' = intercalate ", " [op, B.unpack $ strEncode rId, show e] - in logError ("STORE: " <> T.pack e') $> Left (STORE e') + let e' = T.intercalate ", " [op, decodeLatin1 $ strEncode rId, tshow e] + in logError ("STORE: " <> e') $> Left (STORE e') -isolateQueueId :: String -> JournalMsgStore s -> RecipientId -> IO (Either ErrorType a) -> ExceptT ErrorType IO a +isolateQueueId :: Text -> JournalMsgStore s -> RecipientId -> IO (Either ErrorType a) -> ExceptT ErrorType IO a isolateQueueId op JournalMsgStore {queueLocks, sharedLock} rId = tryStore op rId . withLockMapWaitShared rId queueLocks sharedLock op diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal/SharedLock.hs b/src/Simplex/Messaging/Server/MsgStore/Journal/SharedLock.hs index 4e09f3895..87b7294f1 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal/SharedLock.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal/SharedLock.hs @@ -8,6 +8,7 @@ where import Control.Concurrent.STM import qualified Control.Exception as E import Control.Monad +import Data.Text (Text) import Simplex.Messaging.Agent.Lock import Simplex.Messaging.Agent.Client (getMapLock) import Simplex.Messaging.Protocol (RecipientId) @@ -16,14 +17,14 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (($>>), ($>>=)) -- wait until shared lock with passed ID is released and take lock -withLockWaitShared :: RecipientId -> Lock -> TMVar RecipientId -> String -> IO a -> IO a +withLockWaitShared :: RecipientId -> Lock -> TMVar RecipientId -> Text -> IO a -> IO a withLockWaitShared rId lock shared name = E.bracket_ (atomically $ waitShared rId shared >> putTMVar lock name) (void $ atomically $ takeTMVar lock) -- wait until shared lock with passed ID is released and take lock from Map for this ID -withLockMapWaitShared :: RecipientId -> TMap RecipientId Lock -> TMVar RecipientId -> String -> IO a -> IO a +withLockMapWaitShared :: RecipientId -> TMap RecipientId Lock -> TMVar RecipientId -> Text -> IO a -> IO a withLockMapWaitShared rId locks shared name a = E.bracket (atomically $ waitShared rId shared >> getPutLock (getMapLock locks) rId name) diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index afde3ff82..ed24e85a4 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -24,6 +24,7 @@ import Control.Monad.Trans.Except import Data.Functor (($>)) import Data.Int (Int64) import qualified Data.Map.Strict as M +import Data.Text (Text) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.MsgStore.Types import Simplex.Messaging.Server.QueueStore @@ -178,10 +179,10 @@ instance MsgStoreClass STMMsgStore where Just _ -> modifyTVar' size (subtract 1) _ -> pure () - isolateQueue :: STMQueue -> String -> STM a -> ExceptT ErrorType IO a + isolateQueue :: STMQueue -> Text -> STM a -> ExceptT ErrorType IO a isolateQueue _ _ = liftIO . atomically {-# INLINE isolateQueue #-} - unsafeRunStore :: STMQueue -> String -> STM a -> IO a + unsafeRunStore :: STMQueue -> Text -> STM a -> IO a unsafeRunStore _ _ = atomically {-# INLINE unsafeRunStore #-} diff --git a/src/Simplex/Messaging/Server/MsgStore/Types.hs b/src/Simplex/Messaging/Server/MsgStore/Types.hs index 82778b5a4..e0d32482d 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Types.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Types.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} @@ -22,6 +23,7 @@ import Data.Functor (($>)) import Data.Int (Int64) import Data.Kind import Data.Maybe (fromMaybe) +import Data.Text (Text) import Data.Time.Clock.System (SystemTime (systemSeconds)) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore @@ -61,8 +63,8 @@ class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => M getQueueSize_ :: MsgQueue (StoreQueue s) -> StoreMonad s Int tryPeekMsg_ :: StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s (Maybe Message) tryDeleteMsg_ :: StoreQueue s -> MsgQueue (StoreQueue s) -> Bool -> StoreMonad s () - isolateQueue :: StoreQueue s -> String -> StoreMonad s a -> ExceptT ErrorType IO a - unsafeRunStore :: StoreQueue s -> String -> StoreMonad s a -> IO a + isolateQueue :: StoreQueue s -> Text -> StoreMonad s a -> ExceptT ErrorType IO a + unsafeRunStore :: StoreQueue s -> Text -> StoreMonad s a -> IO a data MSType = MSMemory | MSJournal @@ -141,7 +143,7 @@ tryDelPeekMsg st q msgId' = | otherwise -> pure (Nothing, Just msg) -- The action is called with Nothing when it is known that the queue is empty -withPeekMsgQueue :: MsgStoreClass s => s -> StoreQueue s -> String -> (Maybe (MsgQueue (StoreQueue s), Message) -> StoreMonad s a) -> ExceptT ErrorType IO a +withPeekMsgQueue :: MsgStoreClass s => s -> StoreQueue s -> Text -> (Maybe (MsgQueue (StoreQueue s), Message) -> StoreMonad s a) -> ExceptT ErrorType IO a withPeekMsgQueue st q op a = isolateQueue q op $ getPeekMsgQueue st q >>= a {-# INLINE withPeekMsgQueue #-} diff --git a/src/Simplex/Messaging/Server/Prometheus.hs b/src/Simplex/Messaging/Server/Prometheus.hs index 2aea7ac6a..de16873ee 100644 --- a/src/Simplex/Messaging/Server/Prometheus.hs +++ b/src/Simplex/Messaging/Server/Prometheus.hs @@ -13,16 +13,17 @@ import Data.Time.Clock.System (systemEpochDay) import Data.Time.Format.ISO8601 (iso8601Show) import Network.Socket (ServiceName) import Simplex.Messaging.Server.MsgStore.Types (LoadedQueueCounts (..)) +import Simplex.Messaging.Server.QueueStore.Types (EntityCounts (..)) import Simplex.Messaging.Server.Stats import Simplex.Messaging.Transport (simplexMQVersion) import Simplex.Messaging.Transport.Server (SocketStats (..)) +import Simplex.Messaging.Util (tshow) data ServerMetrics = ServerMetrics { statsData :: ServerStatsData, activeQueueCounts :: PeriodStatCounts, activeNtfCounts :: PeriodStatCounts, - queueCount :: Int, - notifierCount :: Int, + entityCounts :: EntityCounts, rtsOptions :: Text } @@ -40,15 +41,16 @@ data RealTimeMetrics = RealTimeMetrics data RTSubscriberMetrics = RTSubscriberMetrics { subsCount :: Int, - subClientsCount :: Int + subClientsCount :: Int, + subServicesCount :: Int } {-# FOURMOLU_DISABLE\n#-} prometheusMetrics :: ServerMetrics -> RealTimeMetrics -> UTCTime -> Text prometheusMetrics sm rtm ts = - time <> queues <> subscriptions <> messages <> ntfMessages <> ntfs <> relays <> info + time <> queues <> subscriptions <> messages <> ntfMessages <> ntfs <> relays <> services <> info where - ServerMetrics {statsData, activeQueueCounts = ps, activeNtfCounts = psNtf, queueCount, notifierCount, rtsOptions} = sm + ServerMetrics {statsData, activeQueueCounts = ps, activeNtfCounts = psNtf, entityCounts, rtsOptions} = sm RealTimeMetrics { socketStats, threadsCount, @@ -105,6 +107,8 @@ prometheusMetrics sm rtm ts = _pMsgFwds, _pMsgFwdsOwn, _pMsgFwdsRecv, + _rcvServices, + _ntfServices, _qCount, _msgCount, _ntfCount @@ -145,7 +149,7 @@ prometheusMetrics sm rtm ts = \\n\ \# HELP simplex_smp_queues_total2 Total number of stored queues (second type of count).\n\ \# TYPE simplex_smp_queues_total2 gauge\n\ - \simplex_smp_queues_total2 " <> mshow queueCount <> "\n# qCount2\n\ + \simplex_smp_queues_total2 " <> mshow (queueCount entityCounts) <> "\n# qCount2\n\ \\n\ \# HELP simplex_smp_queues_daily Daily active queues.\n\ \# TYPE simplex_smp_queues_daily gauge\n\ @@ -269,7 +273,7 @@ prometheusMetrics sm rtm ts = \\n\ \# HELP simplex_smp_queues_notify_total2 Total number of stored queues with notification flag (second type of count).\n\ \# TYPE simplex_smp_queues_notify_total2 gauge\n\ - \simplex_smp_queues_notify_total2 " <> mshow notifierCount <> "\n# ntfCount2\n\ + \simplex_smp_queues_notify_total2 " <> mshow (notifierCount entityCounts) <> "\n# ntfCount2\n\ \\n" ntfs = "# Notifications (server)\n\ @@ -348,6 +352,60 @@ prometheusMetrics sm rtm ts = \# TYPE simplex_smp_relay_messages_received counter\n\ \simplex_smp_relay_messages_received " <> mshow _pMsgFwdsRecv <> "\n# pMsgFwdsRecv\n\ \\n" + services = + "# Services\n\ + \# --------\n\ + \# HELP simplex_smp_rcv_services_count The count of receiving services.\n\ + \# TYPE simplex_smp_rcv_services_count gauge\n\ + \simplex_smp_rcv_services_count " <> mshow (rcvServiceCount entityCounts) <> "\n# rcvServiceCount\n\ + \\n\ + \# HELP simplex_smp_rcv_services_queues_count The count of queues associated with receiving services.\n\ + \# TYPE simplex_smp_rcv_services_queues_count gauge\n\ + \simplex_smp_rcv_services_queues_count " <> mshow (rcvServiceQueuesCount entityCounts) <> "\n# rcv.rcvServiceQueuesCount\n\ + \\n\ + \# HELP simplex_smp_ntf_services_count The count of notification services.\n\ + \# TYPE simplex_smp_ntf_services_count gauge\n\ + \simplex_smp_ntf_services_count " <> mshow (ntfServiceCount entityCounts) <> "\n# ntfServiceCount\n\ + \\n\ + \# HELP simplex_smp_ntf_services_queues_count The count of queues associated with notification services.\n\ + \# TYPE simplex_smp_ntf_services_queues_count gauge\n\ + \simplex_smp_ntf_services_queues_count " <> mshow (ntfServiceQueuesCount entityCounts) <> "\n# ntfServiceQueuesCount\n\ + \\n" + <> showServices _rcvServices "rcv" "receiving" + <> showServices _ntfServices "ntf" "notification" + showServices ss pfx name = + "# HELP simplex_smp_" <> pfx <> "_services_assoc_new New queue associations with " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_assoc_new counter\n\ + \simplex_smp_" <> pfx <> "_services_assoc_new " <> mshow (_srvAssocNew ss) <> "\n# " <> pfx <> ".srvAssocNew\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_assoc_duplicate Duplicate queue associations with " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_assoc_duplicate counter\n\ + \simplex_smp_" <> pfx <> "_services_assoc_duplicate " <> mshow (_srvAssocDuplicate ss) <> "\n# " <> pfx <> ".srvAssocDuplicate\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_assoc_updated Updated queue associations with " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_assoc_updated counter\n\ + \simplex_smp_" <> pfx <> "_services_assoc_updated " <> mshow (_srvAssocUpdated ss) <> "\n# " <> pfx <> ".srvAssocUpdated\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_assoc_removed Removed queue associations with " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_assoc_removed counter\n\ + \simplex_smp_" <> pfx <> "_services_assoc_removed " <> mshow (_srvAssocRemoved ss) <> "\n# " <> pfx <> ".srvAssocRemoved\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_count Service subscriptions by " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_count counter\n\ + \simplex_smp_" <> pfx <> "_services_sub_count " <> mshow (_srvSubCount ss) <> "\n# " <> pfx <> ".srvSubCount\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_duplicate Duplicate service subscriptions by " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_duplicate counter\n\ + \simplex_smp_" <> pfx <> "_services_sub_duplicate " <> mshow (_srvSubDuplicate ss) <> "\n# " <> pfx <> ".srvSubDuplicate\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_queues Queues subscribed by " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_queues gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_queues " <> mshow (_srvSubQueues ss) <> "\n# " <> pfx <> ".srvSubQueues\n\ + \\n\ + \# HELP simplex_smp_" <> pfx <> "_services_sub_end Ended subscriptions with " <> name <> " services.\n\ + \# TYPE simplex_smp_" <> pfx <> "_services_sub_end gauge\n\ + \simplex_smp_" <> pfx <> "_services_sub_end " <> mshow (_srvSubEnd ss) <> "\n# " <> pfx <> ".srvSubEnd\n\ + \\n" info = "# Info\n\ \# ----\n\ @@ -376,6 +434,10 @@ prometheusMetrics sm rtm ts = \# TYPE simplex_smp_subscribtion_clients_total gauge\n\ \simplex_smp_subscribtion_clients_total " <> mshow (subClientsCount smpSubs) <> "\n# smp.subClientsCount\n\ \\n\ + \# HELP simplex_smp_subscribtion_services_total Subscribed services, first counting method\n\ + \# TYPE simplex_smp_subscribtion_services_total gauge\n\ + \simplex_smp_subscribtion_services_total " <> mshow (subServicesCount smpSubs) <> "\n# smp.subServicesCount\n\ + \\n\ \# HELP simplex_smp_subscription_ntf_total Total notification subscripbtions (from ntf server)\n\ \# TYPE simplex_smp_subscription_ntf_total gauge\n\ \simplex_smp_subscription_ntf_total " <> mshow (subsCount ntfSubs) <> "\n# ntf.subsCount\n\ @@ -384,6 +446,10 @@ prometheusMetrics sm rtm ts = \# TYPE simplex_smp_subscription_ntf_clients_total gauge\n\ \simplex_smp_subscription_ntf_clients_total " <> mshow (subClientsCount ntfSubs) <> "\n# ntf.subClientsCount\n\ \\n\ + \# HELP simplex_smp_subscribtion_nts_services_total Subscribed NTF services, first counting method\n\ + \# TYPE simplex_smp_subscribtion_nts_services_total gauge\n\ + \simplex_smp_subscribtion_nts_services_total " <> mshow (subServicesCount ntfSubs) <> "\n# ntf.subServicesCount\n\ + \\n\ \# HELP simplex_smp_loaded_queues_queue_count Total loaded queues count (all queues for memory/journal storage)\n\ \# TYPE simplex_smp_loaded_queues_queue_count gauge\n\ \simplex_smp_loaded_queues_queue_count " <> mshow (loadedQueueCount loadedCounts) <> "\n# loadedCounts.loadedQueueCount\n\ @@ -410,9 +476,9 @@ prometheusMetrics sm rtm ts = <> "# TYPE " <> metric <> " gauge\n" <> T.concat (map (\(port, ss) -> metric <> "{port=\"" <> T.pack port <> "\"} " <> mshow (sel ss) <> "\n") socketStats) <> "\n" - mstr a = T.pack a <> " " <> tsEpoch + mstr a = a <> " " <> tsEpoch mshow :: Show a => a -> Text - mshow = mstr . show - tsEpoch = T.pack $ show @Int64 $ floor @Double $ realToFrac (ts `diffUTCTime` epoch) * 1000 + mshow = mstr . tshow + tsEpoch = tshow @Int64 $ floor @Double $ realToFrac (ts `diffUTCTime` epoch) * 1000 epoch = UTCTime systemEpochDay 0 {-# FOURMOLU_ENABLE\n#-} diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index e90359d1d..cbac2bf08 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} @@ -10,13 +11,18 @@ module Simplex.Messaging.Server.QueueStore where -import Control.Applicative ((<|>)) +import Control.Applicative (optional, (<|>)) +import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import Data.Time.Clock.System (SystemTime (..), getSystemTime) +import qualified Data.X509 as X +import qualified Data.X509.Validation as XV +import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol +import Simplex.Messaging.Transport (SMPServiceRole) #if defined(dbServerPostgres) import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Database.PostgreSQL.Simple.FromField (FromField (..)) @@ -34,22 +40,55 @@ data QueueRec = QueueRec queueData :: Maybe (LinkId, QueueLinkData), notifier :: Maybe NtfCreds, status :: ServerEntityStatus, - updatedAt :: Maybe RoundedSystemTime + updatedAt :: Maybe RoundedSystemTime, + rcvServiceId :: Maybe ServiceId } deriving (Show) data NtfCreds = NtfCreds - { notifierId :: !NotifierId, - notifierKey :: !NtfPublicAuthKey, - rcvNtfDhSecret :: !RcvNtfDhSecret + { notifierId :: NotifierId, + notifierKey :: NtfPublicAuthKey, + rcvNtfDhSecret :: RcvNtfDhSecret, + ntfServiceId :: Maybe ServiceId } deriving (Show) instance StrEncoding NtfCreds where - strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} = strEncode (notifierId, notifierKey, rcvNtfDhSecret) + strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret, ntfServiceId} = + strEncode (notifierId, notifierKey, rcvNtfDhSecret) + <> maybe "" ((" nsrv=" <>) . strEncode) ntfServiceId strP = do (notifierId, notifierKey, rcvNtfDhSecret) <- strP - pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} + ntfServiceId <- optional $ " nsrv=" *> strP + pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret, ntfServiceId} + +data ServiceRec = ServiceRec + { serviceId :: ServiceId, + serviceRole :: SMPServiceRole, + serviceCert :: X.CertificateChain, + serviceCertHash :: XV.Fingerprint, -- SHA512 hash of long-term service client certificate. See comment for ClientHandshake. + serviceCreatedAt :: RoundedSystemTime + } + deriving (Show) + +type CertFingerprint = B.ByteString + +instance StrEncoding ServiceRec where + strEncode ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash, serviceCreatedAt} = + B.unwords + [ "service_id=" <> strEncode serviceId, + "role=" <> smpEncode serviceRole, + "cert=" <> strEncode serviceCert, + "cert_hash=" <> strEncode serviceCertHash, + "created_at=" <> strEncode serviceCreatedAt + ] + strP = do + serviceId <- "service_id=" *> strP + serviceRole <- " role=" *> smpP + serviceCert <- " cert=" *> strP + serviceCertHash <- " cert_hash=" *> strP + serviceCreatedAt <- " created_at=" *> strP + pure ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash, serviceCreatedAt} data ServerEntityStatus = EntityActive diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index 7d2107f5a..20307ac9d 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -21,7 +21,9 @@ module Simplex.Messaging.Server.QueueStore.Postgres ( PostgresQueueStore (..), PostgresStoreCfg (..), + batchInsertServices, batchInsertQueues, + foldServiceRecs, foldQueueRecs, handleDuplicate, withLog_, @@ -43,13 +45,16 @@ import Data.Bitraversable (bimapM) import Data.Either (fromRight) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (intersperse) +import Data.List (foldl', intersperse, partition) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map.Strict as M import Data.Maybe (catMaybes, fromMaybe) -import qualified Data.Text as T +import qualified Data.Set as S +import Data.Text (Text) import Data.Time.Clock.System (SystemTime (..), getSystemTime) -import Database.PostgreSQL.Simple (Binary (..), Only (..), Query, SqlError, (:.) (..)) +import qualified Data.X509 as X +import qualified Data.X509.Validation as XV +import Database.PostgreSQL.Simple (Binary (..), In (..), Only (..), Query, SqlError, (:.) (..)) import qualified Database.PostgreSQL.Simple as DB import qualified Database.PostgreSQL.Simple.Copy as DB import Database.PostgreSQL.Simple.FromField (FromField (..)) @@ -65,16 +70,18 @@ import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.Postgres.Config import Simplex.Messaging.Server.QueueStore.Postgres.Migrations (serverMigrations) -import Simplex.Messaging.Server.QueueStore.STM (readQueueRecIO) +import Simplex.Messaging.Server.QueueStore.STM (STMService (..), readQueueRecIO) import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, tshow, (<$$>)) +import Simplex.Messaging.Transport (SMPServiceRole (..)) +import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>)) import System.Exit (exitFailure) import System.IO (IOMode (..), hFlush, stdout) import UnliftIO.STM @@ -96,6 +103,7 @@ data PostgresQueueStore q = PostgresQueueStore -- this map only cashes the queues that were attempted to be subscribed to, notifiers :: TMap NotifierId RecipientId, notifierLocks :: TMap NotifierId Lock, + serviceLocks :: TMap CertFingerprint Lock, deletedTTL :: Int64 } @@ -111,7 +119,8 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where links <- TM.emptyIO notifiers <- TM.emptyIO notifierLocks <- TM.emptyIO - pure PostgresQueueStore {dbStore, dbStoreLog, queues, senders, links, notifiers, notifierLocks, deletedTTL} + serviceLocks <- TM.emptyIO + pure PostgresQueueStore {dbStore, dbStoreLog, queues, senders, links, notifiers, notifierLocks, serviceLocks, deletedTTL} where err e = do logError $ "STORE: newQueueStore, error opening PostgreSQL database, " <> tshow e @@ -131,18 +140,23 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where fmap (fromRight 0) $ runExceptT $ withDB' "removeDeletedQueues" st $ \db -> DB.execute db "DELETE FROM msg_queues WHERE deleted_at < ?" (Only old) - queueCounts :: PostgresQueueStore q -> IO QueueCounts - queueCounts st = + getEntityCounts :: PostgresQueueStore q -> IO EntityCounts + getEntityCounts st = withConnection (dbStore st) $ \db -> do - (queueCount, notifierCount) : _ <- - DB.query_ + (queueCount, notifierCount, rcvServiceCount, ntfServiceCount, rcvServiceQueuesCount, ntfServiceQueuesCount) : _ <- + DB.query db [sql| SELECT (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL) AS queue_count, - (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL AND notifier_id IS NOT NULL) AS notifier_count + (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL AND notifier_id IS NOT NULL) AS notifier_count, + (SELECT COUNT(1) FROM services WHERE service_role = ?) AS rcv_service_count, + (SELECT COUNT(1) FROM services WHERE service_role = ?) AS ntf_service_count, + (SELECT COUNT(1) FROM msg_queues WHERE rcv_service_id IS NOT NULL AND deleted_at IS NULL) AS rcv_service_queues_count, + (SELECT COUNT(1) FROM msg_queues WHERE ntf_service_id IS NOT NULL AND deleted_at IS NULL) AS ntf_service_queues_count |] - pure QueueCounts {queueCount, notifierCount} + (SRMessaging, SRNotifier) + pure EntityCounts {queueCount, notifierCount, rcvServiceCount, ntfServiceCount, rcvServiceQueuesCount, ntfServiceQueuesCount} -- this implementation assumes that the lock is already taken by addQueue -- and relies on unique constraints in the database to prevent duplicate IDs. @@ -169,13 +183,15 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where getQueue_ :: DirectParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) getQueue_ st mkQ party qId = case party of SRecipient -> getRcvQueue qId - SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue + SSender -> getSndQueue + SProxyService -> getSndQueue SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue -- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>)) where PostgresQueueStore {queues, senders, links, notifiers} = st getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right) + getSndQueue = TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue loadRcvQueue = do (rId, qRec) <- loadQueue " WHERE recipient_id = ?" liftIO $ cacheQueue rId qRec $ \_ -> pure () -- recipient map already checked, not caching sender ref @@ -273,20 +289,20 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where where rId = recipientId sq - addQueueNotifier :: PostgresQueueStore q -> q -> NtfCreds -> IO (Either ErrorType (Maybe NotifierId)) + addQueueNotifier :: PostgresQueueStore q -> q -> NtfCreds -> IO (Either ErrorType (Maybe NtfCreds)) addQueueNotifier st sq ntfCreds@NtfCreds {notifierId = nId, notifierKey, rcvNtfDhSecret} = withQueueRec sq "addQueueNotifier" $ \q -> ExceptT $ withLockMap (notifierLocks st) nId "addQueueNotifier" $ ifM (TM.memberIO nId notifiers) (pure $ Left DUPLICATE_) $ runExceptT $ do assertUpdated $ withDB "addQueueNotifier" st $ \db -> E.try (update db) >>= bimapM handleDuplicate pure - nId_ <- forM (notifier q) $ \NtfCreds {notifierId} -> atomically (TM.delete notifierId notifiers) $> notifierId + nc_ <- forM (notifier q) $ \nc@NtfCreds {notifierId} -> atomically (TM.delete notifierId notifiers) $> nc let !q' = q {notifier = Just ntfCreds} atomically $ writeTVar (queueRec sq) $ Just q' -- cache queue notifier ID – after notifier is added ntf server will likely subscribe atomically $ TM.insert nId rId notifiers withLog "addQueueNotifier" st $ \s -> logAddNotifier s rId ntfCreds - pure nId_ + pure nc_ where PostgresQueueStore {notifiers} = st rId = recipientId sq @@ -300,16 +316,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where |] (nId, notifierKey, rcvNtfDhSecret, rId) - deleteQueueNotifier :: PostgresQueueStore q -> q -> IO (Either ErrorType (Maybe NotifierId)) + deleteQueueNotifier :: PostgresQueueStore q -> q -> IO (Either ErrorType (Maybe NtfCreds)) deleteQueueNotifier st sq = withQueueRec sq "deleteQueueNotifier" $ \q -> - ExceptT $ fmap sequence $ forM (notifier q) $ \NtfCreds {notifierId = nId} -> + ExceptT $ fmap sequence $ forM (notifier q) $ \nc@NtfCreds {notifierId = nId} -> withLockMap (notifierLocks st) nId "deleteQueueNotifier" $ runExceptT $ do assertUpdated $ withDB' "deleteQueueNotifier" st update atomically $ TM.delete nId $ notifiers st atomically $ writeTVar (queueRec sq) $ Just q {notifier = Nothing} withLog "deleteQueueNotifier" st (`logDeleteNotifier` rId) - pure nId + pure nc where rId = recipientId sq update db = @@ -371,6 +387,75 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where rId = recipientId sq qr = queueRec sq + getCreateService :: PostgresQueueStore q -> ServiceRec -> IO (Either ErrorType ServiceId) + getCreateService st sr@ServiceRec {serviceId = newSrvId, serviceRole, serviceCertHash = XV.Fingerprint fp} = + withLockMap (serviceLocks st) fp "getCreateService" $ E.uninterruptibleMask_ $ runExceptT $ do + (serviceId, new) <- + withDB "getCreateService" st $ \db -> + maybeFirstRow id (DB.query db "SELECT service_id, service_role FROM services WHERE service_cert_hash = ?" (Only (Binary fp))) >>= \case + Just (serviceId, role) + | role == serviceRole -> pure $ Right (serviceId, False) + | otherwise -> pure $ Left SERVICE + Nothing -> + E.try (DB.execute db insertServiceQuery (serviceRecToRow sr)) + >>= bimapM handleDuplicate (\_ -> pure (newSrvId, True)) + when new $ withLog "getCreateService" st (`logNewService` sr) + pure serviceId + + setQueueService :: (PartyI p, SubscriberParty p) => PostgresQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) + setQueueService st sq party serviceId = withQueueRec sq "setQueueService" $ \q -> case party of + SRecipient + | rcvServiceId q == serviceId -> pure () + | otherwise -> do + assertUpdated $ withDB' "setQueueService" st $ \db -> + DB.execute db "UPDATE msg_queues SET rcv_service_id = ? WHERE recipient_id = ? AND deleted_at IS NULL" (serviceId, rId) + updateQueueRec q {rcvServiceId = serviceId} + SNotifier -> case notifier q of + Nothing -> throwE AUTH + Just nc@NtfCreds {ntfServiceId = prevSrvId} + | prevSrvId == serviceId -> pure () + | otherwise -> do + assertUpdated $ withDB' "setQueueService" st $ \db -> + DB.execute db "UPDATE msg_queues SET ntf_service_id = ? WHERE recipient_id = ? AND notifier_id IS NOT NULL AND deleted_at IS NULL" (serviceId, rId) + updateQueueRec q {notifier = Just nc {ntfServiceId = serviceId}} + where + rId = recipientId sq + updateQueueRec :: QueueRec -> ExceptT ErrorType IO () + updateQueueRec q' = do + atomically $ writeTVar (queueRec sq) $ Just q' + withLog "setQueueService" st $ \sl -> logQueueService sl rId party serviceId + + getQueueNtfServices :: PostgresQueueStore q -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)])) + getQueueNtfServices st ntfs = E.uninterruptibleMask_ $ runExceptT $ do + snIds <- + withDB' "getQueueNtfServices" st $ \db -> + DB.query db "SELECT ntf_service_id, notifier_id FROM msg_queues WHERE notifier_id IN ? AND deleted_at IS NULL" (Only (In (map fst ntfs))) + pure $ + if null snIds + then ([], ntfs) + else + let snIds' = foldl' (\m (sId, nId) -> M.alter (Just . maybe (S.singleton nId) (S.insert nId)) sId m) M.empty snIds + in foldr addService ([], ntfs) (M.assocs snIds') + where + addService :: + (Maybe ServiceId, S.Set NotifierId) -> + ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]) -> + ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]) + addService (serviceId, snIds) (ssNtfs, ntfs') = + let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs' + in ((serviceId, sNtfs) : ssNtfs, restNtfs) + + getNtfServiceQueueCount :: PostgresQueueStore q -> ServiceId -> IO (Either ErrorType Int64) + getNtfServiceQueueCount st serviceId = + E.uninterruptibleMask_ $ runExceptT $ withDB' "getNtfServiceQueueCount" st $ \db -> + fmap (fromMaybe 0) $ maybeFirstRow fromOnly $ + DB.query db "SELECT count(1) FROM msg_queues WHERE ntf_service_id = ? AND deleted_at IS NULL" (Only serviceId) + +batchInsertServices :: [STMService] -> PostgresQueueStore q -> IO Int64 +batchInsertServices services' toStore = + withConnection (dbStore toStore) $ \db -> + DB.executeMany db insertServiceQuery $ map (serviceRecToRow . serviceRec) services' + batchInsertQueues :: StoreQueueClass q => Bool -> M.Map RecipientId q -> PostgresQueueStore q' -> IO Int64 batchInsertQueues tty queues toStore = do qs <- catMaybes <$> mapM (\(rId, q) -> (rId,) <$$> readTVarIO (queueRec q)) (M.assocs queues) @@ -381,7 +466,7 @@ batchInsertQueues tty queues toStore = do DB.copy_ db [sql| - COPY msg_queues (recipient_id, recipient_keys, rcv_dh_secret, sender_id, sender_key, queue_mode, notifier_id, notifier_key, rcv_ntf_dh_secret, status, updated_at, link_id, fixed_data, user_data) + COPY msg_queues (recipient_id, recipient_keys, rcv_dh_secret, sender_id, sender_key, queue_mode, notifier_id, notifier_key, rcv_ntf_dh_secret, ntf_service_id, status, updated_at, link_id, rcv_service_id, fixed_data, user_data) FROM STDIN WITH (FORMAT CSV) |] mapM_ (putQueue db) (zip [1..] qs) @@ -399,10 +484,24 @@ insertQueueQuery :: Query insertQueueQuery = [sql| INSERT INTO msg_queues - (recipient_id, recipient_keys, rcv_dh_secret, sender_id, sender_key, queue_mode, notifier_id, notifier_key, rcv_ntf_dh_secret, status, updated_at, link_id, fixed_data, user_data) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?) + (recipient_id, recipient_keys, rcv_dh_secret, sender_id, sender_key, queue_mode, notifier_id, notifier_key, rcv_ntf_dh_secret, ntf_service_id, status, updated_at, link_id, rcv_service_id, fixed_data, user_data) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) |] +insertServiceQuery :: Query +insertServiceQuery = + [sql| + INSERT INTO services + (service_id, service_role, service_cert, service_cert_hash, created_at) + VALUES (?,?,?,?,?) + |] + +foldServiceRecs :: forall a q. Monoid a => PostgresQueueStore q -> (ServiceRec -> IO a) -> IO a +foldServiceRecs st f = + withConnection (dbStore st) $ \db -> + DB.fold_ db "SELECT service_id, service_role, service_cert, service_cert_hash, created_at FROM services" mempty $ + \ !acc -> fmap (acc <>) . f . rowToServiceRec + foldQueueRecs :: forall a q. Monoid a => Bool -> Bool -> PostgresQueueStore q -> Maybe Int64 -> ((RecipientId, QueueRec) -> IO a) -> IO a foldQueueRecs tty withData st skipOld_ f = do (n, r) <- withConnection (dbStore st) $ \db -> @@ -417,12 +516,11 @@ foldQueueRecs tty withData st skipOld_ f = do where foldRecs db acc f' = case skipOld_ of Nothing - | withData -> DB.fold_ db (query <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRecWithData - | otherwise -> DB.fold_ db (query <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRec + | withData -> DB.fold_ db (queueRecQueryWithData <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRecWithData + | otherwise -> DB.fold_ db (queueRecQuery <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRec Just old - | withData -> DB.fold db (query <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRecWithData - | otherwise -> DB.fold db (query <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRec - query = if withData then queueRecQueryWithData else queueRecQuery + | withData -> DB.fold db (queueRecQueryWithData <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRecWithData + | otherwise -> DB.fold db (queueRecQuery <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRec progress i = "Processed: " <> show i <> " records" queueRecQuery :: Query @@ -430,9 +528,8 @@ queueRecQuery = [sql| SELECT recipient_id, recipient_keys, rcv_dh_secret, sender_id, sender_key, queue_mode, - notifier_id, notifier_key, rcv_ntf_dh_secret, - status, updated_at, - link_id + notifier_id, notifier_key, rcv_ntf_dh_secret, ntf_service_id, + status, updated_at, link_id, rcv_service_id FROM msg_queues |] @@ -441,23 +538,28 @@ queueRecQueryWithData = [sql| SELECT recipient_id, recipient_keys, rcv_dh_secret, sender_id, sender_key, queue_mode, - notifier_id, notifier_key, rcv_ntf_dh_secret, - status, updated_at, - link_id, fixed_data, user_data + notifier_id, notifier_key, rcv_ntf_dh_secret, ntf_service_id, + status, updated_at, link_id, rcv_service_id, + fixed_data, user_data FROM msg_queues |] -type QueueRecRow = (RecipientId, NonEmpty RcvPublicAuthKey, RcvDhSecret, SenderId, Maybe SndPublicAuthKey, Maybe QueueMode, Maybe NotifierId, Maybe NtfPublicAuthKey, Maybe RcvNtfDhSecret, ServerEntityStatus, Maybe RoundedSystemTime, Maybe LinkId) +type QueueRecRow = + ( RecipientId, NonEmpty RcvPublicAuthKey, RcvDhSecret, + SenderId, Maybe SndPublicAuthKey, Maybe QueueMode, + Maybe NotifierId, Maybe NtfPublicAuthKey, Maybe RcvNtfDhSecret, Maybe ServiceId, + ServerEntityStatus, Maybe RoundedSystemTime, Maybe LinkId, Maybe ServiceId + ) queueRecToRow :: (RecipientId, QueueRec) -> QueueRecRow :. (Maybe EncDataBytes, Maybe EncDataBytes) -queueRecToRow (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier = n, status, updatedAt}) = - (rId, recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, notifierId <$> n, notifierKey <$> n, rcvNtfDhSecret <$> n, status, updatedAt, linkId_) +queueRecToRow (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier = n, status, updatedAt, rcvServiceId}) = + (rId, recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, notifierId <$> n, notifierKey <$> n, rcvNtfDhSecret <$> n, ntfServiceId =<< n, status, updatedAt, linkId_, rcvServiceId) :. (fst <$> queueData_, snd <$> queueData_) where (linkId_, queueData_) = queueDataColumns queueData queueRecToText :: (RecipientId, QueueRec) -> ByteString -queueRecToText (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier = n, status, updatedAt}) = +queueRecToText (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier = n, status, updatedAt, rcvServiceId}) = LB.toStrict $ BB.toLazyByteString $ mconcat tabFields <> BB.char7 '\n' where tabFields = BB.char7 ',' `intersperse` fields @@ -471,9 +573,11 @@ queueRecToText (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, nullable (notifierId <$> n), nullable (notifierKey <$> n), nullable (rcvNtfDhSecret <$> n), + nullable (ntfServiceId =<< n), BB.char7 '"' <> renderField (toField status) <> BB.char7 '"', nullable updatedAt, nullable linkId_, + nullable rcvServiceId, nullable (fst <$> queueData_), nullable (snd <$> queueData_) ] @@ -494,19 +598,32 @@ queueDataColumns = \case Nothing -> (Nothing, Nothing) rowToQueueRec :: QueueRecRow -> (RecipientId, QueueRec) -rowToQueueRec (rId, recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, notifierId_, notifierKey_, rcvNtfDhSecret_, status, updatedAt, linkId_) = - let notifier = NtfCreds <$> notifierId_ <*> notifierKey_ <*> rcvNtfDhSecret_ +rowToQueueRec (rId, recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, notifierId_, notifierKey_, rcvNtfDhSecret_, ntfServiceId, status, updatedAt, linkId_, rcvServiceId) = + let notifier = mkNotifier (notifierId_, notifierKey_, rcvNtfDhSecret_) ntfServiceId queueData = (,(EncDataBytes "", EncDataBytes "")) <$> linkId_ - in (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt}) + in (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt, rcvServiceId}) rowToQueueRecWithData :: QueueRecRow :. (Maybe EncDataBytes, Maybe EncDataBytes) -> (RecipientId, QueueRec) -rowToQueueRecWithData ((rId, recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, notifierId_, notifierKey_, rcvNtfDhSecret_, status, updatedAt, linkId_) :. (immutableData_, userData_)) = - let notifier = NtfCreds <$> notifierId_ <*> notifierKey_ <*> rcvNtfDhSecret_ +rowToQueueRecWithData ((rId, recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, notifierId_, notifierKey_, rcvNtfDhSecret_, ntfServiceId, status, updatedAt, linkId_, rcvServiceId) :. (immutableData_, userData_)) = + let notifier = mkNotifier (notifierId_, notifierKey_, rcvNtfDhSecret_) ntfServiceId encData = fromMaybe (EncDataBytes "") queueData = (,(encData immutableData_, encData userData_)) <$> linkId_ - in (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt}) + in (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt, rcvServiceId}) -setStatusDB :: StoreQueueClass q => String -> PostgresQueueStore q -> q -> ServerEntityStatus -> ExceptT ErrorType IO () -> IO (Either ErrorType ()) +mkNotifier :: (Maybe NotifierId, Maybe NtfPublicAuthKey, Maybe RcvNtfDhSecret) -> Maybe ServiceId -> Maybe NtfCreds +mkNotifier (Just notifierId, Just notifierKey, Just rcvNtfDhSecret) ntfServiceId = + Just NtfCreds {notifierId, notifierKey, rcvNtfDhSecret, ntfServiceId} +mkNotifier _ _ = Nothing + +serviceRecToRow :: ServiceRec -> (ServiceId, SMPServiceRole, X.CertificateChain, Binary ByteString, RoundedSystemTime) +serviceRecToRow ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash = XV.Fingerprint fp, serviceCreatedAt} = + (serviceId, serviceRole, serviceCert, Binary fp, serviceCreatedAt) + +rowToServiceRec :: (ServiceId, SMPServiceRole, X.CertificateChain, Binary ByteString, RoundedSystemTime) -> ServiceRec +rowToServiceRec (serviceId, serviceRole, serviceCert, Binary fp, serviceCreatedAt) = + ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash = XV.Fingerprint fp, serviceCreatedAt} + +setStatusDB :: StoreQueueClass q => Text -> PostgresQueueStore q -> q -> ServerEntityStatus -> ExceptT ErrorType IO () -> IO (Either ErrorType ()) setStatusDB op st sq status writeLog = withQueueRec sq op $ \q -> do assertUpdated $ withDB' op st $ \db -> @@ -514,33 +631,33 @@ setStatusDB op st sq status writeLog = atomically $ writeTVar (queueRec sq) $ Just q {status} writeLog -withQueueRec :: StoreQueueClass q => q -> String -> (QueueRec -> ExceptT ErrorType IO a) -> IO (Either ErrorType a) +withQueueRec :: StoreQueueClass q => q -> Text -> (QueueRec -> ExceptT ErrorType IO a) -> IO (Either ErrorType a) withQueueRec sq op action = withQueueLock sq op $ E.uninterruptibleMask_ $ runExceptT $ ExceptT (readQueueRecIO $ queueRec sq) >>= action assertUpdated :: ExceptT ErrorType IO Int64 -> ExceptT ErrorType IO () assertUpdated = (>>= \n -> when (n == 0) (throwE AUTH)) -withDB' :: String -> PostgresQueueStore q -> (DB.Connection -> IO a) -> ExceptT ErrorType IO a +withDB' :: Text -> PostgresQueueStore q -> (DB.Connection -> IO a) -> ExceptT ErrorType IO a withDB' op st action = withDB op st $ fmap Right . action -withDB :: forall a q. String -> PostgresQueueStore q -> (DB.Connection -> IO (Either ErrorType a)) -> ExceptT ErrorType IO a +withDB :: forall a q. Text -> PostgresQueueStore q -> (DB.Connection -> IO (Either ErrorType a)) -> ExceptT ErrorType IO a withDB op st action = ExceptT $ E.try (withConnection (dbStore st) action) >>= either logErr pure where logErr :: E.SomeException -> IO (Either ErrorType a) - logErr e = logError ("STORE: " <> T.pack err) $> Left (STORE err) + logErr e = logError ("STORE: " <> err) $> Left (STORE err) where - err = op <> ", withDB, " <> show e + err = op <> ", withDB, " <> tshow e -withLog :: MonadIO m => String -> PostgresQueueStore q -> (StoreLog 'WriteMode -> IO ()) -> m () +withLog :: MonadIO m => Text -> PostgresQueueStore q -> (StoreLog 'WriteMode -> IO ()) -> m () withLog op PostgresQueueStore {dbStoreLog} = withLog_ op dbStoreLog {-# INLINE withLog #-} -withLog_ :: MonadIO m => String -> Maybe (StoreLog 'WriteMode) -> (StoreLog 'WriteMode -> IO ()) -> m () +withLog_ :: MonadIO m => Text -> Maybe (StoreLog 'WriteMode) -> (StoreLog 'WriteMode -> IO ()) -> m () withLog_ op sl_ action = forM_ sl_ $ \sl -> liftIO $ action sl `catchAny` \e -> - logWarn $ "STORE: " <> T.pack (op <> ", withLog, " <> show e) + logWarn $ "STORE: " <> op <> ", withLog, " <> tshow e handleDuplicate :: SqlError -> IO ErrorType handleDuplicate e = case constraintViolation e of @@ -553,6 +670,14 @@ instance ToField (NonEmpty C.APublicAuthKey) where toField = toField . Binary . instance FromField (NonEmpty C.APublicAuthKey) where fromField = blobFieldDecoder smpDecode +instance ToField SMPServiceRole where toField = toField . decodeLatin1 . smpEncode + +instance FromField SMPServiceRole where fromField = fromTextField_ $ eitherToMaybe . smpDecode . encodeUtf8 + +instance ToField X.CertificateChain where toField = toField . Binary . smpEncode . C.encodeCertChain + +instance FromField X.CertificateChain where fromField = blobFieldDecoder (parseAll C.certChainP) + #if !defined(dbPostgres) instance ToField EntityId where toField (EntityId s) = toField $ Binary s diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs index b1c5501f6..e8469d1cc 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs @@ -13,7 +13,8 @@ serverSchemaMigrations :: [(String, Text, Maybe Text)] serverSchemaMigrations = [ ("20250207_initial", m20250207_initial, Nothing), ("20250319_updated_index", m20250319_updated_index, Just down_m20250319_updated_index), - ("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links) + ("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links), + ("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs) ] -- | The list of migrations in ascending order by date @@ -48,7 +49,7 @@ CREATE INDEX idx_msg_queues_deleted_at ON msg_queues (deleted_at); |] m20250319_updated_index :: Text -m20250319_updated_index = +m20250319_updated_index = T.pack [r| DROP INDEX idx_msg_queues_deleted_at; @@ -119,3 +120,42 @@ UPDATE msg_queues SET recipient_keys = substring(recipient_keys from 3); ALTER TABLE msg_queues RENAME COLUMN recipient_keys TO recipient_key; |] + +m20250514_service_certs :: Text +m20250514_service_certs = + T.pack + [r| +CREATE TABLE services( + service_id BYTEA NOT NULL, + service_role TEXT NOT NULL, + service_cert BYTEA NOT NULL, + service_cert_hash BYTEA NOT NULL UNIQUE, + created_at BIGINT NOT NULL, + PRIMARY KEY (service_id) +); + +CREATE INDEX idx_services_service_role ON services(service_role); + +ALTER TABLE msg_queues + ADD COLUMN rcv_service_id BYTEA REFERENCES services(service_id) ON DELETE SET NULL ON UPDATE RESTRICT, + ADD COLUMN ntf_service_id BYTEA REFERENCES services(service_id) ON DELETE SET NULL ON UPDATE RESTRICT; + +CREATE INDEX idx_msg_queues_rcv_service_id ON msg_queues(rcv_service_id, deleted_at); +CREATE INDEX idx_msg_queues_ntf_service_id ON msg_queues(ntf_service_id, deleted_at); + |] + +down_m20250514_service_certs :: Text +down_m20250514_service_certs = + T.pack + [r| +DROP INDEX idx_msg_queues_rcv_service_id; +DROP INDEX idx_msg_queues_ntf_service_id; + +ALTER TABLE msg_queues + DROP COLUMN rcv_service_id, + DROP COLUMN ntf_service_id; + +DROP INDEX idx_services_service_role; + +DROP TABLE services; + |] diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql b/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql index 2910b6959..6c0501d8b 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql @@ -41,7 +41,19 @@ CREATE TABLE smp_server.msg_queues ( queue_mode text, link_id bytea, fixed_data bytea, - user_data bytea + user_data bytea, + rcv_service_id bytea, + ntf_service_id bytea +); + + + +CREATE TABLE smp_server.services ( + service_id bytea NOT NULL, + service_role text NOT NULL, + service_cert bytea NOT NULL, + service_cert_hash bytea NOT NULL, + created_at bigint NOT NULL ); @@ -56,6 +68,16 @@ ALTER TABLE ONLY smp_server.msg_queues +ALTER TABLE ONLY smp_server.services + ADD CONSTRAINT services_pkey PRIMARY KEY (service_id); + + + +ALTER TABLE ONLY smp_server.services + ADD CONSTRAINT services_service_cert_hash_key UNIQUE (service_cert_hash); + + + CREATE UNIQUE INDEX idx_msg_queues_link_id ON smp_server.msg_queues USING btree (link_id); @@ -64,6 +86,14 @@ CREATE UNIQUE INDEX idx_msg_queues_notifier_id ON smp_server.msg_queues USING bt +CREATE INDEX idx_msg_queues_ntf_service_id ON smp_server.msg_queues USING btree (ntf_service_id, deleted_at); + + + +CREATE INDEX idx_msg_queues_rcv_service_id ON smp_server.msg_queues USING btree (rcv_service_id, deleted_at); + + + CREATE UNIQUE INDEX idx_msg_queues_sender_id ON smp_server.msg_queues USING btree (sender_id); @@ -72,3 +102,17 @@ CREATE INDEX idx_msg_queues_updated_at ON smp_server.msg_queues USING btree (del +CREATE INDEX idx_services_service_role ON smp_server.services USING btree (service_role); + + + +ALTER TABLE ONLY smp_server.msg_queues + ADD CONSTRAINT msg_queues_ntf_service_id_fkey FOREIGN KEY (ntf_service_id) REFERENCES smp_server.services(service_id) ON UPDATE RESTRICT ON DELETE SET NULL; + + + +ALTER TABLE ONLY smp_server.msg_queues + ADD CONSTRAINT msg_queues_rcv_service_id_fkey FOREIGN KEY (rcv_service_id) REFERENCES smp_server.services(service_id) ON UPDATE RESTRICT ON DELETE SET NULL; + + + diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 61fa3af45..522f2f28e 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -16,6 +16,7 @@ module Simplex.Messaging.Server.QueueStore.STM ( STMQueueStore (..), + STMService (..), setStoreLog, withLog', readQueueRecIO, @@ -28,16 +29,22 @@ import Control.Logger.Simple import Control.Monad import Data.Bitraversable (bimapM) import Data.Functor (($>)) +import Data.Int (Int64) +import Data.List (partition) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map.Strict as M -import qualified Data.Text as T +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.X509.Validation as XV import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (anyM, ifM, ($>>), ($>>=), (<$$)) +import Simplex.Messaging.Transport (SMPServiceRole (..)) +import Simplex.Messaging.Util (anyM, ifM, tshow, ($>>), ($>>=), (<$$)) import System.IO import UnliftIO.STM @@ -45,10 +52,18 @@ data STMQueueStore q = STMQueueStore { queues :: TMap RecipientId q, senders :: TMap SenderId RecipientId, notifiers :: TMap NotifierId RecipientId, + services :: TMap ServiceId STMService, + serviceCerts :: TMap CertFingerprint ServiceId, links :: TMap LinkId RecipientId, storeLog :: TVar (Maybe (StoreLog 'WriteMode)) } +data STMService = STMService + { serviceRec :: ServiceRec, + serviceRcvQueues :: TVar (Set RecipientId), + serviceNtfQueues :: TVar (Set NotifierId) + } + setStoreLog :: STMQueueStore q -> StoreLog 'WriteMode -> IO () setStoreLog st sl = atomically $ writeTVar (storeLog st) (Just sl) @@ -60,9 +75,11 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where queues <- TM.emptyIO senders <- TM.emptyIO notifiers <- TM.emptyIO + services <- TM.emptyIO + serviceCerts <- TM.emptyIO links <- TM.emptyIO storeLog <- newTVarIO Nothing - pure STMQueueStore {queues, senders, notifiers, links, storeLog} + pure STMQueueStore {queues, senders, notifiers, links, services, serviceCerts, storeLog} closeQueueStore :: STMQueueStore q -> IO () closeQueueStore STMQueueStore {queues, senders, notifiers, storeLog} = do @@ -76,11 +93,25 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where compactQueues _ = pure 0 {-# INLINE compactQueues #-} - queueCounts :: STMQueueStore q -> IO QueueCounts - queueCounts st = do + getEntityCounts :: STMQueueStore q -> IO EntityCounts + getEntityCounts st = do queueCount <- M.size <$> readTVarIO (queues st) notifierCount <- M.size <$> readTVarIO (notifiers st) - pure QueueCounts {queueCount, notifierCount} + ss <- readTVarIO (services st) + rcvServiceQueuesCount <- serviceQueuesCount serviceRcvQueues ss + ntfServiceQueuesCount <- serviceQueuesCount serviceNtfQueues ss + pure + EntityCounts + { queueCount, + notifierCount, + rcvServiceCount = serviceCount SRMessaging ss, + ntfServiceCount = serviceCount SRNotifier ss, + rcvServiceQueuesCount, + ntfServiceQueuesCount + } + where + serviceCount role = M.foldl' (\ !n s -> if serviceRole (serviceRec s) == role then n + 1 else n) 0 + serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size <$> readTVarIO (serviceSel s)) 0 addQueue_ :: STMQueueStore q -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q) addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData} = do @@ -101,11 +132,13 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where getQueue_ st _ party qId = maybe (Left AUTH) Right <$> case party of SRecipient -> TM.lookupIO qId queues - SSender -> TM.lookupIO qId senders $>>= (`TM.lookupIO` queues) + SSender -> getSndQueue + SProxyService -> getSndQueue SNotifier -> TM.lookupIO qId notifiers $>>= (`TM.lookupIO` queues) SSenderLink -> TM.lookupIO qId links $>>= (`TM.lookupIO` queues) where STMQueueStore {queues, senders, notifiers, links} = st + getSndQueue = TM.lookupIO qId senders $>>= (`TM.lookupIO` queues) getQueueLinkData :: STMQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData) getQueueLinkData _ q lnkId = atomically $ readQueueRec (queueRec q) $>>= pure . getData @@ -162,31 +195,31 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where writeTVar qr $ Just q {senderKey = Just sKey} pure $ Right () - addQueueNotifier :: STMQueueStore q -> q -> NtfCreds -> IO (Either ErrorType (Maybe NotifierId)) + addQueueNotifier :: STMQueueStore q -> q -> NtfCreds -> IO (Either ErrorType (Maybe NtfCreds)) addQueueNotifier st sq ntfCreds@NtfCreds {notifierId = nId} = atomically (readQueueRec qr $>>= add) - $>>= \nId_ -> nId_ <$$ withLog "addQueueNotifier" st (\s -> logAddNotifier s rId ntfCreds) + $>>= \nc_ -> nc_ <$$ withLog "addQueueNotifier" st (\s -> logAddNotifier s rId ntfCreds) where rId = recipientId sq qr = queueRec sq STMQueueStore {notifiers} = st add q = ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ do - nId_ <- forM (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers $> notifierId + nc_ <- forM (notifier q) $ \nc -> nc <$ removeNotifier st nc let !q' = q {notifier = Just ntfCreds} writeTVar qr $ Just q' TM.insert nId rId notifiers - pure $ Right nId_ + pure $ Right nc_ - deleteQueueNotifier :: STMQueueStore q -> q -> IO (Either ErrorType (Maybe NotifierId)) + deleteQueueNotifier :: STMQueueStore q -> q -> IO (Either ErrorType (Maybe NtfCreds)) deleteQueueNotifier st sq = withQueueRec qr delete - $>>= \nId_ -> nId_ <$$ withLog "deleteQueueNotifier" st (`logDeleteNotifier` recipientId sq) + $>>= \nc_ -> nc_ <$$ withLog "deleteQueueNotifier" st (`logDeleteNotifier` recipientId sq) where qr = queueRec sq - delete q = forM (notifier q) $ \NtfCreds {notifierId} -> do - TM.delete notifierId $ notifiers st + delete q = forM (notifier q) $ \nc -> do + removeNotifier st nc writeTVar qr $ Just q {notifier = Nothing} - pure notifierId + pure nc suspendQueue :: STMQueueStore q -> q -> IO (Either ErrorType ()) suspendQueue st sq = @@ -219,16 +252,93 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where deleteStoreQueue :: STMQueueStore q -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q))) deleteStoreQueue st sq = withQueueRec qr delete - $>>= \q -> withLog "deleteStoreQueue" st (`logDeleteQueue` recipientId sq) + $>>= \q -> withLog "deleteStoreQueue" st (`logDeleteQueue` rId) >>= mapM (\_ -> (q,) <$> atomically (swapTVar (msgQueue sq) Nothing)) where + rId = recipientId sq qr = queueRec sq - delete q = do + delete q@QueueRec {senderId, rcvServiceId} = do writeTVar qr Nothing - TM.delete (senderId q) $ senders st - forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId $ notifiers st + TM.delete senderId $ senders st + mapM_ (removeServiceQueue st serviceRcvQueues rId) rcvServiceId + mapM_ (removeNotifier st) $ notifier q pure q + getCreateService :: STMQueueStore q -> ServiceRec -> IO (Either ErrorType ServiceId) + getCreateService st sr@ServiceRec {serviceId = newSrvId, serviceRole, serviceCertHash = XV.Fingerprint fp} = + TM.lookupIO fp serviceCerts + >>= maybe + (atomically $ TM.lookup fp serviceCerts >>= maybe newService checkService) + (atomically . checkService) + $>>= \(serviceId, new) -> + if new + then serviceId <$$ withLog "getCreateService" st (`logNewService` sr) + else pure $ Right serviceId + where + STMQueueStore {services, serviceCerts} = st + checkService sId = + TM.lookup sId services >>= \case + Just STMService {serviceRec = ServiceRec {serviceId, serviceRole = role}} + | role == serviceRole -> pure $ Right (serviceId, False) + | otherwise -> pure $ Left $ SERVICE + Nothing -> newService_ + newService = ifM (TM.member newSrvId services) (pure $ Left DUPLICATE_) newService_ + newService_ = do + TM.insertM newSrvId newSTMService services + TM.insert fp newSrvId serviceCerts + pure $ Right (newSrvId, True) + newSTMService = do + serviceRcvQueues <- newTVar S.empty + serviceNtfQueues <- newTVar S.empty + pure STMService {serviceRec = sr, serviceRcvQueues, serviceNtfQueues} + + setQueueService :: (PartyI p, SubscriberParty p) => STMQueueStore q -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) + setQueueService st sq party serviceId = + atomically (readQueueRec qr $>>= setService) + $>> withLog "setQueueService" st (\sl -> logQueueService sl rId party serviceId) + where + qr = queueRec sq + rId = recipientId sq + setService :: QueueRec -> STM (Either ErrorType ()) + setService q@QueueRec {rcvServiceId = prevSrvId} = case party of + SRecipient + | prevSrvId == serviceId -> pure $ Right () + | otherwise -> do + updateServiceQueues serviceRcvQueues rId prevSrvId + let !q' = Just q {rcvServiceId = serviceId} + writeTVar qr q' $> Right () + SNotifier -> case notifier q of + Nothing -> pure $ Left AUTH + Just nc@NtfCreds {notifierId = nId, ntfServiceId = prevNtfSrvId} + | prevNtfSrvId == serviceId -> pure $ Right () + | otherwise -> do + let !q' = Just q {notifier = Just nc {ntfServiceId = serviceId}} + updateServiceQueues serviceNtfQueues nId prevNtfSrvId + writeTVar qr q' $> Right () + updateServiceQueues :: (STMService -> TVar (Set QueueId)) -> QueueId -> Maybe ServiceId -> STM () + updateServiceQueues serviceSel qId prevSrvId = do + mapM_ (removeServiceQueue st serviceSel qId) prevSrvId + mapM_ (addServiceQueue st serviceSel qId) serviceId + + getQueueNtfServices :: STMQueueStore q -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)])) + getQueueNtfServices st ntfs = do + ss <- readTVarIO (services st) + (ssNtfs, noServiceNtfs) <- if M.null ss then pure ([], ntfs) else foldM addService ([], ntfs) (M.assocs ss) + ns <- readTVarIO (notifiers st) + let (ntfs', deleteNtfs) = partition (\(nId, _) -> M.member nId ns) noServiceNtfs + ssNtfs' = (Nothing, ntfs') : ssNtfs + pure $ Right (ssNtfs', deleteNtfs) + where + addService (ssNtfs, ntfs') (serviceId, s) = do + snIds <- readTVarIO $ serviceNtfQueues s + let (sNtfs, restNtfs) = partition (\(nId, _) -> S.member nId snIds) ntfs' + pure ((Just serviceId, sNtfs) : ssNtfs, restNtfs) + + getNtfServiceQueueCount :: STMQueueStore q -> ServiceId -> IO (Either ErrorType Int64) + getNtfServiceQueueCount st serviceId = + TM.lookupIO serviceId (services st) >>= + maybe (pure $ Left AUTH) (fmap (Right . fromIntegral . S.size) . readTVarIO . serviceNtfQueues) + withQueueRec :: TVar (Maybe QueueRec) -> (QueueRec -> STM a) -> IO (Either ErrorType a) withQueueRec qr a = atomically $ readQueueRec qr >>= mapM a @@ -238,6 +348,21 @@ setStatus qr status = Just q -> (Right (), Just q {status}) Nothing -> (Left AUTH, Nothing) +addServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM () +addServiceQueue st serviceSel qId serviceId = + TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.insert qId)) +{-# INLINE addServiceQueue #-} + +removeServiceQueue :: STMQueueStore q -> (STMService -> TVar (Set QueueId)) -> QueueId -> ServiceId -> STM () +removeServiceQueue st serviceSel qId serviceId = + TM.lookup serviceId (services st) >>= mapM_ (\s -> modifyTVar' (serviceSel s) (S.delete qId)) +{-# INLINE removeServiceQueue #-} + +removeNotifier :: STMQueueStore q -> NtfCreds -> STM () +removeNotifier st NtfCreds {notifierId = nId, ntfServiceId} = do + TM.delete nId $ notifiers st + mapM_ (removeServiceQueue st serviceNtfQueues nId) ntfServiceId + readQueueRec :: TVar (Maybe QueueRec) -> STM (Either ErrorType QueueRec) readQueueRec qr = maybe (Left AUTH) Right <$> readTVar qr {-# INLINE readQueueRec #-} @@ -246,16 +371,16 @@ readQueueRecIO :: TVar (Maybe QueueRec) -> IO (Either ErrorType QueueRec) readQueueRecIO qr = maybe (Left AUTH) Right <$> readTVarIO qr {-# INLINE readQueueRecIO #-} -withLog' :: String -> TVar (Maybe (StoreLog 'WriteMode)) -> (StoreLog 'WriteMode -> IO ()) -> IO (Either ErrorType ()) +withLog' :: Text -> TVar (Maybe (StoreLog 'WriteMode)) -> (StoreLog 'WriteMode -> IO ()) -> IO (Either ErrorType ()) withLog' name sl action = readTVarIO sl >>= maybe (pure $ Right ()) (E.try . E.uninterruptibleMask_ . action >=> bimapM logErr pure) where logErr :: E.SomeException -> IO ErrorType - logErr e = logError ("STORE: " <> T.pack err) $> STORE err + logErr e = logError ("STORE: " <> err) $> STORE err where - err = name <> ", withLog, " <> show e + err = name <> ", withLog, " <> tshow e -withLog :: String -> STMQueueStore q -> (StoreLog 'WriteMode -> IO ()) -> IO (Either ErrorType ()) +withLog :: Text -> STMQueueStore q -> (StoreLog 'WriteMode -> IO ()) -> IO (Either ErrorType ()) withLog name = withLog' name . storeLog {-# INLINE withLog #-} diff --git a/src/Simplex/Messaging/Server/QueueStore/Types.hs b/src/Simplex/Messaging/Server/QueueStore/Types.hs index 9a52a9df8..e8af996cb 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Types.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Types.hs @@ -11,6 +11,7 @@ import Control.Concurrent.STM import Control.Monad import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) +import Data.Text (Text) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.TMap (TMap) @@ -20,13 +21,13 @@ class StoreQueueClass q where recipientId :: q -> RecipientId queueRec :: q -> TVar (Maybe QueueRec) msgQueue :: q -> TVar (Maybe (MsgQueue q)) - withQueueLock :: q -> String -> IO a -> IO a + withQueueLock :: q -> Text -> IO a -> IO a class StoreQueueClass q => QueueStoreClass q s where type QueueStoreCfg s newQueueStore :: QueueStoreCfg s -> IO s closeQueueStore :: s -> IO () - queueCounts :: s -> IO QueueCounts + getEntityCounts :: s -> IO EntityCounts loadedQueues :: s -> TMap RecipientId q compactQueues :: s -> IO Int64 addQueue_ :: s -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q) @@ -36,17 +37,25 @@ class StoreQueueClass q => QueueStoreClass q s where deleteQueueLinkData :: s -> q -> IO (Either ErrorType ()) secureQueue :: s -> q -> SndPublicAuthKey -> IO (Either ErrorType ()) updateKeys :: s -> q -> NonEmpty RcvPublicAuthKey -> IO (Either ErrorType ()) - addQueueNotifier :: s -> q -> NtfCreds -> IO (Either ErrorType (Maybe NotifierId)) - deleteQueueNotifier :: s -> q -> IO (Either ErrorType (Maybe NotifierId)) + addQueueNotifier :: s -> q -> NtfCreds -> IO (Either ErrorType (Maybe NtfCreds)) + deleteQueueNotifier :: s -> q -> IO (Either ErrorType (Maybe NtfCreds)) suspendQueue :: s -> q -> IO (Either ErrorType ()) blockQueue :: s -> q -> BlockingInfo -> IO (Either ErrorType ()) unblockQueue :: s -> q -> IO (Either ErrorType ()) updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec) deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q))) + getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId) + setQueueService :: (PartyI p, SubscriberParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) + getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)])) + getNtfServiceQueueCount :: s -> ServiceId -> IO (Either ErrorType Int64) -data QueueCounts = QueueCounts +data EntityCounts = EntityCounts { queueCount :: Int, - notifierCount :: Int + notifierCount :: Int, + rcvServiceCount :: Int, + ntfServiceCount :: Int, + rcvServiceQueuesCount :: Int, + ntfServiceQueuesCount :: Int } withLoadedQueues :: (Monoid a, QueueStoreClass q s) => s -> (q -> IO a) -> IO a diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index bbab6d8d2..da90b7216 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -18,13 +18,14 @@ import Data.IntSet (IntSet) import qualified Data.IntSet as IS import Data.Set (Set) import qualified Data.Set as S +import Data.Text (Text) import Data.Time.Calendar.Month (pattern MonthDay) import Data.Time.Calendar.OrdinalDate (mondayStartWeek) import Data.Time.Clock (UTCTime (..)) import GHC.IORef (atomicSwapIORef) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (EntityId (..)) -import Simplex.Messaging.Util (atomicModifyIORef'_, unlessM) +import Simplex.Messaging.Util (atomicModifyIORef'_, tshow, unlessM) data ServerStats = ServerStats { fromTime :: IORef UTCTime, @@ -78,6 +79,8 @@ data ServerStats = ServerStats pMsgFwds :: ProxyStats, pMsgFwdsOwn :: ProxyStats, pMsgFwdsRecv :: IORef Int, + rcvServices :: ServiceStats, + ntfServices :: ServiceStats, qCount :: IORef Int, msgCount :: IORef Int, ntfCount :: IORef Int @@ -133,6 +136,8 @@ data ServerStatsData = ServerStatsData _pMsgFwds :: ProxyStatsData, _pMsgFwdsOwn :: ProxyStatsData, _pMsgFwdsRecv :: Int, + _ntfServices :: ServiceStatsData, + _rcvServices :: ServiceStatsData, _qCount :: Int, _msgCount :: Int, _ntfCount :: Int @@ -190,6 +195,8 @@ newServerStats ts = do pMsgFwds <- newProxyStats pMsgFwdsOwn <- newProxyStats pMsgFwdsRecv <- newIORef 0 + rcvServices <- newServiceStats + ntfServices <- newServiceStats qCount <- newIORef 0 msgCount <- newIORef 0 ntfCount <- newIORef 0 @@ -244,6 +251,8 @@ newServerStats ts = do pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv, + rcvServices, + ntfServices, qCount, msgCount, ntfCount @@ -300,6 +309,8 @@ getServerStatsData s = do _pMsgFwds <- getProxyStatsData $ pMsgFwds s _pMsgFwdsOwn <- getProxyStatsData $ pMsgFwdsOwn s _pMsgFwdsRecv <- readIORef $ pMsgFwdsRecv s + _rcvServices <- getServiceStatsData $ rcvServices s + _ntfServices <- getServiceStatsData $ ntfServices s _qCount <- readIORef $ qCount s _msgCount <- readIORef $ msgCount s _ntfCount <- readIORef $ ntfCount s @@ -354,6 +365,8 @@ getServerStatsData s = do _pMsgFwds, _pMsgFwdsOwn, _pMsgFwdsRecv, + _rcvServices, + _ntfServices, _qCount, _msgCount, _ntfCount @@ -411,6 +424,8 @@ setServerStats s d = do setProxyStats (pMsgFwds s) $! _pMsgFwds d setProxyStats (pMsgFwdsOwn s) $! _pMsgFwdsOwn d writeIORef (pMsgFwdsRecv s) $! _pMsgFwdsRecv d + setServiceStats (rcvServices s) $! _rcvServices d + setServiceStats (ntfServices s) $! _ntfServices d writeIORef (qCount s) $! _qCount d writeIORef (msgCount s) $! _msgCount d writeIORef (ntfCount s) $! _ntfCount d @@ -473,7 +488,11 @@ instance StrEncoding ServerStatsData where strEncode (_pMsgFwds d), "pMsgFwdsOwn:", strEncode (_pMsgFwdsOwn d), - "pMsgFwdsRecv=" <> strEncode (_pMsgFwdsRecv d) + "pMsgFwdsRecv=" <> strEncode (_pMsgFwdsRecv d), + "rcvServices:", + strEncode (_rcvServices d), + "ntfServices:", + strEncode (_ntfServices d) ] strP = do _fromTime <- "fromTime=" *> strP <* A.endOfLine @@ -541,6 +560,8 @@ instance StrEncoding ServerStatsData where _pMsgFwds <- proxyStatsP "pMsgFwds:" _pMsgFwdsOwn <- proxyStatsP "pMsgFwdsOwn:" _pMsgFwdsRecv <- opt "pMsgFwdsRecv=" + _rcvServices <- serviceStatsP "rcvServices:" + _ntfServices <- serviceStatsP "ntfServices:" pure ServerStatsData { _fromTime, @@ -592,6 +613,8 @@ instance StrEncoding ServerStatsData where _pMsgFwds, _pMsgFwdsOwn, _pMsgFwdsRecv, + _rcvServices, + _ntfServices, _qCount, _msgCount = 0, _ntfCount = 0 @@ -603,6 +626,10 @@ instance StrEncoding ServerStatsData where optional (A.string key >> A.endOfLine) >>= \case Just _ -> strP <* optional A.endOfLine _ -> pure newProxyStatsData + serviceStatsP key = + optional (A.string key >> A.endOfLine) >>= \case + Just _ -> strP <* optional A.endOfLine + _ -> pure newServiceStatsData data PeriodStats = PeriodStats { day :: IORef IntSet, @@ -653,17 +680,17 @@ instance StrEncoding PeriodStatsData where bsSetP = S.foldl' (\s -> (`IS.insert` s) . hash) IS.empty <$> strP @(Set ByteString) data PeriodStatCounts = PeriodStatCounts - { dayCount :: String, - weekCount :: String, - monthCount :: String + { dayCount :: Text, + weekCount :: Text, + monthCount :: Text } periodStatDataCounts :: PeriodStatsData -> PeriodStatCounts periodStatDataCounts PeriodStatsData {_day, _week, _month} = PeriodStatCounts - { dayCount = show $ IS.size _day, - weekCount = show $ IS.size _week, - monthCount = show $ IS.size _month + { dayCount = tshow $ IS.size _day, + weekCount = tshow $ IS.size _week, + monthCount = tshow $ IS.size _month } periodStatCounts :: PeriodStats -> UTCTime -> IO PeriodStatCounts @@ -676,8 +703,8 @@ periodStatCounts ps ts = do monthCount <- periodCount mDay $ month ps pure PeriodStatCounts {dayCount, weekCount, monthCount} where - periodCount :: Int -> IORef IntSet -> IO String - periodCount 1 ref = show . IS.size <$> atomicSwapIORef ref IS.empty + periodCount :: Int -> IORef IntSet -> IO Text + periodCount 1 ref = tshow . IS.size <$> atomicSwapIORef ref IS.empty periodCount _ _ = pure "" updatePeriodStats :: PeriodStats -> EntityId -> IO () @@ -764,3 +791,156 @@ instance StrEncoding ProxyStatsData where _pErrorsCompat <- "errorsCompat=" *> strP <* A.endOfLine _pErrorsOther <- "errorsOther=" *> strP pure ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} + +data ServiceStats = ServiceStats + { srvAssocNew :: IORef Int, + srvAssocDuplicate :: IORef Int, + srvAssocUpdated :: IORef Int, + srvAssocRemoved :: IORef Int, + srvSubCount :: IORef Int, + srvSubDuplicate :: IORef Int, + srvSubQueues :: IORef Int, + srvSubEnd :: IORef Int + } + +data ServiceStatsData = ServiceStatsData + { _srvAssocNew :: Int, + _srvAssocDuplicate :: Int, + _srvAssocUpdated :: Int, + _srvAssocRemoved :: Int, + _srvSubCount :: Int, + _srvSubDuplicate :: Int, + _srvSubQueues :: Int, + _srvSubEnd :: Int + } + deriving (Show) + +newServiceStatsData :: ServiceStatsData +newServiceStatsData = + ServiceStatsData + { _srvAssocNew = 0, + _srvAssocDuplicate = 0, + _srvAssocUpdated = 0, + _srvAssocRemoved = 0, + _srvSubCount = 0, + _srvSubDuplicate = 0, + _srvSubQueues = 0, + _srvSubEnd = 0 + } + +newServiceStats :: IO ServiceStats +newServiceStats = do + srvAssocNew <- newIORef 0 + srvAssocDuplicate <- newIORef 0 + srvAssocUpdated <- newIORef 0 + srvAssocRemoved <- newIORef 0 + srvSubCount <- newIORef 0 + srvSubDuplicate <- newIORef 0 + srvSubQueues <- newIORef 0 + srvSubEnd <- newIORef 0 + pure + ServiceStats + { srvAssocNew, + srvAssocDuplicate, + srvAssocUpdated, + srvAssocRemoved, + srvSubCount, + srvSubDuplicate, + srvSubQueues, + srvSubEnd + } + +getServiceStatsData :: ServiceStats -> IO ServiceStatsData +getServiceStatsData s = do + _srvAssocNew <- readIORef $ srvAssocNew s + _srvAssocDuplicate <- readIORef $ srvAssocDuplicate s + _srvAssocUpdated <- readIORef $ srvAssocUpdated s + _srvAssocRemoved <- readIORef $ srvAssocRemoved s + _srvSubCount <- readIORef $ srvSubCount s + _srvSubDuplicate <- readIORef $ srvSubDuplicate s + _srvSubQueues <- readIORef $ srvSubQueues s + _srvSubEnd <- readIORef $ srvSubEnd s + pure + ServiceStatsData + { _srvAssocNew, + _srvAssocDuplicate, + _srvAssocUpdated, + _srvAssocRemoved, + _srvSubCount, + _srvSubDuplicate, + _srvSubQueues, + _srvSubEnd + } + +getResetServiceStatsData :: ServiceStats -> IO ServiceStatsData +getResetServiceStatsData s = do + _srvAssocNew <- atomicSwapIORef (srvAssocNew s) 0 + _srvAssocDuplicate <- atomicSwapIORef (srvAssocDuplicate s) 0 + _srvAssocUpdated <- atomicSwapIORef (srvAssocUpdated s) 0 + _srvAssocRemoved <- atomicSwapIORef (srvAssocRemoved s) 0 + _srvSubCount <- atomicSwapIORef (srvSubCount s) 0 + _srvSubDuplicate <- atomicSwapIORef (srvSubDuplicate s) 0 + _srvSubQueues <- atomicSwapIORef (srvSubQueues s) 0 + _srvSubEnd <- atomicSwapIORef (srvSubEnd s) 0 + pure + ServiceStatsData + { _srvAssocNew, + _srvAssocDuplicate, + _srvAssocUpdated, + _srvAssocRemoved, + _srvSubCount, + _srvSubDuplicate, + _srvSubQueues, + _srvSubEnd + } + +-- this function is not thread safe, it is used on server start only +setServiceStats :: ServiceStats -> ServiceStatsData -> IO () +setServiceStats s d = do + writeIORef (srvAssocNew s) $! _srvAssocNew d + writeIORef (srvAssocDuplicate s) $! _srvAssocDuplicate d + writeIORef (srvAssocUpdated s) $! _srvAssocUpdated d + writeIORef (srvAssocRemoved s) $! _srvAssocRemoved d + writeIORef (srvSubCount s) $! _srvSubCount d + writeIORef (srvSubDuplicate s) $! _srvSubDuplicate d + writeIORef (srvSubQueues s) $! _srvSubQueues d + writeIORef (srvSubEnd s) $! _srvSubEnd d + +instance StrEncoding ServiceStatsData where + strEncode ServiceStatsData {_srvAssocNew, _srvAssocDuplicate, _srvAssocUpdated, _srvAssocRemoved, _srvSubCount, _srvSubDuplicate, _srvSubQueues, _srvSubEnd} = + "assocNew=" + <> strEncode _srvAssocNew + <> "\nassocDuplicate=" + <> strEncode _srvAssocDuplicate + <> "\nassocUpdatedt=" + <> strEncode _srvAssocUpdated + <> "\nassocRemoved=" + <> strEncode _srvAssocRemoved + <> "\nsubCount=" + <> strEncode _srvSubCount + <> "\nsubDuplicate=" + <> strEncode _srvSubDuplicate + <> "\nsubQueues=" + <> strEncode _srvSubQueues + <> "\nsubEnd=" + <> strEncode _srvSubEnd + strP = do + _srvAssocNew <- "assocNew=" *> strP <* A.endOfLine + _srvAssocDuplicate <- "assocDuplicate=" *> strP <* A.endOfLine + _srvAssocUpdated <- "assocUpdatedt=" *> strP <* A.endOfLine + _srvAssocRemoved <- "assocRemoved=" *> strP <* A.endOfLine + _srvSubCount <- "subCount=" *> strP <* A.endOfLine + _srvSubDuplicate <- "subDuplicate=" *> strP <* A.endOfLine + _srvSubQueues <- "subQueues=" *> strP <* A.endOfLine + _srvSubEnd <- "subEnd=" *> strP + pure + ServiceStatsData + { _srvAssocNew, + _srvAssocDuplicate, + _srvAssocUpdated, + _srvAssocRemoved, + _srvSubCount, + _srvSubDuplicate, + _srvSubQueues, + _srvSubEnd + } diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs index dffc818e3..0baad8a11 100644 --- a/src/Simplex/Messaging/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -29,6 +29,8 @@ module Simplex.Messaging.Server.StoreLog logDeleteQueue, logDeleteNotifier, logUpdateQueueTime, + logNewService, + logQueueService, readWriteStoreLog, readLogLines, foldLogLines, @@ -74,6 +76,8 @@ data StoreLogRecord | DeleteQueue QueueId | DeleteNotifier QueueId | UpdateTime QueueId RoundedSystemTime + | NewService ServiceRec + | QueueService RecipientId ASubscriberParty (Maybe ServiceId) deriving (Show) data SLRTag @@ -89,24 +93,29 @@ data SLRTag | DeleteQueue_ | DeleteNotifier_ | UpdateTime_ + | NewService_ + | QueueService_ instance StrEncoding QueueRec where - strEncode QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt} = - B.unwords - [ "rk=" <> strEncode recipientKeys, - "rdh=" <> strEncode rcvDhSecret, - "sid=" <> strEncode senderId, - "sk=" <> strEncode senderKey + strEncode QueueRec {recipientKeys, rcvDhSecret, rcvServiceId, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt} = + B.concat + [ p "rk=" recipientKeys, + p " rdh=" rcvDhSecret, + p " sid=" senderId, + p " sk=" senderKey, + maybe "" ((" queue_mode=" <>) . smpEncode) queueMode, + opt " link_id=" (fst <$> queueData), + opt " queue_data=" (snd <$> queueData), + opt " notifier=" notifier, + opt " updated_at=" updatedAt, + statusStr, + opt " rsrv=" rcvServiceId ] - <> maybe "" ((" queue_mode=" <>) . smpEncode) queueMode - <> opt " link_id=" (fst <$> queueData) - <> opt " queue_data=" (snd <$> queueData) - <> opt " notifier=" notifier - <> opt " updated_at=" updatedAt - <> statusStr where + p :: StrEncoding a => ByteString -> a -> ByteString + p param = (param <>) . strEncode opt :: StrEncoding a => ByteString -> Maybe a -> ByteString - opt param = maybe "" ((param <>) . strEncode) + opt = maybe "" . p statusStr = case status of EntityActive -> "" _ -> " status=" <> strEncode status @@ -124,7 +133,20 @@ instance StrEncoding QueueRec where notifier <- optional $ " notifier=" *> strP updatedAt <- optional $ " updated_at=" *> strP status <- (" status=" *> strP) <|> pure EntityActive - pure QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, queueMode, queueData, notifier, status, updatedAt} + rcvServiceId <- optional $ " rsrv=" *> strP + pure + QueueRec + { recipientKeys, + rcvDhSecret, + senderId, + senderKey, + queueMode, + queueData, + notifier, + status, + updatedAt, + rcvServiceId + } where toQueueMode sndSecure = Just $ if sndSecure then QMMessaging else QMContact @@ -142,6 +164,8 @@ instance StrEncoding SLRTag where DeleteQueue_ -> "DELETE" DeleteNotifier_ -> "NDELETE" UpdateTime_ -> "TIME" + NewService_ -> "NEW_SERVICE" + QueueService_ -> "QUEUE_SERVICE" strP = A.choice @@ -156,7 +180,9 @@ instance StrEncoding SLRTag where "UNBLOCK" $> UnblockQueue_, "DELETE" $> DeleteQueue_, "NDELETE" $> DeleteNotifier_, - "TIME" $> UpdateTime_ + "TIME" $> UpdateTime_, + "NEW_SERVICE" $> NewService_, + "QUEUE_SERVICE" $> QueueService_ ] instance StrEncoding StoreLogRecord where @@ -173,6 +199,8 @@ instance StrEncoding StoreLogRecord where DeleteQueue rId -> strEncode (DeleteQueue_, rId) DeleteNotifier rId -> strEncode (DeleteNotifier_, rId) UpdateTime rId t -> strEncode (UpdateTime_, rId, t) + NewService sr -> strEncode (NewService_, sr) + QueueService rId party serviceId -> strEncode (QueueService_, rId, party, serviceId) strP = strP_ >>= \case @@ -188,6 +216,8 @@ instance StrEncoding StoreLogRecord where DeleteQueue_ -> DeleteQueue <$> strP DeleteNotifier_ -> DeleteNotifier <$> strP UpdateTime_ -> UpdateTime <$> strP_ <*> strP + NewService_ -> NewService <$> strP + QueueService_ -> QueueService <$> strP_ <*> strP_ <*> strP openWriteStoreLog :: Bool -> FilePath -> IO (StoreLog 'WriteMode) openWriteStoreLog append f = do @@ -253,6 +283,12 @@ logDeleteNotifier s = writeStoreLogRecord s . DeleteNotifier logUpdateQueueTime :: StoreLog 'WriteMode -> QueueId -> RoundedSystemTime -> IO () logUpdateQueueTime s qId t = writeStoreLogRecord s $ UpdateTime qId t +logNewService :: StoreLog 'WriteMode -> ServiceRec -> IO () +logNewService s = writeStoreLogRecord s . NewService + +logQueueService :: (PartyI p, SubscriberParty p) => StoreLog 'WriteMode -> RecipientId -> SParty p -> Maybe ServiceId -> IO () +logQueueService s rId party = writeStoreLogRecord s . QueueService rId (ASP party) + readWriteStoreLog :: (FilePath -> s -> IO ()) -> (StoreLog 'WriteMode -> s -> IO ()) -> FilePath -> s -> IO (StoreLog 'WriteMode) readWriteStoreLog readStore writeStore f st = ifM diff --git a/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs b/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs index bc576001c..ea6c9ed4a 100644 --- a/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs +++ b/src/Simplex/Messaging/Server/StoreLog/ReadWrite.hs @@ -2,8 +2,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Simplex.Messaging.Server.StoreLog.ReadWrite where @@ -16,24 +18,23 @@ import qualified Data.ByteString.Char8 as B import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (ErrorType, RecipientId, SParty (..)) -import Simplex.Messaging.Server.QueueStore (QueueRec) +import Simplex.Messaging.Protocol (ASubscriberParty (..), ErrorType, RecipientId, SParty (..)) +import Simplex.Messaging.Server.QueueStore (QueueRec, ServiceRec (..)) +import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..), STMService (..)) import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.Util (tshow) import System.IO -writeQueueStore :: forall q s. QueueStoreClass q s => StoreLog 'WriteMode -> s -> IO () -writeQueueStore s st = withLoadedQueues st $ writeQueue +writeQueueStore :: forall q. StoreQueueClass q => StoreLog 'WriteMode -> STMQueueStore q -> IO () +writeQueueStore s st = do + readTVarIO (services st) >>= mapM_ (logNewService s . serviceRec) + withLoadedQueues st $ writeQueue where writeQueue :: q -> IO () - writeQueue q = do - let rId = recipientId q - readTVarIO (queueRec q) >>= \case - Just q' -> logCreateQueue s rId q' - Nothing -> pure () + writeQueue q = readTVarIO (queueRec q) >>= mapM_ (logCreateQueue s $ recipientId q) -readQueueStore :: forall q s. QueueStoreClass q s => Bool -> (RecipientId -> QueueRec -> IO q) -> FilePath -> s -> IO () +readQueueStore :: forall q. StoreQueueClass q => Bool -> (RecipientId -> QueueRec -> IO q) -> FilePath -> STMQueueStore q -> IO () readQueueStore tty mkQ f st = readLogLines tty f $ \_ -> processLine where processLine :: B.ByteString -> IO () @@ -53,6 +54,14 @@ readQueueStore tty mkQ f st = readLogLines tty f $ \_ -> processLine DeleteQueue qId -> withQueue qId "DeleteQueue" $ deleteStoreQueue st DeleteNotifier qId -> withQueue qId "DeleteNotifier" $ deleteQueueNotifier st UpdateTime qId t -> withQueue qId "UpdateTime" $ \q -> updateQueueTime st q t + NewService sr@ServiceRec {serviceId} -> getCreateService @q st sr >>= \case + Right serviceId' + | serviceId == serviceId' -> pure () + | otherwise -> logError $ errPfx <> "created with the wrong ID " <> decodeLatin1 (strEncode serviceId') + Left e -> logError $ errPfx <> tshow e + where + errPfx = "STORE: getCreateService, stored service " <> decodeLatin1 (strEncode serviceId) <> ", " + QueueService rId (ASP party) serviceId -> withQueue rId "QueueService" $ \q -> setQueueService st q party serviceId printError :: String -> IO () printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s withQueue :: forall a. RecipientId -> T.Text -> (q -> IO (Either ErrorType a)) -> IO () diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 13384ce64..0b0c440e4 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -1,10 +1,12 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} @@ -53,6 +55,7 @@ module Simplex.Messaging.Transport encryptedBlockSMPVersion, blockedEntitySMPVersion, shortLinksSMPVersion, + serviceCertsSMPVersion, simplexMQVersion, smpBlockSize, TransportConfig (..), @@ -70,6 +73,9 @@ module Simplex.Messaging.Transport -- * TLS Transport TLS (..), SessionId, + ServiceId, + EntityId (..), + pattern NoEntity, ALPN, connectTLS, closeTLS, @@ -82,6 +88,11 @@ module Simplex.Messaging.Transport THandleParams (..), THandleAuth (..), CertChainPubKey (..), + ServiceCredentials (..), + THClientService' (..), + THClientService, + THPeerClientService, + SMPServiceRole (..), TSbChainKeys (..), TransportError (..), HandshakeError (..), @@ -97,7 +108,7 @@ where import Control.Applicative (optional) import Control.Concurrent.STM -import Control.Monad (forM, when, (<$!>)) +import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class import Control.Monad.Trans.Except (throwE) @@ -125,8 +136,10 @@ import qualified Network.TLS.Extra as TE import qualified Paths_simplexmq as SMQ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON) import Simplex.Messaging.Transport.Buffer +import Simplex.Messaging.Transport.Shared import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal @@ -154,6 +167,7 @@ smpBlockSize = 16384 -- 12 - BLOCKED error for blocked queues (1/11/2025) -- 14 - proxyServer handshake property to disable transport encryption between server and proxy (1/19/2025) -- 15 - short links, with associated data passed in NEW of LSET command (3/30/2025) +-- 16 - service certificates (5/31/2025) data SMPVersion @@ -193,6 +207,9 @@ proxyServerHandshakeSMPVersion = VersionSMP 14 shortLinksSMPVersion :: VersionSMP shortLinksSMPVersion = VersionSMP 15 +serviceCertsSMPVersion :: VersionSMP +serviceCertsSMPVersion = VersionSMP 16 + minClientSMPRelayVersion :: VersionSMP minClientSMPRelayVersion = VersionSMP 6 @@ -200,13 +217,13 @@ minServerSMPRelayVersion :: VersionSMP minServerSMPRelayVersion = VersionSMP 6 currentClientSMPRelayVersion :: VersionSMP -currentClientSMPRelayVersion = VersionSMP 15 +currentClientSMPRelayVersion = VersionSMP 16 legacyServerSMPRelayVersion :: VersionSMP legacyServerSMPRelayVersion = VersionSMP 6 currentServerSMPRelayVersion :: VersionSMP -currentServerSMPRelayVersion = VersionSMP 15 +currentServerSMPRelayVersion = VersionSMP 16 -- Max SMP protocol version to be used in e2e encrypted -- connection between client and server, as defined by SMP proxy. @@ -255,9 +272,14 @@ class Typeable c => Transport (c :: TransportPeer -> Type) where transportConfig :: c p -> TransportConfig -- | Upgrade TLS context to connection - getTransportConnection :: TransportPeerI p => TransportConfig -> X.CertificateChain -> T.Context -> IO (c p) + getTransportConnection :: TransportPeerI p => TransportConfig -> Bool -> X.CertificateChain -> T.Context -> IO (c p) - -- | TLS certificate chain, server's in the client, client's in the server (empty chain) + -- | Whether TLS certificate chain was provided to peer + -- It is always True for the server. + -- It is True for the client when server requested it AND non-empty chain is sent. + certificateSent :: c p -> Bool + + -- | TLS certificate chain, server's in the client, client's in the server (empty chain for non-service clients) getPeerCertChain :: c p -> X.CertificateChain -- | tls-unique channel binding per RFC5929 @@ -317,6 +339,7 @@ data TLS (p :: TransportPeer) = TLS tlsUniq :: ByteString, tlsBuffer :: TBuffer, tlsALPN :: Maybe ALPN, + tlsCertSent :: Bool, -- see comment for certificateSent tlsPeerCert :: X.CertificateChain, tlsTransportConfig :: TransportConfig } @@ -332,13 +355,13 @@ connectTLS host_ TransportConfig {logTLSErrors} params sock = logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e host = maybe "" (\h -> " (" <> h <> ")") host_ -getTLS :: forall p. TransportPeerI p => TransportConfig -> X.CertificateChain -> T.Context -> IO (TLS p) -getTLS cfg tlsPeerCert cxt = withTlsUnique @TLS @p cxt newTLS +getTLS :: forall p. TransportPeerI p => TransportConfig -> Bool -> X.CertificateChain -> T.Context -> IO (TLS p) +getTLS cfg tlsCertSent tlsPeerCert cxt = withTlsUnique @TLS @p cxt newTLS where newTLS tlsUniq = do tlsBuffer <- newTBuffer tlsALPN <- T.getNegotiatedProtocol cxt - pure TLS {tlsContext = cxt, tlsALPN, tlsTransportConfig = cfg, tlsPeerCert, tlsUniq, tlsBuffer} + pure TLS {tlsContext = cxt, tlsALPN, tlsTransportConfig = cfg, tlsCertSent, tlsPeerCert, tlsUniq, tlsBuffer} withTlsUnique :: forall c p. TransportPeerI p => T.Context -> (ByteString -> IO (c p)) -> IO (c p) withTlsUnique cxt f = @@ -396,6 +419,8 @@ instance Transport TLS where {-# INLINE transportConfig #-} getTransportConnection = getTLS {-# INLINE getTransportConnection #-} + certificateSent = tlsCertSent + {-# INLINE certificateSent #-} getPeerCertChain = tlsPeerCert {-# INLINE getPeerCertChain #-} getSessionALPN = tlsALPN @@ -450,22 +475,37 @@ data THandleParams v p = THandleParams encryptBlock :: Maybe TSbChainKeys, -- | send multiple transmissions in a single block -- based on protocol version - batch :: Bool + batch :: Bool, + -- | include service signature (or '0' if it is absent), based on protocol version + serviceAuth :: Bool } data THandleAuth (p :: TransportPeer) where THAuthClient :: - { serverPeerPubKey :: C.PublicKeyX25519, -- used by the client to combine with client's private per-queue key - serverCertKey :: CertChainPubKey, -- the key here is serverPeerPubKey signed with server certificate + { peerServerPubKey :: C.PublicKeyX25519, -- used by the client to combine with client's private per-queue key + peerServerCertKey :: CertChainPubKey, -- the key here is peerServerCertKey signed with server certificate + clientService :: Maybe THClientService, sessSecret :: Maybe C.DhSecretX25519 -- session secret (will be used in SMP proxy only) } -> THandleAuth 'TClient THAuthServer :: { serverPrivKey :: C.PrivateKeyX25519, -- used by the server to combine with client's public per-queue key + peerClientService :: Maybe THPeerClientService, sessSecret' :: Maybe C.DhSecretX25519 -- session secret (will be used in SMP proxy only) } -> THandleAuth 'TServer +type THClientService = THClientService' C.PrivateKeyEd25519 + +type THPeerClientService = THClientService' C.PublicKeyEd25519 + +data THClientService' k = THClientService + { serviceId :: ServiceId, + serviceRole :: SMPServiceRole, + serviceCertHash :: XV.Fingerprint, + serviceKey :: k + } + data TSbChainKeys = TSbChainKeys { sndKey :: TVar C.SbChainKey, rcvKey :: TVar C.SbChainKey @@ -474,6 +514,16 @@ data TSbChainKeys = TSbChainKeys -- | TLS-unique channel binding type SessionId = ByteString +type ServiceId = EntityId + +-- this type is used for server entities only +newtype EntityId = EntityId {unEntityId :: ByteString} + deriving (Eq, Ord, Show) + deriving newtype (Encoding, StrEncoding) + +pattern NoEntity :: EntityId +pattern NoEntity = EntityId "" + data SMPServerHandshake = SMPServerHandshake { smpVersionRange :: VersionRangeSMP, sessionId :: SessionId, @@ -482,6 +532,14 @@ data SMPServerHandshake = SMPServerHandshake authPubKey :: Maybe CertChainPubKey } +-- This is the third handshake message that SMP server sends to services +-- in response to them sending `clientService` field. +-- The client would wait for this message in case `clientService` was sent +-- (and it can only be sent once client knows that service supports it.) +data SMPServerHandshakeResponse + = SMPServerHandshakeResponse {serviceId :: ServiceId} + | SMPServerHandshakeError {handshakeError :: TransportError} + data SMPClientHandshake = SMPClientHandshake { -- | agreed SMP server protocol version smpVersion :: VersionSMP, @@ -489,27 +547,79 @@ data SMPClientHandshake = SMPClientHandshake keyHash :: C.KeyHash, -- | pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. authPubKey :: Maybe C.PublicKeyX25519, + -- TODO [certs] remove proxyServer, as serviceInfo includes it as clientRole -- | Whether connecting client is a proxy server (send from SMP v12). -- This property, if True, disables additional transport encrytion inside TLS. -- (Proxy server connection already has additional encryption, so this layer is not needed there). - proxyServer :: Bool + proxyServer :: Bool, + -- | optional long-term service client certificate of a high-volume service using SMP server. + -- This certificate MUST be used both in TLS and in protocol handshake. + -- It signs the key that is used to authorize: + -- - queue creation commands (in addition to authorization by queue key) - it creates association of the queue with this certificate, + -- - "handover" subscription command (in addition to queue key) - it also creates association, + -- - bulk subscription command CSUB. + -- SHA512 hash of this certificate is stored to associate queues with this client. + -- These certificates are used by the servers and services connecting to SMP servers: + -- - chat relays, + -- - notification servers, + -- - high traffic chat bots, + -- - high traffic business support clients. + clientService :: Maybe SMPClientHandshakeService } +data SMPClientHandshakeService = SMPClientHandshakeService + { serviceRole :: SMPServiceRole, + serviceCertKey :: CertChainPubKey + } + +data ServiceCredentials = ServiceCredentials + { serviceRole :: SMPServiceRole, + serviceCreds :: T.Credential, + serviceCertHash :: XV.Fingerprint, + serviceSignKey :: C.APrivateSignKey + } + +data SMPServiceRole = SRMessaging | SRNotifier | SRProxy deriving (Eq, Show) + instance Encoding SMPClientHandshake where - smpEncode SMPClientHandshake {smpVersion = v, keyHash, authPubKey, proxyServer} = + smpEncode SMPClientHandshake {smpVersion = v, keyHash, authPubKey, proxyServer, clientService} = smpEncode (v, keyHash) <> encodeAuthEncryptCmds v authPubKey <> ifHasProxy v (smpEncode proxyServer) "" + <> ifHasService v (smpEncode clientService) "" smpP = do (v, keyHash) <- smpP -- TODO drop SMP v6: remove special parser and make key non-optional authPubKey <- authEncryptCmdsP v smpP proxyServer <- ifHasProxy v smpP (pure False) - pure SMPClientHandshake {smpVersion = v, keyHash, authPubKey, proxyServer} + clientService <- ifHasService v smpP (pure Nothing) + pure SMPClientHandshake {smpVersion = v, keyHash, authPubKey, proxyServer, clientService} + +instance Encoding SMPClientHandshakeService where + smpEncode SMPClientHandshakeService {serviceRole, serviceCertKey} = + smpEncode (serviceRole, serviceCertKey) + smpP = do + (serviceRole, serviceCertKey) <- smpP + pure SMPClientHandshakeService {serviceRole, serviceCertKey} + +instance Encoding SMPServiceRole where + smpEncode = \case + SRMessaging -> "M" + SRNotifier -> "N" + SRProxy -> "P" + smpP = + A.anyChar >>= \case + 'M' -> pure SRMessaging + 'N' -> pure SRNotifier + 'P' -> pure SRProxy + _ -> fail "bad SMPServiceRole" ifHasProxy :: VersionSMP -> a -> a -> a ifHasProxy v a b = if v >= proxyServerHandshakeSMPVersion then a else b +ifHasService :: VersionSMP -> a -> a -> a +ifHasService v a b = if v >= serviceCertsSMPVersion then a else b + instance Encoding SMPServerHandshake where smpEncode SMPServerHandshake {smpVersionRange, sessionId, authPubKey} = smpEncode (smpVersionRange, sessionId) <> auth @@ -543,6 +653,16 @@ encodeAuthEncryptCmds v k authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authCmdsSMPVersion then optional p else pure Nothing +instance Encoding SMPServerHandshakeResponse where + smpEncode = \case + SMPServerHandshakeResponse serviceId -> smpEncode ('R', serviceId) + SMPServerHandshakeError handshakeError -> smpEncode ('E', handshakeError) + smpP = + A.anyChar >>= \case + 'R' -> SMPServerHandshakeResponse <$> smpP + 'E' -> SMPServerHandshakeError <$> smpP + _ -> fail "bad SMPServerHandshakeResponse" + -- | Error of SMP encrypted transport over TCP. data TransportError = -- | error parsing transport block @@ -568,6 +688,8 @@ data HandshakeError IDENTITY | -- | v7 authentication failed BAD_AUTH + | -- | error reading/creating service record + BAD_SERVICE deriving (Eq, Read, Show, Exception) instance Encoding TransportError where @@ -615,27 +737,52 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize, encryptBlo -- | Server SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpServerHandshake :: forall c. Transport c => X.CertificateChain -> C.APrivateSignKey -> c 'TServer -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c 'TServer) -smpServerHandshake srvCert srvSignKey c (k, pk) kh smpVRange = do - let th@THandle {params = THandleParams {sessionId}} = smpTHandle c - sk = C.signX509 srvSignKey $ C.publicToX509 k +smpServerHandshake :: + forall c. Transport c => + X.CertificateChain -> + C.APrivateSignKey -> + c 'TServer -> + C.KeyPairX25519 -> + C.KeyHash -> + VersionRangeSMP -> + (SMPServiceRole -> X.CertificateChain -> XV.Fingerprint -> ExceptT TransportError IO ServiceId) -> + ExceptT TransportError IO (THandleSMP c 'TServer) +smpServerHandshake srvCert srvSignKey c (k, pk) kh smpVRange getService = do + let sk = C.signX509 srvSignKey $ C.publicToX509 k smpVersionRange = maybe legacyServerSMPRelayVRange (const smpVRange) $ getSessionALPN c sendHandshake th $ SMPServerHandshake {sessionId, smpVersionRange, authPubKey = Just (CertChainPubKey srvCert sk)} - getHandshake th >>= \case - SMPClientHandshake {smpVersion = v, keyHash, authPubKey = k', proxyServer} - | keyHash /= kh -> - throwE $ TEHandshake IDENTITY - | otherwise -> - case compatibleVRange' smpVersionRange v of - Just (Compatible vr) -> liftIO $ smpTHandleServer th v vr pk k' proxyServer - Nothing -> throwE TEVersion + SMPClientHandshake {smpVersion = v, keyHash, authPubKey = k', proxyServer, clientService} <- getHandshake th + when (keyHash /= kh) $ throwE $ TEHandshake IDENTITY + case compatibleVRange' smpVersionRange v of + Just (Compatible vr) -> do + service <- mapM getClientService clientService + liftIO $ smpTHandleServer th v vr pk k' proxyServer service + Nothing -> throwE TEVersion + where + th@THandle {params = THandleParams {sessionId}} = smpTHandle c + getClientService :: SMPClientHandshakeService -> ExceptT TransportError IO THPeerClientService + getClientService SMPClientHandshakeService {serviceRole, serviceCertKey = CertChainPubKey cc exact} = handleError sendErr $ do + unless (getPeerCertChain c == cc) $ throwE $ TEHandshake BAD_AUTH + (idCert, serviceKey) <- liftEitherWith (const $ TEHandshake BAD_AUTH) $ do + (leafCert, idCert) <- case chainIdCaCerts cc of + CCSelf cert -> pure (cert, cert) + CCValid {leafCert, idCert} -> pure (leafCert, idCert) + _ -> throwError "bad certificate" + serviceCertKey <- getCertVerifyKey leafCert + (idCert,) <$> (C.x509ToPublic' =<< C.verifyX509 serviceCertKey exact) + let fp = XV.getFingerprint idCert X.HashSHA256 + serviceId <- getService serviceRole cc fp + sendHandshake th $ SMPServerHandshakeResponse {serviceId} + pure THClientService {serviceId, serviceRole, serviceCertHash = fp, serviceKey} + sendErr err = do + sendHandshake th $ SMPServerHandshakeError {handshakeError = err} + throwError err -- | Client SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpClientHandshake :: forall c. Transport c => c 'TClient -> Maybe C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> Bool -> ExceptT TransportError IO (THandleSMP c 'TClient) -smpClientHandshake c ks_ keyHash@(C.KeyHash kh) vRange proxyServer = do - let th@THandle {params = THandleParams {sessionId}} = smpTHandle c +smpClientHandshake :: forall c. Transport c => c 'TClient -> Maybe C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> Bool -> Maybe (ServiceCredentials, C.KeyPairEd25519) -> ExceptT TransportError IO (THandleSMP c 'TClient) +smpClientHandshake c ks_ keyHash@(C.KeyHash kh) vRange proxyServer serviceKeys_ = do SMPServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th when (sessionId /= sessId) $ throwE TEBadSession -- Below logic downgrades version range in case the "client" is SMP proxy server and it is @@ -657,30 +804,55 @@ smpClientHandshake c ks_ keyHash@(C.KeyHash kh) vRange proxyServer = do else vRange case smpVersionRange `compatibleVRange` smpVRange of Just (Compatible vr) -> do - ck_ <- forM authPubKey $ \certKey@(CertChainPubKey (X.CertificateChain cert) exact) -> + ck_ <- forM authPubKey $ \certKey@(CertChainPubKey chain exact) -> liftEitherWith (const $ TEHandshake BAD_AUTH) $ do - case cert of - [_leaf, ca] | XV.Fingerprint kh == XV.getFingerprint ca X.HashSHA256 -> pure () + case chainIdCaCerts chain of + CCValid {idCert} | XV.Fingerprint kh == XV.getFingerprint idCert X.HashSHA256 -> pure () _ -> throwError "bad certificate" serverKey <- getServerVerifyKey c (,certKey) <$> (C.x509ToPublic' =<< C.verifyX509 serverKey exact) let v = maxVersion vr - sendHandshake th $ SMPClientHandshake {smpVersion = v, keyHash, authPubKey = fst <$> ks_, proxyServer} - liftIO $ smpTHandleClient th v vr (snd <$> ks_) ck_ proxyServer + serviceKeys = case serviceKeys_ of + Just sks | v >= serviceCertsSMPVersion && certificateSent c -> Just sks + _ -> Nothing + clientService = mkClientService <$> serviceKeys + hs = SMPClientHandshake {smpVersion = v, keyHash, authPubKey = fst <$> ks_, proxyServer, clientService} + sendHandshake th hs + service <- mapM getClientService serviceKeys + liftIO $ smpTHandleClient th v vr (snd <$> ks_) ck_ proxyServer service Nothing -> throwE TEVersion + where + th@THandle {params = THandleParams {sessionId}} = smpTHandle c + mkClientService :: (ServiceCredentials, C.KeyPairEd25519) -> SMPClientHandshakeService + mkClientService (ServiceCredentials {serviceRole, serviceCreds, serviceSignKey}, (k, _)) = + let sk = C.signX509 serviceSignKey $ C.publicToX509 k + in SMPClientHandshakeService {serviceRole, serviceCertKey = CertChainPubKey (fst serviceCreds) sk} + getClientService :: (ServiceCredentials, C.KeyPairEd25519) -> ExceptT TransportError IO THClientService + getClientService (ServiceCredentials {serviceRole, serviceCertHash}, (_, pk)) = + getHandshake th >>= \case + SMPServerHandshakeResponse {serviceId} -> pure THClientService {serviceId, serviceRole, serviceCertHash, serviceKey = pk} + SMPServerHandshakeError {handshakeError} -> throwE handshakeError -smpTHandleServer :: forall c. THandleSMP c 'TServer -> VersionSMP -> VersionRangeSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> Bool -> IO (THandleSMP c 'TServer) -smpTHandleServer th v vr pk k_ proxyServer = do - let thAuth = Just THAuthServer {serverPrivKey = pk, sessSecret' = (`C.dh'` pk) <$!> k_} +smpTHandleServer :: forall c. THandleSMP c 'TServer -> VersionSMP -> VersionRangeSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> Bool -> Maybe THPeerClientService -> IO (THandleSMP c 'TServer) +smpTHandleServer th v vr pk k_ proxyServer peerClientService = do + let thAuth = Just THAuthServer {serverPrivKey = pk, peerClientService, sessSecret' = (`C.dh'` pk) <$!> k_} be <- blockEncryption th v proxyServer thAuth pure $ smpTHandle_ th v vr thAuth $ uncurry TSbChainKeys <$> be -smpTHandleClient :: forall c. THandleSMP c 'TClient -> VersionSMP -> VersionRangeSMP -> Maybe C.PrivateKeyX25519 -> Maybe (C.PublicKeyX25519, CertChainPubKey) -> Bool -> IO (THandleSMP c 'TClient) -smpTHandleClient th v vr pk_ ck_ proxyServer = do - let thAuth = (\(k, ck) -> THAuthClient {serverPeerPubKey = k, serverCertKey = forceCertChain ck, sessSecret = C.dh' k <$!> pk_}) <$!> ck_ +smpTHandleClient :: forall c. THandleSMP c 'TClient -> VersionSMP -> VersionRangeSMP -> Maybe C.PrivateKeyX25519 -> Maybe (C.PublicKeyX25519, CertChainPubKey) -> Bool -> Maybe THClientService -> IO (THandleSMP c 'TClient) +smpTHandleClient th v vr pk_ ck_ proxyServer clientService = do + let thAuth = clientTHParams <$!> ck_ be <- blockEncryption th v proxyServer thAuth -- swap is needed to use client's sndKey as server's rcvKey and vice versa pure $ smpTHandle_ th v vr thAuth $ uncurry TSbChainKeys . swap <$> be + where + clientTHParams (k, ck) = + THAuthClient + { peerServerPubKey = k, + peerServerCertKey = forceCertChain ck, + clientService, + sessSecret = C.dh' k <$!> pk_ + } blockEncryption :: THandleSMP c p -> VersionSMP -> Bool -> Maybe (THandleAuth p) -> IO (Maybe (TVar C.SbChainKey, TVar C.SbChainKey)) blockEncryption THandle {params = THandleParams {sessionId}} v proxyServer = \case @@ -695,17 +867,30 @@ blockEncryption THandle {params = THandleParams {sessionId}} v proxyServer = \ca smpTHandle_ :: forall c p. THandleSMP c p -> VersionSMP -> VersionRangeSMP -> Maybe (THandleAuth p) -> Maybe TSbChainKeys -> THandleSMP c p smpTHandle_ th@THandle {params} v vr thAuth encryptBlock = -- TODO drop SMP v6: make thAuth non-optional - let params' = params {thVersion = v, thServerVRange = vr, thAuth, implySessId = v >= authCmdsSMPVersion, encryptBlock} + -- * Note: update version-based parameters in smpTHParamsSetVersion as well. + let params' = + params + { thVersion = v, + thServerVRange = vr, + thAuth, + implySessId = v >= authCmdsSMPVersion, + encryptBlock, + serviceAuth = v >= serviceCertsSMPVersion -- optional service signature will be encoded for all commands and responses + } in (th :: THandleSMP c p) {params = params'} -{-# INLINE forceCertChain #-} forceCertChain :: CertChainPubKey -> CertChainPubKey forceCertChain cert@(CertChainPubKey (X.CertificateChain cc) signedKey) = length (show cc) `seq` show signedKey `seq` cert +{-# INLINE forceCertChain #-} -- This function is only used with v >= 8, so currently it's a simple record update. --- It may require some parameters update in the future, to be consistent with smpTHandle_. +-- * Note: it requires updating version-based parameters, to be consistent with smpTHandle_. smpTHParamsSetVersion :: VersionSMP -> THandleParams SMPVersion p -> THandleParams SMPVersion p -smpTHParamsSetVersion v params = params {thVersion = v} +smpTHParamsSetVersion v params = + params + { thVersion = v, + serviceAuth = v >= serviceCertsSMPVersion + } {-# INLINE smpTHParamsSetVersion #-} sendHandshake :: (Transport c, Encoding smp) => THandle v c p -> smp -> ExceptT TransportError IO () @@ -728,7 +913,8 @@ smpTHandle c = THandle {connection = c, params} thAuth = Nothing, implySessId = False, encryptBlock = Nothing, - batch = True + batch = True, + serviceAuth = False } $(J.deriveJSON (sumTypeJSON id) ''HandshakeError) diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index fa8975cb4..1dc2f56e6 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -30,13 +30,13 @@ where import Control.Applicative (optional, (<|>)) import Control.Logger.Simple (logError) -import Control.Monad (when) import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isAsciiLower, isDigit, isHexDigit) import Data.Default (def) +import Data.IORef import Data.IP import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L @@ -57,6 +57,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll, parseString) import Simplex.Messaging.Transport import Simplex.Messaging.Transport.KeepAlive +import Simplex.Messaging.Transport.Shared import Simplex.Messaging.Util (bshow, catchAll, tshow, (<$?>)) import System.IO.Error import Text.Read (readMaybe) @@ -136,7 +137,16 @@ defaultTcpConnectTimeout :: Int defaultTcpConnectTimeout = 25_000_000 defaultTransportClientConfig :: TransportClientConfig -defaultTransportClientConfig = TransportClientConfig Nothing defaultTcpConnectTimeout (Just defaultKeepAliveOpts) True Nothing Nothing True +defaultTransportClientConfig = + TransportClientConfig + { socksProxy = Nothing, + tcpConnectTimeout = defaultTcpConnectTimeout, + tcpKeepAlive = Just defaultKeepAliveOpts, + logTLSErrors = True, + clientCredentials = Nothing, + clientALPN = Nothing, + useSNI = True + } clientTransportConfig :: TransportClientConfig -> TransportConfig clientTransportConfig TransportClientConfig {logTLSErrors} = @@ -149,8 +159,9 @@ runTransportClient = runTLSTransportClient defaultSupportedParams Nothing runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe SocksCredentials -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c 'TClient -> IO a) -> IO a runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials, clientALPN, useSNI} socksCreds host port keyHash client = do serverCert <- newEmptyTMVarIO + clientCredsSent <- newIORef False let hostName = B.unpack $ strEncode host - clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials clientALPN useSNI serverCert + clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials clientCredsSent clientALPN useSNI serverCert connectTCP = case socksProxy of Just proxy -> connectSocksClient proxy socksCreds (hostAddr host) _ -> connectTCPClient hostName @@ -160,13 +171,9 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, let tCfg = clientTransportConfig cfg -- No TLS timeout to avoid failing connections via SOCKS tls <- connectTLS (Just hostName) tCfg clientParams sock - chain <- - atomically (tryTakeTMVar serverCert) >>= \case - Nothing -> do - logError "onServerCertificate didn't fire or failed to get cert chain" - closeTLS tls >> error "onServerCertificate failed" - Just c -> pure c - getTransportConnection tCfg chain tls + chain <- takePeerCertChain serverCert `E.onException` closeTLS tls + sent <- readIORef clientCredsSent + getTransportConnection tCfg sent chain tls client c `E.finally` closeConnection c where hostAddr = \case @@ -265,41 +272,36 @@ instance StrEncoding SocksAuth where password <- A.takeTill (== '@') <* A.char '@' pure SocksAuthUsername {username, password} -mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe T.Credential -> Maybe [ALPN] -> Bool -> TMVar X.CertificateChain -> T.ClientParams -mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ alpn_ sni serverCerts = +mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe T.Credential -> IORef Bool -> Maybe [ALPN] -> Bool -> TMVar (Maybe X.CertificateChain) -> T.ClientParams +mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ clientCredsSent alpn_ sni serverCerts = (T.defaultParamsClient host p) { T.clientUseServerNameIndication = sni, T.clientShared = def {T.sharedCAStore = fromMaybe (T.sharedCAStore def) caStore_}, T.clientHooks = def { T.onServerCertificate = onServerCert, - T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_, + T.onCertificateRequest = onCertRequest, T.onSuggestALPN = pure alpn_ }, T.clientSupported = supported } where p = B.pack port - onServerCert _ _ _ c = do - errs <- maybe def (\ca -> validateCertificateChain ca host p c) cafp_ - when (null errs) $ - atomically (putTMVar serverCerts c) + onServerCert _ _ _ cc = do + errs <- maybe def (\ca -> validateCertificateChain ca host p cc) cafp_ + atomically $ putTMVar serverCerts $ if null errs then Just cc else Nothing pure errs + onCertRequest = case clientCreds_ of + Just _ -> \_ -> clientCreds_ <$ writeIORef clientCredsSent True + Nothing -> \_ -> pure Nothing validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason] -validateCertificateChain _ _ _ (X.CertificateChain []) = pure [XV.EmptyChain] -validateCertificateChain _ _ _ (X.CertificateChain [_]) = pure [XV.EmptyChain] -validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain [_, caCert]) = - if Fingerprint kh == XV.getFingerprint caCert X.HashSHA256 - then x509validate - else pure [XV.UnknownCA] +validateCertificateChain (C.KeyHash kh) host port cc = case chainIdCaCerts cc of + CCEmpty -> pure [XV.EmptyChain] + CCSelf _ -> pure [XV.EmptyChain] + CCValid {idCert, caCert} -> validate idCert caCert + CCLong -> pure [XV.AuthorityTooDeep] where - x509validate :: IO [XV.FailedReason] - x509validate = XV.validate X.HashSHA256 hooks checks certStore cache serviceID cc - where - hooks = XV.defaultHooks - checks = XV.defaultChecks {XV.checkFQHN = False} - certStore = XS.makeCertificateStore [caCert] - cache = XV.exceptionValidationCache [] -- we manually check fingerprint only of the identity certificate (ca.crt) - serviceID = (host, port) -validateCertificateChain _ _ _ _ = pure [XV.AuthorityTooDeep] + validate idCert caCert + | Fingerprint kh == XV.getFingerprint idCert X.HashSHA256 = x509validate caCert (host, port) cc + | otherwise = pure [XV.UnknownCA] diff --git a/src/Simplex/Messaging/Transport/Credentials.hs b/src/Simplex/Messaging/Transport/Credentials.hs index 3c82b2f78..3d6155da0 100644 --- a/src/Simplex/Messaging/Transport/Credentials.hs +++ b/src/Simplex/Messaging/Transport/Credentials.hs @@ -23,6 +23,7 @@ import Data.X509.Validation (Fingerprint (..), getFingerprint) import qualified Network.TLS as TLS import qualified Simplex.Messaging.Crypto as C import qualified Time.System as Hourglass +import qualified Time.Types as HT -- | Generate a certificate chain to be used with TLS fingerprint-pinning -- @@ -54,7 +55,9 @@ genCredentials g parent (before, after) subjectName = do Nothing -> (subjectKeys, subject) -- self-signed Just (keys, cert) -> (keys, X509.certSubjectDN . X509.signedObject $ X509.getSigned cert) today <- Hourglass.dateCurrent - let signed = + -- remove nanoseconds from time - certificate encoding/decoding removes them. + let today' = today {HT.dtTime = (HT.dtTime today) {HT.todNSec = 0}} + signed = C.signCertificate (snd issuerKeys) X509.Certificate @@ -62,7 +65,7 @@ genCredentials g parent (before, after) subjectName = do certSerial = 1, certSignatureAlg = C.signatureAlgorithmX509 issuerKeys, certIssuerDN = issuer, - certValidity = (timeAdd today (-before), timeAdd today after), + certValidity = (timeAdd today' (-before), timeAdd today' after), certSubjectDN = subject, certPubKey = C.toPubKey C.publicToX509 $ fst subjectKeys, certExtensions = X509.Extensions Nothing diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index 4e57dac5b..edb599803 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -19,11 +20,8 @@ module Simplex.Messaging.Transport.Server runTransportServer, runTransportServerSocket, runLocalTCPServer, - runTCPServerSocket, startTCPServer, loadServerCredential, - supportedTLSServerParams, - supportedTLSServerParams_, loadFingerprint, loadFileFingerprint, smpServerHandshake, @@ -34,6 +32,7 @@ import Control.Applicative ((<|>)) import Control.Logger.Simple import Control.Monad import qualified Crypto.Store.X509 as SX +import qualified Data.ByteString as B import Data.Default (def) import Data.IntMap.Strict (IntMap) import qualified Data.IntMap.Strict as IM @@ -47,6 +46,7 @@ import GHC.IO.Exception (ioe_errno) import Network.Socket import qualified Network.TLS as T import Simplex.Messaging.Transport +import Simplex.Messaging.Transport.Shared import Simplex.Messaging.Util (catchAll_, labelMyThread, tshow) import System.Exit (exitFailure) import System.IO.Error (tryIOError) @@ -59,6 +59,7 @@ import UnliftIO.STM data TransportServerConfig = TransportServerConfig { logTLSErrors :: Bool, serverALPN :: Maybe [ALPN], + askClientCert :: Bool, tlsSetupTimeout :: Int, transportTimeout :: Int } @@ -73,11 +74,12 @@ data ServerCredentials = ServerCredentials type AddHTTP = Bool -mkTransportServerConfig :: Bool -> Maybe [ALPN] ->TransportServerConfig -mkTransportServerConfig logTLSErrors serverALPN = +mkTransportServerConfig :: Bool -> Maybe [ALPN] -> Bool -> TransportServerConfig +mkTransportServerConfig logTLSErrors serverALPN askClientCert = TransportServerConfig { logTLSErrors, serverALPN, + askClientCert, tlsSetupTimeout = 60000000, transportTimeout = 40000000 } @@ -90,41 +92,54 @@ serverTransportConfig TransportServerConfig {logTLSErrors} = -- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar. -- -- All accepted connections are passed to the passed function. -runTransportServer :: forall c. Transport c => TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO () +runTransportServer :: Transport c => TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO () runTransportServer started port srvSupported srvCreds cfg server = do ss <- newSocketState runTransportServerState ss started port srvSupported srvCreds cfg server -runTransportServerState :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO () +runTransportServerState :: Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO () runTransportServerState ss started port srvSupported srvCreds cfg server = runTransportServerState_ ss started port srvSupported (const srvCreds) cfg (const server) -runTransportServerState_ :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> (Maybe HostName -> T.Credential) -> TransportServerConfig -> (Socket -> c 'TServer -> IO ()) -> IO () +runTransportServerState_ :: forall c. Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> (Maybe HostName -> T.Credential) -> TransportServerConfig -> (Socket -> c 'TServer -> IO ()) -> IO () runTransportServerState_ ss started port = runTransportServerSocketState ss started (startTCPServer started Nothing port) (transportName (TProxy :: TProxy c 'TServer)) -- | Run a transport server with provided connection setup and handler. -runTransportServerSocket :: Transport c => TMVar Bool -> IO Socket -> String -> T.Credential -> T.ServerParams -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO () -runTransportServerSocket started getSocket threadLabel srvCreds srvParams cfg server = do +runTransportServerSocket :: Transport c => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO () +runTransportServerSocket started getSocket threadLabel srvParams cfg server = do ss <- newSocketState - runTransportServerSocketState_ ss started getSocket threadLabel (const srvCreds) srvParams cfg (const server) - -runTransportServerSocketState :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> T.Supported -> (Maybe HostName -> T.Credential) -> TransportServerConfig -> (Socket -> c 'TServer -> IO ()) -> IO () -runTransportServerSocketState ss started getSocket threadLabel srvSupported srvCreds cfg = - runTransportServerSocketState_ ss started getSocket threadLabel srvCreds srvParams cfg - where - srvParams = supportedTLSServerParams_ srvSupported srvCreds $ serverALPN cfg - --- | Run a transport server with provided connection setup and handler. -runTransportServerSocketState_ :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> (Maybe HostName -> T.Credential) -> T.ServerParams -> TransportServerConfig -> (Socket -> c 'TServer -> IO ()) -> IO () -runTransportServerSocketState_ ss started getSocket threadLabel srvCreds srvParams cfg server = do - labelMyThread $ "transport server for " <> threadLabel - runTCPServerSocket ss started getSocket $ \conn -> - E.bracket (setup conn >>= maybe (fail "tls setup timeout") pure) closeConnection (server conn) + runTransportServerSocketState_ ss started getSocket threadLabel (tlsSetupTimeout cfg) setupTLS (const server) where tCfg = serverTransportConfig cfg - setup conn = timeout (tlsSetupTimeout cfg) $ do - labelMyThread $ threadLabel <> "/setup" + setupTLS conn = do tls <- connectTLS Nothing tCfg srvParams conn - getTransportConnection tCfg (fst $ srvCreds Nothing) tls + getTransportConnection tCfg True (X.CertificateChain []) tls + +runTransportServerSocketState :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> T.Supported -> (Maybe HostName -> T.Credential) -> TransportServerConfig -> (Socket -> c 'TServer -> IO ()) -> IO () +runTransportServerSocketState ss started getSocket threadLabel srvSupported srvCreds cfg server = + runTransportServerSocketState_ ss started getSocket threadLabel (tlsSetupTimeout cfg) setupTLS server + where + tCfg = serverTransportConfig cfg + srvParams = supportedTLSServerParams srvSupported srvCreds $ serverALPN cfg + setupTLS conn + | askClientCert cfg = do + clientCert <- newEmptyTMVarIO + tls <- connectTLS Nothing tCfg (paramsAskClientCert clientCert srvParams) conn + chain <- takePeerCertChain clientCert `E.onException` closeTLS tls + getTransportConnection tCfg True chain tls + | otherwise = do + tls <- connectTLS Nothing tCfg srvParams conn + getTransportConnection tCfg True (X.CertificateChain []) tls + +-- | Run a transport server with provided connection setup and handler. +runTransportServerSocketState_ :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> Int -> (Socket -> IO (c 'TServer)) -> (Socket -> c 'TServer -> IO ()) -> IO () +runTransportServerSocketState_ ss started getSocket threadLabel tlsSetupTimeout setupTLS server = do + labelMyThread $ "transport server for " <> threadLabel + runTCPServerSocket ss started getSocket $ \conn -> do + labelMyThread $ threadLabel <> "/setup" + E.bracket + (timeout tlsSetupTimeout (setupTLS conn) >>= maybe (fail "tls setup timeout") pure) + closeConnection + (server conn) -- | Run TCP server without TLS runLocalTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO () @@ -217,11 +232,8 @@ loadServerCredential ServerCredentials {caCertificateFile, certificateFile, priv Right credential -> pure credential Left _ -> putStrLn "invalid credential" >> exitFailure -supportedTLSServerParams :: T.Credential -> Maybe [ALPN] -> T.ServerParams -supportedTLSServerParams = supportedTLSServerParams_ defaultSupportedParams . const - -supportedTLSServerParams_ :: T.Supported -> (Maybe HostName -> T.Credential) -> Maybe [ALPN] -> T.ServerParams -supportedTLSServerParams_ serverSupported creds alpn_ = +supportedTLSServerParams :: T.Supported -> (Maybe HostName -> T.Credential) -> Maybe [ALPN] -> T.ServerParams +supportedTLSServerParams serverSupported creds alpn_ = def { T.serverWantClientCert = False, T.serverHooks = @@ -232,6 +244,34 @@ supportedTLSServerParams_ serverSupported creds alpn_ = T.serverSupported = serverSupported } +paramsAskClientCert :: TMVar (Maybe X.CertificateChain) -> T.ServerParams -> T.ServerParams +paramsAskClientCert clientCert params = + params + { T.serverWantClientCert = True, + T.serverHooks = + (T.serverHooks params) + { T.onClientCertificate = \cc -> validateClientCertificate cc >>= \case + Just reason -> T.CertificateUsageReject reason <$ atomically (tryPutTMVar clientCert Nothing) + Nothing -> T.CertificateUsageAccept <$ atomically (tryPutTMVar clientCert $ Just cc) + } + } + +validateClientCertificate :: X.CertificateChain -> IO (Maybe T.CertificateRejectReason) +validateClientCertificate cc = case chainIdCaCerts cc of + CCEmpty -> pure Nothing -- client certificates are only used for services + CCSelf cert -> validate cert + CCValid {caCert} -> validate caCert + CCLong -> pure $ Just $ T.CertificateRejectOther "chain too long" + where + validate caCert = usage <$> x509validate caCert ("", B.empty) cc + usage [] = Nothing + usage r = + Just $ + if + | XV.Expired `elem` r || XV.InFuture `elem` r -> T.CertificateRejectExpired + | XV.UnknownCA `elem` r -> T.CertificateRejectUnknownCA + | otherwise -> T.CertificateRejectOther (show r) + loadFingerprint :: ServerCredentials -> IO Fingerprint loadFingerprint ServerCredentials {caCertificateFile} = case caCertificateFile of Just certificateFile -> loadFileFingerprint certificateFile diff --git a/src/Simplex/Messaging/Transport/Shared.hs b/src/Simplex/Messaging/Transport/Shared.hs new file mode 100644 index 000000000..86a5d53e1 --- /dev/null +++ b/src/Simplex/Messaging/Transport/Shared.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} + +module Simplex.Messaging.Transport.Shared where + +import Control.Concurrent.STM +import qualified Control.Exception as E +import Control.Logger.Simple (logError) +import Data.ByteString (ByteString) +import qualified Data.X509 as X +import qualified Data.X509.CertificateStore as XS +import qualified Data.X509.Validation as XV +import Network.Socket (HostName) + +data ChainCertificates + = CCEmpty + | CCSelf X.SignedCertificate + | CCValid {leafCert :: X.SignedCertificate, idCert :: X.SignedCertificate, caCert :: X.SignedCertificate} + | CCLong + +chainIdCaCerts :: X.CertificateChain -> ChainCertificates +chainIdCaCerts (X.CertificateChain chain) = case chain of + [] -> CCEmpty + [cert] -> CCSelf cert + [leafCert, cert] -> CCValid {leafCert, idCert = cert, caCert = cert} -- current long-term online/offline certificates chain + [leafCert, idCert, caCert] -> CCValid {leafCert, idCert, caCert} -- with additional operator certificate (preset in the client) + [leafCert, idCert, _, caCert] -> CCValid {leafCert, idCert, caCert} -- with network certificate + _ -> CCLong + +x509validate :: X.SignedCertificate -> (HostName, ByteString) -> X.CertificateChain -> IO [XV.FailedReason] +x509validate caCert serviceID = XV.validate X.HashSHA256 XV.defaultHooks checks certStore noCache serviceID + where + checks = XV.defaultChecks {XV.checkFQHN = False} + certStore = XS.makeCertificateStore [caCert] + noCache = XV.ValidationCache (\_ _ _ -> pure XV.ValidationCacheUnknown) (\_ _ _ -> pure ()) + +takePeerCertChain :: TMVar (Maybe X.CertificateChain) -> IO (X.CertificateChain) +takePeerCertChain peerCert = + atomically (tryTakeTMVar peerCert) >>= \case + Just (Just cc) -> pure cc + Just Nothing -> logError "peer certificate invalid" >> E.throwIO (userError "peer certificate invalid") + Nothing -> logError "certificate hook not called" >> E.throwIO (userError "certificate hook not called") -- onServerCertificate / onClientCertificate diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index 34c27bedd..3ab213dcd 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -39,6 +39,7 @@ data WS (p :: TransportPeer) = WS wsStream :: Stream, wsConnection :: Connection, wsTransportConfig :: TransportConfig, + wsCertSent :: Bool, wsPeerCert :: X.CertificateChain } @@ -57,6 +58,8 @@ instance Transport WS where {-# INLINE transportConfig #-} getTransportConnection = getWS {-# INLINE getTransportConnection #-} + certificateSent = wsCertSent + {-# INLINE certificateSent #-} getPeerCertChain = wsPeerCert {-# INLINE getPeerCertChain #-} getSessionALPN = wsALPN @@ -83,14 +86,14 @@ instance Transport WS where then E.throwIO TEBadBlock else pure $ B.init s -getWS :: forall p. TransportPeerI p => TransportConfig -> X.CertificateChain -> T.Context -> IO (WS p) -getWS cfg wsPeerCert cxt = withTlsUnique @WS @p cxt connectWS +getWS :: forall p. TransportPeerI p => TransportConfig -> Bool -> X.CertificateChain -> T.Context -> IO (WS p) +getWS cfg wsCertSent wsPeerCert cxt = withTlsUnique @WS @p cxt connectWS where connectWS tlsUniq = do s <- makeTLSContextStream cxt wsConnection <- connectPeer s wsALPN <- T.getNegotiatedProtocol cxt - pure $ WS {tlsUniq, wsALPN, wsStream = s, wsConnection, wsTransportConfig = cfg, wsPeerCert} + pure $ WS {tlsUniq, wsALPN, wsStream = s, wsConnection, wsTransportConfig = cfg, wsCertSent, wsPeerCert} connectPeer :: Stream -> IO Connection connectPeer = case sTransportPeer @p of STServer -> acceptClientRequest diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 3d00257a2..416a61f8d 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -192,7 +192,7 @@ catchThrow action err = catchAllErrors err action throwE {-# INLINE catchThrow #-} allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a -allFinally err action final = tryAllErrors err action >>= \r -> final >> either throwE pure r +allFinally err action final = tryAllErrors err action >>= \r -> final >> except r {-# INLINE allFinally #-} eitherToMaybe :: Either a b -> Maybe b diff --git a/src/Simplex/RemoteControl/Discovery.hs b/src/Simplex/RemoteControl/Discovery.hs index cd61b118b..9eb714029 100644 --- a/src/Simplex/RemoteControl/Discovery.hs +++ b/src/Simplex/RemoteControl/Discovery.hs @@ -81,7 +81,7 @@ startTLSServer port_ startedOnPort credentials hooks server = async . liftIO $ d port <- N.socketPort socket logInfo $ "System-assigned port: " <> tshow port setPort $ Just port - runTransportServerSocket started (pure socket) "RCP TLS" credentials serverParams (mkTransportServerConfig True Nothing) server + runTransportServerSocket started (pure socket) "RCP TLS" serverParams (mkTransportServerConfig True Nothing True) server setPort = void . atomically . tryPutTMVar startedOnPort serverParams = def diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs index 583808c41..d59874921 100644 --- a/tests/AgentTests/EqInstances.hs +++ b/tests/AgentTests/EqInstances.hs @@ -18,7 +18,7 @@ deriving instance Eq (Connection d) deriving instance Eq (SConnType d) -deriving instance Eq (StoredRcvQueue q) +deriving instance Eq (StoredRcvQueue s) deriving instance Eq (StoredSndQueue q) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 56a89e4c1..005c333f8 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -81,11 +81,11 @@ import Data.Word (Word16) import GHC.Stack (withFrozenCallStack) import SMPAgentClient import SMPClient -import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) +import Simplex.Messaging.Agent hiding (createConnection, joinConnection, subscribeConnection, sendMessage) import qualified Simplex.Messaging.Agent as A import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), ServerQueueInfo (..), UserNetworkInfo (..), UserNetworkType (..), waitForUserNetwork) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), Env (..), InitialAgentServers (..), createAgentStore) -import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT) +import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT, INV, JOINED) import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction) import Simplex.Messaging.Agent.Store.Interface @@ -213,6 +213,12 @@ pattern SENT msgId = A.SENT msgId Nothing pattern Rcvd :: AgentMsgId -> AEvent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] +pattern INV :: AConnectionRequestUri -> AEvent 'AEConn +pattern INV cReq = A.INV cReq Nothing + +pattern JOINED :: SndQueueSecured -> AEvent 'AEConn +pattern JOINED sndSecure = A.JOINED sndSecure Nothing + smpCfgVPrev :: ProtocolClientConfig SMPVersion smpCfgVPrev = (smpCfg agentCfg) {serverVRange = prevRange $ serverVRange $ smpCfg agentCfg} @@ -264,13 +270,17 @@ inAnyOrder g rs = withFrozenCallStack $ do createConnection :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs cMode clientData subMode = do - (connId, CCLink cReq _) <- A.createConnection c userId enableNtfs cMode Nothing clientData (IKNoPQ PQSupportOn) subMode + (connId, (CCLink cReq _, Nothing)) <- A.createConnection c userId enableNtfs cMode Nothing clientData (IKNoPQ PQSupportOn) subMode pure (connId, cReq) joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE (ConnId, SndQueueSecured) joinConnection c userId enableNtfs cReq connInfo subMode = do connId <- A.prepareConnectionToJoin c userId enableNtfs cReq PQSupportOn - (connId,) <$> A.joinConnection c userId connId enableNtfs cReq connInfo PQSupportOn subMode + (sndSecure, Nothing) <- A.joinConnection c userId connId enableNtfs cReq connInfo PQSupportOn subMode + pure (connId, sndSecure) + +subscribeConnection :: AgentClient -> ConnId -> AE () +subscribeConnection c = void . A.subscribeConnection c sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId sendMessage c connId msgFlags msgBody = do @@ -326,7 +336,7 @@ functionalAPITests ps = do describe "should connect via 1-time short link with async join" $ testProxyMatrix ps testInviationShortLinkAsync describe "should connect via contact short link" $ testProxyMatrix ps testContactShortLink describe "should add short link to existing contact and connect" $ testProxyMatrix ps testAddContactShortLink - describe "try to create 1-time short link with prev versions" $ testProxyMatrixWithPrev ps testInviationShortLinkPrev + xdescribe "try to create 1-time short link with prev versions" $ testProxyMatrixWithPrev ps testInviationShortLinkPrev describe "server restart" $ do it "should get 1-time link data after restart" $ testInviationShortLinkRestart ps it "should connect via contact short link after restart" $ testContactShortLinkRestart ps @@ -657,9 +667,9 @@ runAgentClientTest pqSupport sqSecured viaProxy alice bob baseId = runAgentClientTestPQ :: HasCallStack => SndQueueSecured -> Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientTestPQ sqSecured viaProxy (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do - (bobId, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing aPQ SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ - sqSecured' <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe + (sqSecured', Nothing) <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` CR.connPQEncryption aPQ @@ -859,14 +869,14 @@ runAgentClientContactTest pqSupport sqSecured viaProxy alice bob baseId = runAgentClientContactTestPQ :: HasCallStack => SndQueueSecured -> Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do - (_, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMContact Nothing Nothing aPQ SMSubscribe + (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMContact Nothing Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ - sqSecuredJoin <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe + (sqSecuredJoin, Nothing) <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` reqPQSupport bobId <- A.prepareConnectionToAccept alice True invId (CR.connPQEncryption aPQ) - sqSecured' <- acceptContact alice bobId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + (sqSecured', Nothing) <- acceptContact alice bobId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob liftIO $ pqSup'' `shouldBe` bPQ @@ -903,7 +913,7 @@ runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, b runAgentClientContactTestPQ3 :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId = runRight_ $ do - (_, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMContact Nothing Nothing aPQ SMSubscribe + (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMContact Nothing Nothing aPQ SMSubscribe (bAliceId, bobId, abPQEnc) <- connectViaContact bob bPQ qInfo sentMessages abPQEnc alice bobId bob bAliceId (tAliceId, tomId, atPQEnc) <- connectViaContact tom tPQ qInfo @@ -912,12 +922,12 @@ runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId msgId = subtract baseId . fst connectViaContact b pq qInfo = do aId <- A.prepareConnectionToJoin b 1 True qInfo pq - sqSecuredJoin <- A.joinConnection b 1 aId True qInfo "bob's connInfo" pq SMSubscribe + (sqSecuredJoin, Nothing) <- A.joinConnection b 1 aId True qInfo "bob's connInfo" pq SMSubscribe liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn bId <- A.prepareConnectionToAccept alice True invId (CR.connPQEncryption aPQ) - sqSecuredAccept <- acceptContact alice bId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + (sqSecuredAccept, Nothing) <- acceptContact alice bId True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe liftIO $ sqSecuredAccept `shouldBe` False -- agent cfg is v8 ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get b liftIO $ pqSup'' `shouldBe` pq @@ -956,9 +966,9 @@ noMessages_ ingoreQCONT c err = tryGet `shouldReturn` () testRejectContactRequest :: HasCallStack => IO () testRejectContactRequest = withAgentClients2 $ \alice bob -> runRight_ $ do - (addrConnId, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMContact Nothing Nothing IKPQOn SMSubscribe + (addrConnId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMContact Nothing Nothing IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - sqSecured <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (sqSecured, Nothing) <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId PQSupportOn _ "bob's connInfo") <- get alice liftIO $ runExceptT (rejectContact alice "abcd" invId) `shouldReturn` Left (CONN NOT_FOUND) @@ -972,7 +982,7 @@ testUpdateConnectionUserId = newUserId <- createUser alice [noAuthSrvCfg testSMPServer] [noAuthSrvCfg testXFTPServer] _ <- changeConnectionUser alice 1 connId newUserId aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - sqSecured' <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (sqSecured', Nothing) <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured' `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn @@ -1128,7 +1138,7 @@ testInviationShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> testInviationShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do let userData = "some user data" - (bId, CCLink connReq (Just shortLink)) <- runRight $ A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMSubscribe + (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMSubscribe (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq @@ -1150,7 +1160,7 @@ testInviationShortLink viaProxy a b = testJoinConn_ :: Bool -> Bool -> AgentClient -> ConnId -> AgentClient -> ConnectionRequestUri c -> ExceptT AgentErrorType IO () testJoinConn_ viaProxy sndSecure a bId b connReq = do aId <- A.prepareConnectionToJoin b 1 True connReq PQSupportOn - sndSecure' <- A.joinConnection b 1 aId True connReq "bob's connInfo" PQSupportOn SMSubscribe + (sndSecure', Nothing) <- A.joinConnection b 1 aId True connReq "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sndSecure' `shouldBe` sndSecure ("", _, CONF confId _ "bob's connInfo") <- get a allowConnection a bId confId "alice's connInfo" @@ -1163,13 +1173,13 @@ testInviationShortLinkPrev :: HasCallStack => Bool -> Bool -> AgentClient -> Age testInviationShortLinkPrev viaProxy sndSecure a b = runRight_ $ do let userData = "some user data" -- can't create short link with previous version - (bId, CCLink connReq Nothing) <- A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKPQOn SMSubscribe + (bId, (CCLink connReq Nothing, Nothing)) <- A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKPQOn SMSubscribe testJoinConn_ viaProxy sndSecure a bId b connReq testInviationShortLinkAsync :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () testInviationShortLinkAsync viaProxy a b = do let userData = "some user data" - (bId, CCLink connReq (Just shortLink)) <- runRight $ A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMSubscribe + (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMSubscribe (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq @@ -1188,7 +1198,7 @@ testContactShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO testContactShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do let userData = "some user data" - (contactId, CCLink connReq0 (Just shortLink)) <- runRight $ A.createConnection a 1 True SCMContact (Just userData) Nothing CR.IKPQOn SMSubscribe + (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- runRight $ A.createConnection a 1 True SCMContact (Just userData) Nothing CR.IKPQOn SMSubscribe Right connReq <- pure $ smpDecode (smpEncode connReq0) (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink @@ -1207,7 +1217,7 @@ testContactShortLink viaProxy a b = liftIO $ sndSecure `shouldBe` False ("", _, REQ invId _ "bob's connInfo") <- get a bId <- A.prepareConnectionToAccept a True invId PQSupportOn - sndSecure' <- acceptContact a bId True invId "alice's connInfo" PQSupportOn SMSubscribe + (sndSecure', Nothing) <- acceptContact a bId True invId "alice's connInfo" PQSupportOn SMSubscribe liftIO $ sndSecure' `shouldBe` True ("", _, CONF confId _ "alice's connInfo") <- get b allowConnection b aId confId "bob's connInfo" @@ -1233,7 +1243,7 @@ testContactShortLink viaProxy a b = testAddContactShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () testAddContactShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do - (contactId, CCLink connReq0 Nothing) <- runRight $ A.createConnection a 1 True SCMContact Nothing Nothing CR.IKPQOn SMSubscribe + (contactId, (CCLink connReq0 Nothing, Nothing)) <- runRight $ A.createConnection a 1 True SCMContact Nothing Nothing CR.IKPQOn SMSubscribe Right connReq <- pure $ smpDecode (smpEncode connReq0) -- let userData = "some user data" shortLink <- runRight $ setConnShortLink a contactId SCMContact userData Nothing @@ -1254,7 +1264,7 @@ testAddContactShortLink viaProxy a b = liftIO $ sndSecure `shouldBe` False ("", _, REQ invId _ "bob's connInfo") <- get a bId <- A.prepareConnectionToAccept a True invId PQSupportOn - sndSecure' <- acceptContact a bId True invId "alice's connInfo" PQSupportOn SMSubscribe + (sndSecure', Nothing) <- acceptContact a bId True invId "alice's connInfo" PQSupportOn SMSubscribe liftIO $ sndSecure' `shouldBe` True ("", _, CONF confId _ "alice's connInfo") <- get b allowConnection b aId confId "bob's connInfo" @@ -1273,7 +1283,7 @@ testAddContactShortLink viaProxy a b = testInviationShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testInviationShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = "some user data" - (bId, CCLink connReq (Just shortLink)) <- withSmpServer ps $ + (bId, (CCLink connReq (Just shortLink), Nothing)) <- withSmpServer ps $ runRight $ A.createConnection a 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMOnlyCreate withSmpServer ps $ do runRight_ $ subscribeConnection a bId @@ -1285,7 +1295,7 @@ testInviationShortLinkRestart ps = withAgentClients2 $ \a b -> do testContactShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testContactShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = "some user data" - (contactId, CCLink connReq0 (Just shortLink)) <- withSmpServer ps $ + (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- withSmpServer ps $ runRight $ A.createConnection a 1 True SCMContact (Just userData) Nothing CR.IKPQOn SMOnlyCreate Right connReq <- pure $ smpDecode (smpEncode connReq0) let updatedData = "updated user data" @@ -1305,7 +1315,7 @@ testContactShortLinkRestart ps = withAgentClients2 $ \a b -> do testAddContactShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testAddContactShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = "some user data" - ((contactId, CCLink connReq0 Nothing), shortLink) <- withSmpServer ps $ runRight $ do + ((contactId, (CCLink connReq0 Nothing, Nothing)), shortLink) <- withSmpServer ps $ runRight $ do r@(contactId, _) <- A.createConnection a 1 True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate (r,) <$> setConnShortLink a contactId SCMContact userData Nothing Right connReq <- pure $ smpDecode (smpEncode connReq0) @@ -1325,7 +1335,7 @@ testAddContactShortLinkRestart ps = withAgentClients2 $ \a b -> do testOldContactQueueShortLink :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testOldContactQueueShortLink ps@(_, msType) = withAgentClients2 $ \a b -> do - (contactId, CCLink connReq Nothing) <- withSmpServer ps $ runRight $ + (contactId, (CCLink connReq Nothing, Nothing)) <- withSmpServer ps $ runRight $ A.createConnection a 1 True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate -- make it an "old" queue let updateStoreLog f = replaceSubstringInFile f " queue_mode=C" "" @@ -2057,9 +2067,9 @@ makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn True makeConnectionForUsers_ :: HasCallStack => PQSupport -> SndQueueSecured -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers_ pqSupport sqSecured alice aliceUserId bob bobUserId = do - (bobId, CCLink qInfo Nothing) <- A.createConnection alice aliceUserId True SCMInvitation Nothing Nothing (CR.IKNoPQ pqSupport) SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice aliceUserId True SCMInvitation Nothing Nothing (CR.IKNoPQ pqSupport) SMSubscribe aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport - sqSecured' <- A.joinConnection bob bobUserId aliceId True qInfo "bob's connInfo" pqSupport SMSubscribe + (sqSecured', Nothing) <- A.joinConnection bob bobUserId aliceId True qInfo "bob's connInfo" pqSupport SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 98f759054..463977331 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -8,6 +8,8 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -58,7 +60,7 @@ import Data.Time.Clock.System (systemToUTCTime) import qualified Database.PostgreSQL.Simple as PSQL import NtfClient import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testNtfServer, testNtfServer2) -import SMPClient (AServerConfig (..), AServerStoreCfg (..), cfgJ2QS, cfgMS, cfgVPrev, ntfTestPort, ntfTestPort2, testServerStoreConfig, testPort, testPort2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withServerCfg) +import SMPClient import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) @@ -79,6 +81,7 @@ import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NMsgMe import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..)) import Simplex.Messaging.Transport (ASrvTransport) +import Simplex.Messaging.Transport.Server (TransportServerConfig (..)) import System.Process (callCommand) import Test.Hspec hiding (fit, it) import UnliftIO @@ -172,6 +175,7 @@ notificationTests ps@(t, _) = do withAPNSMockServer $ \apns -> withNtfServerOn t ntfTestPort2 ntfTestDBCfg2 . withNtfServerThreadOn t ntfTestPort ntfTestDBCfg $ \ntf -> testNotificationsNewToken apns ntf + it "should migrate to service subscriptions" $ testMigrateToServiceSubscriptions ps testNtfMatrix :: HasCallStack => (ASrvTransport, AStoreType) -> (APNSMockServer -> AgentMsgId -> AgentClient -> AgentClient -> IO ()) -> Spec testNtfMatrix ps@(_, msType) runTest = do @@ -727,7 +731,7 @@ testChangeToken apns = withAgent 1 agentCfg initAgentServers testDB2 $ \bob -> d pure (aliceId, bobId) withAgent 3 agentCfg initAgentServers testDB $ \alice1 -> runRight_ $ do - subscribeConnection alice1 bobId + void $ subscribeConnection alice1 bobId -- change notification token void $ registerTestToken alice1 "bcde" NMInstant apns -- send message, receive notification @@ -912,6 +916,94 @@ testNotificationsNewToken apns oldNtf = let testMessageAC = testMessage_ apns a acId c caId testMessageAC "greetings" +testMigrateToServiceSubscriptions :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testMigrateToServiceSubscriptions ps@(t, msType) = withAgentClients2 $ \a b -> do + (c1, c2, c3) <- withSmpServerConfigOn t cfgNoService testPort $ \_ -> do + (c1, c2) <- withAPNSMockServer $ \apns -> do + withNtfServerCfg ntfCfgNoService $ \_ -> runRight $ do + _tkn <- registerTestToken a "abcd" NMInstant apns + -- create 2 connections with ntfs, test delivery + c1 <- testConnectMsg apns a b "hello" + c2 <- testConnectMsg apns a b "hello too" + pure (c1, c2) + liftIO $ threadDelay 250000 + fmap (c1,c2,) $ withAPNSMockServer $ \apns -> + withNtfServer t $ runRight $ do + liftIO $ threadDelay 250000 + testSendMsg apns a b c1 "hello 1" + testSendMsg apns a b c2 "hello 2" + testConnectMsg apns a b "hello 3" + serverDOWN a b 3 + + -- this session creates association of subscriptions with service + c4 <- withAPNSMockServer $ \apns -> withSmpServer ps $ withNtfServer t $ do + serverUP a b 3 + runRight $ do + liftIO $ threadDelay 250000 + testSendMsg apns a b c1 "hey 1" + testSendMsg apns a b c2 "hey 2" + testSendMsg apns a b c3 "hey 3" + testConnectMsg apns a b "hey 4" + serverDOWN a b 4 + + -- this session uses service to subscribe + c5 <- withAPNSMockServer $ \apns -> withSmpServer ps $ withNtfServer t $ do + serverUP a b 4 + runRight $ do + liftIO $ threadDelay 250000 + testSendMsg apns a b c1 "hi 1" + testSendMsg apns a b c2 "hi 2" + testSendMsg apns a b c3 "hi 3" + testSendMsg apns a b c4 "hi 4" + testConnectMsg apns a b "hi 5" + serverDOWN a b 5 + + -- Ntf server does not use server, subscriptions downgrade + c6 <- withAPNSMockServer $ \apns -> withSmpServer ps $ withNtfServerCfg ntfCfgNoService $ \_ -> do + serverUP a b 5 + runRight $ do + testSendMsg apns a b c1 "msg 1" + testSendMsg apns a b c2 "msg 2" + testSendMsg apns a b c3 "msg 3" + testSendMsg apns a b c4 "msg 4" + testSendMsg apns a b c5 "msg 5" + testConnectMsg apns a b "msg 6" + serverDOWN a b 6 + + withAPNSMockServer $ \apns -> withSmpServerConfigOn t cfgNoService testPort $ \_ -> withNtfServerCfg ntfCfgNoService $ \_ -> do + serverUP a b 6 + runRight_ $ do + testSendMsg apns a b c1 "1" + testSendMsg apns a b c2 "2" + testSendMsg apns a b c3 "3" + testSendMsg apns a b c4 "4" + testSendMsg apns a b c5 "5" + testSendMsg apns a b c6 "6" + void $ testConnectMsg apns a b "7" + serverDOWN a b 7 + where + testConnectMsg apns a b msg = do + conn <- makeConnection a b + liftIO $ threadDelay 250000 + testSendMsg apns a b conn msg + pure conn + testSendMsg :: HasCallStack => APNSMockServer -> AgentClient -> AgentClient -> (ConnId, ConnId) -> SMP.MsgBody -> ExceptT AgentErrorType IO () + testSendMsg apns a b (abId, baId) = testMessage_ apns a abId b baId + serverDOWN a b n = do + ("", "", DOWN _ cs) <- nGet a + ("", "", DOWN _ cs') <- nGet b + length cs `shouldBe` n + length cs' `shouldBe` n + serverUP a b n = do + ("", "", UP _ cs) <- nGet a + ("", "", UP _ cs') <- nGet b + length cs `shouldBe` n + length cs' `shouldBe` n + cfgNoService = updateCfg (cfgMS msType) $ \(cfg' :: ServerConfig s) -> + let ServerConfig {transportConfig} = cfg' + in cfg' {transportConfig = transportConfig {askClientCert = False}} :: ServerConfig s + ntfCfgNoService = ntfServerCfg {useServiceCreds = False, transports = [(ntfTestPort, t, False)]} + testMessage_ :: HasCallStack => APNSMockServer -> AgentClient -> ConnId -> AgentClient -> ConnId -> SMP.MsgBody -> ExceptT AgentErrorType IO () testMessage_ apns a aId b bId msg = do msgId <- sendMessage b aId (SMP.MsgFlags True) msg diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index aa4c863d5..fc96ef896 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -230,6 +230,7 @@ rcvQueue1 = sndId = EntityId "2345", queueMode = Just QMMessaging, shortLink = Nothing, + clientService = Nothing, status = New, dbQueueId = DBNewEntity, primary = True, @@ -443,6 +444,7 @@ testUpgradeSndConnToDuplex = sndId = EntityId "4567", queueMode = Just QMMessaging, shortLink = Nothing, + clientService = Nothing, status = New, dbQueueId = DBNewEntity, rcvSwchStatus = Nothing, diff --git a/tests/CLITests.hs b/tests/CLITests.hs index 46a184df0..5dff5ce00 100644 --- a/tests/CLITests.hs +++ b/tests/CLITests.hs @@ -190,7 +190,7 @@ smpServerTestStatic = do X.Certificate {X.certPubKey = X.PubKeyEd25519 _k} : _ca -> print _ca -- pure () leaf : _ -> error $ "Unexpected leaf cert: " <> show leaf [] -> error "Empty chain" - runRight_ . void $ smpClientHandshake tls Nothing caSMP supportedClientSMPRelayVRange False + runRight_ . void $ smpClientHandshake tls Nothing caSMP supportedClientSMPRelayVRange False Nothing logDebug "Combined SMP works" where getCerts :: TLS 'TClient -> [X.Certificate] diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index 40c4ab98d..9069cfc89 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -35,7 +35,7 @@ batchingTests = do it "should break on message that does not fit" testBatchWithMessageV6 it "should break on large message" testBatchWithLargeMessageV6 describe "SMP current" $ do - it "should batch with 136 subscriptions per batch" testBatchSubscriptions + it "should batch with 135 subscriptions per batch" testBatchSubscriptions it "should break on message that does not fit" testBatchWithMessage it "should break on large message" testBatchWithLargeMessage describe "batchTransmissions'" $ do @@ -44,7 +44,7 @@ batchingTests = do it "should break on message that does not fit" testClientBatchWithMessageV6 it "should break on large message" testClientBatchWithLargeMessageV6 describe "SMP current" $ do - it "should batch with 136 subscriptions per batch" testClientBatchSubscriptions + it "should batch with 135 subscriptions per batch" testClientBatchSubscriptions it "should batch with 255 ENDs per batch" testClientBatchENDs it "should batch with 80 NMSGs per batch" testClientBatchNMSGs it "should break on message that does not fit" testClientBatchWithMessage @@ -54,10 +54,11 @@ testBatchSubscriptionsV6 :: IO () testBatchSubscriptionsV6 = do sessId <- atomically . C.randomBytes 32 =<< C.newRandom subs <- replicateM 250 $ randomSUBv6 sessId - let batches1 = batchTransmissions False smpBlockSize $ L.fromList subs + let thParams = testTHandleParams minServerSMPRelayVersion sessId + batches1 = batchTransmissions thParams {batch = False} $ L.fromList subs all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 250 - let batches = batchTransmissions True smpBlockSize $ L.fromList subs + let batches = batchTransmissions thParams $ L.fromList subs length batches `shouldBe` 3 [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches (n1, n2, n3) `shouldBe` (38, 106, 106) @@ -67,13 +68,14 @@ testBatchSubscriptions :: IO () testBatchSubscriptions = do sessId <- atomically . C.randomBytes 32 =<< C.newRandom subs <- replicateM 300 $ randomSUB sessId - let batches1 = batchTransmissions False smpBlockSize $ L.fromList subs + let thParams = testTHandleParams currentClientSMPRelayVersion sessId + batches1 = batchTransmissions thParams {batch = False} $ L.fromList subs all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 300 - let batches = batchTransmissions True smpBlockSize $ L.fromList subs + let batches = batchTransmissions thParams $ L.fromList subs length batches `shouldBe` 3 [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches - (n1, n2, n3) `shouldBe` (28, 136, 136) + (n1, n2, n3) `shouldBe` (30, 135, 135) all lenOk [s1, s2, s3] `shouldBe` True testBatchWithMessageV6 :: IO () @@ -82,11 +84,12 @@ testBatchWithMessageV6 = do subs1 <- replicateM 60 $ randomSUBv6 sessId send <- randomSENDv6 sessId 8000 subs2 <- replicateM 40 $ randomSUBv6 sessId - let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions False smpBlockSize $ L.fromList cmds + let thParams = testTHandleParams minServerSMPRelayVersion sessId + cmds = subs1 <> [send] <> subs2 + batches1 = batchTransmissions thParams {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 101 - let batches = batchTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions thParams $ L.fromList cmds length batches `shouldBe` 2 [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _] <- pure batches (n1, n2) `shouldBe` (47, 54) @@ -98,14 +101,15 @@ testBatchWithMessage = do subs1 <- replicateM 60 $ randomSUB sessId send <- randomSEND sessId 8000 subs2 <- replicateM 40 $ randomSUB sessId - let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions False smpBlockSize $ L.fromList cmds + let thParams = testTHandleParams currentClientSMPRelayVersion sessId + cmds = subs1 <> [send] <> subs2 + batches1 = batchTransmissions thParams {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 101 - let batches = batchTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions thParams $ L.fromList cmds length batches `shouldBe` 2 [TBTransmissions s1 n1 _, TBTransmissions s2 n2 _] <- pure batches - (n1, n2) `shouldBe` (32, 69) + (n1, n2) `shouldBe` (33, 68) all lenOk [s1, s2] `shouldBe` True testBatchWithLargeMessageV6 :: IO () @@ -114,14 +118,15 @@ testBatchWithLargeMessageV6 = do subs1 <- replicateM 50 $ randomSUBv6 sessId send <- randomSENDv6 sessId 17000 subs2 <- replicateM 150 $ randomSUBv6 sessId - let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions False smpBlockSize $ L.fromList cmds + let thParams = testTHandleParams minServerSMPRelayVersion sessId + cmds = subs1 <> [send] <> subs2 + batches1 = batchTransmissions thParams {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` False length batches1 `shouldBe` 201 let batches1' = take 50 batches1 <> drop 51 batches1 all lenOk1 batches1' `shouldBe` True length batches1' `shouldBe` 200 - let batches = batchTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions thParams $ L.fromList cmds length batches `shouldBe` 4 [TBTransmissions s1 n1 _, TBError TELargeMsg _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches (n1, n2, n3) `shouldBe` (50, 44, 106) @@ -133,26 +138,27 @@ testBatchWithLargeMessage = do subs1 <- replicateM 60 $ randomSUB sessId send <- randomSEND sessId 17000 subs2 <- replicateM 150 $ randomSUB sessId - let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions False smpBlockSize $ L.fromList cmds + let thParams = testTHandleParams currentClientSMPRelayVersion sessId + cmds = subs1 <> [send] <> subs2 + batches1 = batchTransmissions thParams {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` False length batches1 `shouldBe` 211 let batches1' = take 60 batches1 <> drop 61 batches1 all lenOk1 batches1' `shouldBe` True length batches1' `shouldBe` 210 - let batches = batchTransmissions True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions thParams $ L.fromList cmds length batches `shouldBe` 4 [TBTransmissions s1 n1 _, TBError TELargeMsg _, TBTransmissions s2 n2 _, TBTransmissions s3 n3 _] <- pure batches - (n1, n2, n3) `shouldBe` (60, 14, 136) + (n1, n2, n3) `shouldBe` (60, 15, 135) all lenOk [s1, s2, s3] `shouldBe` True testClientBatchSubscriptionsV6 :: IO () testClientBatchSubscriptionsV6 = do client <- testClientStubV6 subs <- replicateM 250 $ randomSUBCmdV6 client - let batches1 = batchTransmissions' False smpBlockSize $ L.fromList subs + let batches1 = batchTransmissions' (thParams client) {batch = False} $ L.fromList subs all lenOk1 batches1 `shouldBe` True - let batches = batchTransmissions' True smpBlockSize $ L.fromList subs + let batches = batchTransmissions' (thParams client) $ L.fromList subs length batches `shouldBe` 3 [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches (n1, n2, n3) `shouldBe` (38, 106, 106) @@ -163,13 +169,13 @@ testClientBatchSubscriptions :: IO () testClientBatchSubscriptions = do client <- testClientStub subs <- replicateM 300 $ randomSUBCmd client - let batches1 = batchTransmissions' False smpBlockSize $ L.fromList subs + let batches1 = batchTransmissions' (thParams client) {batch = False} $ L.fromList subs all lenOk1 batches1 `shouldBe` True - let batches = batchTransmissions' True smpBlockSize $ L.fromList subs + let batches = batchTransmissions' (thParams client) $ L.fromList subs length batches `shouldBe` 3 [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches - (n1, n2, n3) `shouldBe` (28, 136, 136) - (length rs1, length rs2, length rs3) `shouldBe` (28, 136, 136) + (n1, n2, n3) `shouldBe` (30, 135, 135) + (length rs1, length rs2, length rs3) `shouldBe` (30, 135, 135) all lenOk [s1, s2, s3] `shouldBe` True testClientBatchENDs :: IO () @@ -177,9 +183,9 @@ testClientBatchENDs = do client <- testClientStub ends <- replicateM 300 randomENDCmd let ends' = map (\t -> Right (Nothing, encodeTransmission (thParams client) t)) ends - batches1 = batchTransmissions False smpBlockSize $ L.fromList ends' + batches1 = batchTransmissions (thParams client) {batch = False} $ L.fromList ends' all lenOk1 batches1 `shouldBe` True - let batches = batchTransmissions True smpBlockSize $ L.fromList ends' + let batches = batchTransmissions (thParams client) $ L.fromList ends' length batches `shouldBe` 2 [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2] <- pure batches (n1, n2) `shouldBe` (45, 255) @@ -192,9 +198,9 @@ testClientBatchNMSGs = do ts <- getSystemTime ntfs <- replicateM 200 $ randomNMSGCmd ts let ntfs' = map (\t -> Right (Nothing, encodeTransmission (thParams client) t)) ntfs - batches1 = batchTransmissions False smpBlockSize $ L.fromList ntfs' + batches1 = batchTransmissions (thParams client) {batch = False} $ L.fromList ntfs' all lenOk1 batches1 `shouldBe` True - let batches = batchTransmissions True smpBlockSize $ L.fromList ntfs' + let batches = batchTransmissions (thParams client) $ L.fromList ntfs' length batches `shouldBe` 3 [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches (n1, n2, n3) `shouldBe` (40, 80, 80) @@ -208,10 +214,10 @@ testClientBatchWithMessageV6 = do send <- randomSENDCmdV6 client 8000 subs2 <- replicateM 40 $ randomSUBCmdV6 client let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + batches1 = batchTransmissions' (thParams client) {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 101 - let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' (thParams client) $ L.fromList cmds length batches `shouldBe` 2 [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2] <- pure batches (n1, n2) `shouldBe` (47, 54) @@ -225,14 +231,14 @@ testClientBatchWithMessage = do send <- randomSENDCmd client 8000 subs2 <- replicateM 40 $ randomSUBCmd client let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + batches1 = batchTransmissions' (thParams client) {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` True length batches1 `shouldBe` 101 - let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' (thParams client) $ L.fromList cmds length batches `shouldBe` 2 [TBTransmissions s1 n1 rs1, TBTransmissions s2 n2 rs2] <- pure batches - (n1, n2) `shouldBe` (32, 69) - (length rs1, length rs2) `shouldBe` (32, 69) + (n1, n2) `shouldBe` (33, 68) + (length rs1, length rs2) `shouldBe` (33, 68) all lenOk [s1, s2] `shouldBe` True testClientBatchWithLargeMessageV6 :: IO () @@ -242,14 +248,14 @@ testClientBatchWithLargeMessageV6 = do send <- randomSENDCmdV6 client 17000 subs2 <- replicateM 150 $ randomSUBCmdV6 client let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + batches1 = batchTransmissions' (thParams client) {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` False length batches1 `shouldBe` 201 let batches1' = take 50 batches1 <> drop 51 batches1 all lenOk1 batches1' `shouldBe` True length batches1' `shouldBe` 200 -- - let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' (thParams client) $ L.fromList cmds length batches `shouldBe` 4 [TBTransmissions s1 n1 rs1, TBError TELargeMsg _, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches (n1, n2, n3) `shouldBe` (50, 44, 106) @@ -257,7 +263,7 @@ testClientBatchWithLargeMessageV6 = do all lenOk [s1, s2, s3] `shouldBe` True -- let cmds' = [send] <> subs1 <> subs2 - let batches' = batchTransmissions' True smpBlockSize $ L.fromList cmds' + let batches' = batchTransmissions' (thParams client) $ L.fromList cmds' length batches' `shouldBe` 3 [TBError TELargeMsg _, TBTransmissions s1' n1' rs1', TBTransmissions s2' n2' rs2'] <- pure batches' (n1', n2') `shouldBe` (94, 106) @@ -271,26 +277,26 @@ testClientBatchWithLargeMessage = do send <- randomSENDCmd client 17000 subs2 <- replicateM 150 $ randomSUBCmd client let cmds = subs1 <> [send] <> subs2 - batches1 = batchTransmissions' False smpBlockSize $ L.fromList cmds + batches1 = batchTransmissions' (thParams client) {batch = False} $ L.fromList cmds all lenOk1 batches1 `shouldBe` False length batches1 `shouldBe` 211 let batches1' = take 60 batches1 <> drop 61 batches1 all lenOk1 batches1' `shouldBe` True length batches1' `shouldBe` 210 -- - let batches = batchTransmissions' True smpBlockSize $ L.fromList cmds + let batches = batchTransmissions' (thParams client) $ L.fromList cmds length batches `shouldBe` 4 [TBTransmissions s1 n1 rs1, TBError TELargeMsg _, TBTransmissions s2 n2 rs2, TBTransmissions s3 n3 rs3] <- pure batches - (n1, n2, n3) `shouldBe` (60, 14, 136) - (length rs1, length rs2, length rs3) `shouldBe` (60, 14, 136) + (n1, n2, n3) `shouldBe` (60, 15, 135) + (length rs1, length rs2, length rs3) `shouldBe` (60, 15, 135) all lenOk [s1, s2, s3] `shouldBe` True -- let cmds' = [send] <> subs1 <> subs2 - let batches' = batchTransmissions' True smpBlockSize $ L.fromList cmds' + let batches' = batchTransmissions' (thParams client) $ L.fromList cmds' length batches' `shouldBe` 3 [TBError TELargeMsg _, TBTransmissions s1' n1' rs1', TBTransmissions s2' n2' rs2'] <- pure batches' - (n1', n2') `shouldBe` (74, 136) - (length rs1', length rs2') `shouldBe` (74, 136) + (n1', n2') `shouldBe` (75, 135) + (length rs1', length rs2') `shouldBe` (75, 135) all lenOk [s1', s2'] `shouldBe` True testClientStubV6 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) @@ -307,13 +313,14 @@ testClientStub = do thAuth_ <- testTHandleAuth currentClientSMPRelayVersion g rKey smpClientStub g sessId currentClientSMPRelayVersion thAuth_ -randomSUBv6 :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSUBv6 :: ByteString -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSUBv6 = randomSUB_ C.SEd25519 minServerSMPRelayVersion -randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSUB :: ByteString -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSUB = randomSUB_ C.SEd25519 currentClientSMPRelayVersion -randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +-- TODO [certs] test with the additional certificate signature +randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSUB_ a v sessId = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -322,7 +329,7 @@ randomSUB_ a v sessId = do thAuth_ <- testTHandleAuth v g rKey let thParams = testTHandleParams v sessId TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, EntityId rId, Cmd SRecipient SUB) - pure $ (,tToSend) <$> authTransmission thAuth_ (Just rpKey) nonce tForAuth + pure $ (,tToSend) <$> authTransmission thAuth_ True (Just rpKey) nonce tForAuth randomSUBCmdV6 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmdV6 = randomSUBCmd_ C.SEd25519 @@ -354,13 +361,13 @@ randomNMSGCmd ts = do Right encNMsgMeta <- pure $ C.cbEncrypt (C.dh' k pk) nonce (smpEncode msgMeta) 128 pure (CorrId "", EntityId nId, NMSG nonce encNMsgMeta) -randomSENDv6 :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSENDv6 :: ByteString -> Int -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSENDv6 = randomSEND_ C.SEd25519 minServerSMPRelayVersion -randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSEND = randomSEND_ C.SX25519 currentClientSMPRelayVersion -randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TAuthorizations, ByteString)) randomSEND_ a v sessId len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g @@ -370,7 +377,7 @@ randomSEND_ a v sessId len = do msg <- atomically $ C.randomBytes len g let thParams = testTHandleParams v sessId TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, EntityId sId, Cmd SSender $ SEND noMsgFlags msg) - pure $ (,tToSend) <$> authTransmission thAuth_ (Just spKey) nonce tForAuth + pure $ (,tToSend) <$> authTransmission thAuth_ False (Just spKey) nonce tForAuth testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion 'TClient testTHandleParams v sessionId = @@ -382,19 +389,20 @@ testTHandleParams v sessionId = thAuth = Nothing, implySessId = v >= authCmdsSMPVersion, encryptBlock = Nothing, - batch = True + batch = True, + serviceAuth = v >= serviceCertsSMPVersion } testTHandleAuth :: VersionSMP -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe (THandleAuth 'TClient)) -testTHandleAuth v g (C.APublicAuthKey a serverPeerPubKey) = case a of +testTHandleAuth v g (C.APublicAuthKey a peerServerPubKey) = case a of C.SX25519 | v >= authCmdsSMPVersion -> do ca <- head <$> XS.readCertificates "tests/fixtures/ca.crt" serverCert <- head <$> XS.readCertificates "tests/fixtures/server.crt" serverKey <- head <$> XF.readKeyFile "tests/fixtures/server.key" signKey <- either error pure $ C.x509ToPrivate (serverKey, []) >>= C.privKey @C.APrivateSignKey (serverAuthPub, _) <- atomically $ C.generateKeyPair @'C.X25519 g - let serverCertKey = CertChainPubKey (X.CertificateChain [serverCert, ca]) (C.signX509 signKey $ C.toPubKey C.publicToX509 serverAuthPub) - pure $ Just THAuthClient {serverPeerPubKey, serverCertKey, sessSecret = Nothing} + let peerServerCertKey = CertChainPubKey (X.CertificateChain [serverCert, ca]) (C.signX509 signKey $ C.toPubKey C.publicToX509 serverAuthPub) + pure $ Just THAuthClient {peerServerPubKey, peerServerCertKey, clientService = Nothing, sessSecret = Nothing} _ -> pure Nothing randomSENDCmdV6 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) diff --git a/tests/CoreTests/MsgStoreTests.hs b/tests/CoreTests/MsgStoreTests.hs index 2d121e8ef..d25b00c7c 100644 --- a/tests/CoreTests/MsgStoreTests.hs +++ b/tests/CoreTests/MsgStoreTests.hs @@ -131,7 +131,8 @@ testNewQueueRecData g qm queueData = do queueData, notifier = Nothing, status = EntityActive, - updatedAt = Nothing + updatedAt = Nothing, + rcvServiceId = Nothing } pure (rId, qr) where @@ -205,7 +206,7 @@ testExportImportStore ms = do g <- C.newRandom (rId1, qr1) <- testNewQueueRec g QMMessaging (rId2, qr2) <- testNewQueueRec g QMMessaging - sl <- readWriteQueueStore True (mkQueue ms True) testStoreLogFile $ queueStore ms + sl <- readWriteQueueStore True (mkQueue ms True) testStoreLogFile $ stmQueueStore ms runRight_ $ do let write q s = writeMsg ms q True =<< mkMessage s q1 <- ExceptT $ addQueue ms rId1 qr1 @@ -230,7 +231,7 @@ testExportImportStore ms = do closeStoreLog sl let cfg = (testJournalStoreCfg MQStoreCfg :: JournalStoreConfig 'QSMemory) {storePath = testStoreMsgsDir2} ms' <- newMsgStore cfg - readWriteQueueStore True (mkQueue ms' True) testStoreLogFile (queueStore ms') >>= closeStoreLog + readWriteQueueStore True (mkQueue ms' True) testStoreLogFile (stmQueueStore ms') >>= closeStoreLog stats@MessageStats {storedMsgsCount = 5, expiredMsgsCount = 0, storedQueues = 2} <- importMessages False ms' testStoreMsgsFile Nothing False printMessageStats "Messages" stats diff --git a/tests/CoreTests/StoreLogTests.hs b/tests/CoreTests/StoreLogTests.hs index cb5861d7a..f03f1d2ee 100644 --- a/tests/CoreTests/StoreLogTests.hs +++ b/tests/CoreTests/StoreLogTests.hs @@ -17,6 +17,8 @@ import qualified Data.ByteString.Char8 as B import Data.Either (partitionEithers) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M +import qualified Data.X509 as X +import qualified Data.X509.Validation as XV import SMPClient import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String @@ -29,6 +31,8 @@ import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..)) import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.Transport (SMPServiceRole (..)) +import Simplex.Messaging.Transport.Credentials (genCredentials) import Test.Hspec hiding (fit, it) import Util @@ -43,7 +47,8 @@ testNtfCreds g = do NtfCreds { notifierId = EntityId "ijkl", notifierKey, - rcvNtfDhSecret = C.dh' k pk + rcvNtfDhSecret = C.dh' k pk, + ntfServiceId = Nothing } data StoreLogTestCase r s = SLTC {name :: String, saved :: [r], state :: s, compacted :: [r]} @@ -52,6 +57,8 @@ type SMPStoreLogTestCase = StoreLogTestCase StoreLogRecord (M.Map RecipientId Qu deriving instance Eq QueueRec +deriving instance Eq ServiceRec + deriving instance Eq StoreLogRecord deriving instance Eq NtfCreds @@ -60,8 +67,8 @@ storeLogTests :: Spec storeLogTests = forM_ [QMMessaging, QMContact] $ \qm -> do g <- runIO C.newRandom - ((rId, qr), ntfCreds, date) <- runIO $ - (,,) <$> testNewQueueRec g qm <*> testNtfCreds g <*> getSystemDate + ((rId, qr), ntfCreds, date, sr@ServiceRec {serviceId}) <- runIO $ + (,,,) <$> testNewQueueRec g qm <*> testNtfCreds g <*> getSystemDate <*> newTestServiceRec g ((rId', qr'), lnkId, qd) <- runIO $ do lnkId <- atomically $ EntityId <$> C.randomBytes 24 g let qd = (EncDataBytes "fixed data", EncDataBytes "user data") @@ -113,6 +120,12 @@ storeLogTests = compacted = [CreateQueue rId qr {notifier = Just ntfCreds}], state = M.fromList [(rId, qr {notifier = Just ntfCreds})] }, + SLTC + { name = "create queue, add notifier, register and associate notification service", + saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, NewService sr, QueueService rId (ASP SNotifier) (Just serviceId)], + compacted = [NewService sr, CreateQueue rId qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}}], + state = M.fromList [(rId, qr {notifier = Just ntfCreds {ntfServiceId = Just serviceId}})] + }, SLTC { name = "delete notifier", saved = [CreateQueue rId qr, AddNotifier rId ntfCreds, DeleteNotifier rId], @@ -133,6 +146,20 @@ storeLogTests = } ] +newTestServiceRec :: TVar ChaChaDRG -> IO ServiceRec +newTestServiceRec g = do + serviceId <- atomically $ EntityId <$> C.randomBytes 24 g + (_, cert) <- genCredentials g Nothing (0, 2400) "ntf.example.com" + serviceCreatedAt <- getSystemDate + pure + ServiceRec + { serviceId, + serviceRole = SRNotifier, + serviceCert = X.CertificateChain [cert], + serviceCertHash = XV.getFingerprint cert X.HashSHA256, + serviceCreatedAt + } + testSMPStoreLog :: String -> [SMPStoreLogTestCase] -> Spec testSMPStoreLog testSuite tests = describe testSuite $ forM_ tests $ \t@SLTC {name, saved} -> it name $ do @@ -141,18 +168,19 @@ testSMPStoreLog testSuite tests = closeStoreLog l replicateM_ 3 $ testReadWrite t #if defined(dbServerPostgres) - qCnt <- fromIntegral <$> importStoreLogToDatabase "tests/tmp/" testStoreLogFile testStoreDBOpts - qCnt `shouldBe` length (compacted t) + (sCnt, qCnt) <- importStoreLogToDatabase "tests/tmp/" testStoreLogFile testStoreDBOpts + fromIntegral (sCnt + qCnt) `shouldBe` length (compacted t) imported <- B.readFile $ testStoreLogFile <> ".bak" - qCnt' <- exportDatabaseToStoreLog "tests/tmp/" testStoreDBOpts testStoreLogFile - qCnt' `shouldBe` qCnt + (sCnt', qCnt') <- exportDatabaseToStoreLog "tests/tmp/" testStoreDBOpts testStoreLogFile + sCnt' `shouldBe` fromIntegral sCnt + qCnt' `shouldBe` fromIntegral qCnt exported <- B.readFile testStoreLogFile imported `shouldBe` exported #endif where testReadWrite SLTC {compacted, state} = do st <- newMsgStore $ testJournalStoreCfg MQStoreCfg - l <- readWriteQueueStore True (mkQueue st True) testStoreLogFile $ queueStore st + l <- readWriteQueueStore True (mkQueue st True) testStoreLogFile $ stmQueueStore st storeState st `shouldReturn` state closeStoreLog l ([], compacted') <- partitionEithers . map strDecode . B.lines <$> B.readFile testStoreLogFile diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 6f4f79909..1bc0b9f3f 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -201,6 +201,7 @@ dummyRQ userId server connId rcvId = sndId = NoEntity, queueMode = Just QMMessaging, shortLink = Nothing, + clientService = Nothing, status = New, dbQueueId = DBEntityId 0, primary = True, diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index bb5af9722..60a160746 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -34,7 +34,7 @@ import Network.HTTP.Types (Status) import qualified Network.HTTP.Types as N import qualified Network.HTTP2.Server as H import Network.Socket -import SMPClient (defaultStartOptions, ntfTestPort, prevRange, serverBracket) +import SMPClient (defaultStartOptions, ntfTestPort, ntfTestServerCredentials, prevRange, serverBracket) import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) @@ -120,7 +120,7 @@ testNtfClient :: Transport c => (THandleNTF c 'TClient -> IO a) -> IO a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> - runExceptT (ntfClientHandshake h testKeyHash supportedClientNTFVRange False) >>= \case + runExceptT (ntfClientHandshake h testKeyHash supportedClientNTFVRange False Nothing) >>= \case Right th -> client th Left e -> error $ show e @@ -144,12 +144,8 @@ ntfServerCfg = subsBatchSize = 900, inactiveClientExpiration = Just defaultInactiveClientExpiration, dbStoreConfig = ntfTestDBCfg, - ntfCredentials = - ServerCredentials - { caCertificateFile = Just "tests/fixtures/ca.crt", - privateKeyFile = "tests/fixtures/server.key", - certificateFile = "tests/fixtures/server.crt" - }, + ntfCredentials = ntfTestServerCredentials, + useServiceCreds = True, periodicNtfsInterval = 1, -- stats config logStatsInterval = Nothing, @@ -159,7 +155,7 @@ ntfServerCfg = prometheusInterval = Nothing, prometheusMetricsFile = ntfTestPrometheusMetricsFile, ntfServerVRange = supportedServerNTFVRange, - transportConfig = mkTransportServerConfig True $ Just alpnSupportedNTFHandshakes, + transportConfig = mkTransportServerConfig True (Just alpnSupportedNTFHandshakes) False, startOptions = defaultStartOptions } @@ -200,11 +196,11 @@ ntfServerTest :: forall c smp. (Transport c, Encoding smp) => TProxy c 'TServer -> - (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - IO (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) + (Maybe TAuthorizations, ByteString, ByteString, smp) -> + IO (Maybe TAuthorizations, ByteString, ByteString, NtfResponse) ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandleNTF c 'TClient -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -242,10 +238,10 @@ apnsMockServerConfig = privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt" }, - transportConfig = mkTransportServerConfig True Nothing + transportConfig = mkTransportServerConfig True Nothing False } -withAPNSMockServer :: (APNSMockServer -> IO ()) -> IO () +withAPNSMockServer :: (APNSMockServer -> IO a) -> IO a withAPNSMockServer = E.bracket (getAPNSMockServer apnsMockServerConfig) closeAPNSMockServer deriving instance Generic APNSAlertBody diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index b8709957d..30e37b080 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -7,6 +7,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} @@ -42,17 +43,18 @@ import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Transport (THandleNTF) import Simplex.Messaging.Parsers (parse, parseAll) import Simplex.Messaging.Protocol hiding (notification) +import Simplex.Messaging.Server.Env.STM (AStoreType) import Simplex.Messaging.Transport import Test.Hspec hiding (fit, it) import UnliftIO.STM import Util -ntfServerTests :: ASrvTransport -> Spec -ntfServerTests t = do +ntfServerTests :: (ASrvTransport, AStoreType) -> Spec +ntfServerTests ps@(t, _) = do describe "Notifications server protocol syntax" $ ntfSyntaxTests t - describe "Notification subscriptions (NKEY)" $ testNotificationSubscription t createNtfQueueNKEY - -- describe "Notification subscriptions (NEW with ntf creds)" $ testNotificationSubscription t createNtfQueueNEW - describe "Retried notification subscription" $ testRetriedNtfSubscription t + describe "Notification subscriptions (NKEY)" $ testNotificationSubscription ps createNtfQueueNKEY + -- describe "Notification subscriptions (NEW with ntf creds)" $ testNotificationSubscription ps createNtfQueueNEW + describe "Retried notification subscription" $ testRetriedNtfSubscription ps ntfSyntaxTests :: ASrvTransport -> Spec ntfSyntaxTests (ATransport t) = do @@ -65,8 +67,8 @@ ntfSyntaxTests (ATransport t) = do where (>#>) :: Encoding smp => - (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) -> + (Maybe TAuthorizations, ByteString, ByteString, smp) -> + (Maybe TAuthorizations, ByteString, ByteString, NtfResponse) -> Expectation command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response @@ -75,7 +77,7 @@ pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command deriving instance Eq NtfResponse -sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TransmissionAuth, ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c 'TClient -> (Maybe TAuthorizations, ByteString, NtfEntityId, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) @@ -87,7 +89,7 @@ signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = Right () <- tPut1 h (authorize tForAuth, tToSend) tGet1 h where - authorize t = case a of + authorize t = (,Nothing) <$> case a of C.SEd25519 -> Just . TASignature . C.ASignature C.SEd25519 $ C.sign' pk t C.SEd448 -> Just . TASignature . C.ASignature C.SEd448 $ C.sign' pk t _ -> Nothing @@ -97,8 +99,8 @@ v .-> key = let J.Object o = v in U.decodeLenient . encodeUtf8 <$> JT.parseEither (J..: key) o -testNotificationSubscription :: ASrvTransport -> CreateQueueFunc -> Spec -testNotificationSubscription (ATransport t) createQueue = +testNotificationSubscription :: (ASrvTransport, AStoreType) -> CreateQueueFunc -> Spec +testNotificationSubscription (ATransport t, msType) createQueue = it "should create notification subscription and notify when message is received" $ do g <- C.newRandom (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g @@ -107,7 +109,7 @@ testNotificationSubscription (ATransport t) createQueue = (dhPub, dhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g let tkn = DeviceToken PPApnsTest "abcd" withAPNSMockServer $ \apns -> - smpTest2' t $ \rh sh -> + smpTest2 t msType $ \rh sh -> ntfTest t $ \nh -> do ((sId, rId, rKey, rcvDhSecret), nId, rcvNtfDhSecret) <- createQueue rh sPub nPub -- register and verify token @@ -180,14 +182,14 @@ testNotificationSubscription (ATransport t) createQueue = smpServer3 `shouldBe` srv notifierId3 `shouldBe` nId -testRetriedNtfSubscription :: ASrvTransport -> Spec -testRetriedNtfSubscription (ATransport t) = +testRetriedNtfSubscription :: (ASrvTransport, AStoreType) -> Spec +testRetriedNtfSubscription (ATransport t, msType) = it "should allow retrying to create notification subscription with the same token and key" $ do g <- C.newRandom (sPub, _sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g withAPNSMockServer $ \apns -> - smpTest' t $ \h -> + smpTest t msType $ \h -> ntfTest t $ \nh -> do ((_sId, _rId, _rKey, _rcvDhSecret), nId, _rcvNtfDhSecret) <- createNtfQueueNKEY h sPub nPub (tknKey, _dhSecret, tId, regCode) <- registerToken nh apns "abcd" diff --git a/tests/PostgresSchemaDump.hs b/tests/PostgresSchemaDump.hs index dbacce3f3..234ac8a30 100644 --- a/tests/PostgresSchemaDump.hs +++ b/tests/PostgresSchemaDump.hs @@ -15,7 +15,6 @@ import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration) import Simplex.Messaging.Util (ifM, whenM) import System.Directory (doesFileExist, removeFile) -import System.Environment (lookupEnv) import System.Process (readCreateProcess, shell) import Test.Hspec hiding (fit, it) import Util @@ -58,7 +57,7 @@ postgresSchemaDumpTest migrations skipComparisonForDownMigrations testDBOpts@DBO getSchema :: FilePath -> IO String getSchema schemaPath = do - ci <- (Just "true" ==) <$> lookupEnv "CI" + ci <- envCI let cmd = ("pg_dump " <> B.unpack connstr <> " --schema " <> B.unpack testDBSchema) <> " --schema-only --no-owner --no-privileges --no-acl --no-subscriptions --no-tablespaces > " diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 276fe0388..b8b10422e 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -33,9 +33,9 @@ import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (.. import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.Server +import Simplex.Messaging.Util (ifM) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal -import System.Environment (lookupEnv) import System.Info (os) import Test.Hspec hiding (fit, it) import UnliftIO.Concurrent @@ -143,10 +143,7 @@ xit'' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a) xit'' d = skipOnCI . it d skipOnCI :: SpecWith a -> SpecWith a -skipOnCI t = - runIO (lookupEnv "CI") >>= \case - Just "true" -> skip "skipped on CI" t - _ -> t +skipOnCI t = ifM (runIO envCI) (skip "skipped on CI" t) t testSMPClient :: Transport c => (THandleSMP c 'TClient -> IO a) -> IO a testSMPClient = testSMPClientVR supportedClientSMPRelayVRange @@ -160,7 +157,7 @@ testSMPClient_ :: Transport c => TransportHost -> ServiceName -> VersionRangeSMP testSMPClient_ host port vr client = do let tcConfig = defaultTransportClientConfig {clientALPN} :: TransportClientConfig runTransportClient tcConfig Nothing host port (Just testKeyHash) $ \h -> - runExceptT (smpClientHandshake h Nothing testKeyHash vr False) >>= \case + runExceptT (smpClientHandshake h Nothing testKeyHash vr False Nothing) >>= \case Right th -> client th Left e -> error $ show e where @@ -168,6 +165,30 @@ testSMPClient_ host port vr client = do | authCmdsSMPVersion `isCompatible` vr = Just alpnSupportedSMPHandshakes | otherwise = Nothing +testNtfServiceClient :: Transport c => TProxy c 'TServer -> C.KeyPairEd25519 -> (THandleSMP c 'TClient -> IO a) -> IO a +testNtfServiceClient _ keys client = do + tlsNtfServerCreds <- loadServerCredential ntfTestServerCredentials + serviceCertHash <- loadFingerprint ntfTestServerCredentials + Right serviceSignKey <- pure $ C.x509ToPrivate' $ snd tlsNtfServerCreds + let service = ServiceCredentials {serviceRole = SRNotifier, serviceCreds = tlsNtfServerCreds, serviceCertHash, serviceSignKey} + tcConfig = + defaultTransportClientConfig + { clientCredentials = Just tlsNtfServerCreds, + clientALPN = Just alpnSupportedSMPHandshakes + } + runTransportClient tcConfig Nothing "localhost" testPort (Just testKeyHash) $ \h -> + runExceptT (smpClientHandshake h Nothing testKeyHash supportedClientSMPRelayVRange False $ Just (service, keys)) >>= \case + Right th -> client th + Left e -> error $ show e + +ntfTestServerCredentials :: ServerCredentials +ntfTestServerCredentials = + ServerCredentials + { caCertificateFile = Just "tests/fixtures/ca.crt", + privateKeyFile = "tests/fixtures/server.key", + certificateFile = "tests/fixtures/server.crt" + } + cfg :: AServerConfig cfg = cfgMS (ASType SQSMemory SMSJournal) @@ -226,7 +247,7 @@ cfgMS msType = withStoreCfg (testServerStoreConfig msType) $ \serverStoreCfg -> }, httpCredentials = Nothing, smpServerVRange = supportedServerSMPRelayVRange, - transportConfig = mkTransportServerConfig True $ Just alpnSupportedSMPHandshakes, + transportConfig = mkTransportServerConfig True (Just alpnSupportedSMPHandshakes) True, controlPort = Nothing, smpAgentCfg = defaultSMPClientAgentConfig {persistErrorInterval = 1}, -- seconds allowSMPProxy = False, @@ -258,9 +279,6 @@ serverStoreConfig_ useDbStoreLog = \case cfgV7 :: AServerConfig cfgV7 = updateCfg cfg $ \cfg' -> cfg' {smpServerVRange = mkVersionRange minServerSMPRelayVersion authCmdsSMPVersion} -cfgV8 :: AStoreType -> AServerConfig -cfgV8 msType = updateCfg (cfgMS msType) $ \cfg' -> cfg' {smpServerVRange = mkVersionRange minServerSMPRelayVersion sendingProxySMPVersion} - cfgVPrev :: AStoreType -> AServerConfig cfgVPrev msType = updateCfg (cfgMS msType) $ \cfg' -> cfg' {smpServerVRange = prevRange $ smpServerVRange cfg'} @@ -366,11 +384,11 @@ smpServerTest :: forall c smp. (Transport c, Encoding smp) => TProxy c 'TServer -> - (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) + (Maybe TAuthorizations, ByteString, ByteString, smp) -> + IO (Maybe TAuthorizations, ByteString, ByteString, BrokerMsg) smpServerTest _ t = runSmpTest (ASType SQSMemory SMSJournal) $ \h -> tPut' h t >> tGet' h where - tPut' :: THandleSMP c 'TClient -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -382,15 +400,9 @@ smpServerTest _ t = runSmpTest (ASType SQSMemory SMSJournal) $ \h -> tPut' h t > smpTest :: (HasCallStack, Transport c) => TProxy c 'TServer -> AStoreType -> (HasCallStack => THandleSMP c 'TClient -> IO ()) -> Expectation smpTest _ msType test' = runSmpTest msType test' `shouldReturn` () -smpTest' :: forall c. (HasCallStack, Transport c) => TProxy c 'TServer -> (HasCallStack => THandleSMP c 'TClient -> IO ()) -> Expectation -smpTest' = (`smpTest` ASType SQSMemory SMSJournal) - smpTestN :: (HasCallStack, Transport c) => AStoreType -> Int -> (HasCallStack => [THandleSMP c 'TClient] -> IO ()) -> Expectation smpTestN msType n test' = runSmpTestN msType n test' `shouldReturn` () -smpTest2' :: forall c. (HasCallStack, Transport c) => TProxy c 'TServer -> (HasCallStack => THandleSMP c 'TClient -> THandleSMP c 'TClient -> IO ()) -> Expectation -smpTest2' = (`smpTest2` ASType SQSMemory SMSJournal) - smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c 'TServer -> AStoreType -> (HasCallStack => THandleSMP c 'TClient -> THandleSMP c 'TClient -> IO ()) -> Expectation smpTest2 t msType = smpTest2Cfg (cfgMS msType) supportedClientSMPRelayVRange t diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index b47641239..207550a32 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -36,7 +36,7 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (pattern PQSupportOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR -import Simplex.Messaging.Protocol (EncRcvMsgBody (..), MsgBody, QueueReqData (..), RcvMessage (..), SubscriptionMode (..), maxMessageLength, noMsgFlags, pattern NoEntity) +import Simplex.Messaging.Protocol (EncRcvMsgBody (..), MsgBody, QueueReqData (..), RcvMessage (..), SubscriptionMode (..), maxMessageLength, noMsgFlags) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..)) import Simplex.Messaging.Server.MsgStore.Types (SQSType (..)) @@ -148,9 +148,9 @@ smpProxyTests = do where oneServer test msType = withSmpServerConfigOn (transport @TLS) (updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) testPort $ const test twoServers test msType = twoServers_ (proxyCfgMS msType) (proxyCfgMS msType) test msType - twoServersFirstProxy test msType = twoServers_ (proxyCfgMS msType) (updateCfg (cfgV8 msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) test msType - twoServersMoreConc test msType = twoServers_ (updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {serverClientConcurrency = 128}) (updateCfg (cfgV8 msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) test msType - twoServersNoConc test msType = twoServers_ (updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {serverClientConcurrency = 1}) (updateCfg (cfgV8 msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) test msType + twoServersFirstProxy test msType = twoServers_ (proxyCfgMS msType) (updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) test msType + twoServersMoreConc test msType = twoServers_ (updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {serverClientConcurrency = 128}) (updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) test msType + twoServersNoConc test msType = twoServers_ (updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {serverClientConcurrency = 1}) (updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {msgQueueQuota = 128, maxJournalMsgCount = 256}) test msType twoServers_ :: AServerConfig -> AServerConfig -> IO () -> AStoreType -> IO () twoServers_ cfg1 cfg2 runTest (ASType qsType _) = withSmpServerConfigOn (transport @TLS) cfg1 testPort $ \_ -> @@ -172,7 +172,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do THAuthClient {} <- maybe (fail "getProtocolClient returned no thAuth") pure $ thAuth $ thParams pc -- set up relay msgQ <- newTBQueueIO 1024 - rc' <- getProtocolClient g (2, relayServ, Nothing) defaultSMPClientConfig {serverVRange = mkVersionRange minServerSMPRelayVersion authCmdsSMPVersion} [] (Just msgQ) ts (\_ -> pure ()) + rc' <- getProtocolClient g (2, relayServ, Nothing) defaultSMPClientConfig {serverVRange = mkVersionRange minServerSMPRelayVersion currentClientSMPRelayVersion} [] (Just msgQ) ts (\_ -> pure ()) rc <- either (fail . show) pure rc' -- prepare receiving queue (rPub, rPriv) <- atomically $ C.generateAuthKeyPair alg g @@ -224,9 +224,9 @@ agentDeliverMessageViaProxy :: (C.AlgorithmI a, C.AuthAlgorithm a) => (NonEmpty agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, bViaProxy) alg msg1 msg2 baseId = withAgent 1 aCfg (servers aTestCfg) testDB $ \alice -> withAgent 2 aCfg (servers bTestCfg) testDB2 $ \bob -> runRight_ $ do - (bobId, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - sqSecured <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (sqSecured, Nothing) <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn @@ -280,9 +280,9 @@ agentDeliverMessagesViaProxyConc agentServers msgs = -- agent connections have to be set up in advance -- otherwise the CONF messages would get mixed with MSG prePair alice bob = do - (bobId, CCLink qInfo Nothing) <- runExceptT' $ A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- runExceptT' $ A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe aliceId <- runExceptT' $ A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - sqSecured <- runExceptT' $ A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (sqSecured, Nothing) <- runExceptT' $ A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True confId <- get alice >>= \case @@ -331,7 +331,7 @@ agentViaProxyVersionError = withAgent 1 agentCfg (servers [SMPServer testHost testPort testKeyHash]) testDB $ \alice -> do Left (A.BROKER _ (TRANSPORT TEVersion)) <- withAgent 2 agentCfg (servers [SMPServer testHost2 testPort2 testKeyHash]) testDB2 $ \bob -> runExceptT $ do - (_bobId, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe + (_bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe pure () @@ -351,9 +351,9 @@ agentViaProxyRetryOffline = do let pqEnc = CR.PQEncOn withServer $ \_ -> do (aliceId, bobId) <- withServer2 $ \_ -> runRight $ do - (bobId, CCLink qInfo Nothing) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice 1 True SCMInvitation Nothing Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - sqSecured <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (sqSecured, Nothing) <- A.joinConnection bob 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 082045990..7398ad022 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -10,6 +11,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} @@ -22,23 +24,25 @@ import Control.Monad import Control.Monad.IO.Class import CoreTests.MsgStoreTests (testJournalStoreCfg) import Data.Bifunctor (first) -import Data.ByteString.Base64 +import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Hashable (hash) import qualified Data.IntSet as IS +import Data.String (IsString (..)) import Data.Type.Equality +import qualified Data.X509.Validation as XV import GHC.Stack (withFrozenCallStack) import SMPClient import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Parsers (parseAll, parseString) import Simplex.Messaging.Protocol import Simplex.Messaging.Server (exportMessages) import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..), ServerStoreCfg (..), readWriteQueueStore) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Server.MsgStore.Journal (JournalStoreConfig (..), QStoreCfg (..)) +import Simplex.Messaging.Server.MsgStore.Journal (JournalStoreConfig (..), QStoreCfg (..), stmQueueStore) import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), SMSType (..), SQSType (..), newMsgStore) import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..)) import Simplex.Messaging.Server.StoreLog (StoreLogRecord (..), closeStoreLog) @@ -76,7 +80,9 @@ serverTests = do describe "Restore messages (old / v2)" testRestoreExpireMessages describe "Save prometheus metrics" testPrometheusMetrics describe "Timing of AUTH error" testTiming - describe "Message notifications" testMessageNotifications + describe "Message notifications" $ do + testMessageNotifications + testMessageServiceNotifications describe "Message expiration" $ do testMsgExpireOnSend testMsgExpireOnInterval @@ -93,30 +99,40 @@ pattern New :: RcvPublicAuthKey -> RcvPublicDhKey -> Command 'Recipient pattern New rPub dhPub = NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just (QRMessaging Nothing))) pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg -pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh _sndSecure _linkId) +pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh _sndSecure _linkId Nothing) pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TransmissionAuth, ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> (Maybe TAuthorizations, ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) -signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do +signSendRecv h pk = signSendRecv_ h pk Nothing + +serviceSignSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +serviceSignSendRecv h pk = signSendRecv_ h pk . Just + +signSendRecv_ :: forall c p. (Transport c, PartyI p) => THandleSMP c 'TClient -> C.APrivateAuthKey -> Maybe C.PrivateKeyEd25519 -> (ByteString, EntityId, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +signSendRecv_ h@THandle {params} (C.APrivateAuthKey a pk) serviceKey_ (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) tGet1 h where - authorize t = case a of - C.SEd25519 -> Just . TASignature . C.ASignature C.SEd25519 $ C.sign' pk t - C.SEd448 -> Just . TASignature . C.ASignature C.SEd448 $ C.sign' pk t - C.SX25519 -> (\THAuthClient {serverPeerPubKey = k} -> TAAuthenticator $ C.cbAuthenticate k pk (C.cbNonce corrId) t) <$> thAuth params + authorize t = (,(`C.sign'` t) <$> serviceKey_) <$> case a of + C.SEd25519 -> Just . TASignature . C.ASignature C.SEd25519 $ C.sign' pk t' + C.SEd448 -> Just . TASignature . C.ASignature C.SEd448 $ C.sign' pk t' + C.SX25519 -> (\THAuthClient {peerServerPubKey = k} -> TAAuthenticator $ C.cbAuthenticate k pk (C.cbNonce corrId) t') <$> thAuth params #if !MIN_VERSION_base(4,18,0) _sx448 -> undefined -- ghc8107 fails to the branch excluded by types #endif + where + t' = case (serviceKey_, thAuth params >>= clientService) of + (Just _, Just THClientService {serviceCertHash = XV.Fingerprint fp}) -> fp <> t + _ -> t tPut1 :: Transport c => THandle v c 'TClient -> SentRawTransmission -> IO (Either TransportError ()) tPut1 h t = do @@ -432,13 +448,13 @@ testDuplex = (bDhPub, bDhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g Resp "abcd" _ (Ids bRcv bSnd bSrvDh) <- signSendRecv bob brKey ("abcd", NoEntity, New brPub bDhPub) let bDec = decryptMsgV3 $ C.dh' bSrvDh bDhPriv - Resp "bcda" _ OK <- signSendRecv bob bsKey ("bcda", aSnd, _SEND $ "reply_id " <> encode (unEntityId bSnd)) + Resp "bcda" _ OK <- signSendRecv bob bsKey ("bcda", aSnd, _SEND $ "reply_id " <> B64.encode (unEntityId bSnd)) -- "reply_id ..." is ad-hoc, not a part of SMP protocol Resp "" _ (Msg mId2 msg2) <- tGet1 alice Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId2) Right ["reply_id", bId] <- pure $ B.words <$> aDec mId2 msg2 - (bId, encode (unEntityId bSnd)) #== "reply queue ID received from Bob" + (bId, B64.encode (unEntityId bSnd)) #== "reply queue ID received from Bob" (asPub, asKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g Resp "dabc" _ OK <- sendRecv alice ("", "dabc", bSnd, _SEND $ "key " <> strEncode asPub) @@ -629,7 +645,7 @@ testWithStoreLog = writeTVar dhShared1 $ Just dhShared writeTVar senderId1 sId1 writeTVar notifierId nId - Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB) + Resp "dabc" _ (SOK Nothing) <- signSendRecv h1 nKey ("dabc", nId, NSUB) (mId1, msg1) <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") >>= \case Resp "" _ (Msg mId1 msg1) -> pure (mId1, msg1) @@ -666,7 +682,7 @@ testWithStoreLog = Just dh1 <- readTVarIO dhShared1 sId1 <- readTVarIO senderId1 nId <- readTVarIO notifierId - Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB) + Resp "dabc" _ (SOK Nothing) <- signSendRecv h1 nKey ("dabc", nId, NSUB) Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") Resp "cdab" _ (Msg mId3 msg3) <- signSendRecv h rKey1 ("cdab", rId1, SUB) (decryptMsgV3 dh1 mId3 msg3, Right "hello") #== "delivered from restored queue" @@ -746,7 +762,7 @@ testRestoreMessages = pure () rId <- readTVarIO recipientId logSize testStoreLogFile `shouldReturn` 2 - logSize testServerStatsBackupFile `shouldReturn` 76 + logSize testServerStatsBackupFile `shouldReturn` 94 Right stats1 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats1 [rId] 5 1 withSmpServerConfigOn at cfg' testPort . runTest t $ \h -> do @@ -762,7 +778,7 @@ testRestoreMessages = logSize testStoreLogFile `shouldReturn` (if compacting then 1 else 2) -- the last message is not removed because it was not ACK'd -- logSize testStoreMsgsFile `shouldReturn` 3 - logSize testServerStatsBackupFile `shouldReturn` 76 + logSize testServerStatsBackupFile `shouldReturn` 94 Right stats2 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats2 [rId] 5 3 @@ -780,7 +796,7 @@ testRestoreMessages = pure () logSize testStoreLogFile `shouldReturn` (if compacting then 1 else 2) removeFile testStoreLogFile - logSize testServerStatsBackupFile `shouldReturn` 76 + logSize testServerStatsBackupFile `shouldReturn` 94 Right stats3 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats3 [rId] 5 5 removeFileIfExists testStoreMsgsFile @@ -869,7 +885,7 @@ testRestoreExpireMessages = where export = do ms <- newMsgStore (testJournalStoreCfg MQStoreCfg) {quota = 4} - readWriteQueueStore True (mkQueue ms True) testStoreLogFile (queueStore ms) >>= closeStoreLog + readWriteQueueStore True (mkQueue ms True) testStoreLogFile (stmQueueStore ms) >>= closeStoreLog removeFileIfExists testStoreMsgsFile exportMessages False ms testStoreMsgsFile False closeMsgStore ms @@ -976,21 +992,24 @@ testMessageNotifications = it "should create simplex connection, subscribe notifier and deliver notifications" $ \(ATransport t, msType) -> do g <- C.newRandom (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g - (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g smpTest4 t msType $ \rh sh nh1 nh2 -> do (sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub let dec = decryptMsgV3 dhShared + (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g (rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g Resp "1" _ (NID nId' _) <- signSendRecv rh rKey ("1", rId, NKEY nPub rcvNtfPubDhKey) Resp "1a" _ (NID nId _) <- signSendRecv rh rKey ("1a", rId, NKEY nPub rcvNtfPubDhKey) nId' `shouldNotBe` nId - Resp "2" _ OK <- signSendRecv nh1 nKey ("2", nId, NSUB) + -- can't subscribe with service signature without service connection + (_, servicePK) <- atomically $ C.generateKeyPair g + Resp "2'" _ (ERR SERVICE) <- serviceSignSendRecv nh1 nKey servicePK ("2'", nId, NSUB) + Resp "2" _ (SOK Nothing) <- signSendRecv nh1 nKey ("2", nId, NSUB) Resp "3" _ OK <- signSendRecv sh sKey ("3", sId, _SEND' "hello") Resp "" _ (Msg mId1 msg1) <- tGet1 rh (dec mId1 msg1, Right "hello") #== "delivered from queue" Resp "3a" _ OK <- signSendRecv rh rKey ("3a", rId, ACK mId1) Resp "" _ (NMSG _ _) <- tGet1 nh1 - Resp "4" _ OK <- signSendRecv nh2 nKey ("4", nId, NSUB) + Resp "4" _ (SOK Nothing) <- signSendRecv nh2 nKey ("4", nId, NSUB) Resp "" nId2 END <- tGet1 nh1 nId2 `shouldBe` nId Resp "5" _ OK <- signSendRecv sh sKey ("5", sId, _SEND' "hello again") @@ -1007,9 +1026,96 @@ testMessageNotifications = Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" + Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3) 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" + (nPub'', nKey'') <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rcvNtfPubDhKey'', _) <- atomically $ C.generateKeyPair g + Resp "8" _ (NID nId'' _) <- signSendRecv rh rKey ("8", rId, NKEY nPub'' rcvNtfPubDhKey'') + Resp "9" _ (SOK Nothing) <- signSendRecv nh1 nKey'' ("9", nId'', NSUB) + Resp "10" _ OK <- signSendRecv sh sKey ("10", sId, _SEND' "one more") + Resp "" _ (Msg mId4 msg4) <- tGet1 rh + (dec mId4 msg4, Right "one more") #== "delivered from queue" + Resp "10a" _ OK <- signSendRecv rh rKey ("10a", rId, ACK mId4) + Resp "" _ (NMSG _ _) <- tGet1 nh1 + pure () + +testMessageServiceNotifications :: SpecWith (ASrvTransport, AStoreType) +testMessageServiceNotifications = + it "should create simplex connection, subscribe notifier as service and deliver notifications" $ \(ATransport t, msType) -> do + g <- C.newRandom + smpTest2 t msType $ \rh sh -> do + (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub + let dec = decryptMsgV3 dhShared + (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g + Resp "1" _ (NID nId _) <- signSendRecv rh rKey ("1", rId, NKEY nPub rcvNtfPubDhKey) + serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g + -- TODO [certs] we need to get certificate fingerprint and include it into signed over for NSUB commands + testNtfServiceClient t serviceKeys $ \nh1 -> do + -- can't subscribe without service signature in service connection + Resp "2a" _ (ERR SERVICE) <- signSendRecv nh1 nKey ("2a", nId, NSUB) + Resp "2b" _ (SOK (Just serviceId)) <- serviceSignSendRecv nh1 nKey servicePK ("2b", nId, NSUB) + -- repeat subscription works, to support retries + Resp "2c" _ (SOK (Just serviceId'')) <- serviceSignSendRecv nh1 nKey servicePK ("2c", nId, NSUB) + serviceId'' `shouldBe` serviceId + deliverMessage rh rId rKey sh sId sKey nh1 "hello" dec + testNtfServiceClient t serviceKeys $ \nh2 -> do + Resp "4" _ (SOK (Just serviceId')) <- serviceSignSendRecv nh2 nKey servicePK ("4", nId, NSUB) + serviceId' `shouldBe` serviceId + -- service subscription is terminated + Resp "" serviceId2 (ENDS 1) <- tGet1 nh1 + serviceId2 `shouldBe` serviceId + deliverMessage rh rId rKey sh sId sKey nh2 "hello again" dec + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case + Nothing -> pure () + Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" + Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) + Resp "" nId3 DELD <- tGet1 nh2 + nId3 `shouldBe` nId + Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") + Resp "" _ (Msg mId3 msg3) <- tGet1 rh + (dec mId3 msg3, Right "hello there") #== "delivered from queue again" + Resp "7a" _ OK <- signSendRecv rh rKey ("7a", rId, ACK mId3) + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case + Nothing -> pure () + Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" + -- new notification credentials + (nPub', nKey') <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rcvNtfPubDhKey', _) <- atomically $ C.generateKeyPair g + Resp "8" _ (NID nId' _) <- signSendRecv rh rKey ("8", rId, NKEY nPub' rcvNtfPubDhKey') + nId' == nId `shouldBe` False + Resp "9" _ (SOK (Just serviceId3)) <- serviceSignSendRecv nh2 nKey' servicePK ("9", nId', NSUB) + serviceId3 `shouldBe` serviceId + -- another queue + (sPub'', sKey'') <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (sId'', rId'', rKey'', dhShared'') <- createAndSecureQueue rh sPub'' + let dec'' = decryptMsgV3 dhShared'' + (nPub'', nKey'') <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (rcvNtfPubDhKey'', _) <- atomically $ C.generateKeyPair g + Resp "10" _ (NID nId'' _) <- signSendRecv rh rKey'' ("10", rId'', NKEY nPub'' rcvNtfPubDhKey'') + nId'' == nId `shouldBe` False + Resp "11" _ (SOK (Just serviceId4)) <- serviceSignSendRecv nh2 nKey'' servicePK ("11", nId'', NSUB) + serviceId4 `shouldBe` serviceId + deliverMessage rh rId rKey sh sId sKey nh2 "connection 1" dec + deliverMessage rh rId'' rKey'' sh sId'' sKey'' nh2 "connection 2" dec'' + -- -- another client makes service subscription + Resp "12" serviceId5 (SOKS 2) <- signSendRecv nh1 (C.APrivateAuthKey C.SEd25519 servicePK) ("12", serviceId, NSUBS) + serviceId5 `shouldBe` serviceId + Resp "" serviceId6 (ENDS 2) <- tGet1 nh2 + serviceId6 `shouldBe` serviceId + deliverMessage rh rId rKey sh sId sKey nh1 "connection 1 one more" dec + deliverMessage rh rId'' rKey'' sh sId'' sKey'' nh1 "connection 2 one more" dec'' + where + deliverMessage rh rId rKey sh sId sKey nh msgText dec = do + Resp "msg-1" _ OK <- signSendRecv sh sKey ("msg-1", sId, _SEND' msgText) + Resp "" _ (Msg mId msg) <- tGet1 rh + Resp "msg-2" _ OK <- signSendRecv rh rKey ("msg-2", rId, ACK mId) + (dec mId msg, Right msgText) #== "delivered from queue" + Resp "" _ (NMSG _ _) <- tGet1 nh + pure () testMsgExpireOnSend :: SpecWith (ASrvTransport, AStoreType) testMsgExpireOnSend = @@ -1110,7 +1216,7 @@ testInvQueueLinkData = -- sender ID must be derived from corrId Resp "1" NoEntity (ERR (CMD PROHIBITED)) <- signSendRecv r rKey ("1", NoEntity, NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just qrd))) - Resp corrId' NoEntity (IDS (QIK rId sId' _srvDh (Just QMMessaging) (Just lnkId))) <- + Resp corrId' NoEntity (IDS (QIK rId sId' _srvDh (Just QMMessaging) (Just lnkId) Nothing)) <- signSendRecv r rKey (corrId, NoEntity, NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just qrd))) (sId', sId) #== "should return the same sender ID" corrId' `shouldBe` CorrId corrId @@ -1167,7 +1273,7 @@ testContactQueueLinkData = -- sender ID must be derived from corrId Resp "1" NoEntity (ERR (CMD PROHIBITED)) <- signSendRecv r rKey ("1", NoEntity, NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just qrd))) - Resp corrId' NoEntity (IDS (QIK rId sId' _srvDh (Just QMContact) (Just lnkId'))) <- + Resp corrId' NoEntity (IDS (QIK rId sId' _srvDh (Just QMContact) (Just lnkId') Nothing)) <- signSendRecv r rKey (corrId, NoEntity, NEW (NewQueueReq rPub dhPub Nothing SMSubscribe (Just qrd))) (lnkId', lnkId) #== "should return the same link ID" (sId', sId) #== "should return the same sender ID" @@ -1218,8 +1324,8 @@ samplePubKey = C.APublicVerifyKey C.SEd25519 "MCowBQYDK2VwAyEAfAOflyvbJv1fszgzkQ sampleDhPubKey :: C.PublicKey 'C.X25519 sampleDhPubKey = "MCowBQYDK2VuAyEAriy+HcARIhqsgSjVnjKqoft+y6pxrxdY68zn4+LjYhQ=" -sampleSig :: Maybe TransmissionAuth -sampleSig = Just $ TASignature "e8JK+8V3fq6kOLqco/SaKlpNaQ7i1gfOrXoqekEl42u4mF8Bgu14T5j0189CGcUhJHw2RwCMvON+qbvQ9ecJAA==" +sampleSig :: Maybe TAuthorizations +sampleSig = Just (TASignature "e8JK+8V3fq6kOLqco/SaKlpNaQ7i1gfOrXoqekEl42u4mF8Bgu14T5j0189CGcUhJHw2RwCMvON+qbvQ9ecJAA==", Nothing) noAuth :: (Char, Maybe BasicAuth) noAuth = ('A', Nothing) @@ -1231,6 +1337,9 @@ instance Eq C.ASignature where Just Refl -> s == s' _ -> False +instance IsString (Maybe TAuthorizations) where + fromString = parseString $ B64.decode >=> C.decodeSignature >=> pure . fmap ((,Nothing) . TASignature) + serverSyntaxTests :: ASrvTransport -> Spec serverSyntaxTests (ATransport t) = do it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) @@ -1270,7 +1379,7 @@ serverSyntaxTests (ATransport t) = do it "no queue ID" $ (sampleSig, "dabc", "", cmd) >#> ("", "dabc", "", ERR $ CMD NO_AUTH) (>#>) :: Encoding smp => - (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) -> + (Maybe TAuthorizations, ByteString, ByteString, smp) -> + (Maybe TAuthorizations, ByteString, ByteString, BrokerMsg) -> Expectation command >#> response = withFrozenCallStack $ smpServerTest t command `shouldReturn` response diff --git a/tests/Test.hs b/tests/Test.hs index 06c627514..4598bb8e4 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -123,7 +123,11 @@ main = do ntfTestStoreDBOpts "src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql" aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $ do - describe "Notifications server" $ ntfServerTests (transport @TLS) + describe "Notifications server (SMP server: jornal store)" $ + ntfServerTests (transport @TLS, ASType SQSMemory SMSJournal) + aroundAll_ (postgressBracket testServerDBConnectInfo) $ + describe "Notifications server (SMP server: postgres+jornal store)" $ + ntfServerTests (transport @TLS, ASType SQSPostgres SMSJournal) aroundAll_ (postgressBracket testServerDBConnectInfo) $ do describe "SMP client agent, postgres+jornal message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSJournal) describe "SMP proxy, postgres+jornal message store" $ diff --git a/tests/Util.hs b/tests/Util.hs index 7ca759781..9a4049d68 100644 --- a/tests/Util.hs +++ b/tests/Util.hs @@ -12,6 +12,7 @@ import Data.Either (partitionEithers) import Data.List (tails) import GHC.Conc (getNumCapabilities, getNumProcessors, setNumCapabilities) import System.Directory (doesFileExist, removeFile) +import System.Environment (lookupEnv) import System.Process (callCommand) import System.Timeout (timeout) import Test.Hspec hiding (fit, it) @@ -51,22 +52,24 @@ testLogLevel = LogError instance Example a => Example (TestWrapper a) where type Arg (TestWrapper a) = Arg a - evaluateExample (TestWrapper action) params hooks state = - runTest `E.catches` [E.Handler onTestFailure, E.Handler onTestException] + evaluateExample (TestWrapper action) params hooks state = do + ci <- envCI + runTest `E.catches` [E.Handler (onTestFailure ci), E.Handler (onTestException ci)] where tt = 120 runTest = timeout (tt * 1000000) (evaluateExample action params hooks state) `finally` callCommand "sync" >>= \case Just r -> pure r Nothing -> throwIO $ userError $ "test timed out after " <> show tt <> " seconds" - onTestFailure :: ResultStatus -> IO Result - onTestFailure = \case - Failure loc_ reason -> do + onTestFailure :: Bool -> ResultStatus -> IO Result + onTestFailure ci = \case + Failure loc_ reason | ci -> do putStrLn $ "Test failed: location " ++ show loc_ ++ ", reason: " ++ show reason retryTest r -> E.throwIO r - onTestException :: SomeException -> IO Result - onTestException e = do + onTestException :: Bool -> SomeException -> IO Result + onTestException False e = E.throwIO e + onTestException True e = do putStrLn $ "Test exception: " ++ show e retryTest retryTest = do @@ -74,6 +77,9 @@ instance Example a => Example (TestWrapper a) where setLogLevel LogDebug runTest `finally` setLogLevel testLogLevel -- change this to match log level in Test.hs +envCI :: IO Bool +envCI = (Just "true" ==) <$> lookupEnv "CI" + it :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a) it label action = Hspec.it label (TestWrapper action) diff --git a/tests/XFTPClient.hs b/tests/XFTPClient.hs index 8533d5a69..d3215ea95 100644 --- a/tests/XFTPClient.hs +++ b/tests/XFTPClient.hs @@ -127,7 +127,7 @@ testXFTPServerConfig = logStatsStartTime = 0, serverStatsLogFile = "tests/tmp/xftp-server-stats.daily.log", serverStatsBackupFile = Nothing, - transportConfig = mkTransportServerConfig True $ Just alpnSupportedXFTPhandshakes, + transportConfig = mkTransportServerConfig True (Just alpnSupportedXFTPhandshakes) False, responseDelay = 0 }