Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 68 additions & 28 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,18 +438,15 @@ async def _fallback_to_text_only_and_retry(
image_fallback_used,
)

def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")
return create_proxy_client("OpenAI", proxy)

def __init__(self, provider_config, provider_settings) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = None
self.api_keys: list = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
self.custom_headers = provider_config.get("custom_headers", {})
self.client: AsyncOpenAI | AsyncAzureOpenAI | None = None
self._client_alive = False
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)

Expand All @@ -459,34 +456,59 @@ def __init__(self, provider_config, provider_settings) -> None:
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])

if "api_version" in provider_config:
self.client = self._create_openai_client()
self._client_alive = True

self.default_params = inspect.signature(
Comment thread
Tz-WIND marked this conversation as resolved.
self.client.chat.completions.create,
).parameters.keys()

model = provider_config.get("model", "unknown")
self.set_model(model)

self.reasoning_key = "reasoning_content"

def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")

return create_proxy_client("OpenAI", proxy)

def _create_openai_client(
self,
api_key: str | None = None,
) -> AsyncOpenAI | AsyncAzureOpenAI:
"""创建 OpenAI/Azure 客户端实例,将初始化逻辑解耦以便复用。"""
api_key = api_key or self.chosen_api_key
if "api_version" in self.provider_config:
# Using Azure OpenAI API
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
return AsyncAzureOpenAI(
api_key=api_key,
api_version=self.provider_config.get("api_version", None),
default_headers=self.custom_headers,
base_url=provider_config.get("api_base", ""),
base_url=self.provider_config.get("api_base", ""),
timeout=self.timeout,
http_client=self._create_http_client(provider_config),
http_client=self._create_http_client(self.provider_config),
)
else:
# Using OpenAI Official API
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
return AsyncOpenAI(
api_key=api_key,
base_url=self.provider_config.get("api_base", None),
default_headers=self.custom_headers,
timeout=self.timeout,
http_client=self._create_http_client(provider_config),
http_client=self._create_http_client(self.provider_config),
)

self.default_params = inspect.signature(
self.client.chat.completions.create,
).parameters.keys()

model = provider_config.get("model", "unknown")
self.set_model(model)

self.reasoning_key = "reasoning_content"
def _ensure_client(self) -> None:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
"""确保 client 可用,仅在真实 API 调用前按需重建。"""
if self.client is None or not self._client_alive:
logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...")
self.client = self._create_openai_client()
self._client_alive = True
self.default_params = inspect.signature(
self.client.chat.completions.create,
).parameters.keys()

def _ollama_disable_thinking_enabled(self) -> bool:
value = self.provider_config.get("ollama_disable_thinking", False)
Expand All @@ -509,6 +531,7 @@ def _apply_provider_specific_extra_body_overrides(
extra_body["reasoning_effort"] = "none"

async def get_models(self):
self._ensure_client()
try:
models_str = []
models = await self.client.models.list()
Expand All @@ -520,6 +543,7 @@ async def get_models(self):
raise Exception(f"获取模型列表失败:{e}")

async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
self._ensure_client()
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
Expand Down Expand Up @@ -592,6 +616,7 @@ async def _query_stream(
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
self._ensure_client()
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
Expand Down Expand Up @@ -1145,7 +1170,10 @@ async def text_chat(
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
self.chosen_api_key = chosen_key
self._ensure_client()
if self.client is not None:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
Expand Down Expand Up @@ -1216,7 +1244,10 @@ async def text_chat_stream(
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
self.chosen_api_key = chosen_key
self._ensure_client()
if self.client is not None:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
break
Expand Down Expand Up @@ -1270,13 +1301,15 @@ async def _remove_image_from_context(self, contexts: list):
return new_contexts

def get_current_key(self) -> str:
return self.client.api_key
return self.chosen_api_key

def get_keys(self) -> list[str]:
return self.api_keys

def set_key(self, key) -> None:
self.client.api_key = key
self.chosen_api_key = key
if self.client is not None:
self.client.api_key = key

async def assemble_context(
self,
Expand Down Expand Up @@ -1355,5 +1388,12 @@ async def encode_image_bs64(self, image_url: str) -> str:
return image_data

async def terminate(self):
"""关闭 client 并将引用置为 None,确保后续仅在真实调用时重建。"""
if self.client:
await self.client.close()
try:
await self.client.close()
except Exception as e:
logger.warning(f"关闭 OpenAI client 时出错: {e}")
finally:
self.client = None
self._client_alive = False
Loading