diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 70a192f94..8cc8ec03a 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -12,6 +12,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} -- | -- Module : Simplex.Messaging.Agent @@ -100,7 +101,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, SMPMsgMeta) +import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (bshow, eitherToMaybe, liftE, liftError, tryError, unlessM, whenM, ($>>=)) @@ -637,8 +638,7 @@ deleteConnection' c connId = -- | Change servers to be used for creating new queues, in Reader monad setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m () -setSMPServers' c servers = do - atomically $ writeTVar (smpServers c) servers +setSMPServers' c = atomically . writeTVar (smpServers c) registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus registerNtfToken' c suppliedDeviceToken suppliedNtfMode = diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 506bc5293..4fb5f2822 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -98,7 +98,7 @@ import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey) +import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, NtfServer, ProtoServer, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -194,7 +194,7 @@ agentDbPath :: AgentClient -> FilePath agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath class ProtocolServerClient msg where - getProtocolServerClient :: AgentMonad m => AgentClient -> ProtocolServer -> m (ProtocolClient msg) + getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg) clientProtocolError :: ErrorType -> AgentErrorType instance ProtocolServerClient BrokerMsg where @@ -311,7 +311,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} srv = do atomically $ TM.delete srv ntfClients logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv -getClientVar :: forall a. ProtocolServer -> TMap ProtocolServer (TMVar a) -> STM (Either (TMVar a) (TMVar a)) +getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a)) getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients where newClientVar :: STM (TMVar a) @@ -333,8 +333,8 @@ newProtocolClient :: forall msg m. AgentMonad m => AgentClient -> - ProtocolServer -> - TMap ProtocolServer (ClientVar msg) -> + ProtoServer msg -> + TMap (ProtoServer msg) (ClientVar msg) -> m (ProtocolClient msg) -> m () -> ClientVar msg -> @@ -383,7 +383,7 @@ closeAgentClient c = liftIO $ do clear :: (AgentClient -> TMap k a) -> IO () clear sel = atomically $ writeTVar (sel c) M.empty -closeProtocolServerClients :: Int -> TMap ProtocolServer (ClientVar msg) -> IO () +closeProtocolServerClients :: Int -> TMap (ProtoServer msg) (ClientVar msg) -> IO () closeProtocolServerClients tcpTimeout cs = readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty) where closeClient cVar = @@ -400,7 +400,7 @@ withAgentLock AgentClient {lock} = (void . atomically $ takeTMVar lock) (atomically $ putTMVar lock ()) -withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> m a) -> m a +withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> (ProtocolClient msg -> m a) -> m a withClient_ c srv action = (getProtocolServerClient c srv >>= action) `catchError` logServerError where logServerError :: AgentErrorType -> m a @@ -408,17 +408,17 @@ withClient_ c srv action = (getProtocolServerClient c srv >>= action) `catchErro logServer "<--" c srv "" $ bshow e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a +withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a withLogClient_ c srv qId cmdStr action = do logServer "-->" c srv qId cmdStr res <- withClient_ c srv action logServer "<--" c srv qId "OK" return res -withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a withClient c srv action = withClient_ c srv $ liftClient (clientProtocolError @msg) . action -withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient (clientProtocolError @msg) . action liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> ExceptT ProtocolClientError IO a -> m a @@ -516,11 +516,11 @@ getSubscriptions AgentClient {subscrConns} = do m <- readTVar subscrConns pure $ M.keysSet m -logServer :: MonadIO m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () +logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] -showServer :: SMPServer -> ByteString +showServer :: ProtocolServer s -> ByteString showServer ProtocolServer {host, port} = B.pack $ host <> if null port then "" else ':' : port diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 4af3c0411..6dd481ee9 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -39,6 +39,7 @@ import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Types +import Simplex.Messaging.Protocol (NtfServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (TLS, Transport (..)) diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 22ee8377f..45aee73e2 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -139,9 +139,9 @@ processNtfSub c (connId, cmd) = do addNtfNTFWorker = addWorker ntfWorkers runNtfWorker addNtfSMPWorker = addWorker ntfSMPWorkers runNtfSMPWorker addWorker :: - (NtfSupervisor -> TMap ProtocolServer (TMVar (), Async ())) -> - (AgentClient -> ProtocolServer -> TMVar () -> m ()) -> - ProtocolServer -> + (NtfSupervisor -> TMap (ProtocolServer s) (TMVar (), Async ())) -> + (AgentClient -> ProtocolServer s -> TMVar () -> m ()) -> + ProtocolServer s -> m () addWorker wsSel runWorker srv = do ws <- asks $ wsSel . ntfSupervisor @@ -340,7 +340,7 @@ closeNtfSupervisor ns = do cancelNtfWorkers_ $ ntfWorkers ns cancelNtfWorkers_ $ ntfSMPWorkers ns -cancelNtfWorkers_ :: TMap ProtocolServer (TMVar (), Async ()) -> IO () +cancelNtfWorkers_ :: TMap (ProtocolServer s) (TMVar (), Async ()) -> IO () cancelNtfWorkers_ wsVar = do ws <- atomically $ stateTVar wsVar (,M.empty) forM_ ws $ uninterruptibleCancel . snd diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 06066629f..8a75f9ddf 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -127,7 +127,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (blobFieldParser, fromTextField_) -import Simplex.Messaging.Protocol (MsgBody, MsgFlags, ProtocolServer (..), RcvNtfDhSecret) +import Simplex.Messaging.Protocol (MsgBody, MsgFlags, NtfServer, ProtocolServer (..), RcvNtfDhSecret, pattern NtfServer) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, eitherToMaybe, ($>>=), (<$$>)) import Simplex.Messaging.Version @@ -656,7 +656,7 @@ getSavedNtfToken db = do |] where ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) = - let ntfServer = ProtocolServer {host, port, keyHash} + let ntfServer = NtfServer host port keyHash ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) ntfMode = fromMaybe NMPeriodic ntfMode_ in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} @@ -736,7 +736,7 @@ getNtfSubscription db connId = where ntfSubscription (smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_) = let smpServer = SMPServer smpHost smpPort smpKeyHash - ntfServer = ProtocolServer ntfHost ntfPort ntfKeyHash + ntfServer = NtfServer ntfHost ntfPort ntfKeyHash action = case (ntfAction_, smpAction_, actionTs_) of (Just ntfAction, Nothing, Just actionTs) -> Just (NtfSubNTFAction ntfAction, actionTs) (Nothing, Just smpAction, Just actionTs) -> Just (NtfSubSMPAction smpAction, actionTs) @@ -745,7 +745,7 @@ getNtfSubscription db connId = createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO () createNtfSubscription db ntfSubscription action actionTs = do - let NtfSubscription {connId, smpServer = (SMPServer host port _), ntfQueueId, ntfServer = (SMPServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription + let NtfSubscription {connId, smpServer = (SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription DB.execute db [sql| @@ -761,7 +761,7 @@ createNtfSubscription db ntfSubscription action actionTs = do (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action supervisorUpdateNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO () -supervisorUpdateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (ProtocolServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do +supervisorUpdateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do updatedAt <- getCurrentTime DB.execute db @@ -789,7 +789,7 @@ supervisorUpdateNtfSubAction db connId action actionTs = do (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action updateNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO () -updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (ProtocolServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do +updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) forM_ r $ \updatedBySupervisor -> do updatedAt <- getCurrentTime @@ -848,7 +848,7 @@ deleteNtfSubscription db connId = do else DB.execute db "DELETE FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) getNextNtfSubNTFAction :: DB.Connection -> NtfServer -> IO (Maybe (NtfSubscription, NtfSubNTFAction, NtfActionTs)) -getNextNtfSubNTFAction db ntfServer@(ProtocolServer ntfHost ntfPort _) = do +getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) = do maybeFirstRow ntfSubAction getNtfSubAction_ $>>= \a@(NtfSubscription {connId}, _, _) -> do DB.execute db "UPDATE ntf_subscriptions SET updated_by_supervisor = ? WHERE conn_id = ?" (False, connId) pure $ Just a @@ -891,7 +891,7 @@ getNextNtfSubSMPAction db smpServer@(SMPServer smpHost smpPort _) = do |] (smpHost, smpPort) ntfSubAction (connId, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = - let ntfServer = ProtocolServer ntfHost ntfPort ntfKeyHash + let ntfServer = NtfServer ntfHost ntfPort ntfKeyHash ntfSubscription = NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} in (ntfSubscription, action, actionTs) @@ -911,7 +911,7 @@ getActiveNtfToken db = (Only NTActive) where ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) = - let ntfServer = ProtocolServer {host, port, keyHash} + let ntfServer = NtfServer host port keyHash ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) ntfMode = fromMaybe NMPeriodic ntfMode_ in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index c39f5b1e2..a04be51ba 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -83,7 +83,7 @@ data ProtocolClient msg = ProtocolClient connected :: TVar Bool, sessionId :: SessionId, thVersion :: Version, - protocolServer :: ProtocolServer, + protocolServer :: ProtoServer msg, tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TMap CorrId (Request msg), @@ -95,7 +95,7 @@ data ProtocolClient msg = ProtocolClient type SMPClient = ProtocolClient SMP.BrokerMsg -- | Type synonym for transmission from some SPM server queue. -type ServerTransmission msg = (ProtocolServer, Version, SessionId, QueueId, msg) +type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg) -- | protocol client configuration. data ProtocolClientConfig = ProtocolClientConfig @@ -137,7 +137,7 @@ type Response msg = Either ProtocolClientError msg -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall msg. Protocol msg => ProtocolServer -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg)) +getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg)) getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing, smpServerVRange} msgQ disconnected = (atomically mkProtocolClient >>= runClient useTransport) `catch` \(e :: IOException) -> pure . Left $ PCEIOError e @@ -378,7 +378,7 @@ sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) -- | Send Protocol command -sendProtocolCommand :: forall msg. ProtocolEncoding (ProtocolCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtocolCommand msg -> ExceptT ProtocolClientError IO msg +sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, thVersion, tcpTimeout} pKey qId cmd = do corrId <- lift_ getNextCorrId t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd) diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 9b8e55184..28421de66 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -146,7 +146,8 @@ instance Encoding ANewNtfEntity where _ -> fail "bad ANewNtfEntity" instance Protocol NtfResponse where - type ProtocolCommand NtfResponse = NtfCmd + type ProtoCommand NtfResponse = NtfCmd + type ProtoType NtfResponse = 'PNTF protocolClientHandshake = ntfClientHandshake protocolPing = NtfCmd SSubscription PING protocolError = \case @@ -323,7 +324,7 @@ instance ProtocolEncoding NtfResponse where | otherwise = Left $ CMD HAS_AUTH data SMPQueueNtf = SMPQueueNtf - { smpServer :: ProtocolServer, + { smpServer :: SMPServer, notifierId :: NotifierId } deriving (Eq, Ord, Show) diff --git a/src/Simplex/Messaging/Notifications/Types.hs b/src/Simplex/Messaging/Notifications/Types.hs index d679d00f2..3b6f7782d 100644 --- a/src/Simplex/Messaging/Notifications/Types.hs +++ b/src/Simplex/Messaging/Notifications/Types.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} module Simplex.Messaging.Notifications.Types where @@ -16,9 +17,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Parsers (blobFieldDecoder, fromTextField_) -import Simplex.Messaging.Protocol (NotifierId, ProtocolServer, SMPServer) - -type NtfServer = ProtocolServer +import Simplex.Messaging.Protocol (NotifierId, NtfServer, SMPServer) data NtfTknAction = NTARegister diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 52dc7d0b3..42713ad2d 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -63,9 +63,13 @@ module Simplex.Messaging.Protocol ClientMessage (..), PrivHeader (..), Protocol (..), + ProtocolType (..), ProtocolServer (..), + ProtoServer, SMPServer, pattern SMPServer, + NtfServer, + pattern NtfServer, SrvLoc (..), CorrId (..), QueueId, @@ -546,41 +550,109 @@ instance Encoding ClientMessage where smpEncode (ClientMessage h msg) = smpEncode h <> msg smpP = ClientMessage <$> smpP <*> A.takeByteString -type SMPServer = ProtocolServer +type SMPServer = ProtocolServer 'PSMP -pattern SMPServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer -pattern SMPServer host port keyHash = ProtocolServer host port keyHash +pattern SMPServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer 'PSMP +pattern SMPServer host port keyHash = ProtocolServer SPSMP host port keyHash {-# COMPLETE SMPServer #-} --- | SMP server location and transport key digest (hash). -data ProtocolServer = ProtocolServer - { host :: HostName, +type NtfServer = ProtocolServer 'PNTF + +pattern NtfServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer 'PNTF +pattern NtfServer host port keyHash = ProtocolServer SPNTF host port keyHash + +{-# COMPLETE NtfServer #-} + +data ProtocolType = PSMP | PNTF + deriving (Eq, Ord, Show) + +instance StrEncoding ProtocolType where + strEncode = \case + PSMP -> "smp" + PNTF -> "ntf" + strP = + A.takeTill (== ':') >>= \case + "smp" -> pure PSMP + "ntf" -> pure PNTF + _ -> fail "bad ProtocolType" + +data SProtocolType (p :: ProtocolType) where + SPSMP :: SProtocolType 'PSMP + SPNTF :: SProtocolType 'PNTF + +deriving instance Eq (SProtocolType p) + +deriving instance Ord (SProtocolType p) + +deriving instance Show (SProtocolType p) + +data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p) + +instance TestEquality SProtocolType where + testEquality SPSMP SPSMP = Just Refl + testEquality SPNTF SPNTF = Just Refl + testEquality _ _ = Nothing + +protocolType :: SProtocolType p -> ProtocolType +protocolType = \case + SPSMP -> PSMP + SPNTF -> PNTF + +aProtocolType :: ProtocolType -> AProtocolType +aProtocolType = \case + PSMP -> AProtocolType SPSMP + PNTF -> AProtocolType SPNTF + +instance ProtocolTypeI p => StrEncoding (SProtocolType p) where + strEncode = strEncode . protocolType + strP = (\(AProtocolType p) -> checkProtocolType p) <$?> strP + +instance StrEncoding AProtocolType where + strEncode (AProtocolType p) = strEncode p + strP = aProtocolType <$> strP + +checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p) +checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of + Just Refl -> Right p + Nothing -> Left "bad ProtocolType" + +class ProtocolTypeI (p :: ProtocolType) where + protocolTypeI :: SProtocolType p + +instance ProtocolTypeI 'PSMP where protocolTypeI = SPSMP + +instance ProtocolTypeI 'PNTF where protocolTypeI = SPNTF + +-- | server location and transport key digest (hash). +data ProtocolServer p = ProtocolServer + { scheme :: SProtocolType p, + host :: HostName, port :: ServiceName, keyHash :: C.KeyHash } deriving (Eq, Ord, Show) -instance IsString ProtocolServer where +instance ProtocolTypeI p => IsString (ProtocolServer p) where fromString = parseString strDecode -instance Encoding ProtocolServer where +instance ProtocolTypeI p => Encoding (ProtocolServer p) where smpEncode ProtocolServer {host, port, keyHash} = smpEncode (host, port, keyHash) smpP = do (host, port, keyHash) <- smpP - pure ProtocolServer {host, port, keyHash} + pure ProtocolServer {scheme = protocolTypeI @p, host, port, keyHash} -instance StrEncoding ProtocolServer where - strEncode ProtocolServer {host, port, keyHash} = - "smp://" <> strEncode keyHash <> "@" <> strEncode (SrvLoc host port) +instance ProtocolTypeI p => StrEncoding (ProtocolServer p) where + strEncode ProtocolServer {scheme, host, port, keyHash} = + strEncode scheme <> "://" <> strEncode keyHash <> "@" <> strEncode (SrvLoc host port) strP = do - _ <- "smp://" + scheme <- strP <* "://" keyHash <- strP <* A.char '@' SrvLoc host port <- strP - pure ProtocolServer {host, port, keyHash} + pure ProtocolServer {scheme, host, port, keyHash} -instance ToJSON ProtocolServer where +instance ProtocolTypeI p => ToJSON (ProtocolServer p) where toJSON = strToJSON toEncoding = strToJEncoding @@ -727,14 +799,18 @@ transmissionP = do command <- A.takeByteString pure RawTransmission {signature, signed, sessId, corrId, entityId, command} -class (ProtocolEncoding msg, ProtocolEncoding (ProtocolCommand msg), Show msg) => Protocol msg where - type ProtocolCommand msg = cmd | cmd -> msg +class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where + type ProtoCommand msg = cmd | cmd -> msg + type ProtoType msg = (sch :: ProtocolType) | sch -> msg protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) - protocolPing :: ProtocolCommand msg + protocolPing :: ProtoCommand msg protocolError :: msg -> Maybe ErrorType +type ProtoServer msg = ProtocolServer (ProtoType msg) + instance Protocol BrokerMsg where - type ProtocolCommand BrokerMsg = Cmd + type ProtoCommand BrokerMsg = Cmd + type ProtoType BrokerMsg = 'PSMP protocolClientHandshake = smpClientHandshake protocolPing = Cmd SSender PING protocolError = \case diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index af4f749bc..cf5bb0a20 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -19,12 +19,7 @@ uri :: String uri = "smp.simplex.im" srv :: SMPServer -srv = - ProtocolServer - { host = "smp.simplex.im", - port = "5223", - keyHash = C.KeyHash "\215m\248\251" - } +srv = SMPServer "smp.simplex.im" "5223" (C.KeyHash "\215m\248\251") queue :: SMPQueueUri queue = diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 08cc3ea37..9a5bbeee7 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -162,7 +162,7 @@ initAgentServers :: InitialAgentServers initAgentServers = InitialAgentServers { smp = L.fromList [testSMPServer], - ntf = ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"] + ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"] } agentCfg :: AgentConfig