diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 98eb84da1..1f3d7099e 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -1045,7 +1045,6 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh withStore' c $ \db -> do setSndQueueStatus db sq Confirmed when (isJust rq_) $ removeConfirmations db connId - -- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO AM_CONN_INFO_REPLY -> pure () AM_REPLY_ -> pure () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index e47509be3..9a3642767 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -98,7 +98,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (isRight, partitionEithers) import Data.Functor (($>)) -import Data.List (partition, (\\)) +import Data.List (partition) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -303,12 +303,9 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do removeClientAndSubs :: IO ([RcvQueue], [ConnId]) removeClientAndSubs = atomically $ do TM.delete srv smpClients - qs <- RQ.getDelSrvQueues srv $ activeSubs c + (qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c mapM_ (`RQ.addQueue` pendingSubs c) qs - cs <- RQ.getConns (activeSubs c) - -- TODO deduplicate conns - let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs - pure (qs, conns) + pure (qs, S.toList conns) serverDown :: ([RcvQueue], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do @@ -345,8 +342,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do unless connected . forM_ client_ $ \cl -> do incClientStat c cl "CONNECT" "" notifySub "" $ hostEvent CONNECT cl - -- TODO deduplicate okConns - let conns = okConns \\ S.toList cs + let conns = S.toList $ S.fromList okConns `S.difference` cs unless (null conns) $ notifySub "" $ UP srv conns let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs @@ -647,8 +643,7 @@ temporaryOrHostError = \case subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())]) subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue qs - forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do - -- TODO check server is correct + forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do modifyTVar (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c case L.nonEmpty qs_ of @@ -667,9 +662,11 @@ subscribeQueues c srv qs = do pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs _ -> pure (Nothing, errs) where - checkQueue rq@RcvQueue {rcvId, server} = do - prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c - pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq + checkQueue rq@RcvQueue {rcvId, server} + | server == srv = do + prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c + pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq + | otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter") queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m () diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index bed4138a2..9ace32db7 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -40,9 +40,9 @@ getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs where addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs' -getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue] -getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty) +getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId) +getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty) where - addQ (removed, qs') rq@RcvQueue {server, rcvId} - | srv == server = (rq : removed, qs') + addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId} + | srv == server = ((rq : remQs, S.insert connId remConns), qs') | otherwise = (removed, M.insert (server, rcvId) rq qs') diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 645a26e3a..0263e374e 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1160,26 +1160,33 @@ instance Encoding CommandError where _ -> fail "bad command error type" -- | Send signed SMP transmission to TCP transport. -tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO (NonEmpty (Either TransportError ())) +tPut :: Transport c => THandle c -> NonEmpty SentRawTransmission -> IO [Either TransportError ()] tPut th trs | batch th = tPutBatch [] $ L.map tEncode trs - | otherwise = forM trs $ tPutBlock th . tEncode + | otherwise = forM (L.toList trs) $ tPutLog . tEncode where - tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ())) + tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO [Either TransportError ()] tPutBatch rs ts = do let (n, s, ts_) = encodeBatch 0 "" ts - r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s) + r <- if n == 0 then largeMsg else replicate n <$> tPutLog (lenEncode n `B.cons` s) let rs' = rs <> r case ts_ of Just ts' -> tPutBatch rs' ts' - _ -> pure $ L.fromList rs' + _ -> pure rs' + largeMsg = putStrLn "tPut error: large message" >> pure [Left TELargeMsg] + tPutLog s = do + r <- tPutBlock th s + case r of + Left e -> putStrLn ("tPut error: " <> show e) + _ -> pure () + pure r encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString)) encodeBatch n s ts@(t :| ts_) | n == 255 = (n, s, Just ts) | otherwise = let s' = s <> smpEncode (Large t) n' = n + 1 - in if B.length s' > blockSize th - 1 + in if B.length s' > blockSize th - 1 -- one byte is reserved for the number of messages in the batch then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts else case L.nonEmpty ts_ of Just ts' -> encodeBatch n' s' ts' diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 1bf3d46df..52c3558ed 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -256,7 +256,6 @@ receive th Client {rcvQ, sndQ, activeAt} = forever $ do send :: Transport c => THandle c -> Client -> IO () send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ - -- TODO the line below can return Lefts, but we ignore it and do not disconnect the client void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts atomically . writeTVar activeAt =<< liftIO getSystemTime where