diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
index 68eeadfe055..65bd215830a 100644
--- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
+++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
@@ -1,6 +1,6 @@
 # coding=utf-8
-import ast
 import io
+import requests
 
 import uuid_utils.compat as uuid
 from django.db.models import QuerySet
@@ -8,8 +8,11 @@
 from application.flow.common import WorkflowMode
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
+from common.utils.common import get_file_name_from_content_disposition, get_file_name_from_url
+from common.utils.logger import maxkb_logger
 from knowledge.models import File, FileSourceType
 from knowledge.serializers.document import split_handles, parse_table_handle_list, FileBufferHandle
+from oss.serializers.file import validate_url, SafeHTTPAdapter
 
 splitter = '\n`-----------------------------------`\n'
 
@@ -23,7 +26,6 @@ def execute(self, document, chat_id=None, **kwargs):
         get_buffer = FileBufferHandle().get_buffer
 
         self.context['document_list'] = document
-        content = []
         if document is None or not isinstance(document, list):
             return NodeResult({'content': '', 'document_list': []}, {})
 
@@ -62,9 +64,108 @@ def save_image(image_list):
                 if not QuerySet(File).filter(id=new_file.id).exists():
                     new_file.save(file_bytes)
 
+        # Download a file from a URL and save it as a File object
+        def download_and_save_file(url, file_name=None):
+            try:
+                # Validate that the URL is safe to request
+                validated_url = validate_url(url)
+
+                # Create a safe HTTP session
+                session = requests.Session()
+                safe_adapter = SafeHTTPAdapter()
+                session.mount('http://', safe_adapter)
+                session.mount('https://', safe_adapter)
+
+                try:
+                    # Send a GET request to download the file
+                    response = session.get(
+                        validated_url,
+                        timeout=30,
+                        allow_redirects=True
+                    )
+                    response.raise_for_status()
+
+                    # Work out the file name (if one was not provided)
+                    if not file_name:
+                        # Prefer the file name from the Content-Disposition header when it has one
+                        file_name = get_file_name_from_content_disposition(response.headers.get('Content-Disposition', ''))
+                        if file_name is None:
+                            # Otherwise extract the file name from the URL path
+                            file_name = get_file_name_from_url(validated_url, 'downloaded_document')
+
+                    # Get the file content
+                    file_bytes = response.content
+
+                    # Generate the file ID
+                    file_id = uuid.uuid7()
+
+                    # Determine source_type and source_id
+                    source_type = FileSourceType.APPLICATION.value if application_id else FileSourceType.KNOWLEDGE.value if knowledge_id else FileSourceType.TOOL.value
+                    source_id = application_id or knowledge_id or tool_id
+
+                    # Build the File object
+                    meta = {
+                        'debug': False if (application_id or knowledge_id or tool_id) else True,
+                        'chat_id': chat_id,
+                        'application_id': str(application_id) if application_id else None,
+                        'knowledge_id': str(knowledge_id) if knowledge_id else None,
+                        'tool_id': str(tool_id) if tool_id else None,
+                        'file_id': str(file_id),
+                        'source_url': url
+                    }
+
+                    new_file = File(
+                        id=file_id,
+                        file_name=file_name,
+                        file_size=len(file_bytes),
+                        source_type=source_type,
+                        source_id=source_id,
+                        meta=meta
+                    )
+
+                    # Save the file to the database
+                    new_file.save(file_bytes)
+
+                    maxkb_logger.info(f'Successfully downloaded and saved file from URL: {url}, file_id: {file_id}')
+
+                    return new_file
+
+                finally:
+                    session.close()
+
+            except Exception as e:
+                maxkb_logger.error(f'Failed to download document file from URL: {url}, error: {str(e)}')
+                raise Exception(f'Failed to download document file: {str(e)}') from e
+
+        content = []
         document_list = []
         for doc in document:
-            file = QuerySet(File).filter(id=doc['file_id']).first()
+            # API callers may pass entries in the wrong format (e.g. bare strings); fail with a descriptive error
+            if not isinstance(doc, dict):
+                raise ValueError('The "document_list" parameters must be in the format of `[{ "url": "http......" }, ......]`')
+
+            # If the entry is an HTTP(S) document URL, download it and save it to the file table first
+            if not doc.get("file_id") and isinstance(doc.get("url"), str) and doc["url"].startswith(("http:", "https:")):
+                try:
+                    # Download and save the file
+                    file = download_and_save_file(doc["url"], doc.get('name', None))
+
+                    # Update the doc dict with the new file_id
+                    doc['file_id'] = str(file.id)
+                    if not doc.get('name'):
+                        doc['name'] = file.file_name
+
+                    maxkb_logger.info(f'Downloaded file from URL and assigned file_id: {doc["file_id"]}')
+                except Exception as e:
+                    maxkb_logger.error(f'Error processing document URL: {doc.get("url")}, error: {str(e)}')
+                    raise
+            elif doc.get("file_id"):
+                file = QuerySet(File).filter(id=doc['file_id']).first()
+                if file is None:
+                    raise ValueError('Please provide a valid document file ID or URL')
+            else:
+                raise ValueError('Please provide a valid document file ID or URL')
+
             buffer = io.BytesIO(file.get_bytes())
             buffer.name = doc['name']  # this is the important line
 
diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py
index 5c6e17b35bb..e4498e1bf2a 100644
--- a/apps/common/utils/common.py
+++ b/apps/common/utils/common.py
@@ -23,6 +23,7 @@
 from django.db.models import QuerySet
 from django.utils.translation import gettext as _
 from pydub import AudioSegment
+from urllib.parse import urlparse
 
 from ..database_model_manage.database_model_manage import DatabaseModelManage
 from ..exception.app_exception import AppApiException
@@ -409,6 +410,7 @@ def is_valid_uuid(uuid_string):
     except ValueError:
         return False
 
+
 def common_convert_value(_type, value):
     if value is None:
         return None
@@ -436,3 +438,38 @@ def common_convert_value(_type, value):
             return v
         raise Exception(_('type error'))
     return value
+
+
+def get_file_name_from_content_disposition(content_disposition, default=None):
+    """
+    Try to extract the file name from a `Content-Disposition` response header.
+
+    :param content_disposition: value of the `Content-Disposition` response header
+    :param default: file name to return when the header does not provide one
+    :return: the file name, or `default`
+    """
+    if not content_disposition:
+        return default
+
+    file_name = default
+    if 'filename=' in content_disposition:
+        filename_part = content_disposition.split('filename=')[1].split(';')[0].strip().strip('"\'')
+        if filename_part:
+            file_name = filename_part
+
+    return file_name
+
+
+def get_file_name_from_url(url, default=None):
+    """
+    Try to extract the file name from the last path segment of a URL.
+    :param url: URL of the file
+    :param default: file name to return when the URL path has no last segment
+    :return: the file name, or `default`
+    """
+    if not url:
+        return default
+
+    parsed_url = urlparse(url)
+    path_parts = parsed_url.path.split('/')
+    return path_parts[-1] if path_parts and path_parts[-1] else default
diff --git a/ui/src/utils/common.ts b/ui/src/utils/common.ts
index b2148247ff1..ac10083a520 100644
--- a/ui/src/utils/common.ts
+++ b/ui/src/utils/common.ts
@@ -64,9 +64,8 @@ const typeList: any = {
 
 export function getImgUrl(name: string) {
   const list = Object.values(typeList).flat()
-  const type = list.includes(fileType(name).toLowerCase())
-    ? fileType(name).toLowerCase()
-    : 'unknown'
+  const typeStr = fileType(name).toLowerCase()
+  const type = list.includes(typeStr) ? typeStr : 'unknown'
   return new URL(`../assets/fileType/${type}-icon.svg`, import.meta.url).href
 }
 