mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-06-04 13:01:29 +00:00
refactor types for DB entity (#1548)
This commit is contained in:
@@ -135,6 +135,7 @@ library
|
||||
Simplex.Messaging.Server.QueueStore.QueueInfo
|
||||
Simplex.Messaging.ServiceScheme
|
||||
Simplex.Messaging.Session
|
||||
Simplex.Messaging.Agent.Store.Entity
|
||||
Simplex.Messaging.TMap
|
||||
Simplex.Messaging.Transport
|
||||
Simplex.Messaging.Transport.Buffer
|
||||
@@ -308,6 +309,7 @@ library
|
||||
, network-transport ==0.5.6
|
||||
, network-udp ==0.0.*
|
||||
, random >=1.1 && <1.3
|
||||
, scientific ==0.3.7.*
|
||||
, simple-logger ==0.1.*
|
||||
, socks ==0.6.*
|
||||
, stm ==2.5.*
|
||||
|
||||
@@ -216,6 +216,7 @@ import Simplex.Messaging.Protocol
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.ServiceScheme (ServiceScheme (..))
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPVersion)
|
||||
import Simplex.Messaging.Util
|
||||
@@ -833,7 +834,7 @@ newConn c userId enableNtfs cMode userData_ clientData pqInitKeys subMode = do
|
||||
`catchE` \e -> withStore' c (`deleteConnRecord` connId) >> throwE e
|
||||
|
||||
setContactShortLink' :: AgentClient -> ConnId -> ConnInfo -> Maybe CRClientData -> AM (ConnShortLink 'CMContact)
|
||||
setContactShortLink' c connId userData clientData =
|
||||
setContactShortLink' c connId userData clientData =
|
||||
withConnLock c connId "setContactShortLink" $
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (ContactConnection _ rq) -> do
|
||||
@@ -934,7 +935,7 @@ newRcvConnSrv c userId connId enableNtfs cMode userData_ clientData pqInitKeys s
|
||||
createRcvQueue nonce_ qd e2eKeys = do
|
||||
AgentConfig {smpClientVRange = vr} <- asks config
|
||||
-- TODO [notifications] send correct NTF credentials here
|
||||
-- let ntfCreds_ = Nothing
|
||||
-- let ntfCreds_ = Nothing
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue_ c userId connId srvWithAuth vr qd subMode nonce_ e2eKeys `catchAgentError` \e -> liftIO (print e) >> throwE e
|
||||
atomically $ incSMPServerStat c userId srv connCreated
|
||||
rq' <- withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
@@ -1122,7 +1123,7 @@ joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMo
|
||||
Nothing -> throwE $ AGENT A_VERSION
|
||||
|
||||
delInvSL :: AgentClient -> ConnId -> SMPServerWithAuth -> SMP.LinkId -> AM ()
|
||||
delInvSL c connId srv lnkId =
|
||||
delInvSL c connId srv lnkId =
|
||||
withStore' c (\db -> deleteInvShortLink db (protoServer srv) lnkId) `catchE` \e ->
|
||||
liftIO $ nonBlockingWriteTBQueue (subQ c) ("", connId, AEvt SAEConn (ERR $ INTERNAL $ "error deleting short link " <> show e))
|
||||
|
||||
@@ -1293,7 +1294,7 @@ getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
|
||||
msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e
|
||||
when (isNothing msg_) $ do
|
||||
atomically $ releaseGetLock c rq
|
||||
forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBQueueId dbQueueId) msgTs
|
||||
forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBEntityId dbQueueId) msgTs
|
||||
pure msg_
|
||||
{-# INLINE getConnectionMessages' #-}
|
||||
|
||||
@@ -1910,7 +1911,7 @@ switchConnection' c connId =
|
||||
_ -> throwE $ CMD PROHIBITED "switchConnection: not duplex"
|
||||
|
||||
switchDuplexConnection :: AgentClient -> Connection 'CDuplex -> RcvQueue -> AM ConnectionStats
|
||||
switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs sqs) rq@RcvQueue {server, dbQueueId = DBQueueId dbQueueId, sndId} = do
|
||||
switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs sqs) rq@RcvQueue {server, dbQueueId = DBEntityId dbQueueId, sndId} = do
|
||||
checkRQSwchStatus rq RSSwitchStarted
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
-- try to get the server that is different from all queues, or at least from the primary rcv queue
|
||||
@@ -2940,7 +2941,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
Just qInfo@(Compatible sqInfo@SMPQueueInfo {queueAddress}) ->
|
||||
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
|
||||
(Just _, _) -> qError "QADD: queue address is already used in connection"
|
||||
(_, Just sq@SndQueue {dbQueueId = DBQueueId dbQueueId}) -> do
|
||||
(_, Just sq@SndQueue {dbQueueId = DBEntityId dbQueueId}) -> do
|
||||
let (delSqs, keepSqs) = L.partition ((Just dbQueueId ==) . dbReplaceQId) sqs
|
||||
case L.nonEmpty keepSqs of
|
||||
Just sqs' -> do
|
||||
@@ -3278,7 +3279,7 @@ newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAdd
|
||||
e2ePubKey = Just e2ePubKey,
|
||||
-- setting status to Secured prevents SKEY when queue was already secured with LKEY
|
||||
status = if isJust sndKeys_ then Secured else New,
|
||||
dbQueueId = DBNewQueue,
|
||||
dbQueueId = DBNewEntity,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
sndSwchStatus = Nothing,
|
||||
|
||||
@@ -278,6 +278,7 @@ import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server.QueueStore.QueueInfo
|
||||
import Simplex.Messaging.Session
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPVersion, SessionId, THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion)
|
||||
@@ -1083,7 +1084,7 @@ sendOrProxySMPCommand ::
|
||||
UserId ->
|
||||
SMPServer ->
|
||||
ConnId -> -- session entity ID, for short links LinkId is used
|
||||
ByteString ->
|
||||
ByteString ->
|
||||
SMP.EntityId -> -- sender or link ID
|
||||
(SMPClient -> ProxiedRelay -> ExceptT SMPClientError IO (Either ProxyClientError a)) ->
|
||||
(SMPClient -> ExceptT SMPClientError IO a) ->
|
||||
@@ -1395,7 +1396,7 @@ newRcvQueue_ c userId connId (ProtoServerWithAuth srv auth) vRange cqrd subMode
|
||||
queueMode,
|
||||
shortLink,
|
||||
status = New,
|
||||
dbQueueId = DBNewQueue,
|
||||
dbQueueId = DBNewEntity,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
rcvSwchStatus = Nothing,
|
||||
@@ -1408,7 +1409,7 @@ newRcvQueue_ c userId connId (ProtoServerWithAuth srv auth) vRange cqrd subMode
|
||||
where
|
||||
mkShortLinkCreds :: (THandleParams SMPVersion 'TClient, QueueIdsKeys) -> AM (Maybe ShortLinkCreds)
|
||||
mkShortLinkCreds (thParams', QIK {sndId, queueMode, linkId}) = case (cqrd, queueMode) of
|
||||
(CQRMessaging ld, Just QMMessaging) ->
|
||||
(CQRMessaging ld, Just QMMessaging) ->
|
||||
withLinkData ld $ \lnkId CQRData {linkKey, privSigKey, srvReq = (sndId', d)} ->
|
||||
if sndId == sndId'
|
||||
then pure $ Just $ ShortLinkCreds lnkId linkKey privSigKey (fst d)
|
||||
|
||||
@@ -52,30 +52,19 @@ import Simplex.Messaging.Protocol
|
||||
VersionSMPC,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
|
||||
createStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore)
|
||||
createStore dbOpts = createDBStore dbOpts appMigrations
|
||||
|
||||
-- * Queue types
|
||||
|
||||
data QueueStored = QSStored | QSNew
|
||||
type RcvQueue = StoredRcvQueue 'DBStored
|
||||
|
||||
data SQueueStored (q :: QueueStored) where
|
||||
SQSStored :: SQueueStored 'QSStored
|
||||
SQSNew :: SQueueStored 'QSNew
|
||||
|
||||
data DBQueueId (q :: QueueStored) where
|
||||
DBQueueId :: Int64 -> DBQueueId 'QSStored
|
||||
DBNewQueue :: DBQueueId 'QSNew
|
||||
|
||||
deriving instance Show (DBQueueId q)
|
||||
|
||||
type RcvQueue = StoredRcvQueue 'QSStored
|
||||
|
||||
type NewRcvQueue = StoredRcvQueue 'QSNew
|
||||
type NewRcvQueue = StoredRcvQueue 'DBNew
|
||||
|
||||
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
|
||||
data StoredRcvQueue (q :: QueueStored) = RcvQueue
|
||||
data StoredRcvQueue (q :: DBStored) = RcvQueue
|
||||
{ userId :: UserId,
|
||||
connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
@@ -98,7 +87,7 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue
|
||||
-- | queue status
|
||||
status :: QueueStatus,
|
||||
-- | database queue ID (within connection)
|
||||
dbQueueId :: DBQueueId q,
|
||||
dbQueueId :: DBEntityId' q,
|
||||
-- | True for a primary or a next primary queue of the connection (next if dbReplaceQueueId is set)
|
||||
primary :: Bool,
|
||||
-- | database queue ID to replace, Nothing if this queue is not replacing another, `Just Nothing` is used for replacing old queues
|
||||
@@ -160,12 +149,12 @@ data InvShortLink = InvShortLink
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
type SndQueue = StoredSndQueue 'QSStored
|
||||
type SndQueue = StoredSndQueue 'DBStored
|
||||
|
||||
type NewSndQueue = StoredSndQueue 'QSNew
|
||||
type NewSndQueue = StoredSndQueue 'DBNew
|
||||
|
||||
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
|
||||
data StoredSndQueue (q :: QueueStored) = SndQueue
|
||||
data StoredSndQueue (q :: DBStored) = SndQueue
|
||||
{ userId :: UserId,
|
||||
connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
@@ -184,7 +173,7 @@ data StoredSndQueue (q :: QueueStored) = SndQueue
|
||||
-- | queue status
|
||||
status :: QueueStatus,
|
||||
-- | database queue ID (within connection)
|
||||
dbQueueId :: DBQueueId q,
|
||||
dbQueueId :: DBEntityId' q,
|
||||
-- | True for a primary or a next primary queue of the connection (next if dbReplaceQueueId is set)
|
||||
primary :: Bool,
|
||||
-- | ID of the queue this one is replacing
|
||||
@@ -257,7 +246,7 @@ instance SMPQueueRec RcvQueue where
|
||||
{-# INLINE qUserId #-}
|
||||
qConnId RcvQueue {connId} = connId
|
||||
{-# INLINE qConnId #-}
|
||||
dbQId RcvQueue {dbQueueId = DBQueueId qId} = qId
|
||||
dbQId RcvQueue {dbQueueId = DBEntityId qId} = qId
|
||||
{-# INLINE dbQId #-}
|
||||
dbReplaceQId RcvQueue {dbReplaceQueueId} = dbReplaceQueueId
|
||||
{-# INLINE dbReplaceQId #-}
|
||||
@@ -267,7 +256,7 @@ instance SMPQueueRec SndQueue where
|
||||
{-# INLINE qUserId #-}
|
||||
qConnId SndQueue {connId} = connId
|
||||
{-# INLINE qConnId #-}
|
||||
dbQId SndQueue {dbQueueId = DBQueueId qId} = qId
|
||||
dbQId SndQueue {dbQueueId = DBEntityId qId} = qId
|
||||
{-# INLINE dbQId #-}
|
||||
dbReplaceQId SndQueue {dbReplaceQueueId} = dbReplaceQueueId
|
||||
{-# INLINE dbReplaceQId #-}
|
||||
|
||||
@@ -283,6 +283,7 @@ import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, firstRow, firstRow', ifM, maybeFirstRow, tshow, ($>>=), (<$$>))
|
||||
import Simplex.Messaging.Version.Internal
|
||||
@@ -858,7 +859,7 @@ createRcvMsg db connId rq@RcvQueue {dbQueueId} rcvMsgData@RcvMsgData {msgMeta =
|
||||
updateRcvMsgHash db connId sndMsgId internalRcvId internalHash
|
||||
setLastBrokerTs db connId dbQueueId brokerTs
|
||||
|
||||
setLastBrokerTs :: DB.Connection -> ConnId -> DBQueueId 'QSStored -> UTCTime -> IO ()
|
||||
setLastBrokerTs :: DB.Connection -> ConnId -> DBEntityId -> UTCTime -> IO ()
|
||||
setLastBrokerTs db connId dbQueueId brokerTs =
|
||||
DB.execute db "UPDATE rcv_queues SET last_broker_ts = ? WHERE conn_id = ? AND rcv_queue_id = ? AND (last_broker_ts IS NULL OR last_broker_ts < ?)" (brokerTs, connId, dbQueueId, brokerTs)
|
||||
|
||||
@@ -1212,7 +1213,7 @@ getSndRatchet db connId v =
|
||||
DB.query db "SELECT ratchet_state, x3dh_pub_key_1, x3dh_pub_key_2, pq_pub_kem FROM ratchets WHERE conn_id = ?" (Only connId)
|
||||
where
|
||||
result = \case
|
||||
(Just ratchetState, Just k1, Just k2, pKem_) ->
|
||||
(Just ratchetState, Just k1, Just k2, pKem_) ->
|
||||
let params = case pKem_ of
|
||||
Nothing -> CR.AE2ERatchetParams CR.SRKSProposed (CR.E2ERatchetParams v k1 k2 Nothing)
|
||||
Just (CR.ARKP s pKem) -> CR.AE2ERatchetParams s (CR.E2ERatchetParams v k1 k2 (Just pKem))
|
||||
@@ -1811,15 +1812,6 @@ instance ToField QueueStatus where toField = toField . serializeQueueStatus
|
||||
|
||||
instance FromField QueueStatus where fromField = fromTextField_ queueStatusT
|
||||
|
||||
instance ToField (DBQueueId 'QSStored) where toField (DBQueueId qId) = toField qId
|
||||
|
||||
instance FromField (DBQueueId 'QSStored) where
|
||||
#if defined(dbPostgres)
|
||||
fromField x dat = DBQueueId <$> fromField x dat
|
||||
#else
|
||||
fromField x = DBQueueId <$> fromField x
|
||||
#endif
|
||||
|
||||
instance ToField InternalRcvId where toField (InternalRcvId x) = toField x
|
||||
|
||||
deriving newtype instance FromField InternalRcvId
|
||||
@@ -2018,13 +2010,13 @@ insertSndQueue_ db connId' sq@SndQueue {..} serverKeyHash_ = do
|
||||
smp_client_version=EXCLUDED.smp_client_version,
|
||||
server_key_hash=EXCLUDED.server_key_hash
|
||||
|]
|
||||
((host server, port server, sndId, queueMode, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret)
|
||||
((host server, port server, sndId, queueMode, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret)
|
||||
:. (status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_))
|
||||
pure (sq :: NewSndQueue) {connId = connId', dbQueueId = qId}
|
||||
|
||||
newQueueId_ :: [Only Int64] -> DBQueueId 'QSStored
|
||||
newQueueId_ [] = DBQueueId 1
|
||||
newQueueId_ (Only maxId : _) = DBQueueId (maxId + 1)
|
||||
newQueueId_ :: [Only Int64] -> DBEntityId
|
||||
newQueueId_ [] = DBEntityId 1
|
||||
newQueueId_ (Only maxId : _) = DBEntityId (maxId + 1)
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
@@ -2160,7 +2152,7 @@ rcvQueueQuery =
|
||||
|
||||
toRcvQueue ::
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe QueueMode)
|
||||
:. (QueueStatus, DBQueueId 'QSStored, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int)
|
||||
:. (QueueStatus, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int)
|
||||
:. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret)
|
||||
:. (Maybe SMP.LinkId, Maybe LinkKey, Maybe C.PrivateKeyEd25519, Maybe EncDataBytes) ->
|
||||
RcvQueue
|
||||
@@ -2210,7 +2202,7 @@ sndQueueQuery =
|
||||
toSndQueue ::
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId, Maybe QueueMode)
|
||||
:. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus)
|
||||
:. (DBQueueId 'QSStored, BoolInt, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) ->
|
||||
:. (DBEntityId, BoolInt, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) ->
|
||||
SndQueue
|
||||
toSndQueue
|
||||
( (userId, keyHash, connId, host, port, sndId, queueMode)
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.Entity where
|
||||
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import Data.Int (Int64)
|
||||
import Data.Scientific (floatingOrInteger)
|
||||
import Simplex.Messaging.Agent.Store.DB (FromField (..), ToField (..))
|
||||
|
||||
data DBStored = DBStored | DBNew
|
||||
|
||||
data SDBStored (s :: DBStored) where
|
||||
SDBStored :: SDBStored 'DBStored
|
||||
SDBNew :: SDBStored 'DBNew
|
||||
|
||||
deriving instance Show (SDBStored s)
|
||||
|
||||
class DBStoredI s where sdbStored :: SDBStored s
|
||||
|
||||
instance DBStoredI 'DBStored where sdbStored = SDBStored
|
||||
|
||||
instance DBStoredI 'DBNew where sdbStored = SDBNew
|
||||
|
||||
data DBEntityId' (s :: DBStored) where
|
||||
DBEntityId :: Int64 -> DBEntityId' 'DBStored
|
||||
DBNewEntity :: DBEntityId' 'DBNew
|
||||
|
||||
deriving instance Show (DBEntityId' s)
|
||||
|
||||
deriving instance Eq (DBEntityId' s)
|
||||
|
||||
type DBEntityId = DBEntityId' 'DBStored
|
||||
|
||||
type DBNewEntity = DBEntityId' 'DBNew
|
||||
|
||||
instance ToJSON (DBEntityId' s) where
|
||||
toEncoding = \case
|
||||
DBEntityId i -> toEncoding i
|
||||
DBNewEntity -> JE.null_
|
||||
toJSON = \case
|
||||
DBEntityId i -> toJSON i
|
||||
DBNewEntity -> J.Null
|
||||
|
||||
instance DBStoredI s => FromJSON (DBEntityId' s) where
|
||||
parseJSON v = case (v, sdbStored @s) of
|
||||
(J.Null, SDBNew) -> pure DBNewEntity
|
||||
(J.Number n, SDBStored) -> case floatingOrInteger n of
|
||||
Left (_ :: Double) -> fail "bad DBEntityId"
|
||||
Right i -> pure $ DBEntityId (fromInteger i)
|
||||
_ -> fail "bad DBEntityId"
|
||||
omittedField = case sdbStored @s of
|
||||
SDBStored -> Nothing
|
||||
SDBNew -> Just DBNewEntity
|
||||
|
||||
instance FromField DBEntityId where
|
||||
#if defined(dbPostgres)
|
||||
fromField x dat = DBEntityId <$> fromField x dat
|
||||
#else
|
||||
fromField x = DBEntityId <$> fromField x
|
||||
#endif
|
||||
|
||||
instance ToField DBEntityId where toField (DBEntityId i) = toField i
|
||||
@@ -179,8 +179,6 @@ module Simplex.Messaging.Crypto
|
||||
unPad,
|
||||
|
||||
-- * X509 Certificates
|
||||
SignedCertificate,
|
||||
Certificate,
|
||||
signCertificate,
|
||||
signX509,
|
||||
verifyX509,
|
||||
@@ -240,7 +238,7 @@ import Data.String
|
||||
import Data.Type.Equality
|
||||
import Data.Typeable (Proxy (Proxy), Typeable)
|
||||
import Data.Word (Word32)
|
||||
import Data.X509
|
||||
import qualified Data.X509 as X
|
||||
import Data.X509.Validation (Fingerprint (..), getFingerprint)
|
||||
import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+))
|
||||
import Network.Transport.Internal (decodeWord16, encodeWord16)
|
||||
@@ -1160,12 +1158,12 @@ sign :: APrivateSignKey -> ByteString -> ASignature
|
||||
sign (APrivateSignKey a k) = ASignature a . sign' k
|
||||
{-# INLINE sign #-}
|
||||
|
||||
signCertificate :: APrivateSignKey -> Certificate -> SignedCertificate
|
||||
signCertificate :: APrivateSignKey -> X.Certificate -> X.SignedCertificate
|
||||
signCertificate = signX509
|
||||
{-# INLINE signCertificate #-}
|
||||
|
||||
signX509 :: (ASN1Object o, Eq o, Show o) => APrivateSignKey -> o -> SignedExact o
|
||||
signX509 key = fst . objectToSignedExact f
|
||||
signX509 :: (ASN1Object o, Eq o, Show o) => APrivateSignKey -> o -> X.SignedExact o
|
||||
signX509 key = fst . X.objectToSignedExact f
|
||||
where
|
||||
f bytes =
|
||||
( signatureBytes $ sign key bytes,
|
||||
@@ -1174,33 +1172,33 @@ signX509 key = fst . objectToSignedExact f
|
||||
)
|
||||
{-# INLINE signX509 #-}
|
||||
|
||||
verifyX509 :: (ASN1Object o, Eq o, Show o) => APublicVerifyKey -> SignedExact o -> Either String o
|
||||
verifyX509 :: (ASN1Object o, Eq o, Show o) => APublicVerifyKey -> X.SignedExact o -> Either String o
|
||||
verifyX509 key exact = do
|
||||
signature <- case signedAlg of
|
||||
SignatureALG_IntrinsicHash PubKeyALG_Ed25519 -> ASignature SEd25519 <$> decodeSignature signedSignature
|
||||
SignatureALG_IntrinsicHash PubKeyALG_Ed448 -> ASignature SEd448 <$> decodeSignature signedSignature
|
||||
X.SignatureALG_IntrinsicHash X.PubKeyALG_Ed25519 -> ASignature SEd25519 <$> decodeSignature signedSignature
|
||||
X.SignatureALG_IntrinsicHash X.PubKeyALG_Ed448 -> ASignature SEd448 <$> decodeSignature signedSignature
|
||||
_ -> Left "unknown x509 signature algorithm"
|
||||
if verify key signature $ getSignedData exact then Right signedObject else Left "bad signature"
|
||||
if verify key signature $ X.getSignedData exact then Right signedObject else Left "bad signature"
|
||||
where
|
||||
Signed {signedObject, signedAlg, signedSignature} = getSigned exact
|
||||
X.Signed {signedObject, signedAlg, signedSignature} = X.getSigned exact
|
||||
{-# INLINE verifyX509 #-}
|
||||
|
||||
certificateFingerprint :: SignedCertificate -> KeyHash
|
||||
certificateFingerprint :: X.SignedCertificate -> KeyHash
|
||||
certificateFingerprint = signedFingerprint
|
||||
{-# INLINE certificateFingerprint #-}
|
||||
|
||||
signedFingerprint :: (ASN1Object o, Eq o, Show o) => SignedExact o -> KeyHash
|
||||
signedFingerprint :: (ASN1Object o, Eq o, Show o) => X.SignedExact o -> KeyHash
|
||||
signedFingerprint o = KeyHash fp
|
||||
where
|
||||
Fingerprint fp = getFingerprint o HashSHA256
|
||||
Fingerprint fp = getFingerprint o X.HashSHA256
|
||||
|
||||
class SignatureAlgorithmX509 a where
|
||||
signatureAlgorithmX509 :: a -> SignatureALG
|
||||
signatureAlgorithmX509 :: a -> X.SignatureALG
|
||||
|
||||
instance SignatureAlgorithm a => SignatureAlgorithmX509 (SAlgorithm a) where
|
||||
signatureAlgorithmX509 = \case
|
||||
SEd25519 -> SignatureALG_IntrinsicHash PubKeyALG_Ed25519
|
||||
SEd448 -> SignatureALG_IntrinsicHash PubKeyALG_Ed448
|
||||
SEd25519 -> X.SignatureALG_IntrinsicHash X.PubKeyALG_Ed25519
|
||||
SEd448 -> X.SignatureALG_IntrinsicHash X.PubKeyALG_Ed448
|
||||
{-# INLINE signatureAlgorithmX509 #-}
|
||||
|
||||
instance SignatureAlgorithmX509 APrivateSignKey where
|
||||
@@ -1217,31 +1215,31 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where
|
||||
{-# INLINE signatureAlgorithmX509 #-}
|
||||
|
||||
-- | A wrapper to marshall signed ASN1 objects, like certificates.
|
||||
newtype SignedObject a = SignedObject {getSignedExact :: SignedExact a}
|
||||
newtype SignedObject a = SignedObject {getSignedExact :: X.SignedExact a}
|
||||
|
||||
instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where
|
||||
#if defined(dbPostgres)
|
||||
fromField f dat = SignedObject <$> blobFieldDecoder decodeSignedObject f dat
|
||||
fromField f dat = SignedObject <$> blobFieldDecoder X.decodeSignedObject f dat
|
||||
#else
|
||||
fromField = fmap SignedObject . blobFieldDecoder decodeSignedObject
|
||||
fromField = fmap SignedObject . blobFieldDecoder X.decodeSignedObject
|
||||
#endif
|
||||
|
||||
instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where
|
||||
toField (SignedObject s) = toField . Binary $ encodeSignedObject s
|
||||
toField (SignedObject s) = toField . Binary $ X.encodeSignedObject s
|
||||
|
||||
instance (Eq a, Show a, ASN1Object a) => Encoding (SignedObject a) where
|
||||
smpEncode (SignedObject exact) = smpEncode . Large $ encodeSignedObject exact
|
||||
smpP = fmap SignedObject . decodeSignedObject . unLarge <$?> smpP
|
||||
smpEncode (SignedObject exact) = smpEncode . Large $ X.encodeSignedObject exact
|
||||
smpP = fmap SignedObject . X.decodeSignedObject . unLarge <$?> smpP
|
||||
|
||||
encodeCertChain :: CertificateChain -> L.NonEmpty Large
|
||||
encodeCertChain :: X.CertificateChain -> L.NonEmpty Large
|
||||
encodeCertChain cc = L.fromList $ map Large blobs
|
||||
where
|
||||
CertificateChainRaw blobs = encodeCertificateChain cc
|
||||
X.CertificateChainRaw blobs = X.encodeCertificateChain cc
|
||||
|
||||
certChainP :: A.Parser CertificateChain
|
||||
certChainP :: A.Parser X.CertificateChain
|
||||
certChainP = do
|
||||
rawChain <- CertificateChainRaw . map unLarge . L.toList <$> smpP
|
||||
either (fail . show) pure $ decodeCertificateChain rawChain
|
||||
rawChain <- X.CertificateChainRaw . map unLarge . L.toList <$> smpP
|
||||
either (fail . show) pure $ X.decodeCertificateChain rawChain
|
||||
|
||||
-- | Signature verification.
|
||||
--
|
||||
@@ -1453,19 +1451,19 @@ xSalsa20 secret nonce msg = (rs, msg')
|
||||
(rs, state2) = XSalsa.generate state1 32
|
||||
(msg', _) = XSalsa.combine state2 msg
|
||||
|
||||
publicToX509 :: PublicKey a -> PubKey
|
||||
publicToX509 :: PublicKey a -> X.PubKey
|
||||
publicToX509 = \case
|
||||
PublicKeyEd25519 k -> PubKeyEd25519 k
|
||||
PublicKeyEd448 k -> PubKeyEd448 k
|
||||
PublicKeyX25519 k -> PubKeyX25519 k
|
||||
PublicKeyX448 k -> PubKeyX448 k
|
||||
PublicKeyEd25519 k -> X.PubKeyEd25519 k
|
||||
PublicKeyEd448 k -> X.PubKeyEd448 k
|
||||
PublicKeyX25519 k -> X.PubKeyX25519 k
|
||||
PublicKeyX448 k -> X.PubKeyX448 k
|
||||
|
||||
privateToX509 :: PrivateKey a -> PrivKey
|
||||
privateToX509 :: PrivateKey a -> X.PrivKey
|
||||
privateToX509 = \case
|
||||
PrivateKeyEd25519 k _ -> PrivKeyEd25519 k
|
||||
PrivateKeyEd448 k _ -> PrivKeyEd448 k
|
||||
PrivateKeyX25519 k _ -> PrivKeyX25519 k
|
||||
PrivateKeyX448 k _ -> PrivKeyX448 k
|
||||
PrivateKeyEd25519 k _ -> X.PrivKeyEd25519 k
|
||||
PrivateKeyEd448 k _ -> X.PrivKeyEd448 k
|
||||
PrivateKeyX25519 k _ -> X.PrivKeyX25519 k
|
||||
PrivateKeyX448 k _ -> X.PrivKeyX448 k
|
||||
|
||||
encodeASNObj :: ASN1Object a => a -> ByteString
|
||||
encodeASNObj k = toStrict . encodeASN1 DER $ toASN1 k []
|
||||
@@ -1478,20 +1476,20 @@ decodePubKey = decodeKey >=> x509ToPublic >=> pubKey
|
||||
decodePrivKey :: CryptoPrivateKey k => ByteString -> Either String k
|
||||
decodePrivKey = decodeKey >=> x509ToPrivate >=> privKey
|
||||
|
||||
x509ToPublic :: (PubKey, [ASN1]) -> Either String APublicKey
|
||||
x509ToPublic :: (X.PubKey, [ASN1]) -> Either String APublicKey
|
||||
x509ToPublic = \case
|
||||
(PubKeyEd25519 k, []) -> Right . APublicKey SEd25519 $ PublicKeyEd25519 k
|
||||
(PubKeyEd448 k, []) -> Right . APublicKey SEd448 $ PublicKeyEd448 k
|
||||
(PubKeyX25519 k, []) -> Right . APublicKey SX25519 $ PublicKeyX25519 k
|
||||
(PubKeyX448 k, []) -> Right . APublicKey SX448 $ PublicKeyX448 k
|
||||
(X.PubKeyEd25519 k, []) -> Right . APublicKey SEd25519 $ PublicKeyEd25519 k
|
||||
(X.PubKeyEd448 k, []) -> Right . APublicKey SEd448 $ PublicKeyEd448 k
|
||||
(X.PubKeyX25519 k, []) -> Right . APublicKey SX25519 $ PublicKeyX25519 k
|
||||
(X.PubKeyX448 k, []) -> Right . APublicKey SX448 $ PublicKeyX448 k
|
||||
r -> keyError r
|
||||
|
||||
x509ToPrivate :: (PrivKey, [ASN1]) -> Either String APrivateKey
|
||||
x509ToPrivate :: (X.PrivKey, [ASN1]) -> Either String APrivateKey
|
||||
x509ToPrivate = \case
|
||||
(PrivKeyEd25519 k, []) -> Right . APrivateKey SEd25519 . PrivateKeyEd25519 k $ Ed25519.toPublic k
|
||||
(PrivKeyEd448 k, []) -> Right . APrivateKey SEd448 . PrivateKeyEd448 k $ Ed448.toPublic k
|
||||
(PrivKeyX25519 k, []) -> Right . APrivateKey SX25519 . PrivateKeyX25519 k $ X25519.toPublic k
|
||||
(PrivKeyX448 k, []) -> Right . APrivateKey SX448 . PrivateKeyX448 k $ X448.toPublic k
|
||||
(X.PrivKeyEd25519 k, []) -> Right . APrivateKey SEd25519 . PrivateKeyEd25519 k $ Ed25519.toPublic k
|
||||
(X.PrivKeyEd448 k, []) -> Right . APrivateKey SEd448 . PrivateKeyEd448 k $ Ed448.toPublic k
|
||||
(X.PrivKeyX25519 k, []) -> Right . APrivateKey SX25519 . PrivateKeyX25519 k $ X25519.toPublic k
|
||||
(X.PrivKeyX448 k, []) -> Right . APrivateKey SX448 . PrivateKeyX448 k $ X448.toPublic k
|
||||
r -> keyError r
|
||||
|
||||
decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1])
|
||||
|
||||
@@ -49,7 +49,7 @@ import qualified Data.Text as T
|
||||
import Data.Time.Clock.System (getSystemTime)
|
||||
import Data.Tuple (swap)
|
||||
import Data.Word (Word16)
|
||||
import qualified Data.X509 as X509
|
||||
import qualified Data.X509 as X
|
||||
import Data.X509.Validation (Fingerprint (..), getFingerprint)
|
||||
import Network.Socket (PortNumber, SockAddr (..), hostAddressToTuple)
|
||||
import qualified Network.TLS as TLS
|
||||
@@ -157,7 +157,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
tlsHooks r knownHost_ hostCAHash =
|
||||
def
|
||||
{ TLS.onNewHandshake = \_ -> atomically $ isNothing <$> tryReadTMVar r,
|
||||
TLS.onClientCertificate = \(X509.CertificateChain chain) ->
|
||||
TLS.onClientCertificate = \(X.CertificateChain chain) ->
|
||||
case chain of
|
||||
[_leaf, ca] -> do
|
||||
let kh = certFingerprint ca
|
||||
@@ -190,16 +190,16 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
}
|
||||
pure $ signInvitation (snd sessKeys) idPrivKey inv
|
||||
|
||||
genTLSCredentials :: TVar ChaChaDRG -> C.APrivateSignKey -> C.SignedCertificate -> IO TLS.Credential
|
||||
genTLSCredentials :: TVar ChaChaDRG -> C.APrivateSignKey -> X.SignedCertificate -> IO TLS.Credential
|
||||
genTLSCredentials drg caKey caCert = do
|
||||
let caCreds = (C.signatureKeyPair caKey, caCert)
|
||||
leaf <- genCredentials drg (Just caCreds) (0, 24 * 999999) "localhost" -- session-signing cert
|
||||
pure . snd $ tlsCredentials (leaf :| [caCreds])
|
||||
|
||||
certFingerprint :: X509.SignedCertificate -> C.KeyHash
|
||||
certFingerprint :: X.SignedCertificate -> C.KeyHash
|
||||
certFingerprint caCert = C.KeyHash fp
|
||||
where
|
||||
Fingerprint fp = getFingerprint caCert X509.HashSHA256
|
||||
Fingerprint fp = getFingerprint caCert X.HashSHA256
|
||||
|
||||
cancelHostClient :: RCHostClient -> IO ()
|
||||
cancelHostClient RCHostClient {action, client_ = RCHClient_ {announcer, endSession}} = do
|
||||
|
||||
@@ -22,8 +22,6 @@ deriving instance Eq (StoredRcvQueue q)
|
||||
|
||||
deriving instance Eq (StoredSndQueue q)
|
||||
|
||||
deriving instance Eq (DBQueueId q)
|
||||
|
||||
deriving instance Eq ClientNtfCreds
|
||||
|
||||
deriving instance Eq ShortLinkCreds
|
||||
|
||||
@@ -54,6 +54,7 @@ import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding.String (StrEncoding (..))
|
||||
import Simplex.Messaging.Protocol (EntityId (..), QueueMode (..), SubscriptionMode (..), pattern VersionSMPC)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import System.Random
|
||||
import Test.Hspec hiding (fit, it)
|
||||
import UnliftIO.Directory (removeFile)
|
||||
@@ -230,7 +231,7 @@ rcvQueue1 =
|
||||
queueMode = Just QMMessaging,
|
||||
shortLink = Nothing,
|
||||
status = New,
|
||||
dbQueueId = DBNewQueue,
|
||||
dbQueueId = DBNewEntity,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
rcvSwchStatus = Nothing,
|
||||
@@ -252,7 +253,7 @@ sndQueue1 =
|
||||
e2ePubKey = Nothing,
|
||||
e2eDhSecret = testDhSecret,
|
||||
status = New,
|
||||
dbQueueId = DBNewQueue,
|
||||
dbQueueId = DBNewEntity,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
sndSwchStatus = Nothing,
|
||||
@@ -271,11 +272,11 @@ testCreateRcvConn =
|
||||
g <- C.newRandom
|
||||
Right (connId, rq@RcvQueue {dbQueueId}) <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
connId `shouldBe` "conn1"
|
||||
dbQueueId `shouldBe` DBQueueId 1
|
||||
dbQueueId `shouldBe` DBEntityId 1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq))
|
||||
Right sq@SndQueue {dbQueueId = dbQueueId'} <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
dbQueueId' `shouldBe` DBQueueId 1
|
||||
dbQueueId' `shouldBe` DBEntityId 1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq]))
|
||||
|
||||
@@ -287,7 +288,7 @@ testCreateRcvConnRandomId =
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rq))
|
||||
Right sq@SndQueue {dbQueueId = dbQueueId'} <- upgradeRcvConnToDuplex db connId sndQueue1
|
||||
dbQueueId' `shouldBe` DBQueueId 1
|
||||
dbQueueId' `shouldBe` DBEntityId 1
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq] [sq]))
|
||||
|
||||
@@ -305,11 +306,11 @@ testCreateSndConn =
|
||||
g <- C.newRandom
|
||||
Right (connId, sq@SndQueue {dbQueueId}) <- createSndConn db g cData1 sndQueue1
|
||||
connId `shouldBe` "conn1"
|
||||
dbQueueId `shouldBe` DBQueueId 1
|
||||
dbQueueId `shouldBe` DBEntityId 1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq))
|
||||
Right rq@RcvQueue {dbQueueId = dbQueueId'} <- upgradeSndConnToDuplex db "conn1" rcvQueue1
|
||||
dbQueueId' `shouldBe` DBQueueId 1
|
||||
dbQueueId' `shouldBe` DBEntityId 1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq]))
|
||||
|
||||
@@ -321,7 +322,7 @@ testCreateSndConnRandomID =
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sq))
|
||||
Right (rq@RcvQueue {dbQueueId = dbQueueId'}) <- upgradeSndConnToDuplex db connId rcvQueue1
|
||||
dbQueueId' `shouldBe` DBQueueId 1
|
||||
dbQueueId' `shouldBe` DBEntityId 1
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq] [sq]))
|
||||
|
||||
@@ -412,7 +413,7 @@ testUpgradeRcvConnToDuplex =
|
||||
e2ePubKey = Nothing,
|
||||
e2eDhSecret = testDhSecret,
|
||||
status = New,
|
||||
dbQueueId = DBNewQueue,
|
||||
dbQueueId = DBNewEntity,
|
||||
sndSwchStatus = Nothing,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
@@ -443,7 +444,7 @@ testUpgradeSndConnToDuplex =
|
||||
queueMode = Just QMMessaging,
|
||||
shortLink = Nothing,
|
||||
status = New,
|
||||
dbQueueId = DBNewQueue,
|
||||
dbQueueId = DBNewEntity,
|
||||
rcvSwchStatus = Nothing,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
|
||||
@@ -14,7 +14,8 @@ import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
import Data.String (IsString (..))
|
||||
import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId)
|
||||
import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..))
|
||||
import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..))
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (EntityId (..), QueueMode (..), RecipientId, SMPServer, pattern NoEntity, pattern VersionSMPC)
|
||||
@@ -201,7 +202,7 @@ dummyRQ userId server connId rcvId =
|
||||
queueMode = Just QMMessaging,
|
||||
shortLink = Nothing,
|
||||
status = New,
|
||||
dbQueueId = DBQueueId 0,
|
||||
dbQueueId = DBEntityId 0,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
rcvSwchStatus = Nothing,
|
||||
|
||||
Reference in New Issue
Block a user