-
Notifications
You must be signed in to change notification settings - Fork 148
Expand file tree
/
Copy pathagent_loop.py
More file actions
124 lines (113 loc) · 6.4 KB
/
agent_loop.py
File metadata and controls
124 lines (113 loc) · 6.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import os
import re
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class StepOutcome:
    """Result of executing a single tool call, as consumed by ``agent_runner_loop``.

    Fields are positional in declaration order: ``StepOutcome(data, next_prompt, should_exit)``.
    """

    # Tool payload; when not None it is serialized and sent back as a tool result.
    data: Any
    # Prompt for the next turn; a falsy value marks the current task as done.
    next_prompt: Optional[str] = None
    # When True the agent loop returns {'result': 'EXITED', ...} immediately.
    should_exit: bool = False
def try_call_generator(func, *args, **kwargs):
    """Call ``func`` and transparently delegate to it if it returned an iterator.

    Handler hooks may be plain functions or generator functions. When the call
    returns an iterator (e.g. a generator), its items are re-yielded to the
    caller and its return value becomes the result. Any other return value,
    including containers such as tuples, sets, lists, dicts or strings, is
    passed through unchanged.

    Must be driven with ``yield from``:
    ``ret = yield from try_call_generator(f, *args)``.

    Returns:
        The delegated generator's return value, or ``func``'s plain return value.
    """
    ret = func(*args, **kwargs)
    # Only real iterators are delegated to. Containers are iterable too, but
    # ``yield from`` on e.g. a tuple would stream its elements as output and
    # evaluate to None, silently discarding the value the callee returned.
    if isinstance(ret, Iterator):
        ret = yield from ret
    return ret
class BaseHandler:
    """Base class for tool handlers driven by ``agent_runner_loop``.

    Subclasses implement each tool as a ``do_<tool_name>(self, args, response)``
    method. Such a method may be a plain function or a generator; generator
    output is forwarded through ``try_call_generator``. The hook methods below
    are no-op defaults that subclasses may override.
    """

    def tool_before_callback(self, tool_name, args, response):
        """Hook run before a ``do_*`` tool method. Default: does nothing."""

    def tool_after_callback(self, tool_name, args, response, ret):
        """Hook run after a ``do_*`` tool method returned ``ret``. Default: does nothing."""

    def next_prompt_patcher(self, next_prompt, outcome, turn):
        """Return the prompt for the next turn; the default returns it unchanged."""
        return next_prompt

    def dispatch(self, tool_name, args, response, index=0):
        """Route a tool call to ``do_<tool_name>`` and return its result.

        This is a generator: it yields streamed text and returns the tool's
        result (normally a ``StepOutcome``). For known tools, ``args`` is
        mutated in place: ``args['_index']`` is set to ``index`` first.

        Tools without a ``do_*`` method:
        - ``'bad_json'`` returns a ``StepOutcome`` whose ``next_prompt`` is
          ``args['msg']`` (or ``'bad_json'``), without yielding anything.
        - anything else yields a notice and returns a ``StepOutcome`` whose
          ``next_prompt`` names the unknown tool.
        """
        handler_name = f"do_{tool_name}"
        if not hasattr(self, handler_name):
            if tool_name == 'bad_json':
                return StepOutcome(None, next_prompt=args.get('msg', 'bad_json'), should_exit=False)
            yield f"未知工具: {tool_name}\n"
            return StepOutcome(None, next_prompt=f"未知工具 {tool_name}", should_exit=False)
        args['_index'] = index
        # The before/after hooks' return values are deliberately ignored.
        yield from try_call_generator(self.tool_before_callback, tool_name, args, response)
        result = yield from try_call_generator(getattr(self, handler_name), args, response)
        yield from try_call_generator(self.tool_after_callback, tool_name, args, response, result)
        return result
def json_default(o):
    """Fallback encoder for ``json.dumps(..., default=json_default)``.

    Sets are emitted as lists; every other unsupported object falls back to ``str(o)``.
    """
    return list(o) if isinstance(o, set) else str(o)
def exhaust(g):
    """Drive iterator ``g`` to completion, discarding every yielded item.

    Returns:
        The generator's return value (``StopIteration.value``; None for plain iterators).
    """
    while True:
        try:
            next(g)
        except StopIteration as stop:
            return stop.value
