diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 4ad44ef4a..1407c3862 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -765,11 +765,11 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = (Just tknId, Nothing) | savedDeviceToken == suppliedDeviceToken -> when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered - | otherwise -> replaceToken tknId $> NTRegistered + | otherwise -> replaceToken tknId (Just tknId, Just (NTAVerify code)) | savedDeviceToken == suppliedDeviceToken -> t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code - | otherwise -> replaceToken tknId $> NTRegistered + | otherwise -> replaceToken tknId (Just tknId, Just NTACheck) | savedDeviceToken == suppliedDeviceToken -> do ns <- asks ntfSupervisor @@ -781,7 +781,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete pure ntfTknStatus -- TODO -- agentNtfCheckToken c tknId tkn >>= \case - | otherwise -> replaceToken tknId $> NTRegistered + | otherwise -> replaceToken tknId (Just tknId, Just NTADelete) -> do agentNtfDeleteToken c tknId tkn withStore' c (`removeNtfToken` tkn) @@ -792,13 +792,27 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = withStore' c $ \db -> updateNtfMode db tkn suppliedNtfMode pure status where - replaceToken :: NtfTokenId -> m () + replaceToken :: NtfTokenId -> m NtfTknStatus replaceToken tknId = do - agentNtfReplaceToken c tknId tkn suppliedDeviceToken - withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken ns <- asks ntfSupervisor - atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode} - _ -> + tryReplace ns `catchError` \e -> + if temporaryAgentError e || e == BROKER HOST + then throwError e + else do + withStore' c $ \db -> removeNtfToken db tkn + atomically $ nsRemoveNtfToken ns + createToken + where + tryReplace ns = do + agentNtfReplaceToken c tknId tkn suppliedDeviceToken + withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken + atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode} + pure NTRegistered + _ -> createToken + where + t tkn = withToken c tkn Nothing + createToken :: m NtfTknStatus + createToken = getNtfServer c >>= \case Just ntfServer -> asks (cmdSignAlg . config) >>= \case @@ -810,8 +824,6 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = registerToken tkn pure NTRegistered _ -> throwError $ CMD PROHIBITED - where - t tkn = withToken c tkn Nothing registerToken :: NtfToken -> m () registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do (tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 38ef6cc72..d9af0fb29 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -106,6 +106,7 @@ removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} = forM_ tIds $ \(regKey, tId') -> do TM.delete regKey tknRegs TM.delete tId' $ tokens st + -- TODO remove token subscriptions as in deleteNtfToken pure $ map snd tIds removeTokenRegistration :: NtfStore -> NtfTknData -> STM () @@ -130,6 +131,7 @@ deleteNtfToken st tknId = do ) ) + -- TODO refactor qs <- TM.lookupDelete tknId (tokenSubscriptions st) >>= mapM diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index f3a0cc111..a8dbd1a17 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -276,31 +276,25 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do liftIO $ threadDelay 1000000 aliceId <- joinConnection bob True qInfo "bob's connInfo" liftIO $ threadDelay 750000 - liftIO $ print 0 void $ messageNotification apnsQ ("", _, CONF confId _ "bob's connInfo") <- get alice liftIO $ threadDelay 500000 allowConnection alice bobId confId "alice's connInfo" - liftIO $ print 1 void $ messageNotification apnsQ get bob ##> ("", aliceId, INFO "alice's connInfo") - liftIO $ print 2 void $ messageNotification apnsQ get alice ##> ("", bobId, CON) - liftIO $ print 3 void $ messageNotification apnsQ get bob ##> ("", aliceId, CON) -- bob sends message 1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello" get bob ##> ("", aliceId, SENT $ baseId + 1) - liftIO $ print 4 void $ messageNotification apnsQ get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False ackMessage alice bobId $ baseId + 1 -- alice sends message 2 <- msgId <$> sendMessage alice bobId (SMP.MsgFlags True) "hey there" get alice ##> ("", bobId, SENT $ baseId + 2) - liftIO $ print 5 void $ messageNotification apnsQ get bob =##> \case ("", c, Msg "hey there") -> c == aliceId; _ -> False ackMessage bob aliceId $ baseId + 2