Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# coding=utf-8
import ast
import io
import requests

import uuid_utils.compat as uuid
from django.db.models import QuerySet

from application.flow.common import WorkflowMode
from application.flow.i_step_node import NodeResult
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
from common.utils.common import get_file_name_from_content_disposition, get_file_name_from_url
from common.utils.logger import maxkb_logger
from knowledge.models import File, FileSourceType
from knowledge.serializers.document import split_handles, parse_table_handle_list, FileBufferHandle
from oss.serializers.file import validate_url, SafeHTTPAdapter

splitter = '\n`-----------------------------------`\n'

Expand All @@ -23,7 +26,6 @@ def execute(self, document, chat_id=None, **kwargs):
get_buffer = FileBufferHandle().get_buffer

self.context['document_list'] = document
content = []
if document is None or not isinstance(document, list):
return NodeResult({'content': '', 'document_list': []}, {})

Expand Down Expand Up @@ -62,9 +64,106 @@ def save_image(image_list):
if not QuerySet(File).filter(id=new_file.id).exists():
new_file.save(file_bytes)

# 从URL下载文件并保存为File对象
def download_and_save_file(url, file_name=None):
try:
# 验证URL安全性
validated_url = validate_url(url)

# 创建安全的HTTP会话
session = requests.Session()
safe_adapter = SafeHTTPAdapter()
session.mount('http://', safe_adapter)
session.mount('https://', safe_adapter)

try:
# 发送GET请求下载文件
response = session.get(
validated_url,
timeout=30,
allow_redirects=True
)
response.raise_for_status()

# 获取文件名(如果未提供)
if not file_name:
# 如果Content-Disposition头中有文件名,优先使用
file_name = get_file_name_from_content_disposition(response.headers.get('Content-Disposition', ''))
if file_name is None:
# 从URL路径中提取文件名
file_name = get_file_name_from_url(validated_url, 'downloaded_document')

# 获取文件内容
file_bytes = response.content

# 生成文件ID
file_id = uuid.uuid7()

# 确定source_type和source_id
source_type = FileSourceType.APPLICATION.value if application_id else FileSourceType.KNOWLEDGE.value if knowledge_id else FileSourceType.TOOL.value
source_id = application_id or knowledge_id or tool_id

# 创建File对象
meta = {
'debug': False if (application_id or knowledge_id or tool_id) else True,
'chat_id': chat_id,
'application_id': str(application_id) if application_id else None,
'knowledge_id': str(knowledge_id) if knowledge_id else None,
'tool_id': str(tool_id) if tool_id else None,
'file_id': str(file_id),
'source_url': url
}

new_file = File(
id=file_id,
file_name=file_name,
file_size=len(file_bytes),
source_type=source_type,
source_id=source_id,
meta=meta
)

# 保存文件到数据库
new_file.save(file_bytes)

maxkb_logger.info(f'Successfully downloaded and saved file from URL: {url}, file_id: {file_id}')

return new_file

finally:
session.close()

except Exception as e:
maxkb_logger.error(f'Failed to download document file from URL: {url}, error: {str(e)}')
raise Exception(f'Failed to download document file: {str(e)}')

content = []
document_list = []
for doc in document:
file = QuerySet(File).filter(id=doc['file_id']).first()
# 考虑API调用时,用户传错了格式,抛出异常提示
if isinstance(doc, str):
raise ValueError('The "document_list" parameters must be in the format of `[{ "url": "http......" }, ......]`')

# 如果是文档的 HTTP(s) URL地址,则先下载并保存到file表中
if not doc.get("file_id") and doc.get("url") and (doc.get("url").startswith("http:") or doc.get("url").startswith("https:")):
try:
# 下载并保存文件
file = download_and_save_file(doc["url"], doc.get('name', None))

# 更新doc字典,添加file_id
doc['file_id'] = str(file.id)
if not doc.get('name'):
doc['name'] = file.file_name

maxkb_logger.info(f'Downloaded file from URL and assigned file_id: {doc["file_id"]}')
except Exception as e:
maxkb_logger.error(f'Error processing document URL: {doc.get("url")}, error: {str(e)}')
raise e
elif doc.get("file_id"):
file = QuerySet(File).filter(id=doc['file_id']).first()
else:
raise ValueError('Please provide a valid document file ID or URL')

buffer = io.BytesIO(file.get_bytes())
buffer.name = doc['name'] # this is the important line

Expand Down
37 changes: 37 additions & 0 deletions apps/common/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from django.db.models import QuerySet
from django.utils.translation import gettext as _
from pydub import AudioSegment
from urllib.parse import urlparse

from ..database_model_manage.database_model_manage import DatabaseModelManage
from ..exception.app_exception import AppApiException
Expand Down Expand Up @@ -409,6 +410,7 @@ def is_valid_uuid(uuid_string):
except ValueError:
return False


def common_convert_value(_type, value):
if value is None:
return None
Expand Down Expand Up @@ -436,3 +438,38 @@ def common_convert_value(_type, value):
return v
raise Exception(_('type error'))
return value


def get_file_name_from_content_disposition(content_disposition, default = None):
"""
尝试从响应头 `Content-Disposition` 中获取文件名

:param content_disposition: 响应头 `Content-Disposition`
:param default: 默认文件名
:return: 文件名
"""
if not content_disposition:
return default

file_name = default
if 'filename=' in content_disposition:
filename_part = content_disposition.split('filename=')[1].split(';')[0].strip('"\'')
if filename_part:
file_name = filename_part

return file_name


def get_file_name_from_url(url, default = None):
"""
尝试从url中获取文件名
:param url: 文件URL地址
:param default: 默认文件名
:return: 文件名
"""
if not url:
return default

parsed_url = urlparse(url)
path_parts = parsed_url.path.split('/')
return path_parts[-1] if path_parts and path_parts[-1] else default
5 changes: 2 additions & 3 deletions ui/src/utils/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ const typeList: any = {
export function getImgUrl(name: string) {
const list = Object.values(typeList).flat()

const type = list.includes(fileType(name).toLowerCase())
? fileType(name).toLowerCase()
: 'unknown'
const typeStr = fileType(name).toLowerCase()
const type = list.includes(typeStr) ? typeStr : 'unknown'
return new URL(`../assets/fileType/${type}-icon.svg`, import.meta.url).href
}

Expand Down
Loading