add protocol check for getFileTransfer protocol

This commit is contained in:
IC Rainbow
2023-12-24 20:40:56 +02:00
parent 5202349f6c
commit fce04aa34d
2 changed files with 16 additions and 10 deletions

View File

@@ -853,19 +853,20 @@ getFileTransfer :: DB.Connection -> User -> Int64 -> ExceptT StoreError IO FileT
getFileTransfer db user@User {userId} fileId =
fileTransfer =<< liftIO (getFileTransferRow_ db userId fileId)
where
fileTransfer :: [(Maybe Int64, Maybe Int64)] -> ExceptT StoreError IO FileTransfer
fileTransfer [(Nothing, Nothing)] = FTLocal <$> getLocalFileMeta db user fileId
fileTransfer [(Nothing, Just _)] = FTRcv <$> getRcvFileTransfer db user fileId
fileTransfer _ = do
fileTransfer :: [(Maybe Int64, Maybe Int64, FileProtocol)] -> ExceptT StoreError IO FileTransfer
fileTransfer [(Nothing, Just _, _)] = FTRcv <$> getRcvFileTransfer db user fileId
fileTransfer [(Just _, Nothing, _)] = do
(ftm, fts) <- getSndFileTransfer db user fileId
pure $ FTSnd {fileTransferMeta = ftm, sndFileTransfers = fts}
fileTransfer [(Nothing, Nothing, FPLocal)] = FTLocal <$> getLocalFileMeta db userId fileId
fileTransfer _ = throwError $ SEBadFileTransfer fileId
getFileTransferRow_ :: DB.Connection -> UserId -> Int64 -> IO [(Maybe Int64, Maybe Int64)]
getFileTransferRow_ :: DB.Connection -> UserId -> Int64 -> IO [(Maybe Int64, Maybe Int64, FileProtocol)]
getFileTransferRow_ db userId fileId =
DB.query
db
[sql|
SELECT s.file_id, r.file_id
SELECT s.file_id, r.file_id, f.protocol
FROM files f
LEFT JOIN snd_files s ON s.file_id = f.file_id
LEFT JOIN rcv_files r ON r.file_id = f.file_id
@@ -926,8 +927,8 @@ getFileTransferMeta_ db userId fileId =
xftpSndFile = (\fId -> XFTPSndFile {agentSndFileId = fId, privateSndFileDescr, agentSndFileDeleted, cryptoArgs}) <$> aSndFileId_
in FileTransferMeta {fileId, xftpSndFile, fileName, fileSize, chunkSize, filePath, fileInline, cancelled = fromMaybe False cancelled_}
getLocalFileMeta :: DB.Connection -> User -> Int64 -> ExceptT StoreError IO LocalFileMeta
getLocalFileMeta db User {userId} fileId =
getLocalFileMeta :: DB.Connection -> UserId -> Int64 -> ExceptT StoreError IO LocalFileMeta
getLocalFileMeta db userId fileId =
ExceptT . firstRow localFileMeta (SEFileNotFound fileId) $
DB.query
db
@@ -956,16 +957,20 @@ getNoteFolderFileInfo db User {userId} NoteFolder {noteFolderId} =
getLocalCryptoFile :: DB.Connection -> UserId -> Int64 -> Bool -> ExceptT StoreError IO CryptoFile
getLocalCryptoFile db userId fileId sent =
liftIO (getFileTransferRow_ db userId fileId) >>= \case
[(Nothing, Just _)] -> do
[(Nothing, Just _, _)] -> do
when sent $ throwError $ SEFileNotFound fileId
RcvFileTransfer {fileStatus, cryptoArgs} <- getRcvFileTransfer_ db userId fileId
case fileStatus of
RFSComplete RcvFileInfo {filePath} -> pure $ CryptoFile filePath cryptoArgs
_ -> throwError $ SEFileNotFound fileId
_ -> do
[(Just _, Nothing, _)] -> do
unless sent $ throwError $ SEFileNotFound fileId
FileTransferMeta {filePath, xftpSndFile} <- getFileTransferMeta_ db userId fileId
pure $ CryptoFile filePath $ xftpSndFile >>= \f -> f.cryptoArgs
[(Nothing, Nothing, FPLocal)] -> do
LocalFileMeta {filePath, fileCryptoArgs} <- getLocalFileMeta db userId fileId
pure $ CryptoFile filePath fileCryptoArgs
_ -> throwError $ SEBadFileTransfer fileId
updateDirectCIFileStatus :: forall d. MsgDirectionI d => DB.Connection -> VersionRange -> User -> Int64 -> CIFileStatus d -> ExceptT StoreError IO AChatItem
updateDirectCIFileStatus db vr user fileId fileStatus = do

View File

@@ -78,6 +78,7 @@ data StoreError
| SEFileNotFound {fileId :: FileTransferId}
| SERcvFileInvalid {fileId :: FileTransferId}
| SERcvFileInvalidDescrPart
| SEBadFileTransfer {fileId :: FileTransferId}
| SESharedMsgIdNotFoundByFileId {fileId :: FileTransferId}
| SEFileIdNotFoundBySharedMsgId {sharedMsgId :: SharedMsgId}
| SESndFileNotFoundXFTP {agentSndFileId :: AgentSndFileId}