def get_pretty_json(data):
    """Render tool arguments as indented JSON for human-readable display.

    If ``data`` is a dict whose ``"script"`` value is a string, every ``"; "``
    in it is broken onto a new line so one-liners read better. Afterwards,
    escaped newline sequences (backslash + ``n``) in the JSON text are turned
    into real line breaks. The result is therefore meant for display and is not
    guaranteed to be valid JSON.

    The input is never mutated; a shallow copy is modified instead.
    """
    # Arguments come from model-generated JSON, so "script" may be null, a
    # list, etc.; only reformat actual strings instead of crashing on .replace.
    if isinstance(data, dict) and isinstance(data.get("script"), str):
        data = data.copy()
        data["script"] = data["script"].replace("; ", ";\n ")
    return json.dumps(data, indent=2, ensure_ascii=False).replace('\\n', '\n')
_TOOL_ICONS = {'file_read': '📖', 'file_write': '✏️', 'file_patch': '✏️', 'code_run': '⚙️',
'web_scan': '🌐', 'web_execute_js': '🌐', 'update_working_checkpoint': '💾', 'ask_user': '❓', 'start_long_term_update': '💾'}
def _parse_tool_calls(response):
    """Normalize ``response.tool_calls`` into ``{'tool_name', 'args', 'id'}`` dicts.

    A response without tool calls becomes a single synthetic ``no_tool`` call.
    Empty or missing arguments are treated as ``{}``. Arguments that are not a
    valid JSON object are mapped to the ``bad_json`` pseudo-tool (handled by
    ``BaseHandler.dispatch``). Malformed model output is thus reported back to
    the model instead of raising and killing the loop.
    """
    if not response.tool_calls:
        return [{'tool_name': 'no_tool', 'args': {}}]
    parsed = []
    for tc in response.tool_calls:
        name, raw = tc.function.name, tc.function.arguments
        problem = None
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            args = {}  # treat empty arguments as "no arguments" rather than an error
        else:
            try:
                args = json.loads(raw)
            except (TypeError, ValueError) as err:  # json.JSONDecodeError subclasses ValueError
                args, problem = None, str(err)
            else:
                # dispatch() writes args['_index'], so anything but a dict would crash there.
                if not isinstance(args, dict):
                    problem = f"expected a JSON object, got {type(args).__name__}"
        if problem is not None:
            msg = (f"bad_json: arguments for tool `{name}` are not a valid JSON object ({problem}). "
                   "Please call the tool again with valid JSON arguments.")
            parsed.append({'tool_name': 'bad_json', 'args': {'msg': msg}, 'id': tc.id})
        else:
            parsed.append({'tool_name': name, 'args': args, 'id': tc.id})
    return parsed


