diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 2ffe6f0f7..3fe7311e4 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -362,7 +362,7 @@ reconnectAllServers c = liftIO $ do closeProtocolServerClients c ntfClients -- | Register device notifications token -registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus +registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m (NtfTknStatus, Maybe NtfServer) registerNtfToken c = withAgentEnv c .: registerNtfToken' c -- | Verify device notifications token @@ -1573,22 +1573,23 @@ connectionStats = \case setProtocolServers' :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> m () setProtocolServers' c userId srvs = atomically $ TM.insert userId srvs (userServers c) -registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus +registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m (NtfTknStatus, Maybe NtfServer) registerNtfToken' c suppliedDeviceToken suppliedNtfMode = withStore' c getSavedNtfToken >>= \case - Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId, ntfTknStatus, ntfTknAction, ntfMode = savedNtfMode} -> do - status <- case (ntfTokenId, ntfTknAction) of + Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId, ntfTknStatus, ntfTknAction, ntfMode = savedNtfMode, ntfServer} -> do + (status, srv) <- case (ntfTokenId, ntfTknAction) of (Nothing, Just NTARegister) -> do when (savedDeviceToken /= suppliedDeviceToken) $ withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken - registerToken tkn $> NTRegistered + registerToken tkn $> (NTRegistered, Just ntfServer) -- possible improvement: add minimal time before repeat registration (Just tknId, Nothing) | savedDeviceToken == suppliedDeviceToken -> - when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered + when (ntfTknStatus == NTRegistered) (registerToken tkn) $> (NTRegistered, Just ntfServer) | otherwise -> replaceToken tknId (Just tknId, Just (NTAVerify code)) - | savedDeviceToken == suppliedDeviceToken -> - t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code + | savedDeviceToken == suppliedDeviceToken -> do + status' <- t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code + pure (status', Just ntfServer) | otherwise -> replaceToken tknId (Just tknId, Just NTACheck) | savedDeviceToken == suppliedDeviceToken -> do @@ -1600,19 +1601,19 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete -- possible improvement: get updated token status from the server, or maybe TCRON could return the current status - pure ntfTknStatus + pure (ntfTknStatus, Just ntfServer) | otherwise -> replaceToken tknId (Just tknId, Just NTADelete) -> do agentNtfDeleteToken c tknId tkn withStore' c (`removeNtfToken` tkn) ns <- asks ntfSupervisor atomically $ nsRemoveNtfToken ns - pure NTExpired - _ -> pure ntfTknStatus + pure (NTExpired, Nothing) + _ -> pure (ntfTknStatus, Just ntfServer) withStore' c $ \db -> updateNtfMode db tkn suppliedNtfMode - pure status + pure (status, srv) where - replaceToken :: NtfTokenId -> m NtfTknStatus + replaceToken :: NtfTokenId -> m (NtfTknStatus, Maybe NtfServer) replaceToken tknId = do ns <- asks ntfSupervisor tryReplace ns `catchAgentError` \e -> @@ -1627,11 +1628,11 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = agentNtfReplaceToken c tknId tkn suppliedDeviceToken withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode} - pure NTRegistered + pure (NTRegistered, Just ntfServer) _ -> createToken where t tkn = withToken c tkn Nothing - createToken :: m NtfTknStatus + createToken :: m (NtfTknStatus, Maybe NtfServer) createToken = getNtfServer c >>= \case Just ntfServer -> @@ -1643,7 +1644,7 @@ registerNtfToken' c suppliedDeviceToken suppliedNtfMode = let tkn = newNtfToken suppliedDeviceToken ntfServer tknKeys dhKeys suppliedNtfMode withStore' c (`createNtfToken` tkn) registerToken tkn - pure NTRegistered + pure (NTRegistered, Just ntfServer) _ -> throwError $ CMD PROHIBITED registerToken :: NtfToken -> m () registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 15ba1993e..1d3090819 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -121,7 +121,7 @@ testNotificationToken APNSMockServer {apnsQ} = do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB runRight_ $ do let tkn = DeviceToken PPApnsTest "abcd" - NTRegistered <- registerNtfToken a tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" @@ -149,13 +149,13 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB runRight_ $ do let tkn = DeviceToken PPApnsTest "abcd" - NTRegistered <- registerNtfToken a tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" liftIO $ sendApnsResponse APNSRespOk - NTRegistered <- registerNtfToken a tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- atomically $ readTBQueue apnsQ _ <- ntfData' .-> "verification" @@ -174,7 +174,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do a' <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 runRight_ $ do let tkn = DeviceToken PPApnsTest "abcd" - NTRegistered <- registerNtfToken a tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" @@ -182,7 +182,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do liftIO $ sendApnsResponse APNSRespOk verifyNtfToken a tkn nonce verification - NTRegistered <- registerNtfToken a' tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a' tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- atomically $ readTBQueue apnsQ verification' <- ntfData' .-> "verification" @@ -208,7 +208,7 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB let tkn = DeviceToken PPApnsTest "abcd" ntfData <- withNtfServer t . runRight $ do - NTRegistered <- registerNtfToken a tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ liftIO $ sendApnsResponse APNSRespOk @@ -245,7 +245,7 @@ testNtfTokenMultipleServers t APNSMockServer {apnsQ} = do withNtfServerThreadOn t ntfTestPort $ \ntf -> withNtfServerThreadOn t ntfTestPort2 $ \ntf2 -> runRight_ $ do -- register a new token, the agent picks a server and stores its choice - NTRegistered <- registerNtfToken a tkn NMPeriodic + (NTRegistered, _) <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" @@ -309,7 +309,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do get bob ##> ("", aliceId, CON) -- register notification token let tkn = DeviceToken PPApnsTest "abcd" - NTRegistered <- registerNtfToken alice tkn NMInstant + (NTRegistered, _) <- registerNtfToken alice tkn NMInstant APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" @@ -400,7 +400,7 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do registerTestToken :: AgentClient -> ByteString -> NotificationsMode -> TBQueue APNSMockRequest -> ExceptT AgentErrorType IO DeviceToken registerTestToken a token mode apnsQ = do let tkn = DeviceToken PPApnsTest token - NTRegistered <- registerNtfToken a tkn mode + (NTRegistered, _) <- registerNtfToken a tkn mode Just APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- timeout 1000000 . atomically $ readTBQueue apnsQ verification' <- ntfData' .-> "verification" @@ -433,7 +433,7 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False ackMessage alice bobId (baseId + 1) Nothing -- set mode to NMPeriodic - NTActive <- registerNtfToken alice tkn NMPeriodic + (NTActive, _) <- registerNtfToken alice tkn NMPeriodic -- send message, no notification liftIO $ threadDelay 750000 2 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello again" @@ -442,7 +442,7 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False ackMessage alice bobId (baseId + 2) Nothing -- set mode to NMInstant - NTActive <- registerNtfToken alice tkn NMInstant + (NTActive, _) <- registerNtfToken alice tkn NMInstant -- send message, receive notification liftIO $ threadDelay 500000 3 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello there"