diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index ae0436e87..8ccc83332 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -791,7 +791,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do notify OK LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK - SWCH -> noServer $ tryCommand $ switchConnection' c connId >>= notify . SWITCH SPStarted + SWCH -> noServer $ tryCommand $ switchConnection' c connId >>= notify . SWITCH QDRcv SPStarted DEL -> withServer' . tryCommand $ deleteConnection' c connId >> notify OK _ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd) AInternalCommand cmd -> case cmd of @@ -847,7 +847,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) let conn' = DuplexConnection cData (rq'' :| rqs') sqs - notify $ SWITCH SPCompleted $ connectionStats conn' + notify $ SWITCH QDRcv SPCompleted $ connectionStats conn' _ -> internalErr "ICQDelete: cannot delete the only queue in connection" where ack srv rId srvMsgId = do @@ -1076,7 +1076,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh deleteConnSndQueue db connId sq' let sqs'' = sq'' :| sqs' conn' = DuplexConnection cData' rqs sqs'' - notify . SWITCH SPCompleted $ connectionStats conn' + notify . SWITCH QDSnd SPCompleted $ connectionStats conn' _ -> internalErr msgId "sent QTEST: there is only one queue in connection" _ -> internalErr msgId "sent QTEST: queue not in connection or not replacing another queue" _ -> internalErr msgId "QTEST sent not in duplex connection" @@ -1694,7 +1694,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}} void . enqueueMessages c cData sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)] let conn' = DuplexConnection cData rqs (sq <| sq' :| sqs_) - notify . SWITCH SPStarted $ connectionStats conn' + notify . SWITCH QDSnd SPStarted $ connectionStats conn' _ -> qError "absent sender keys" _ -> qError "QADD: replaced queue address is not found in connection" _ -> throwError $ AGENT A_VERSION @@ -1711,7 +1711,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm let dhSecret = C.dh' dhPublicKey dhPrivKey withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer' enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey - notify . SWITCH SPConfirmed $ connectionStats conn + notify . SWITCH QDRcv SPConfirmed $ connectionStats conn | otherwise -> qError "QKEY: queue already secured" _ -> qError "QKEY: queue address not found in connection" where @@ -1729,7 +1729,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm let sq'' = (sq' :: SndQueue) {status = Secured} -- sending QTEST to the new queue only, the old one will be removed if sent successfully void $ enqueueMessages c cData [sq''] SMP.noMsgFlags $ QTEST [addr] - notify . SWITCH SPConfirmed $ connectionStats conn + notify . SWITCH QDSnd SPConfirmed $ connectionStats conn _ -> qError "QUSE: queue address not found in connection" qError :: String -> m () diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 9b96ce573..d0907493c 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -51,6 +51,7 @@ module Simplex.Messaging.Agent.Protocol MsgMeta (..), ConnectionStats (..), SwitchPhase (..), + QueueDirection (..), SMPConfirmation (..), AgentMsgEnvelope (..), AgentMessage (..), @@ -252,7 +253,7 @@ data ACommand (p :: AParty) where DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent DOWN :: SMPServer -> [ConnId] -> ACommand Agent UP :: SMPServer -> [ConnId] -> ACommand Agent - SWITCH :: SwitchPhase -> ConnectionStats -> ACommand Agent + SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent SEND :: MsgFlags -> MsgBody -> ACommand Client MID :: AgentMsgId -> ACommand Agent SENT :: AgentMsgId -> ACommand Agent @@ -345,6 +346,26 @@ aCommandTag = \case ERR _ -> ERR_ SUSPENDED -> SUSPENDED_ +data QueueDirection = QDRcv | QDSnd + deriving (Eq, Show) + +instance StrEncoding QueueDirection where + strEncode = \case + QDRcv -> "rcv" + QDSnd -> "snd" + strP = + A.takeTill (== ' ') >>= \case + "rcv" -> pure QDRcv + "snd" -> pure QDSnd + _ -> fail "bad QueueDirection" + +instance ToJSON QueueDirection where + toEncoding = strToJEncoding + toJSON = strToJSON + +instance FromJSON QueueDirection where + parseJSON = strParseJSON "QueueDirection" + data SwitchPhase = SPStarted | SPConfirmed | SPCompleted deriving (Eq, Show) @@ -1254,7 +1275,7 @@ commandP binaryP = DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP) DOWN_ -> s (DOWN <$> strP_ <*> connections) UP_ -> s (UP <$> strP_ <*> connections) - SWITCH_ -> s (SWITCH <$> strP_ <*> strP) + SWITCH_ -> s (SWITCH <$> strP_ <*> strP_ <*> strP) MID_ -> s (MID <$> A.decimal) SENT_ -> s (SENT <$> A.decimal) MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP) @@ -1297,7 +1318,7 @@ serializeCommand = \case DISCONNECT p h -> s (DISCONNECT_, p, h) DOWN srv conns -> B.unwords [s DOWN_, s srv, connections conns] UP srv conns -> B.unwords [s UP_, s srv, connections conns] - SWITCH phase srvs -> s (SWITCH_, phase, srvs) + SWITCH dir phase srvs -> s (SWITCH_, dir, phase, srvs) SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody] MID mId -> s (MID_, Str $ bshow mId) SENT mId -> s (SENT_, Str $ bshow mId) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 6fa506204..75f06869b 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -669,23 +669,25 @@ testSwitchConnection servers = do switchComplete :: AgentClient -> ByteString -> AgentClient -> ByteString -> ExceptT AgentErrorType IO () switchComplete a bId b aId = do - phase a bId SPStarted - phase b aId SPStarted - phase a bId SPConfirmed - phase b aId SPConfirmed - phase b aId SPCompleted - phase a bId SPCompleted + phase a bId QDRcv SPStarted + phase b aId QDSnd SPStarted + phase a bId QDRcv SPConfirmed + phase b aId QDSnd SPConfirmed + phase b aId QDSnd SPCompleted + phase a bId QDRcv SPCompleted -phase :: AgentClient -> ByteString -> SwitchPhase -> ExceptT AgentErrorType IO () -phase c connId p = +phase :: AgentClient -> ByteString -> QueueDirection -> SwitchPhase -> ExceptT AgentErrorType IO () +phase c connId d p = get c >>= \(_, connId', msg) -> do liftIO $ connId `shouldBe` connId' case msg of - SWITCH p' _ -> liftIO $ p `shouldBe` p' - ERR (AGENT A_DUPLICATE) -> phase c connId p + SWITCH d' p' _ -> liftIO $ do + d `shouldBe` d' + p `shouldBe` p' + ERR (AGENT A_DUPLICATE) -> phase c connId d p r -> do liftIO . putStrLn $ "expected: " <> show p <> ", received: " <> show r - SWITCH _ _ <- pure r + SWITCH _ _ _ <- pure r pure () testSwitchAsync :: InitialAgentServers -> IO () @@ -698,13 +700,13 @@ testSwitchAsync servers = do withB' = session withB aId withA' $ \a -> do switchConnectionAsync a "" bId - phase a bId SPStarted - withB' $ \b -> phase b aId SPStarted - withA' $ \a -> phase a bId SPConfirmed + phase a bId QDRcv SPStarted + withB' $ \b -> phase b aId QDSnd SPStarted + withA' $ \a -> phase a bId QDRcv SPConfirmed withB' $ \b -> do - phase b aId SPConfirmed - phase b aId SPCompleted - withA' $ \a -> phase a bId SPCompleted + phase b aId QDSnd SPConfirmed + phase b aId QDSnd SPCompleted + withA' $ \a -> phase a bId QDRcv SPCompleted Right () <- withA $ \a -> withB $ \b -> runExceptT $ do subscribeConnection a bId subscribeConnection b aId @@ -733,7 +735,7 @@ testSwitchDelete servers = do exchangeGreetingsMsgId 4 a bId b aId disconnectAgentClient b switchConnectionAsync a "" bId - phase a bId SPStarted + phase a bId QDRcv SPStarted deleteConnectionAsync a "1" bId ("1", bId', OK) <- get a liftIO $ bId `shouldBe` bId'