def agent_runner_loop(client, system_prompt, user_input, handler, tools_schema, max_turns=40, verbose=True, initial_user_content=None):
    """Drive a multi-turn tool-calling conversation between ``client`` and ``handler``.

    This is a generator. It yields human-readable progress text (Markdown
    fences when ``verbose``) and returns one of:

    - ``{'result': 'EXITED', 'data': ...}`` as soon as a tool outcome sets ``should_exit``;
    - ``{'result': 'CURRENT_TASK_DONE', 'data': ...}`` when a turn produced no
      next prompt and no ``handler._done_hooks`` prompt is left;
    - ``{'result': 'MAX_TURNS_EXCEEDED'}`` when ``handler.max_turns`` is reached.

    Each turn, every tool call is dispatched via ``handler.dispatch``. Each
    non-None ``outcome.data`` is sent back as a tool result (JSON for dict/list,
    ``str`` otherwise). The distinct ``next_prompt`` values, in first-seen
    order, are joined into the next user message. A tool outcome with an empty
    ``next_prompt`` skips the remaining tool calls of that turn.

    Args:
        client: chat client; ``client.chat(messages=..., tools=...)`` must be a
            generator returning a response with ``.content`` and ``.tool_calls``.
            ``client.last_tools`` is cleared every 10th turn and after an
            unknown-tool outcome.
        system_prompt: system message for the first turn.
        user_input: first user message, used unless ``initial_user_content`` is given.
        handler: ``BaseHandler``-like object; this function sets its
            ``_done_hooks``, ``max_turns`` and ``current_turn`` attributes.
        tools_schema: tool definitions passed through to ``client.chat``.
        max_turns: initial turn budget, stored on ``handler.max_turns``. That
            attribute is re-read every turn, so the handler may adjust it.
        verbose: stream model and tool output in full instead of a one-line
            summary per tool call.
        initial_user_content: optional override for the first user message content.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_user_content if initial_user_content is not None else user_input},
    ]
    turn = 0
    handler._done_hooks = []
    handler.max_turns = max_turns
    md = '**' if verbose else ''
    while turn < handler.max_turns:
        turn += 1
        yield f"{md}LLM Running (Turn {turn}) ...{md}\n\n"
        if turn % 10 == 0:
            # Reset the tool description every 10 turns so an overly large
            # context does not degrade model performance.
            client.last_tools = ''
        response_gen = client.chat(messages=messages, tools=tools_schema)
        if verbose:
            response = yield from response_gen
            yield '\n\n'
        else:
            response = exhaust(response_gen)
        cleaned = _clean_content(response.content)
        if cleaned:
            yield cleaned + '\n'
        tool_calls = _parse_tool_calls(response)

        tool_results = []
        # Insertion-ordered de-duplication (a dict used as an ordered set):
        # a plain set made the joined prompt order vary between runs.
        next_prompts = {}
        should_exit = None
        for ii, tc in enumerate(tool_calls):
            tool_name, args, tid = tc['tool_name'], tc['args'], tc.get('id', '')
            if tool_name != 'no_tool':
                icon = _TOOL_ICONS.get(tool_name, '🛠️')
                if verbose:
                    yield f"{icon} 正在调用工具: `{tool_name}` 📥参数:\n````text\n{get_pretty_json(args)}\n````\n"
                else:
                    yield f"{icon} {tool_name}({_compact_tool_args(tool_name, args)})\n\n\n"
            handler.current_turn = turn
            gen = handler.dispatch(tool_name, args, response, index=ii)
            try:
                first = next(gen)
            except StopIteration as stop:
                outcome = stop.value  # the tool finished without streaming any output
            else:
                if verbose:
                    # Wrap streamed tool output in a fenced block.
                    yield '`````\n'
                    yield first
                    outcome = yield from gen
                    yield '`````\n'
                else:
                    outcome = exhaust(gen)
            if outcome.should_exit:
                return {'result': 'EXITED', 'data': outcome.data}  # should_exit means exit immediately
            if not outcome.next_prompt:
                should_exit = {'result': 'CURRENT_TASK_DONE', 'data': outcome.data}
                break
            if outcome.next_prompt.startswith('未知工具'):
                client.last_tools = ''  # re-send tool descriptions after an unknown-tool call
            if outcome.data is not None:
                if type(outcome.data) in (dict, list):
                    datastr = json.dumps(outcome.data, ensure_ascii=False, default=json_default)
                else:
                    datastr = str(outcome.data)
                tool_results.append({'tool_use_id': tid, 'content': datastr})
            next_prompts[outcome.next_prompt] = None
        if not next_prompts:
            if not handler._done_hooks:
                return should_exit
            # Pending done-hooks are fed back as prompts, one per turn, before finishing.
            next_prompts[handler._done_hooks.pop(0)] = None
        next_prompt = handler.next_prompt_patcher("\n".join(next_prompts), None, turn)
        # Only the new message is sent; per the original design, history is kept in *Session.
        messages = [{"role": "user", "content": next_prompt, "tool_results": tool_results}]
    return {'result': 'MAX_TURNS_EXCEEDED'}
def _clean_content(text):
if not text: return ''
def _shrink_code(m):
lines = m.group(0).split('\n')
lang = lines[0].replace('```','').strip()
body = [l for l in lines[1:-1] if l.strip()] # 去掉```行和空行
if len(body) <= 6: return m.group(0) # 短代码保留
preview = '\n'.join(body[:5])
return f'```{lang}\n{preview}\n ... ({len(body)} lines)\n```'
text = re.sub(r'```[\s\S]*?```', _shrink_code, text)
for p in [r'<file_content>[\s\S]*?</file_content>', r'<tool_(?:use|call)>[\s\S]*?</tool_(?:use|call)>', r'(\r?\n){3,}']:
text = re.sub(p, '\n\n' if '\\n' in p else '', text)
return text.strip()
def _compact_tool_args(name, args):
a = {k: v for k, v in args.items() if k != '_index'}
for k in ('path',): # 只缩短路径
if k in a: a[k] = os.path.basename(a[k])
if name == 'update_working_checkpoint': s = a.get('key_info', ''); return (s[:60]+'...') if len(s)>60 else s
s = json.dumps(a, ensure_ascii=False); return (s[:120]+'...') if len(s)>120 else s