Move generic push functions to Push.hs

This commit is contained in:
sim
2025-06-26 16:59:25 +02:00
parent a2d777bda0
commit e90c15bb90
2 changed files with 82 additions and 58 deletions

View File

@@ -0,0 +1,81 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use newtype instead of data" #-}
module Simplex.Messaging.Notifications.Server.Push where
import Crypto.Hash.Algorithms (SHA256 (..))
import qualified Crypto.PubKey.ECC.ECDSA as EC
import qualified Crypto.PubKey.ECC.Types as ECT
import qualified Crypto.Store.PKCS8 as PK
import Data.ASN1.BinaryEncoding (DER (..))
import Data.ASN1.Encoding
import Data.ASN1.Types
import Data.Aeson (ToJSON)
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as JQ
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Text (Text)
import Data.Time.Clock.System
import qualified Data.X509 as X
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Parsers (defaultJSON)
data JWTHeader = JWTHeader
{ alg :: Text, -- key algorithm, ES256 for APNS
kid :: Text -- key ID
}
deriving (Show)
data JWTClaims = JWTClaims
{ iss :: Text, -- issuer, team ID for APNS
iat :: Int64 -- issue time, seconds from epoch
}
deriving (Show)
data JWTToken = JWTToken JWTHeader JWTClaims
deriving (Show)
mkJWTToken :: JWTHeader -> Text -> IO JWTToken
mkJWTToken hdr iss = do
iat <- systemSeconds <$> getSystemTime
pure $ JWTToken hdr JWTClaims {iss, iat}
type SignedJWTToken = ByteString
$(JQ.deriveToJSON defaultJSON ''JWTHeader)
$(JQ.deriveToJSON defaultJSON ''JWTClaims)
signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken
signedJWTToken pk (JWTToken hdr claims) = do
let hc = jwtEncode hdr <> "." <> jwtEncode claims
sig <- EC.sign pk SHA256 hc
pure $ hc <> "." <> serialize sig
where
jwtEncode :: ToJSON a => a -> ByteString
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
readECPrivateKey :: FilePath -> IO EC.PrivateKey
readECPrivateKey f = do
-- this pattern match is specific to APNS key type, it may need to be extended for other push providers
[PK.Unprotected (X.PrivKeyEC X.PrivKeyEC_Named {privkeyEC_name, privkeyEC_priv})] <- PK.readKeyFile f
pure EC.PrivateKey {private_curve = ECT.getCurveByName privkeyEC_name, private_d = privkeyEC_priv}
data PushNotification
= PNVerification NtfRegCode
| PNMessage (NonEmpty PNMessageData)
| -- | PNAlert Text
PNCheckMessages
deriving (Show)

View File

@@ -16,14 +16,8 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import Control.Monad.Trans.Except
import Crypto.Hash.Algorithms (SHA256 (..))
import qualified Crypto.PubKey.ECC.ECDSA as EC
import qualified Crypto.PubKey.ECC.Types as ECT
import Crypto.Random (ChaChaDRG)
import qualified Crypto.Store.PKCS8 as PK
import Data.ASN1.BinaryEncoding (DER (..))
import Data.ASN1.Encoding
import Data.ASN1.Types
import Data.Aeson (ToJSON, (.=))
import qualified Data.Aeson as J
import qualified Data.Aeson.Encoding as JE
@@ -32,18 +26,15 @@ import Data.Bifunctor (first)
import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Builder (lazyByteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as LB
import qualified Data.CaseInsensitive as CI
import Data.Int (Int64)
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict (Map)
import Data.Maybe (isNothing)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.Time.Clock.System
import qualified Data.X509 as X
import qualified Data.X509.CertificateStore as XS
import Network.HPACK.Token as HT
import Network.HTTP.Types (Status)
@@ -53,6 +44,7 @@ import qualified Network.HTTP2.Client as H
import Network.Socket (HostName, ServiceName)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push
import Simplex.Messaging.Notifications.Server.Push.APNS.Internal
import Simplex.Messaging.Notifications.Server.Store.Types (NtfTknRec (..))
import Simplex.Messaging.Parsers (defaultJSON)
@@ -62,55 +54,6 @@ import Simplex.Messaging.Util (safeDecodeUtf8, tshow)
import System.Environment (getEnv)
import UnliftIO.STM
data JWTHeader = JWTHeader
{ alg :: Text, -- key algorithm, ES256 for APNS
kid :: Text -- key ID
}
deriving (Show)
data JWTClaims = JWTClaims
{ iss :: Text, -- issuer, team ID for APNS
iat :: Int64 -- issue time, seconds from epoch
}
deriving (Show)
data JWTToken = JWTToken JWTHeader JWTClaims
deriving (Show)
mkJWTToken :: JWTHeader -> Text -> IO JWTToken
mkJWTToken hdr iss = do
iat <- systemSeconds <$> getSystemTime
pure $ JWTToken hdr JWTClaims {iss, iat}
type SignedJWTToken = ByteString
$(JQ.deriveToJSON defaultJSON ''JWTHeader)
$(JQ.deriveToJSON defaultJSON ''JWTClaims)
signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken
signedJWTToken pk (JWTToken hdr claims) = do
let hc = jwtEncode hdr <> "." <> jwtEncode claims
sig <- EC.sign pk SHA256 hc
pure $ hc <> "." <> serialize sig
where
jwtEncode :: ToJSON a => a -> ByteString
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
readECPrivateKey :: FilePath -> IO EC.PrivateKey
readECPrivateKey f = do
-- this pattern match is specific to APNS key type, it may need to be extended for other push providers
[PK.Unprotected (X.PrivKeyEC X.PrivKeyEC_Named {privkeyEC_name, privkeyEC_priv})] <- PK.readKeyFile f
pure EC.PrivateKey {private_curve = ECT.getCurveByName privkeyEC_name, private_d = privkeyEC_priv}
data PushNotification
= PNVerification NtfRegCode
| PNMessage (NonEmpty PNMessageData)
| -- | PNAlert Text
PNCheckMessages
deriving (Show)
data APNSNotification = APNSNotification {aps :: APNSNotificationBody, notificationData :: Maybe J.Value}
deriving (Show)