diff --git a/src/Simplex/Chat/Store/Files.hs b/src/Simplex/Chat/Store/Files.hs index 343167d8d3..6977fca74b 100644 --- a/src/Simplex/Chat/Store/Files.hs +++ b/src/Simplex/Chat/Store/Files.hs @@ -853,19 +853,20 @@ getFileTransfer :: DB.Connection -> User -> Int64 -> ExceptT StoreError IO FileT getFileTransfer db user@User {userId} fileId = fileTransfer =<< liftIO (getFileTransferRow_ db userId fileId) where - fileTransfer :: [(Maybe Int64, Maybe Int64)] -> ExceptT StoreError IO FileTransfer - fileTransfer [(Nothing, Nothing)] = FTLocal <$> getLocalFileMeta db user fileId - fileTransfer [(Nothing, Just _)] = FTRcv <$> getRcvFileTransfer db user fileId - fileTransfer _ = do + fileTransfer :: [(Maybe Int64, Maybe Int64, FileProtocol)] -> ExceptT StoreError IO FileTransfer + fileTransfer [(Nothing, Just _, _)] = FTRcv <$> getRcvFileTransfer db user fileId + fileTransfer [(Just _, Nothing, _)] = do (ftm, fts) <- getSndFileTransfer db user fileId pure $ FTSnd {fileTransferMeta = ftm, sndFileTransfers = fts} + fileTransfer [(Nothing, Nothing, FPLocal)] = FTLocal <$> getLocalFileMeta db userId fileId + fileTransfer _ = throwError $ SEBadFileTransfer fileId -getFileTransferRow_ :: DB.Connection -> UserId -> Int64 -> IO [(Maybe Int64, Maybe Int64)] +getFileTransferRow_ :: DB.Connection -> UserId -> Int64 -> IO [(Maybe Int64, Maybe Int64, FileProtocol)] getFileTransferRow_ db userId fileId = DB.query db [sql| - SELECT s.file_id, r.file_id + SELECT s.file_id, r.file_id, f.protocol FROM files f LEFT JOIN snd_files s ON s.file_id = f.file_id LEFT JOIN rcv_files r ON r.file_id = f.file_id @@ -926,8 +927,8 @@ getFileTransferMeta_ db userId fileId = xftpSndFile = (\fId -> XFTPSndFile {agentSndFileId = fId, privateSndFileDescr, agentSndFileDeleted, cryptoArgs}) <$> aSndFileId_ in FileTransferMeta {fileId, xftpSndFile, fileName, fileSize, chunkSize, filePath, fileInline, cancelled = fromMaybe False cancelled_} -getLocalFileMeta :: DB.Connection -> User -> Int64 -> ExceptT StoreError IO LocalFileMeta -getLocalFileMeta db User {userId} fileId = +getLocalFileMeta :: DB.Connection -> UserId -> Int64 -> ExceptT StoreError IO LocalFileMeta +getLocalFileMeta db userId fileId = ExceptT . firstRow localFileMeta (SEFileNotFound fileId) $ DB.query db @@ -956,16 +957,20 @@ getNoteFolderFileInfo db User {userId} NoteFolder {noteFolderId} = getLocalCryptoFile :: DB.Connection -> UserId -> Int64 -> Bool -> ExceptT StoreError IO CryptoFile getLocalCryptoFile db userId fileId sent = liftIO (getFileTransferRow_ db userId fileId) >>= \case - [(Nothing, Just _)] -> do + [(Nothing, Just _, _)] -> do when sent $ throwError $ SEFileNotFound fileId RcvFileTransfer {fileStatus, cryptoArgs} <- getRcvFileTransfer_ db userId fileId case fileStatus of RFSComplete RcvFileInfo {filePath} -> pure $ CryptoFile filePath cryptoArgs _ -> throwError $ SEFileNotFound fileId - _ -> do + [(Just _, Nothing, _)] -> do unless sent $ throwError $ SEFileNotFound fileId FileTransferMeta {filePath, xftpSndFile} <- getFileTransferMeta_ db userId fileId pure $ CryptoFile filePath $ xftpSndFile >>= \f -> f.cryptoArgs + [(Nothing, Nothing, FPLocal)] -> do + LocalFileMeta {filePath, fileCryptoArgs} <- getLocalFileMeta db userId fileId + pure $ CryptoFile filePath fileCryptoArgs + _ -> throwError $ SEBadFileTransfer fileId updateDirectCIFileStatus :: forall d. MsgDirectionI d => DB.Connection -> VersionRange -> User -> Int64 -> CIFileStatus d -> ExceptT StoreError IO AChatItem updateDirectCIFileStatus db vr user fileId fileStatus = do diff --git a/src/Simplex/Chat/Store/Shared.hs b/src/Simplex/Chat/Store/Shared.hs index 7a9a106e71..9240f7d0c9 100644 --- a/src/Simplex/Chat/Store/Shared.hs +++ b/src/Simplex/Chat/Store/Shared.hs @@ -78,6 +78,7 @@ data StoreError | SEFileNotFound {fileId :: FileTransferId} | SERcvFileInvalid {fileId :: FileTransferId} | SERcvFileInvalidDescrPart + | SEBadFileTransfer {fileId :: FileTransferId} | SESharedMsgIdNotFoundByFileId {fileId :: FileTransferId} | SEFileIdNotFoundBySharedMsgId {sharedMsgId :: SharedMsgId} | SESndFileNotFoundXFTP {agentSndFileId :: AgentSndFileId}