diff --git a/package.yaml b/package.yaml index 7ba64f531..02a088e23 100644 --- a/package.yaml +++ b/package.yaml @@ -180,6 +180,7 @@ ghc-options: - -Wall - -Wcompat - -Werror=incomplete-patterns + - -Werror=missing-methods - -Wredundant-constraints - -Wincomplete-record-updates - -Wincomplete-uni-patterns diff --git a/simplexmq.cabal b/simplexmq.cabal index 19c99acc4..3366cb0b8 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -174,7 +174,7 @@ library src default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 include-dirs: cbits c-sources: @@ -255,7 +255,7 @@ executable ntf-server apps/ntf-server default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -330,7 +330,7 @@ executable smp-agent apps/smp-agent default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -405,7 +405,7 @@ executable smp-server apps/smp-server default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -480,7 +480,7 @@ executable xftp apps/xftp default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -555,7 +555,7 @@ executable xftp-server apps/xftp-server default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts build-depends: aeson ==2.2.* , ansi-terminal >=0.10 && <0.12 @@ -662,7 +662,7 @@ test-suite simplexmq-test tests default-extensions: StrictData - ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts -with-rtsopts=-A64M -with-rtsopts=-N1 + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Werror=missing-methods -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -O2 -threaded -rtsopts -with-rtsopts=-A64M -with-rtsopts=-N1 build-depends: HUnit ==1.6.* , QuickCheck ==2.14.* diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index d04117942..415ead6c0 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -262,9 +262,9 @@ runXFTPRcvLocalWorker c Worker {doWork} = do withStore' c $ \db -> updateRcvFileStatus db rcvFileId RFSDecrypting chunkPaths <- getChunkPaths chunks encSize <- liftIO $ foldM (\s path -> (s +) . fromIntegral <$> getFileSize path) 0 chunkPaths - when (FileSize encSize /= size) $ throwError $ XFTP XFTP.SIZE + when (FileSize encSize /= size) $ throwError $ XFTP "" XFTP.SIZE encDigest <- liftIO $ LC.sha512Hash <$> readChunks chunkPaths - when (FileDigest encDigest /= digest) $ throwError $ XFTP XFTP.DIGEST + when (FileDigest encDigest /= digest) $ throwError $ XFTP "" XFTP.DIGEST let destFile = CryptoFile fsSavePath cfArgs void $ liftError (INTERNAL . show) $ decryptChunks encSize chunkPaths key nonce $ \_ -> pure destFile case redirect of @@ -281,10 +281,11 @@ runXFTPRcvLocalWorker c Worker {doWork} = do -- proceed with redirect yaml <- liftError (INTERNAL . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath) next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of - Left _ -> throwError . XFTP $ XFTP.REDIRECT "decode error" + -- TODO switch to another error constructor + Left _ -> throwError . XFTP "" $ XFTP.REDIRECT "decode error" Right (ValidFileDescription fd@FileDescription {size = dstSize, digest = dstDigest}) - | dstSize /= redirectSize -> throwError . XFTP $ XFTP.REDIRECT "size mismatch" - | dstDigest /= redirectDigest -> throwError . XFTP $ XFTP.REDIRECT "digest mismatch" + | dstSize /= redirectSize -> throwError . XFTP "" $ XFTP.REDIRECT "size mismatch" + | dstDigest /= redirectDigest -> throwError . XFTP "" $ XFTP.REDIRECT "digest mismatch" | otherwise -> pure fd -- register and download chunks from the actual file withStore c $ \db -> updateRcvFileRedirect db redirectDbId next diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f93d03e35..f834927d8 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -158,7 +158,7 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations -import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission) +import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission, TransmissionType (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -734,11 +734,9 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv (SCMContact, CR.IKUsePQ) -> throwError $ CMD PROHIBITED _ -> pure () AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - (rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e + (rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e rq' <- withStore c $ \db -> updateNewConnRcv db connId rq - liftIO $ case subMode of - SMOnlyCreate -> pure () - SMSubscribe -> addSubscription c rq' + lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) @@ -863,12 +861,10 @@ joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode createReplyQueue :: AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> AM SMPQueueInfo createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do - (rq, qUri) <- newRcvQueue c userId connId srv (versionToRange smpClientVersion) subMode + (rq, qUri, tSess, sessId) <- newRcvQueue c userId connId srv (versionToRange smpClientVersion) subMode let qInfo = toVersionT qUri smpClientVersion rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq - liftIO $ case subMode of - SMOnlyCreate -> pure () - SMSubscribe -> addSubscription c rq' + lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) @@ -928,7 +924,7 @@ subscribeConnections' c connIds = do (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs mapM_ (mapM_ (\(cData, sqs) -> mapM_ (lift . resumeMsgDelivery c cData) sqs) . sndQueue) cs mapM_ (resumeConnCmds c) $ M.keys cs - rcvRs <- lift $ connResults <$> subscribeQueues c (concat $ M.elems rcvQs) + rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) when (instantNotifications tkn) . void . lift . forkIO . void . runExceptT $ sendNtfCreate ns rcvRs conns @@ -1326,13 +1322,13 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork Left e -> do let err = if msgType == AM_A_MSG_ then MERR mId e else ERR e case e of - SMP SMP.QUOTA -> case msgType of + SMP _ SMP.QUOTA -> case msgType of AM_CONN_INFO -> connError msgId NOT_AVAILABLE AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE _ -> do expireTs <- addUTCTime (-quotaExceededTimeout) <$> liftIO getCurrentTime if internalTs < expireTs then notifyDelMsgs msgId e expireTs else retrySndMsg RISlow - SMP SMP.AUTH -> case msgType of + SMP _ SMP.AUTH -> case msgType of AM_CONN_INFO -> connError msgId NOT_AVAILABLE AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE AM_RATCHET_INFO -> connError msgId NOT_AVAILABLE @@ -1508,10 +1504,10 @@ switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs s -- try to get the server that is different from all queues, or at least from the primary rcv queue srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs) srv' <- if srv == server then getNextServer c userId [server] else pure srvAuth - (q, qUri) <- newRcvQueue c userId connId srv' clientVRange SMSubscribe + (q, qUri, tSess, sessId) <- newRcvQueue c userId connId srv' clientVRange SMSubscribe let rq' = (q :: NewRcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' - liftIO $ addSubscription c rq'' + lift $ addNewQueueSubscription c rq'' tSess sessId void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD let rqs' = updatedQs rq1 rqs <> [rq''] @@ -1565,7 +1561,7 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM () ackQueueMessage c rq srvMsgId = sendAck c rq srvMsgId `catchAgentError` \case - SMP SMP.NO_MSG -> pure () + SMP _ SMP.NO_MSG -> pure () e -> throwError e -- | Suspend SMP agent connection (OFF command) in Reader monad @@ -1895,7 +1891,7 @@ deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do withStore' c $ \db -> updateNtfToken db tkn ntfTknStatus ntfTknAction atomically $ nsUpdateToken ns tkn {ntfTknStatus, ntfTknAction} agentNtfDeleteToken c tknId tkn `catchAgentError` \case - NTF AUTH -> pure () + NTF _ AUTH -> pure () e -> throwError e withStore' c $ \db -> removeNtfToken db tkn atomically $ nsRemoveNtfToken ns @@ -1912,7 +1908,7 @@ withToken c tkn@NtfToken {deviceToken, ntfMode} from_ (toStatus, toAction_) f = let updatedToken = tkn {ntfTknStatus = toStatus, ntfTknAction = toAction_} atomically $ nsUpdateToken ns updatedToken pure toStatus - Left e@(NTF AUTH) -> do + Left e@(NTF _ AUTH) -> do withStore' c $ \db -> removeNtfToken db tkn atomically $ nsRemoveNtfToken ns void $ registerNtfToken' c deviceToken ntfMode @@ -1995,11 +1991,13 @@ getSMPServer c userId = withUserServers c userId pickServer {-# INLINE getSMPServer #-} subscriber :: AgentClient -> AM' () -subscriber c@AgentClient {msgQ} = forever $ do +subscriber c@AgentClient {subQ, msgQ} = forever $ do t <- atomically $ readTBQueue msgQ agentOperationBracket c AORcvNetwork waitUntilActive $ runExceptT (processSMPTransmission c t) >>= \case - Left e -> liftIO $ print e + Left e -> do + logError $ tshow e + atomically $ writeTBQueue subQ ("", "", APC SAEConn $ ERR e) Right _ -> return () cleanupManager :: AgentClient -> AM' () @@ -2076,8 +2074,8 @@ data ACKd = ACKd | ACKPending -- | make sure to ACK or throw in each message processing branch -- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL -processSMPTransmission :: AgentClient -> ServerTransmission SMPVersion BrokerMsg -> AM () -processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, isResponse, rId, cmd) = do +processSMPTransmission :: AgentClient -> ServerTransmission SMPVersion ErrorType BrokerMsg -> AM () +processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, tType, rId, cmd) = do (rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId) processSMP rq conn $ toConnData conn where @@ -2087,14 +2085,15 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, conn cData@ConnData {userId, connId, connAgentVersion, ratchetSyncState = rss} = withConnLock c connId "processSMP" $ case cmd of - SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> + Right (SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId}) -> void . handleNotifyAck $ do + isGET <- atomically $ hasGetLock c rq + unless isGET checkExpiredResponse msg' <- decryptSMPMessage rq msg ack' <- handleNotifyAck $ case msg' of SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody SMP.ClientRcvMsgQuota {} -> queueDrained >> ack - whenM (atomically $ hasGetLock c rq) $ - notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') + when isGET $ notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') pure ack' where queueDrained = case conn of @@ -2237,7 +2236,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd handleNotifyAck :: AM ACKd -> AM ACKd handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack - SMP.END -> + Right SMP.END -> atomically (TM.lookup tSess smpClients $>>= (tryReadTMVar . sessionVar) >>= processEND) >>= logServer "<--" c srv rId where @@ -2250,9 +2249,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, | otherwise -> ignored _ -> ignored ignored = pure "END from disconnected client - ignored" - _ -> do - logServer "<--" c srv rId $ "unexpected: " <> bshow cmd - notify . ERR $ BROKER (B.unpack $ strEncode srv) $ if isResponse then TIMEOUT else UNEXPECTED + Right (SMP.ERR e) -> notify $ ERR $ SMP (B.unpack $ strEncode srv) e + Right SMP.OK -> checkExpiredResponse + Right _ -> unexpected + Left e -> notify $ ERR $ protocolClientError SMP (B.unpack $ strEncode srv) e where notify :: forall e m. MonadIO m => AEntityI e => ACommand 'Agent e -> m () notify = atomically . notify' @@ -2266,6 +2266,27 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, enqueueCmd :: InternalCommand -> AM () enqueueCmd = enqueueCommand c "" connId (Just srv) . AInternalCommand + unexpected :: AM () + unexpected = do + logServer "<--" c srv rId $ "unexpected: " <> bshow cmd + -- TODO add extended information about transmission type once UNEXPECTED has string + notify . ERR $ BROKER (B.unpack $ strEncode srv) UNEXPECTED + + checkExpiredResponse :: AM () + checkExpiredResponse = case tType of + TTEvent -> pure () + TTUncorrelatedResponse -> unexpected + TTExpiredResponse (SMP.Cmd _ cmd') -> case cmd' of + SMP.SUB -> do + added <- + atomically $ + ifM + ((&&) <$> hasPendingSubscription c connId <*> activeClientSession c tSess sessId) + (True <$ addSubscription c rq) + (pure False) + when added $ notify $ UP srv [connId] + _ -> pure () + decryptClientMessage :: C.DhSecretX25519 -> SMP.ClientMsgEnvelope -> AM (SMP.PrivHeader, AgentMsgEnvelope) decryptClientMessage e2eDh SMP.ClientMsgEnvelope {cmNonce, cmEncBody} = do clientMsg <- agentCbDecrypt e2eDh cmNonce cmEncBody diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index cfed6b932..d9b059d59 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -41,6 +41,7 @@ module Simplex.Messaging.Agent.Client getQueueMessage, decryptSMPMessage, addSubscription, + addNewQueueSubscription, getSubscriptions, sendConfirmation, sendInvitation, @@ -77,11 +78,14 @@ module Simplex.Messaging.Agent.Client logSecret, removeSubscription, hasActiveSubscription, + hasPendingSubscription, hasGetLock, + activeClientSession, agentClientStore, agentDRG, getAgentSubscriptions, slowNetworkConfig, + protocolClientError, Worker (..), SessionVar (..), SubscriptionsInfo (..), @@ -152,7 +156,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (lefts, partitionEithers) +import Data.Either (partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) import Data.List (deleteFirstsBy, foldl', partition, (\\)) @@ -229,7 +233,7 @@ import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (SMPVersion) +import Simplex.Messaging.Transport (SMPVersion, SessionId, THandleParams (sessionId)) import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version @@ -260,7 +264,7 @@ data AgentClient = AgentClient active :: TVar Bool, rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), - msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion ErrorType BrokerMsg), smpServers :: TMap UserId (NonEmpty SMPServerWithAuth), smpClients :: TMap SMPTransportSession SMPClientVar, ntfServers :: TVar [NtfServer], @@ -511,7 +515,7 @@ agentDRG AgentClient {agentEnv = Env {random}} = random class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where type Client msg = c | c -> msg getProtocolServerClient :: AgentClient -> TransportSession msg -> AM (Client msg) - clientProtocolError :: err -> AgentErrorType + clientProtocolError :: HostName -> err -> AgentErrorType closeProtocolServerClient :: Client msg -> IO () clientServer :: Client msg -> String clientTransportHost :: Client msg -> TransportHost @@ -644,7 +648,7 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do resubscribe :: AM () resubscribe = do cs <- readTVarIO $ RQ.getConnections $ activeSubs c - rs <- lift . subscribeQueues c $ L.toList qs + (rs, sessId_) <- lift . subscribeQueues c $ L.toList qs let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs liftIO $ do let conns = filter (`M.notMember` cs) okConns @@ -653,7 +657,10 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs forM_ (listToMaybe tempErrs) $ \(_, err) -> do when (null okConns && M.null cs && null finalErrs) . liftIO $ - closeClient c smpClients tSess + forM_ sessId_ $ \sessId -> do + -- We only close the client session that was used to subscribe. + v_ <- atomically $ ifM (activeClientSession c tSess sessId) (TM.lookupDelete tSess $ smpClients c) (pure Nothing) + mapM_ (closeClient_ c) v_ throwError err notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO () notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd) @@ -938,13 +945,13 @@ withXFTPClient c (userId, srv, entityId) cmdStr action = do tSess <- liftIO $ mkTransportSession c userId srv entityId withLogClient c tSess entityId cmdStr action -liftClient :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a +liftClient :: (Show err, Encoding err) => (HostName -> err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a liftClient protocolError_ = liftError . protocolClientError protocolError_ {-# INLINE liftClient #-} -protocolClientError :: (Show err, Encoding err) => (err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType +protocolClientError :: (Show err, Encoding err) => (HostName -> err -> AgentErrorType) -> HostName -> ProtocolClientError err -> AgentErrorType protocolClientError protocolError_ host = \case - PCEProtocolError e -> protocolError_ e + PCEProtocolError e -> protocolError_ host e PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e PCEUnexpectedResponse _ -> BROKER host UNEXPECTED PCEResponseTimeout -> BROKER host TIMEOUT @@ -1023,7 +1030,7 @@ runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do liftError (testErr TSUploadFile) $ X.uploadXFTPChunk xftp spKey sId chunkSpec liftError (testErr TSDownloadFile) $ X.downloadXFTPChunk g xftp rpKey rId $ XFTPRcvChunkSpec rcvPath chSize digest rcvDigest <- liftIO $ C.sha256Hash <$> B.readFile rcvPath - unless (digest == rcvDigest) $ throwError $ ProtocolTestFailure TSCompareFile $ XFTP DIGEST + unless (digest == rcvDigest) $ throwError $ ProtocolTestFailure TSCompareFile $ XFTP (B.unpack $ strEncode srv) DIGEST liftError (testErr TSDeleteFile) $ X.deleteXFTPChunk xftp spKey sId ok <- tcpTimeout xftpNetworkConfig `timeout` X.closeXFTPClient xftp incClientStat c userId xftp "XFTP_TEST" "OK" @@ -1098,7 +1105,7 @@ getSessionMode :: AgentClient -> IO TransportSessionMode getSessionMode = atomically . fmap sessionMode . getNetworkConfig {-# INLINE getSessionMode #-} -newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri) +newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId) newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do C.AuthAlg a <- asks (rcvAuthAlg . config) g <- asks random @@ -1107,8 +1114,9 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do (e2eDhKey, e2ePrivKey) <- atomically $ C.generateKeyPair g logServer "-->" c srv "" "NEW" tSess <- liftIO $ mkTransportSession c userId srv connId - QIK {rcvId, sndId, rcvPublicDhKey} <- - withClient c tSess "NEW" $ \smp -> createSMPQueue smp rKeys dhKey auth subMode + (sessId, QIK {rcvId, sndId, rcvPublicDhKey}) <- + withClient c tSess "NEW" $ \smp -> + (sessionId $ thParams smp,) <$> createSMPQueue smp rKeys dhKey auth subMode liftIO . logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] let rq = RcvQueue @@ -1130,17 +1138,18 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do clientNtfCreds = Nothing, deleteErrors = 0 } - pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey) + qUri = SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey + pure (rq, qUri, tSess, sessId) -processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> IO (Either SMPClientError ()) -processSubResult c rq r = do - case r of - Left e -> - unless (temporaryClientError e) . atomically $ do - RQ.deleteQueue rq (pendingSubs c) - TM.insert (RQ.qKey rq) e (removedSubs c) - _ -> addSubscription c rq - pure r +processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> STM () +processSubResult c rq@RcvQueue {connId} = \case + Left e -> + unless (temporaryClientError e) $ do + RQ.deleteQueue rq (pendingSubs c) + TM.insert (RQ.qKey rq) e (removedSubs c) + Right () -> + whenM (hasPendingSubscription c connId) $ + addSubscription c rq temporaryAgentError :: AgentErrorType -> Bool temporaryAgentError = \case @@ -1157,7 +1166,7 @@ temporaryOrHostError = \case {-# INLINE temporaryOrHostError #-} -- | Subscribe to queues. The list of results can have a different order. -subscribeQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())] +subscribeQueues :: AgentClient -> [RcvQueue] -> AM' ([(RcvQueue, Either AgentErrorType ())], Maybe SessionId) subscribeQueues c qs = do (errs, qs') <- partitionEithers <$> mapM checkQueue qs atomically $ do @@ -1165,20 +1174,43 @@ subscribeQueues c qs = do RQ.batchAddQueues (pendingSubs c) qs' env <- ask -- only "checked" queues are subscribed - (errs <>) <$> sendTSessionBatches "SUB" 90 id (subscribeQueues_ env) c qs' + session <- newTVarIO Nothing + rs <- sendTSessionBatches "SUB" 90 id (subscribeQueues_ env session) c qs' + (errs <> rs,) <$> readTVarIO session where checkQueue rq = do prohibited <- atomically $ hasGetLock c rq pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED) else Right rq - subscribeQueues_ :: Env -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) - subscribeQueues_ env smp qs' = do + subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) + subscribeQueues_ env session smp qs' = do rs <- sendBatch subscribeSMPQueues smp qs' - mapM_ (uncurry $ processSubResult c) rs - when (any temporaryClientError . lefts . map snd $ L.toList rs) $ - runReaderT (resubscribeSMPSession c $ transportSession' smp) env - pure rs + active <- + atomically $ + ifM + (activeClientSession c tSess sessId) + (writeTVar session (Just sessId) >> processSubResults rs $> True) + (pure False) + if active + then when (hasTempErrors rs) resubscribe $> rs + else do + logWarn "subcription batch result for replaced SMP client, resubscribing" + resubscribe $> L.map (second $ \_ -> Left PCENetworkError) rs + where + tSess = transportSession' smp + sessId = sessionId $ thParams smp + hasTempErrors = any (either temporaryClientError (const False) . snd) + processSubResults :: NonEmpty (RcvQueue, Either SMPClientError ()) -> STM () + processSubResults = mapM_ $ uncurry $ processSubResult c + resubscribe = resubscribeSMPSession c tSess `runReaderT` env -type BatchResponses e r = (NonEmpty (RcvQueue, Either e r)) +activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool +activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c) + where + sameSess = \case + Just (Right smp) -> sessId == sessionId (thParams smp) + _ -> False + +type BatchResponses e r = NonEmpty (RcvQueue, Either e r) -- statBatchSize is not used to batch the commands, only for traffic statistics sendTSessionBatches :: forall q r. ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)] @@ -1213,16 +1245,35 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) -addSubscription :: AgentClient -> RcvQueue -> IO () -addSubscription c rq@RcvQueue {connId} = atomically $ do +addSubscription :: AgentClient -> RcvQueue -> STM () +addSubscription c rq@RcvQueue {connId} = do modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ activeSubs c RQ.deleteQueue rq $ pendingSubs c +addPendingSubscription :: AgentClient -> RcvQueue -> STM () +addPendingSubscription c rq@RcvQueue {connId} = do + modifyTVar' (subscrConns c) $ S.insert connId + RQ.addQueue rq $ pendingSubs c + +addNewQueueSubscription :: AgentClient -> RcvQueue -> SMPTransportSession -> SessionId -> AM' () +addNewQueueSubscription c rq tSess sessId = do + same <- + atomically $ + ifM + (activeClientSession c tSess sessId) + (True <$ addSubscription c rq) + (False <$ addPendingSubscription c rq) + unless same $ resubscribeSMPSession c tSess + hasActiveSubscription :: AgentClient -> ConnId -> STM Bool hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c {-# INLINE hasActiveSubscription #-} +hasPendingSubscription :: AgentClient -> ConnId -> STM Bool +hasPendingSubscription c connId = RQ.hasConn connId $ pendingSubs c +{-# INLINE hasPendingSubscription #-} + removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do modifyTVar' (subscrConns c) $ S.delete connId diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 98db26ab4..e136a8bbb 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -193,13 +193,13 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet ( InitialKeys (..), PQEncryption (..), - pattern PQEncOff, PQSupport, - pattern PQSupportOn, - pattern PQSupportOff, RcvE2ERatchetParams, RcvE2ERatchetParamsUri, - SndE2ERatchetParams + SndE2ERatchetParams, + pattern PQEncOff, + pattern PQSupportOff, + pattern PQSupportOn, ) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -213,14 +213,14 @@ import Simplex.Messaging.Protocol MsgId, NMsgMeta, ProtocolServer (..), + SMPClientVersion, SMPMsgMeta, SMPServer, SMPServerWithAuth, SndPublicAuthKey, SubscriptionMode, - SMPClientVersion, - VersionSMPC, VersionRangeSMPC, + VersionSMPC, initialSMPClientVersion, legacyEncodeServer, legacyServerP, @@ -908,7 +908,7 @@ instance Encoding AgentMsgEnvelope where -- AgentRatchetInfo is not encrypted with double ratchet, but with per-queue E2E encryption data AgentMessage = -- used by the initiating party when confirming reply queue - AgentConnInfo ConnInfo + AgentConnInfo ConnInfo | -- AgentConnInfoReply is used by accepting party in duplexHandshake mode (v2), allowing to include reply queue(s) in the initial confirmation. -- It made removed REPLY message unnecessary. AgentConnInfoReply (NonEmpty SMPQueueInfo) ConnInfo @@ -1382,9 +1382,9 @@ deriving instance Show (ConnectionRequestUri m) data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m) instance Eq AConnectionRequestUri where - ACR m cr == ACR m' cr' = case testEquality m m' of - Just Refl -> cr == cr' - _ -> False + ACR m cr == ACR m' cr' = case testEquality m m' of + Just Refl -> cr == cr' + _ -> False deriving instance Show AConnectionRequestUri @@ -1469,11 +1469,11 @@ data AgentErrorType | -- | connection errors CONN {connErr :: ConnectionErrorType} | -- | SMP protocol errors forwarded to agent clients - SMP {smpErr :: ErrorType} + SMP {serverAddress :: String, smpErr :: ErrorType} | -- | NTF protocol errors forwarded to agent clients - NTF {ntfErr :: ErrorType} + NTF {serverAddress :: String, ntfErr :: ErrorType} | -- | XFTP protocol errors forwarded to agent clients - XFTP {xftpErr :: XFTPErrorType} + XFTP {serverAddress :: String, xftpErr :: XFTPErrorType} | -- | XRCP protocol errors forwarded to agent clients RCP {rcpErr :: RCErrorType} | -- | SMP server errors @@ -1584,9 +1584,9 @@ instance StrEncoding AgentErrorType where strP = "CMD " *> (CMD <$> parseRead1) <|> "CONN " *> (CONN <$> parseRead1) - <|> "SMP " *> (SMP <$> strP) - <|> "NTF " *> (NTF <$> strP) - <|> "XFTP " *> (XFTP <$> strP) + <|> "SMP " *> (SMP <$> textP <*> _strP) + <|> "NTF " *> (NTF <$> textP <*> _strP) + <|> "XFTP " *> (XFTP <$> textP <*> _strP) <|> "RCP " *> (RCP <$> strP) <|> "BROKER " *> (BROKER <$> textP <* " RESPONSE " <*> (RESPONSE <$> textP)) <|> "BROKER " *> (BROKER <$> textP <* " TRANSPORT " <*> (TRANSPORT <$> transportErrorP)) @@ -1602,9 +1602,9 @@ instance StrEncoding AgentErrorType where strEncode = \case CMD e -> "CMD " <> bshow e CONN e -> "CONN " <> bshow e - SMP e -> "SMP " <> strEncode e - NTF e -> "NTF " <> strEncode e - XFTP e -> "XFTP " <> strEncode e + SMP srv e -> "SMP " <> text srv <> " " <> strEncode e + NTF srv e -> "NTF " <> text srv <> " " <> strEncode e + XFTP srv e -> "XFTP " <> text srv <> " " <> strEncode e RCP e -> "RCP " <> strEncode e BROKER srv (RESPONSE e) -> "BROKER " <> text srv <> " RESPONSE " <> text e BROKER srv (TRANSPORT e) -> "BROKER " <> text srv <> " TRANSPORT " <> serializeTransportError e diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 176602f4b..6cee90839 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -70,6 +70,7 @@ module Simplex.Messaging.Client proxyUsername, temporaryClientError, ServerTransmission, + TransmissionType (..), ClientCommand, -- * For testing @@ -80,9 +81,11 @@ module Simplex.Messaging.Client ) where +import Control.Applicative ((<|>)) import Control.Concurrent.Async import Control.Concurrent.STM import Control.Exception +import Control.Logger.Simple import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class (liftIO) @@ -110,7 +113,7 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (SocksProxy, TransportClientConfig (..), TransportHost (..), defaultTcpConnectTimeout, runTransportClient) import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Transport.WebSockets (WS) -import Simplex.Messaging.Util (bshow, diffToMicroseconds, raceAny_, threadDelay', whenM) +import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, raceAny_, threadDelay', tshow, whenM) import Simplex.Messaging.Version import System.Timeout (timeout) @@ -129,15 +132,14 @@ data PClient v err msg = PClient transportSession :: TransportSession msg, transportHost :: TransportHost, tcpTimeout :: Int, - rcvConcurrency :: Int, sendPings :: TVar Bool, lastReceived :: TVar UTCTime, timeoutErrorCount :: TVar Int, clientCorrId :: TVar ChaChaDRG, sentCommands :: TMap CorrId (Request err msg), - sndQ :: TBQueue (TVar Bool, ByteString), + sndQ :: TBQueue ByteString, rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), - msgQ :: Maybe (TBQueue (ServerTransmission v msg)) + msgQ :: Maybe (TBQueue (ServerTransmission v err msg)) } smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe (THandleAuth 'TClient) -> STM SMPClient @@ -170,7 +172,6 @@ smpClientStub g sessionId thVersion thAuth = do transportSession = (1, "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001", Nothing), transportHost = "localhost", tcpTimeout = 15_000_000, - rcvConcurrency = 8, sendPings, lastReceived, timeoutErrorCount, @@ -188,7 +189,9 @@ type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg) -- | Type synonym for transmission from some SPM server queue. -type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, Bool, EntityId, msg) +type ServerTransmission v err msg = (TransportSession msg, Version v, SessionId, TransmissionType msg, EntityId, Either (ProtocolClientError err) msg) + +data TransmissionType msg = TTEvent | TTUncorrelatedResponse | TTExpiredResponse (ProtoCommand msg) data HostMode = -- | prefer (or require) onion hosts when connecting via SOCKS proxy @@ -287,6 +290,8 @@ defaultSMPClientConfig = defaultClientConfig (Just supportedSMPHandshakes) suppo data Request err msg = Request { corrId :: CorrId, entityId :: EntityId, + command :: ProtoCommand msg, + pending :: TVar Bool, responseVar :: TMVar (Either (ProtocolClientError err) msg) } @@ -333,7 +338,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId) -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) +getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v err msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, clientALPN, serverVRange, agreeSecret} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> @@ -341,7 +346,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize `catch` \(e :: IOException) -> pure . Left $ PCEIOError e Left e -> pure $ Left e where - NetworkConfig {tcpConnectTimeout, tcpTimeout, rcvConcurrency, smpPingInterval} = networkConfig + NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig mkProtocolClient :: TransportHost -> UTCTime -> STM (PClient v err msg) mkProtocolClient transportHost ts = do connected <- newTVar False @@ -363,7 +368,6 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize timeoutErrorCount, clientCorrId, sentCommands, - rcvConcurrency, sndQ, rcvQ, msgQ @@ -402,11 +406,11 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize atomically $ do writeTVar (connected c) True putTMVar cVar $ Right c' - raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0]) + raceAny_ ([send c' th, process c', receive c' th] <> [monitor c' | smpPingInterval > 0]) `finally` disconnected c' send :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO () - send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= \(active, s) -> whenM (readTVarIO active) (void $ tPutLog h s) + send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= void . tPutLog h receive :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO () receive ProtocolClient {client_ = PClient {rcvQ, lastReceived, timeoutErrorCount}} h = forever $ do @@ -414,8 +418,8 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize getCurrentTime >>= atomically . writeTVar lastReceived atomically $ writeTVar timeoutErrorCount 0 - ping :: ProtocolClient v err msg -> IO () - ping c@ProtocolClient {client_ = PClient {sendPings, lastReceived, timeoutErrorCount}} = loop smpPingInterval + monitor :: ProtocolClient v err msg -> IO () + monitor c@ProtocolClient {client_ = PClient {sendPings, lastReceived, timeoutErrorCount}} = loop smpPingInterval where loop :: Int64 -> IO () loop delay = do @@ -439,27 +443,34 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO () processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) - | isResponse = + | not $ B.null $ bs corrId = atomically (TM.lookup corrId sentCommands) >>= \case - Nothing -> sendMsg respOrErr - Just Request {entityId, responseVar} -> atomically $ do - TM.delete corrId sentCommands - putTMVar responseVar $ response entityId - | otherwise = sendMsg respOrErr + Nothing -> sendMsg TTUncorrelatedResponse + Just Request {entityId, command, pending, responseVar} -> do + wasPending <- + atomically $ do + TM.delete corrId sentCommands + ifM + (swapTVar pending False) + (True <$ tryPutTMVar responseVar (response entityId)) + (pure False) + unless wasPending $ sendMsg $ if entityId == entId then TTExpiredResponse command else TTUncorrelatedResponse + | otherwise = sendMsg TTEvent where - isResponse = not $ B.null $ bs corrId response entityId - | entityId == entId = - case respOrErr of - Left e -> Left $ PCEResponseError e - Right r -> case protocolError r of - Just e -> Left $ PCEProtocolError e - _ -> Right r + | entityId == entId = clientResp | otherwise = Left . PCEUnexpectedResponse $ bshow respOrErr - sendMsg :: Either err msg -> IO () - sendMsg = \case - Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c isResponse entId msg) msgQ - Left e -> putStrLn $ "SMP client error: " <> show e + clientResp = case respOrErr of + Left e -> Left $ PCEResponseError e + Right r -> case protocolError r of + Just e -> Left $ PCEProtocolError e + _ -> Right r + sendMsg :: TransmissionType msg -> IO () + sendMsg tType = case msgQ of + Just q -> atomically $ writeTBQueue q $ serverTransmission c tType entId clientResp + Nothing -> case clientResp of + Left e -> logError $ "SMP client error: " <> tshow e + Right _ -> logWarn $ "SMP client unprocessed event" proxyUsername :: TransportSession msg -> ByteString proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_ @@ -558,11 +569,11 @@ processSUBResponse c (Response rId r) = case r of Left e -> pure $ Left e writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () -writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c False rId msg) (msgQ $ client_ c) +writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c TTEvent rId (Right msg)) (msgQ $ client_ c) -serverTransmission :: ProtocolClient v err msg -> Bool -> RecipientId -> msg -> ServerTransmission v msg -serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} isResponse entityId message = - (transportSession, thVersion, sessionId, isResponse, entityId, message) +serverTransmission :: ProtocolClient v err msg -> TransmissionType msg -> RecipientId -> Either (ProtocolClientError err) msg -> ServerTransmission v err msg +serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} tType entityId msgOrErr = + (transportSession, thVersion, sessionId, tType, entityId, msgOrErr) -- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue -- @@ -687,7 +698,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg) -- | Send multiple commands with batching and collect responses -sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) +sendProtocolCommands :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs @@ -704,30 +715,28 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz where diff = L.length cs - length rs -streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () +streamProtocolCommands :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg] -sendBatch c@ProtocolClient {client_ = PClient {rcvConcurrency, sndQ}} b = do +sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of TBError e Request {entityId} -> do putStrLn "send error: large message" pure [Response entityId $ Left $ PCETransportError e] TBTransmissions s n rs | n > 0 -> do - active <- newTVarIO True - atomically $ writeTBQueue sndQ (active, s) - mapConcurrently (getResponse c active) rs + atomically $ writeTBQueue sndQ s + mapConcurrently (getResponse c) rs | otherwise -> pure [] TBTransmission s r -> do - active <- newTVarIO True - atomically $ writeTBQueue sndQ (active, s) - (: []) <$> getResponse c active r + atomically $ writeTBQueue sndQ s + (: []) <$> getResponse c r -- | Send Protocol command -sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg +sendProtocolCommand :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd = ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd) where @@ -738,30 +747,30 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand Right t | B.length s > blockSize - 2 -> pure . Left $ PCETransportError TELargeMsg | otherwise -> do - active <- newTVarIO True - atomically (writeTBQueue sndQ (active, s)) - response <$> getResponse c active r + atomically $ writeTBQueue sndQ s + response <$> getResponse c r where s | batch = tEncodeBatch1 t | otherwise = tEncode t --- TODO switch to timeout or TimeManager that supports Int64 -getResponse :: ProtocolClient v err msg -> TVar Bool -> Request err msg -> IO (Response err msg) -getResponse ProtocolClient {client_ = PClient {tcpTimeout, timeoutErrorCount, sentCommands}} active Request {corrId, entityId, responseVar} = do - response <- - timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case - Just r -> atomically (writeTVar timeoutErrorCount 0) $> r - Nothing -> do - atomically (writeTVar active False >> TM.delete corrId sentCommands) - atomically $ modifyTVar' timeoutErrorCount (+ 1) - pure $ Left PCEResponseTimeout +getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg) +getResponse ProtocolClient {client_ = PClient {tcpTimeout, timeoutErrorCount}} Request {entityId, pending, responseVar} = do + r <- tcpTimeout `timeout` atomically (takeTMVar responseVar) + response <- atomically $ do + writeTVar pending False + -- Try to read response again in case it arrived after timeout expired + -- but before `pending` was set to False above. + -- See `processMsg`. + ((r <|>) <$> tryTakeTMVar responseVar) >>= \case + Just r' -> writeTVar timeoutErrorCount 0 $> r' + Nothing -> modifyTVar' timeoutErrorCount (+ 1) $> Left PCEResponseTimeout pure Response {entityId, response} -mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) -mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do +mkTransmission :: forall v err msg. Protocol v err msg => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) +mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entityId, command) = do corrId <- atomically getNextCorrId - let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd) + let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entityId, command) auth = authTransmission (thAuth thParams) pKey_ corrId tForAuth r <- atomically $ mkRequest corrId pure ((,tToSend) <$> auth, r) @@ -770,7 +779,16 @@ mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCo getNextCorrId = CorrId <$> C.randomBytes 24 clientCorrId -- also used as nonce mkRequest :: CorrId -> STM (Request err msg) mkRequest corrId = do - r <- Request corrId entId <$> newEmptyTMVar + pending <- newTVar True + responseVar <- newEmptyTMVar + let r = + Request + { corrId, + entityId, + command, + pending, + responseVar + } TM.insert corrId r sentCommands pure r diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 4b925c6f6..aed56f1bd 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -39,7 +39,7 @@ import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer) +import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer) import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -94,7 +94,7 @@ defaultSMPClientAgentConfig = data SMPClientAgent = SMPClientAgent { agentCfg :: SMPClientAgentConfig, - msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion ErrorType BrokerMsg), agentQ :: TBQueue SMPClientAgentEvent, randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index b79665c87..37eff1e94 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -218,10 +218,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge receiveSMP :: M () receiveSMP = forever $ do - ((_, srv, _), _, _, _, ntfId, msg) <- atomically $ readTBQueue msgQ + ((_, srv, _), _, _, _tType, ntfId, msgOrErr) <- atomically $ readTBQueue msgQ let smpQueue = SMPQueueNtf srv ntfId - case msg of - SMP.NMSG nmsgNonce encNMsgMeta -> do + case msgOrErr of + Right (SMP.NMSG nmsgNonce encNMsgMeta) -> do ntfTs <- liftIO getSystemTime st <- asks store NtfPushServer {pushQ} <- asks pushServer @@ -231,8 +231,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge findNtfSubscriptionToken st smpQueue >>= mapM_ (\tkn -> writeTBQueue pushQ (tkn, PNMessage PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta})) incNtfStat ntfReceived - SMP.END -> updateSubStatus smpQueue NSEnd - _ -> pure () + Right SMP.END -> updateSubStatus smpQueue NSEnd + Right (SMP.ERR e) -> logError $ "SMP server error: " <> tshow e + Right _ -> logError $ "SMP server unexpected response" + Left e -> logError $ "SMP client error: " <> tshow e receiveAgent = forever $ diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 85beccd6b..a09759814 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -894,9 +894,10 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Serv Just msg -> let encMsg = encryptMsg qr msg in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg) - _ -> forkSub $> ok - _ -> pure ok + _ -> forkSub $> resp + _ -> pure resp where + resp = (corrId, rId, OK) forkSub :: M () forkSub = do atomically . modifyTVar' sub $ \s -> s {subThread = SubPending} diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs index 7a219e106..75543b481 100644 --- a/src/Simplex/Messaging/Session.hs +++ b/src/Simplex/Messaging/Session.hs @@ -10,6 +10,7 @@ import Data.Composition ((.:.)) import Data.Functor (($>)) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util (($>>=)) data SessionVar a = SessionVar { sessionVar :: TMVar a, @@ -36,3 +37,6 @@ removeSessVar' v sessKey vs = TM.lookup sessKey vs >>= \case Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True _ -> pure False + +tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a) +tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 8c06c0d82..e1d383b5a 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -127,7 +127,7 @@ smpBlockSize = 16384 -- 4 - support command batching (7/17/2022) -- 5 - basic auth for SMP servers (11/12/2022) -- 6 - allow creating queues without subscribing (9/10/2023) --- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024) +-- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (4/30/2024) data SMPVersion diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 32610b54e..df117c105 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -242,7 +242,7 @@ testDuplexConnection' (alice, aPQ) (bob, bPQ) = do alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK) alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", A.MID 8 pq) - bob <# ("", "alice", MERR 8 (SMP AUTH)) + bob <#= \case ("", "alice", MERR 8 (SMP _ AUTH)) -> True; _ -> False alice #: ("6", "bob", "DEL") #> ("6", "bob", OK) alice #:# "nothing else should be delivered to alice" @@ -280,7 +280,7 @@ testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK) alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK) bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, A.MID 8 pq) - bob <# ("", aliceConn, MERR 8 (SMP AUTH)) + bob <#= \case ("", cId, MERR 8 (SMP _ AUTH)) -> cId == aliceConn; _ -> False alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK) alice #:# "nothing else should be delivered to alice" @@ -383,7 +383,7 @@ testSubscrNotification t (server, _) client = do killThread server client <#. ("", "", DOWN testSMPServer ["conn1"]) withSmpServer (ATransport t) $ - client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue + client <#= \case ("", "conn1", ERR (SMP _ AUTH)) -> True; _ -> False -- this new server does not have the queue testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO () testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do @@ -630,7 +630,7 @@ syntaxTests t = do <> " subscribe " <> "14\nbob's connInfo" ) - >#> ("311", "a", "ERR SMP AUTH") + >#> ("311", "a", "ERR SMP smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001 AUTH") describe "invalid" $ do it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX") where diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 59e433ea5..f1dcc058a 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -422,7 +422,7 @@ functionalAPITests t = do describe "server with password" $ do let auth = Just "abcd" srv = ProtoServerWithAuth testSMPServer2 - authErr = Just (ProtocolTestFailure TSCreateQueue $ SMP AUTH) + authErr = Just (ProtocolTestFailure TSCreateQueue $ SMP (B.unpack $ strEncode testSMPServer2) AUTH) it "should pass with correct password" $ testSMPServerConnectionTest t auth (srv auth) `shouldReturn` Nothing it "should fail without password" $ testSMPServerConnectionTest t auth (srv Nothing) `shouldReturn` authErr it "should fail with incorrect password" $ testSMPServerConnectionTest t auth (srv $ Just "wrong") `shouldReturn` authErr @@ -537,7 +537,7 @@ runAgentClientTest pqSupport alice@AgentClient {} bob baseId = ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId 5 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 2" - get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) + get bob =##> \case ("", cId, MERR mId (SMP _ AUTH)) -> cId == aliceId && mId == (baseId + 5); _ -> False deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where @@ -669,7 +669,7 @@ runAgentClientContactTest pqSupport alice bob baseId = ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId 5 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 2" - get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) + get bob =##> \case ("", cId, MERR mId (SMP _ AUTH)) -> cId == aliceId && mId == (baseId + 5); _ -> False deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where @@ -1115,7 +1115,7 @@ testExpireMessageQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} testP 5 <- sendMessage a bId SMP.noMsgFlags "2" liftIO $ threadDelay 1000000 6 <- sendMessage a bId SMP.noMsgFlags "3" -- this won't expire - get a =##> \case ("", c, MERR 5 (SMP QUOTA)) -> bId == c; _ -> False + get a =##> \case ("", c, MERR 5 (SMP _ QUOTA)) -> bId == c; _ -> False pure (aId, bId) withAgent 3 agentCfg initAgentServers testDB2 $ \b' -> runRight_ $ do subscribeConnection b' aId @@ -1143,15 +1143,15 @@ testExpireManyMessagesQuota t = withSmpServerConfigOn t cfg {msgQueueQuota = 1} 7 <- sendMessage a bId SMP.noMsgFlags "4" liftIO $ threadDelay 1000000 8 <- sendMessage a bId SMP.noMsgFlags "5" -- this won't expire - get a =##> \case ("", c, MERR 5 (SMP QUOTA)) -> bId == c; _ -> False + get a =##> \case ("", c, MERR 5 (SMP _ QUOTA)) -> bId == c; _ -> False get a >>= \case - ("", c, MERR 6 (SMP QUOTA)) -> do + ("", c, MERR 6 (SMP _ QUOTA)) -> do liftIO $ bId `shouldBe` c - get a =##> \case ("", c', MERR 7 (SMP QUOTA)) -> bId == c'; ("", c', MERRS [7] (SMP QUOTA)) -> bId == c'; _ -> False - ("", c, MERRS [6] (SMP QUOTA)) -> do + get a =##> \case ("", c', MERR 7 (SMP _ QUOTA)) -> bId == c'; ("", c', MERRS [7] (SMP _ QUOTA)) -> bId == c'; _ -> False + ("", c, MERRS [6] (SMP _ QUOTA)) -> do liftIO $ bId `shouldBe` c - get a =##> \case ("", c', MERR 7 (SMP QUOTA)) -> bId == c'; _ -> False - ("", c, MERRS [6, 7] (SMP QUOTA)) -> liftIO $ bId `shouldBe` c + get a =##> \case ("", c', MERR 7 (SMP _ QUOTA)) -> bId == c'; _ -> False + ("", c, MERRS [6, 7] (SMP _ QUOTA)) -> liftIO $ bId `shouldBe` c r -> error $ show r pure (aId, bId) withAgent 3 agentCfg initAgentServers testDB2 $ \b' -> runRight_ $ do @@ -1402,10 +1402,10 @@ makeConnection = makeConnection_ PQSupportOn makeConnection_ :: PQSupport -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnection_ pqEnc alice bob = makeConnectionForUsers_ pqEnc alice 1 bob 1 -makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers :: HasCallStack => AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn -makeConnectionForUsers_ :: PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers_ :: HasCallStack => PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport @@ -1709,7 +1709,7 @@ testAcceptContactAsync = ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" - get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) + get bob =##> \case ("", cId, MERR mId (SMP _ AUTH)) -> cId == aliceId && mId == (baseId + 5); _ -> False deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where @@ -1755,7 +1755,7 @@ testWaitDeliveryNoPending t = withAgentClients2 $ \alice bob -> get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" - get bob ##> ("", aliceId, MERR (baseId + 3) (SMP AUTH)) + get bob =##> \case ("", cId, MERR mId (SMP _ AUTH)) -> cId == aliceId && mId == (baseId + 3); _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" @@ -1850,8 +1850,8 @@ testWaitDeliveryAUTHErr t = liftIO $ noMessages bob "nothing else should be delivered to bob" withSmpServerStoreLogOn t testPort $ \_ -> do - get alice ##> ("", bobId, MERR (baseId + 3) (SMP AUTH)) - get alice ##> ("", bobId, MERR (baseId + 4) (SMP AUTH)) + get alice =##> \case ("", cId, MERR mId (SMP _ AUTH)) -> cId == bobId && mId == (baseId + 3); _ -> False + get alice =##> \case ("", cId, MERR mId (SMP _ AUTH)) -> cId == bobId && mId == (baseId + 4); _ -> False get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" @@ -2422,11 +2422,11 @@ testCreateQueueAuth srvVersion clnt1 clnt2 = do b <- getClient 2 clnt2 testDB2 r <- runRight $ do tryError (createConnection a 1 True SCMInvitation Nothing SMSubscribe) >>= \case - Left (SMP AUTH) -> pure 0 + Left (SMP _ AUTH) -> pure 0 Left e -> throwError e Right (bId, qInfo) -> tryError (joinConnection b 1 True qInfo "bob's connInfo" SMSubscribe) >>= \case - Left (SMP AUTH) -> pure 1 + Left (SMP _ AUTH) -> pure 1 Left e -> throwError e Right aId -> do ("", _, CONF confId _ "bob's connInfo") <- get a diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 2c1045791..3b497c8d4 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -6,8 +6,8 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeApplications #-} -{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} +{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} module AgentTests.NotificationTests where @@ -17,10 +17,6 @@ import AgentTests.FunctionalAPITests createConnection, exchangeGreetingsMsgId, get, - withAgent, - withAgentClients2, - withAgentClientsCfgServers2, - withAgentClients3, joinConnection, makeConnection, nGet, @@ -29,7 +25,11 @@ import AgentTests.FunctionalAPITests sendMessage, switchComplete, testServerMatrix2, + withAgent, + withAgentClients2, + withAgentClients3, withAgentClientsCfg2, + withAgentClientsCfgServers2, (##>), (=##>), pattern CON, @@ -59,8 +59,8 @@ import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO) import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..)) import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..)) import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Types (NtfToken (..)) import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NtfServer, ProtocolServer (..), SMPMsgMeta (..), SubscriptionMode (..)) @@ -151,7 +151,8 @@ testNtfMatrix t runTest = do it "next servers: SMP v7, NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfg agentCfg runTest it "curr servers: SMP v6, NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfg agentCfg agentCfg runTest skip "this case cannot be supported - see RFC" $ - it "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest + it "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ + runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest -- servers can be migrated in any order it "servers: next SMP v7, curr NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfg agentCfg agentCfg runTest it "servers: curr SMP v6, next NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfgV2 agentCfg agentCfg runTest @@ -243,7 +244,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = -- now the second token registration is verified verifyNtfToken a' tkn nonce' verification' -- the first registration is removed - Left (NTF AUTH) <- tryE $ checkNtfToken a tkn + Left (NTF _ AUTH) <- tryE $ checkNtfToken a tkn -- and the second is active NTActive <- checkNtfToken a' tkn pure () @@ -258,7 +259,7 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do atomically $ readTBQueue apnsQ liftIO $ sendApnsResponse APNSRespOk pure ntfData - -- the new agent is created as otherwise when running the tests in CI the old agent was keeping the connection to the server + -- the new agent is created as otherwise when running the tests in CI the old agent was keeping the connection to the server threadDelay 1000000 withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, @@ -266,7 +267,7 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do withNtfServer t . runRight_ $ do verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" - Left (NTF AUTH) <- tryE $ verifyNtfToken a' tkn nonce verification + Left (NTF _ AUTH) <- tryE $ verifyNtfToken a' tkn nonce verification APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- atomically $ readTBQueue apnsQ verification' <- ntfData' .-> "verification" diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index 7b1a7b813..d574bfb4f 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -35,6 +35,9 @@ protocolErrorTests = modifyMaxSuccess (const 1000) $ do errHasSpaces = \case BROKER srv (RESPONSE e) -> hasSpaces srv || hasSpaces e BROKER srv _ -> hasSpaces srv + SMP srv _ -> hasSpaces srv + NTF srv _ -> hasSpaces srv + XFTP srv _ -> hasSpaces srv _ -> False hasSpaces s = ' ' `B.elem` encodeUtf8 (T.pack s) diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 88786bb40..0610bf48d 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -69,7 +69,7 @@ xftpAgentTests = around_ testBracket . describe "agent XFTP API" $ do describe "server with password" $ do let auth = Just "abcd" srv = ProtoServerWithAuth testXFTPServer2 - authErr = Just (ProtocolTestFailure TSCreateFile $ XFTP AUTH) + authErr = Just (ProtocolTestFailure TSCreateFile $ XFTP (B.unpack $ strEncode testXFTPServer2) AUTH) it "should pass with correct password" $ testXFTPServerTest auth (srv auth) `shouldReturn` Nothing it "should fail without password" $ testXFTPServerTest auth (srv Nothing) `shouldReturn` authErr it "should fail with incorrect password" $ testXFTPServerTest auth (srv $ Just "wrong") `shouldReturn` authErr @@ -392,7 +392,8 @@ testXFTPAgentReceiveCleanup = withGlobalLogging logCfgNoLogs $ do -- receive file - should fail with AUTH error withAgent 3 agentCfg initAgentServers testDB2 $ \rcp' -> do runRight_ $ xftpStartWorkers rcp' (Just recipientFiles) - ("", rfId', RFERR (INTERNAL "XFTP {xftpErr = AUTH}")) <- rfGet rcp' + ("", rfId', RFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <- + rfGet rcp' rfId' `shouldBe` rfId -- tmp path should be removed after permanent error @@ -471,7 +472,8 @@ testXFTPAgentSendCleanup = withGlobalLogging logCfgNoLogs $ do -- send file - should fail with AUTH error withAgent 2 agentCfg initAgentServers testDB $ \sndr' -> do runRight_ $ xftpStartWorkers sndr' (Just senderFiles) - ("", sfId', SFERR (INTERNAL "XFTP {xftpErr = AUTH}")) <- sfGet sndr' + ("", sfId', SFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <- + sfGet sndr' sfId' `shouldBe` sfId -- prefix path should be removed after permanent error @@ -506,7 +508,8 @@ testXFTPAgentDelete = withGlobalLogging logCfgNoLogs $ withAgent 3 agentCfg initAgentServers testDB2 $ \rcp2 -> runRight $ do xftpStartWorkers rcp2 (Just recipientFiles) rfId <- xftpReceiveFile rcp2 1 rfd2 Nothing - ("", rfId', RFERR (INTERNAL "XFTP {xftpErr = AUTH}")) <- rfGet rcp2 + ("", rfId', RFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <- + rfGet rcp2 liftIO $ rfId' `shouldBe` rfId testXFTPAgentDeleteRestore :: HasCallStack => IO () @@ -543,7 +546,8 @@ testXFTPAgentDeleteRestore = withGlobalLogging logCfgNoLogs $ do withAgent 5 agentCfg initAgentServers testDB3 $ \rcp2 -> runRight $ do xftpStartWorkers rcp2 (Just recipientFiles) rfId <- xftpReceiveFile rcp2 1 rfd2 Nothing - ("", rfId', RFERR (INTERNAL "XFTP {xftpErr = AUTH}")) <- rfGet rcp2 + ("", rfId', RFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <- + rfGet rcp2 liftIO $ rfId' `shouldBe` rfId testXFTPAgentDeleteOnServer :: HasCallStack => IO () @@ -577,7 +581,8 @@ testXFTPAgentDeleteOnServer = withGlobalLogging logCfgNoLogs $ runRight_ . void $ do -- receive file 1 again rfId1 <- xftpReceiveFile rcp 1 rfd1_2 Nothing - ("", rfId1', RFERR (INTERNAL "XFTP {xftpErr = AUTH}")) <- rfGet rcp + ("", rfId1', RFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <- + rfGet rcp liftIO $ rfId1 `shouldBe` rfId1' -- receive file 2 @@ -609,7 +614,8 @@ testXFTPAgentExpiredOnServer = withGlobalLogging logCfgNoLogs $ do -- receive file 1 again - should fail with AUTH error runRight $ do rfId <- xftpReceiveFile rcp 1 rfd1_2 Nothing - ("", rfId', RFERR (INTERNAL "XFTP {xftpErr = AUTH}")) <- rfGet rcp + ("", rfId', RFERR (INTERNAL "XFTP {serverAddress = \"xftp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:7000\", xftpErr = AUTH}")) <- + rfGet rcp liftIO $ rfId' `shouldBe` rfId -- create and send file 2