diff --git a/charts/openab/templates/configmap.yaml b/charts/openab/templates/configmap.yaml index 78316b59..2cff6218 100644 --- a/charts/openab/templates/configmap.yaml +++ b/charts/openab/templates/configmap.yaml @@ -58,6 +58,19 @@ data: {{- if hasKey ($cfg.discord | default dict) "maxBotTurns" }} max_bot_turns = {{ ($cfg.discord).maxBotTurns | int }} {{- end }} + {{- /* messageProcessingMode: per-message (default) | per-thread | per-lane (turn-boundary batching) */ -}} + {{- if hasKey ($cfg.discord | default dict) "messageProcessingMode" }} + {{- if not (has ($cfg.discord).messageProcessingMode (list "per-message" "per-thread" "per-lane")) }} + {{- fail (printf "agents.%s.discord.messageProcessingMode must be one of: per-message, per-thread, per-lane — got: %s" $name ($cfg.discord).messageProcessingMode) }} + {{- end }} + message_processing_mode = {{ ($cfg.discord).messageProcessingMode | toJson }} + {{- end }} + {{- if hasKey ($cfg.discord | default dict) "maxBufferedMessages" }} + max_buffered_messages = {{ ($cfg.discord).maxBufferedMessages | int }} + {{- end }} + {{- if hasKey ($cfg.discord | default dict) "maxBatchTokens" }} + max_batch_tokens = {{ ($cfg.discord).maxBatchTokens | int }} + {{- end }} {{- end }} {{- if and ($cfg.slack).enabled }} @@ -97,6 +110,19 @@ data: {{- if hasKey ($cfg.slack | default dict) "maxBotTurns" }} max_bot_turns = {{ ($cfg.slack).maxBotTurns | int }} {{- end }} + {{- /* messageProcessingMode: per-message (default) | per-thread | per-lane (turn-boundary batching) */ -}} + {{- if hasKey ($cfg.slack | default dict) "messageProcessingMode" }} + {{- if not (has ($cfg.slack).messageProcessingMode (list "per-message" "per-thread" "per-lane")) }} + {{- fail (printf "agents.%s.slack.messageProcessingMode must be one of: per-message, per-thread, per-lane — got: %s" $name ($cfg.slack).messageProcessingMode) }} + {{- end }} + message_processing_mode = {{ ($cfg.slack).messageProcessingMode | toJson }} + {{- end }} + {{- if hasKey ($cfg.slack | default dict) "maxBufferedMessages" }} + max_buffered_messages = {{ ($cfg.slack).maxBufferedMessages | int }} + {{- end }} + {{- if hasKey ($cfg.slack | default dict) "maxBatchTokens" }} + max_batch_tokens = {{ ($cfg.slack).maxBatchTokens | int }} + {{- end }} {{- end }} [agent] @@ -162,6 +188,19 @@ data: {{- end }} {{- end }} allowed_users = {{ ($cfg.gateway).allowedUsers | default list | toJson }} + {{- /* messageProcessingMode: per-message (default) | per-thread | per-lane (turn-boundary batching) */ -}} + {{- if hasKey ($cfg.gateway | default dict) "messageProcessingMode" }} + {{- if not (has ($cfg.gateway).messageProcessingMode (list "per-message" "per-thread" "per-lane")) }} + {{- fail (printf "agents.%s.gateway.messageProcessingMode must be one of: per-message, per-thread, per-lane — got: %s" $name ($cfg.gateway).messageProcessingMode) }} + {{- end }} + message_processing_mode = {{ ($cfg.gateway).messageProcessingMode | toJson }} + {{- end }} + {{- if hasKey ($cfg.gateway | default dict) "maxBufferedMessages" }} + max_buffered_messages = {{ ($cfg.gateway).maxBufferedMessages | int }} + {{- end }} + {{- if hasKey ($cfg.gateway | default dict) "maxBatchTokens" }} + max_batch_tokens = {{ ($cfg.gateway).maxBatchTokens | int }} + {{- end }} {{- end }} {{- if or ($cfg.cronjobs) (($cfg.cron).usercronEnabled) (($cfg.cron).usercronPath) }} diff --git a/charts/openab/tests/message-processing-mode_test.yaml b/charts/openab/tests/message-processing-mode_test.yaml new file mode 100644 index 00000000..f70fde5b --- /dev/null +++ b/charts/openab/tests/message-processing-mode_test.yaml @@ -0,0 +1,124 @@ +suite: messageProcessingMode & batching params +templates: + - templates/configmap.yaml +tests: + - it: discord renders message_processing_mode = "per-lane" + set: + agents.kiro.discord.messageProcessingMode: per-lane + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'message_processing_mode = "per-lane"' + + - it: discord renders message_processing_mode = "per-thread" + set: + agents.kiro.discord.messageProcessingMode: per-thread + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'message_processing_mode = "per-thread"' + + - it: discord renders message_processing_mode = "per-message" + set: + agents.kiro.discord.messageProcessingMode: per-message + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'message_processing_mode = "per-message"' + + - it: discord rejects invalid messageProcessingMode value + set: + agents.kiro.discord.messageProcessingMode: batched + asserts: + - failedTemplate: + errorPattern: "must be one of: per-message, per-thread, per-lane" + + - it: discord omits message_processing_mode when not set + asserts: + - notMatchRegex: + path: data["config.toml"] + pattern: 'message_processing_mode' + + - it: discord renders maxBufferedMessages and maxBatchTokens + set: + agents.kiro.discord.messageProcessingMode: per-lane + agents.kiro.discord.maxBufferedMessages: 25 + agents.kiro.discord.maxBatchTokens: 32000 + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'max_buffered_messages = 25' + - matchRegex: + path: data["config.toml"] + pattern: 'max_batch_tokens = 32000' + + - it: slack renders message_processing_mode = "per-thread" + set: + agents.kiro.slack.enabled: true + agents.kiro.slack.botToken: xoxb-x + agents.kiro.slack.appToken: xapp-x + agents.kiro.slack.messageProcessingMode: per-thread + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'message_processing_mode = "per-thread"' + + - it: slack rejects invalid messageProcessingMode value + set: + agents.kiro.slack.enabled: true + agents.kiro.slack.botToken: xoxb-x + agents.kiro.slack.appToken: xapp-x + agents.kiro.slack.messageProcessingMode: batched + asserts: + - failedTemplate: + errorPattern: "must be one of: per-message, per-thread, per-lane" + + - it: slack renders maxBufferedMessages and maxBatchTokens + set: + agents.kiro.slack.enabled: true + agents.kiro.slack.botToken: xoxb-x + agents.kiro.slack.appToken: xapp-x + agents.kiro.slack.messageProcessingMode: per-lane + agents.kiro.slack.maxBufferedMessages: 15 + agents.kiro.slack.maxBatchTokens: 18000 + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'max_buffered_messages = 15' + - matchRegex: + path: data["config.toml"] + pattern: 'max_batch_tokens = 18000' + + - it: gateway renders message_processing_mode = "per-lane" + set: + agents.kiro.gateway.enabled: true + agents.kiro.gateway.url: ws://openab-gateway:8080/ws + agents.kiro.gateway.messageProcessingMode: per-lane + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'message_processing_mode = "per-lane"' + + - it: gateway rejects invalid messageProcessingMode value + set: + agents.kiro.gateway.enabled: true + agents.kiro.gateway.url: ws://openab-gateway:8080/ws + agents.kiro.gateway.messageProcessingMode: batched + asserts: + - failedTemplate: + errorPattern: "must be one of: per-message, per-thread, per-lane" + + - it: gateway renders maxBufferedMessages and maxBatchTokens + set: + agents.kiro.gateway.enabled: true + agents.kiro.gateway.url: ws://openab-gateway:8080/ws + agents.kiro.gateway.messageProcessingMode: per-thread + agents.kiro.gateway.maxBufferedMessages: 50 + agents.kiro.gateway.maxBatchTokens: 12000 + asserts: + - matchRegex: + path: data["config.toml"] + pattern: 'max_buffered_messages = 50' + - matchRegex: + path: data["config.toml"] + pattern: 'max_batch_tokens = 12000' diff --git a/charts/openab/values.yaml b/charts/openab/values.yaml index 4f12e315..e81faf9c 100644 --- a/charts/openab/values.yaml +++ b/charts/openab/values.yaml @@ -151,6 +151,14 @@ agents: # multi-agent collaborations; lower to throttle runaway loops more # aggressively. Hard cap remains 100 regardless (compiled-in). # maxBotTurns: 20 + # messageProcessingMode: "per-message" (default) | "per-thread" | "per-lane" + # per-thread: all senders in a thread share one batch → one ACP turn per turn boundary + # per-lane: each (thread, sender) batches independently → no silent-drop risk + # messageProcessingMode: "per-lane" + # maxBufferedMessages: per-thread mpsc capacity for batching modes (default 10) + # maxBufferedMessages: 10 + # maxBatchTokens: soft token cap per ACP turn for batching modes (default 24000) + # maxBatchTokens: 24000 slack: enabled: false botToken: "" # Bot User OAuth Token (xoxb-...) @@ -175,6 +183,14 @@ agents: # multi-agent collaborations; lower to throttle runaway loops more # aggressively. Hard cap remains 100 regardless (compiled-in). # maxBotTurns: 20 + # messageProcessingMode: "per-message" (default) | "per-thread" | "per-lane" + # per-thread: all senders in a thread share one batch → one ACP turn per turn boundary + # per-lane: each (thread, sender) batches independently → no silent-drop risk + # messageProcessingMode: "per-lane" + # maxBufferedMessages: per-thread mpsc capacity for batching modes (default 10) + # maxBufferedMessages: 10 + # maxBatchTokens: soft token cap per ACP turn for batching modes (default 24000) + # maxBatchTokens: 24000 workingDir: /home/agent env: {} envFrom: [] @@ -200,6 +216,14 @@ agents: platform: "telegram" # default platform when gateway is enabled token: "" # optional shared secret (injected via GATEWAY_WS_TOKEN env var) botUsername: "" # optional, for @mention gating + # messageProcessingMode: "per-message" (default) | "per-thread" | "per-lane" + # per-thread: all senders in a thread share one batch → one ACP turn per turn boundary + # per-lane: each (thread, sender) batches independently → no silent-drop risk + # messageProcessingMode: "per-lane" + # maxBufferedMessages: per-thread mpsc capacity for batching modes (default 10) + # maxBufferedMessages: 10 + # maxBatchTokens: soft token cap per ACP turn for batching modes (default 24000) + # maxBatchTokens: 24000 image: "ghcr.io/openabdev/openab-gateway" # gateway container image tag: "" # defaults to Chart.AppVersion strategy: "Recreate" # Recreate (default, prevents concurrent WS conflicts) or RollingUpdate diff --git a/src/adapter.rs b/src/adapter.rs index f558fd07..51cf01cd 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -87,6 +87,10 @@ pub struct SenderContext { #[serde(skip_serializing_if = "Option::is_none")] pub thread_id: Option, pub is_bot: bool, + /// Platform message creation time (ISO 8601 UTC). + /// Discord/Slack: platform timestamp. Gateway: broker receive time (best-effort). + /// Additive field — schema stays openab.sender.v1. + pub timestamp: String, } // --- ChatAdapter trait --- @@ -160,6 +164,32 @@ impl AdapterRouter { &self.pool } + /// Access the reactions config (used by dispatch.rs). + pub fn reactions_config(&self) -> &ReactionsConfig { + &self.reactions_config + } + + /// Pack one arrival event into ContentBlocks using the uniform per-arrival template: + /// Text { "\n{json}\n\n\n{prompt}" } + /// [extra_blocks in arrival order] + /// + /// This is the single packing code path for both per-message and batched dispatch + /// (ADR §3.5). For a batch of N messages, call this N times and concatenate. + pub fn pack_arrival_event( + sender_json: &str, + prompt: &str, + extra_blocks: Vec, + ) -> Vec { + let header = format!( + "\n{}\n\n\n{}", + sender_json, prompt + ); + let mut blocks = Vec::with_capacity(1 + extra_blocks.len()); + blocks.push(ContentBlock::Text { text: header }); + blocks.extend(extra_blocks); + blocks + } + /// Handle an incoming user message. The adapter is responsible for /// filtering, resolving the thread, and building the SenderContext. /// This method handles sender context injection, session management, and streaming. @@ -176,28 +206,7 @@ impl AdapterRouter { ) -> Result<()> { tracing::debug!(platform = adapter.platform(), "processing message"); - // Build content blocks: sender context + prompt text, then extra (images, transcripts) - let prompt_with_sender = format!( - "\n{}\n\n\n{}", - sender_json, prompt - ); - - let mut content_blocks = Vec::with_capacity(1 + extra_blocks.len()); - // Prepend any transcript blocks (they go before the text block) - for block in &extra_blocks { - if matches!(block, ContentBlock::Text { .. }) { - content_blocks.push(block.clone()); - } - } - content_blocks.push(ContentBlock::Text { - text: prompt_with_sender, - }); - // Append non-text blocks (images) - for block in extra_blocks { - if !matches!(block, ContentBlock::Text { .. }) { - content_blocks.push(block); - } - } + let content_blocks = Self::pack_arrival_event(sender_json, prompt, extra_blocks); let thread_key = format!( "{}:{}", @@ -272,6 +281,21 @@ impl AdapterRouter { thread_channel: &ChannelRef, reactions: Arc, other_bot_present: bool, + ) -> Result<()> { + self.stream_prompt_blocks(adapter, thread_key, content_blocks, thread_channel, reactions, other_bot_present).await + } + + /// Drive one ACP turn with the given pre-packed ContentBlocks. + /// Called by both `handle_message` (per-message mode) and `dispatch::dispatch_batch` + /// (batched mode). + pub async fn stream_prompt_blocks( + &self, + adapter: &Arc, + thread_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, ) -> Result<()> { let adapter = adapter.clone(); let thread_channel = thread_channel.clone(); diff --git a/src/config.rs b/src/config.rs index 3bd99504..fff52736 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,6 +4,36 @@ use serde::Deserialize; use std::collections::HashMap; use std::path::Path; +/// Controls how incoming messages are dispatched to ACP turns. +/// +/// - `Message` (default): each message becomes its own ACP turn (v0.8.2-beta.1 behaviour). +/// - `Thread`: one buffer per thread; all senders in a thread share a single batch and +/// produce one ACP turn per turn boundary. +/// - `Lane`: one buffer per (thread, sender); each sender batches independently and gets +/// its own ACP turn — no silent-drop risk when multiple senders address the same thread. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum MessageProcessingMode { + #[default] + Message, + Thread, + Lane, +} + +impl<'de> Deserialize<'de> for MessageProcessingMode { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().replace('-', "_").as_str() { + "per_message" => Ok(Self::Message), + "per_thread" => Ok(Self::Thread), + "per_lane" => Ok(Self::Lane), + other => Err(serde::de::Error::unknown_variant( + other, + &["per-message", "per-thread", "per-lane"], + )), + } + } +} + /// Controls whether the bot processes messages from other Discord bots. /// /// Inspired by Hermes Agent's `DISCORD_ALLOW_BOTS` 3-value design: @@ -120,9 +150,20 @@ pub struct DiscordConfig { /// Default: false (opt-in). `allowed_users` still applies in DMs. #[serde(default)] pub allow_dm: bool, + /// Message dispatch mode. Default: per-message (v0.8.2-beta.1 behaviour). + #[serde(default)] + pub message_processing_mode: MessageProcessingMode, + /// Batched mode only: per-thread channel capacity. Default: 10. + #[serde(default = "default_max_buffered_messages")] + pub max_buffered_messages: usize, + /// Batched mode only: soft token cap for greedy drain. Default: 24000. + #[serde(default = "default_max_batch_tokens")] + pub max_batch_tokens: usize, } fn default_max_bot_turns() -> u32 { 20 } +fn default_max_buffered_messages() -> usize { 10 } +fn default_max_batch_tokens() -> usize { 24_000 } /// Controls whether the bot responds to user messages in threads without @mention. /// @@ -179,6 +220,15 @@ pub struct SlackConfig { /// Human message resets the counter. Default: 20. #[serde(default = "default_max_bot_turns")] pub max_bot_turns: u32, + /// Message dispatch mode. Default: per-message. + #[serde(default)] + pub message_processing_mode: MessageProcessingMode, + /// Batched mode only: per-thread channel capacity. Default: 10. + #[serde(default = "default_max_buffered_messages")] + pub max_buffered_messages: usize, + /// Batched mode only: soft token cap for greedy drain. Default: 24000. + #[serde(default = "default_max_batch_tokens")] + pub max_batch_tokens: usize, } #[derive(Debug, Deserialize)] @@ -205,6 +255,15 @@ pub struct GatewayConfig { /// Enable streaming (typewriter) mode — requires gateway platform to support message editing. #[serde(default)] pub streaming: bool, + /// Message dispatch mode. Default: per-message. + #[serde(default)] + pub message_processing_mode: MessageProcessingMode, + /// Batched mode only: per-thread channel capacity. Default: 10. + #[serde(default = "default_max_buffered_messages")] + pub max_buffered_messages: usize, + /// Batched mode only: soft token cap for greedy drain. Default: 24000. + #[serde(default = "default_max_batch_tokens")] + pub max_batch_tokens: usize, } fn default_gateway_platform() -> String { @@ -285,7 +344,7 @@ impl<'de> Deserialize<'de> for ToolDisplay { } } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct ReactionsConfig { #[serde(default = "default_true")] pub enabled: bool, @@ -452,6 +511,23 @@ fn parse_config(raw: &str, source: &str) -> anyhow::Result { let expanded = expand_env_vars(raw); let config: Config = toml::from_str(&expanded) .map_err(|e| anyhow::anyhow!("failed to parse config from {source}: {e}"))?; + + // Validate max_buffered_messages > 0 (tokio::sync::mpsc::channel panics on 0) + // and max_batch_tokens > 0 (otherwise the consumer's token-cap check forces every + // batch to size 1 — functionally per-message via a confusing path). + if let Some(ref d) = config.discord { + anyhow::ensure!(d.max_buffered_messages > 0, "discord.max_buffered_messages must be > 0"); + anyhow::ensure!(d.max_batch_tokens > 0, "discord.max_batch_tokens must be > 0"); + } + if let Some(ref s) = config.slack { + anyhow::ensure!(s.max_buffered_messages > 0, "slack.max_buffered_messages must be > 0"); + anyhow::ensure!(s.max_batch_tokens > 0, "slack.max_batch_tokens must be > 0"); + } + if let Some(ref g) = config.gateway { + anyhow::ensure!(g.max_buffered_messages > 0, "gateway.max_buffered_messages must be > 0"); + anyhow::ensure!(g.max_batch_tokens > 0, "gateway.max_batch_tokens must be > 0"); + } + Ok(config) } @@ -584,6 +660,94 @@ command = "echo" assert_eq!(ToolDisplay::default(), ToolDisplay::Full); } + #[test] + fn message_processing_mode_parses_per_message() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "per-message" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Message + ); + } + + #[test] + fn message_processing_mode_parses_per_thread() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "per-thread" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Thread + ); + } + + #[test] + fn message_processing_mode_parses_per_lane() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "per-lane" + +[agent] +command = "echo" +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Lane + ); + } + + // The legacy alias "batched" was removed: only per-message / per-thread / per-lane + // are accepted. Configs still using "batched" must migrate to an explicit value. + #[test] + fn message_processing_mode_batched_is_rejected() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "batched" + +[agent] +command = "echo" +"#; + assert!(parse_config(toml, "test").is_err()); + } + + #[test] + fn message_processing_mode_default_is_per_message() { + let cfg = parse_config(MINIMAL_TOML, "test").unwrap(); + assert_eq!( + cfg.discord.unwrap().message_processing_mode, + MessageProcessingMode::Message + ); + } + + #[test] + fn message_processing_mode_unknown_value_errors() { + let toml = r#" +[discord] +bot_token = "t" +message_processing_mode = "bogus" + +[agent] +command = "echo" +"#; + assert!(parse_config(toml, "test").is_err()); + } + #[test] fn parse_gateway_config_explicit_allow_all_overrides_list() { let toml = r#" diff --git a/src/cron.rs b/src/cron.rs index ce5245ec..ffeda136 100644 --- a/src/cron.rs +++ b/src/cron.rs @@ -350,6 +350,7 @@ async fn fire_cronjob( channel_id: reply_channel.parent_id.as_deref().unwrap_or(&reply_channel.channel_id).to_string(), thread_id: reply_channel.thread_id.clone().or(Some(reply_channel.channel_id.clone())), is_bot: true, + timestamp: Utc::now().to_rfc3339(), }; let sender_json = match serde_json::to_string(&sender) { Ok(j) => j, diff --git a/src/discord.rs b/src/discord.rs index 67f3c7c8..6bc224bd 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -160,6 +160,8 @@ pub struct Handler { pub bot_turns: tokio::sync::Mutex, /// Allow the bot to respond to Discord DMs. pub allow_dm: bool, + /// Per-thread dispatcher (Message mode uses cap=1 for FIFO; Thread/Lane use configured cap). + pub dispatcher: Arc, } impl Handler { @@ -527,6 +529,7 @@ impl EventHandler for Handler { &msg.channel_id.to_string(), thread_parent_id.as_deref(), msg.author.bot, + &msg.timestamp.to_rfc3339().unwrap_or_default(), ); // Build extra content blocks from attachments (audio → STT, text → inline, image → encode) @@ -622,7 +625,7 @@ impl EventHandler for Handler { let trigger_msg = discord_msg_ref(&msg); // Per-thread streaming: check if another bot is present in this thread - let other_bot_present = { + let other_bot_present_flag = { let cache = self.multibot_threads.lock().await; cache.contains_key(&msg.channel_id.to_string()) }; @@ -635,14 +638,31 @@ impl EventHandler for Handler { sender.thread_id = Some(thread_channel.channel_id.clone()); } - let router = self.router.clone(); + let dispatcher = self.dispatcher.clone(); + tokio::spawn(async move { + let sender_id = sender.sender_id.clone(); + let sender_name = sender.sender_name.clone(); let sender_json = serde_json::to_string(&sender).unwrap(); - if let Err(e) = router - .handle_message(&adapter, &thread_channel, &sender_json, &prompt, extra_blocks, &trigger_msg, other_bot_present) + let thread_key = + dispatcher.key("discord", &thread_channel.channel_id, &sender_id); + let estimated_tokens = + crate::dispatch::estimate_tokens(&prompt, &extra_blocks); + let buf_msg = crate::dispatch::BufferedMessage { + sender_json, + sender_name, + prompt, + extra_blocks, + trigger_msg, + arrived_at: std::time::Instant::now(), + estimated_tokens, + other_bot_present: other_bot_present_flag, + }; + if let Err(e) = dispatcher + .submit(thread_key, thread_channel, adapter, buf_msg) .await { - error!("handle_message error: {e}"); + error!("dispatcher submit error: {e}"); } }); } @@ -658,6 +678,8 @@ impl EventHandler for Handler { .description("Select the agent mode for this session"), CreateCommand::new("cancel") .description("Cancel the current operation"), + CreateCommand::new("cancel-all") + .description("Cancel current operation and drop all buffered messages"), CreateCommand::new("reset") .description("Reset the conversation session"), ]; @@ -694,6 +716,9 @@ impl EventHandler for Handler { Interaction::Command(cmd) if cmd.data.name == "cancel" => { self.handle_cancel_command(&ctx, &cmd).await; } + Interaction::Command(cmd) if cmd.data.name == "cancel-all" => { + self.handle_cancel_all_command(&ctx, &cmd).await; + } Interaction::Command(cmd) if cmd.data.name == "reset" => { self.handle_reset_command(&ctx, &cmd).await; } @@ -856,16 +881,57 @@ impl Handler { } } + async fn handle_cancel_all_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + // /cancel-all is the nuclear escape hatch: stop the in-flight turn AND clear + // every lane's buffer in this thread, so a human can intervene from a clean slate. + let session_key = format!("discord:{}", cmd.channel_id.get()); + let dropped = self + .dispatcher + .cancel_buffered_thread("discord", &cmd.channel_id.get().to_string()); + + let cancel_result = self.router.pool().cancel_session(&session_key).await; + + let msg = match (cancel_result, dropped) { + (Ok(()), 0) => "🛑 Cancel signal sent.".to_string(), + (Ok(()), n) => format!("🛑 Cancel signal sent. Dropped {n} buffered message(s)."), + (Err(_), 0) => "⚠️ Nothing to cancel — no active session and no buffered messages.".to_string(), + (Err(_), n) => format!("🛑 Dropped {n} buffered message(s). No active session to cancel."), + }; + + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new().content(msg).ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /cancel-all command"); + } + } + async fn handle_reset_command( &self, ctx: &Context, cmd: &serenity::model::application::CommandInteraction, ) { - let thread_key = format!("discord:{}", cmd.channel_id.get()); - let result = self.router.pool().reset_session(&thread_key).await; + // /reset clears every lane's buffer in this thread and tears down the shared + // ACP session — the next message in the thread starts a fresh conversation. + let session_key = format!("discord:{}", cmd.channel_id.get()); + let dropped = self + .dispatcher + .cancel_buffered_thread("discord", &cmd.channel_id.get().to_string()); + + let result = self.router.pool().reset_session(&session_key).await; let msg = match result { + Ok(()) if dropped > 0 => { + format!("🔄 Session reset. Dropped {dropped} buffered message(s). Start a new conversation!") + } Ok(()) => "🔄 Session reset. Start a new conversation!".to_string(), + Err(_) if dropped > 0 => { + format!("🔄 Dropped {dropped} buffered message(s). No active session to reset.") + } Err(_) => "⚠️ No active session to reset. Start a conversation first by @mentioning the bot.".to_string(), }; @@ -1104,6 +1170,7 @@ fn build_sender_context( msg_channel_id: &str, thread_parent_id: Option<&str>, is_bot: bool, + timestamp: &str, ) -> SenderContext { SenderContext { schema: "openab.sender.v1".into(), @@ -1114,6 +1181,7 @@ fn build_sender_context( channel_id: thread_parent_id.unwrap_or(msg_channel_id).to_string(), thread_id: thread_parent_id.map(|_| msg_channel_id.to_string()), is_bot, + timestamp: timestamp.to_string(), } } @@ -1438,7 +1506,7 @@ mod tests { /// In-thread message: channel_id = parent, thread_id = thread channel ID. #[test] fn build_sender_context_in_thread() { - let ctx = build_sender_context("user1", "alice", "Alice", "thread_ch", Some("parent_ch"), false); + let ctx = build_sender_context("user1", "alice", "Alice", "thread_ch", Some("parent_ch"), false, "2026-05-01T00:00:00Z"); assert_eq!(ctx.channel_id, "parent_ch"); assert_eq!(ctx.thread_id, Some("thread_ch".to_string())); assert_eq!(ctx.channel, "discord"); @@ -1449,7 +1517,7 @@ mod tests { /// Non-thread message: channel_id = message channel, thread_id = None. #[test] fn build_sender_context_not_in_thread() { - let ctx = build_sender_context("user1", "alice", "Alice", "main_ch", None, false); + let ctx = build_sender_context("user1", "alice", "Alice", "main_ch", None, false, "2026-05-01T00:00:00Z"); assert_eq!(ctx.channel_id, "main_ch"); assert_eq!(ctx.thread_id, None); } @@ -1457,7 +1525,7 @@ mod tests { /// Bot sender: is_bot flag propagated correctly. #[test] fn build_sender_context_bot_sender() { - let ctx = build_sender_context("bot1", "mybot", "MyBot", "ch", Some("parent"), true); + let ctx = build_sender_context("bot1", "mybot", "MyBot", "ch", Some("parent"), true, "2026-05-01T00:00:00Z"); assert!(ctx.is_bot); assert_eq!(ctx.channel_id, "parent"); assert_eq!(ctx.thread_id, Some("ch".to_string())); diff --git a/src/dispatch.rs b/src/dispatch.rs new file mode 100644 index 00000000..0aa824aa --- /dev/null +++ b/src/dispatch.rs @@ -0,0 +1,1381 @@ +//! Turn-boundary message batching dispatcher. +//! +//! See ADR: turn-boundary-batching-adr.md for full design rationale. +//! +//! # Invariants +//! - I1: First message after idle has zero added latency. +//! - I2: At most one in-flight ACP turn per thread. +//! - I3: Broker structural fidelity — no merging, splitting, reordering, or +//! semantic transformation of arrival events. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use async_trait::async_trait; +use tracing::{debug, error, info, info_span, warn}; + +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef}; +use crate::acp::ContentBlock; +use crate::config::ReactionsConfig; +use crate::error_display::format_user_error; +use crate::reactions::StatusReactionController; + +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- + +/// One arrival event buffered for a future ACP turn. +pub struct BufferedMessage { + /// Serialised SenderContext JSON (already built by the platform adapter). + pub sender_json: String, + /// Author display name — denormalised from `sender_json` so observability + /// fields (per-event tracing in `dispatch_batch`) don't pay a JSON parse. + /// Per ADR §2.3 each arrival event carries its sender name. + pub sender_name: String, + /// User-visible prompt text (verbatim, never transformed). + pub prompt: String, + /// Attachment blocks (images, STT transcripts) in arrival order. + pub extra_blocks: Vec, + /// Anchor for reactions (👀 / ❌). + pub trigger_msg: MessageRef, + /// Broker receive time — used for `buffer_wait_ms` observability. + pub arrived_at: Instant, + /// Rough token estimate for `max_batch_tokens` cap. + pub estimated_tokens: usize, + /// Snapshot at submit time. Captured per-message so a batch reflects the + /// freshest known state; `dispatch_batch` reads `batch.last()`. + pub other_bot_present: bool, +} + +/// How `thread_key` is built for the dispatcher's per-thread map. +/// +/// - `Thread`: one mpsc per thread → all senders in a thread share one batch → one +/// ACP turn per batch (cheaper, but risks silent drop when the agent's single reply +/// forgets to address some senders). +/// - `Lane`: one mpsc per (thread, sender) → each sender batches independently and +/// gets a dedicated ACP turn. Sessions are still shared per-thread; turns serialise +/// through the shared session. +/// +/// Derived from `config::MessageProcessingMode` in `main.rs`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BatchGrouping { + Thread, + Lane, +} + +/// Error returned by `Dispatcher::submit`. +#[derive(Debug)] +pub enum DispatchError { + /// The per-thread consumer task has exited unexpectedly. + ConsumerDead, +} + +impl std::fmt::Display for DispatchError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConsumerDead => write!(f, "dispatch consumer exited unexpectedly"), + } + } +} + +impl std::error::Error for DispatchError {} + +// --------------------------------------------------------------------------- +// Internal types +// --------------------------------------------------------------------------- + +struct ThreadHandle { + tx: tokio::sync::mpsc::Sender, + consumer: tokio::task::JoinHandle<()>, + /// Race-safe eviction counter (§2.5). Plain u64 — all reads/writes under per_thread lock. + generation: u64, + channel_id: String, + adapter_kind: String, +} + +impl ThreadHandle { + /// Approximate number of messages still buffered in the mpsc — used for + /// shutdown / cancel logging. Not exact: tokio's mpsc has no sync `.len()`. + fn pending_count(&self) -> usize { + self.tx.max_capacity() - self.tx.capacity() + } +} + +// --------------------------------------------------------------------------- +// DispatchTarget — trait seam between Dispatcher and AdapterRouter +// --------------------------------------------------------------------------- + +/// Surface that `consumer_loop` / `dispatch_batch` need from the underlying +/// router. Extracted as a trait so the dispatcher can be unit-tested without +/// spinning up a real `SessionPool` (which forks ACP CLI subprocesses). +/// `AdapterRouter` is the production implementor; tests use a mock that +/// records calls. +#[async_trait] +pub trait DispatchTarget: Send + Sync + 'static { + fn reactions_config(&self) -> &ReactionsConfig; + + /// Ensure the ACP session for `session_key` exists (idempotent). + async fn ensure_session(&self, session_key: &str) -> Result<()>; + + /// Drive one ACP turn with the pre-packed `content_blocks`. + #[allow(clippy::too_many_arguments)] + async fn stream_prompt_blocks( + &self, + adapter: &Arc, + session_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, + ) -> Result<()>; +} + +#[async_trait] +impl DispatchTarget for AdapterRouter { + fn reactions_config(&self) -> &ReactionsConfig { + AdapterRouter::reactions_config(self) + } + + async fn ensure_session(&self, session_key: &str) -> Result<()> { + self.pool().get_or_create(session_key).await + } + + async fn stream_prompt_blocks( + &self, + adapter: &Arc, + session_key: &str, + content_blocks: Vec, + thread_channel: &ChannelRef, + reactions: Arc, + other_bot_present: bool, + ) -> Result<()> { + AdapterRouter::stream_prompt_blocks( + self, + adapter, + session_key, + content_blocks, + thread_channel, + reactions, + other_bot_present, + ) + .await + } +} + +// --------------------------------------------------------------------------- +// Dispatcher +// --------------------------------------------------------------------------- + +/// Default idle timeout for per-thread consumer tasks. When no message arrives +/// within this window the consumer exits, allowing `per_thread` map cleanup on +/// the next `submit` (via `SendError` → `try_evict_locked`). Prevents unbounded +/// task/memory growth from one-shot thread keys (e.g. Slack non-thread messages). +pub const DEFAULT_CONSUMER_IDLE_TIMEOUT: Duration = Duration::from_secs(300); + +/// Per-thread message dispatcher for batched mode. +/// +/// Constructed once in `main.rs` and shared via `Arc`. Platform adapters call +/// `submit()` from their per-message `tokio::spawn`'d tasks. +pub struct Dispatcher { + /// std::sync::Mutex — critical section has no .await; tokio::Mutex buys nothing here. + per_thread: Mutex>, + /// Monotonic counter for `ThreadHandle.generation` (§2.5). Pre-fetched on + /// every `submit` and consumed only when a fresh handle is inserted; wasted + /// values are fine because generations need only be monotonic, not contiguous. + next_generation: AtomicU64, + target: Arc, + max_buffered_messages: usize, + max_batch_tokens: usize, + grouping: BatchGrouping, + idle_timeout: Duration, +} + +impl Dispatcher { + pub fn new( + target: Arc, + max_buffered_messages: usize, + max_batch_tokens: usize, + grouping: BatchGrouping, + ) -> Self { + Self::with_idle_timeout( + target, + max_buffered_messages, + max_batch_tokens, + grouping, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ) + } + + /// Like `new`, but with a custom consumer idle timeout. Test-only knob — + /// production code should use `new` (which applies `DEFAULT_CONSUMER_IDLE_TIMEOUT`). + pub fn with_idle_timeout( + target: Arc, + max_buffered_messages: usize, + max_batch_tokens: usize, + grouping: BatchGrouping, + idle_timeout: Duration, + ) -> Self { + Self { + per_thread: Mutex::new(HashMap::new()), + next_generation: AtomicU64::new(0), + target, + max_buffered_messages, + max_batch_tokens, + grouping, + idle_timeout, + } + } + + /// Build the dispatcher key for a (platform, thread, sender) tuple. + /// + /// In `Thread` mode the sender is ignored; in `Lane` mode the sender is appended + /// so each (thread, sender) pair gets its own mpsc and consumer. + /// + /// Note: this is the *dispatcher* key, not the *session pool* key. Session pool keys + /// are always `:` regardless of grouping (the ACP session is + /// shared per-thread by design). + pub fn key(&self, platform: &str, thread_id: &str, sender_id: &str) -> String { + match self.grouping { + BatchGrouping::Thread => format!("{platform}:{thread_id}"), + BatchGrouping::Lane => format!("{platform}:{thread_id}:{sender_id}"), + } + } + + /// Build the shared session pool key for a routed channel. + /// + /// Unlike dispatcher keys, session keys never include sender identity. + /// They track the logical conversation thread across all grouping modes. + fn session_key(thread_channel: &ChannelRef) -> String { + let logical_thread_id = thread_channel + .thread_id + .as_deref() + .unwrap_or(&thread_channel.channel_id); + format!("{}:{}", thread_channel.platform, logical_thread_id) + } + + /// Submit one arrival event for the given thread. + /// + /// - If the thread has no active consumer, one is spawned lazily. + /// - If the channel is full, this future parks until space is available + /// (backpressure — no data loss, no error). + /// - If the consumer has died (`SendError`), surfaces ❌ + ⚠️ and returns + /// `Err(DispatchError::ConsumerDead)` (§2.5). + /// + /// `adapter` is passed per-call (not stored on `Dispatcher`) because the + /// Discord adapter is constructed inside serenity's `ready` callback via + /// `OnceLock` — after the Dispatcher is built in `main.rs`. + pub async fn submit( + &self, + thread_key: String, + thread_channel: ChannelRef, + adapter: Arc, + msg: BufferedMessage, + ) -> Result<(), DispatchError> { + let cap = self.max_buffered_messages; + let target = Arc::clone(&self.target); + let max_tokens = self.max_batch_tokens; + let idle_timeout = self.idle_timeout; + + // Pre-fetch a generation in case we end up inserting a fresh handle. + // Wasted if the entry already exists; generations need only be monotonic. + let next_g = self.next_generation.fetch_add(1, Ordering::Relaxed); + + let (tx, my_generation) = { + let mut map = self.per_thread.lock().unwrap(); + + // Proactive stale-entry cleanup: if the consumer has exited (idle + // timeout or unexpected), remove the entry so `or_insert_with` + // creates a fresh one. Prevents map leak from one-shot thread keys + // and avoids the first-message-after-idle being treated as an error. + if let Some(handle) = map.get(&thread_key) { + if handle.consumer.is_finished() { + map.remove(&thread_key); + } + } + + let entry = map.entry(thread_key.clone()).or_insert_with(|| { + let (tx, rx) = tokio::sync::mpsc::channel(cap); + let consumer = tokio::spawn(consumer_loop( + thread_key.clone(), + thread_channel.clone(), + rx, + Arc::clone(&target), + Arc::clone(&adapter), + cap, + max_tokens, + idle_timeout, + )); + ThreadHandle { + tx, + consumer, + generation: next_g, + channel_id: thread_channel.channel_id.clone(), + adapter_kind: adapter.platform().to_string(), + } + }); + (entry.tx.clone(), entry.generation) + }; + + if let Err(e) = tx.send(msg).await { + // Consumer has exited between our check and the send — race-safe + // eviction under lock (§2.5), then transparent retry once. + { + let mut map = self.per_thread.lock().unwrap(); + Self::try_evict_locked(&mut map, &thread_key, my_generation); + } + let failed_msg = e.0; + + // Retry: spawn a fresh consumer and re-send. If this also fails, + // surface the error to the user. + let retry_g = self.next_generation.fetch_add(1, Ordering::Relaxed); + let (retry_tx, retry_gen) = { + let mut map = self.per_thread.lock().unwrap(); + let entry = map.entry(thread_key.clone()).or_insert_with(|| { + let (tx, rx) = tokio::sync::mpsc::channel(cap); + let consumer = tokio::spawn(consumer_loop( + thread_key.clone(), + thread_channel.clone(), + rx, + Arc::clone(&target), + Arc::clone(&adapter), + cap, + max_tokens, + idle_timeout, + )); + ThreadHandle { + tx, + consumer, + generation: retry_g, + channel_id: thread_channel.channel_id.clone(), + adapter_kind: adapter.platform().to_string(), + } + }); + (entry.tx.clone(), entry.generation) + }; + + if let Err(e2) = retry_tx.send(failed_msg).await { + // Retry also failed — truly unexpected. Surface error. + { + let mut map = self.per_thread.lock().unwrap(); + Self::try_evict_locked(&mut map, &thread_key, retry_gen); + } + let failed_msg = e2.0; + let _ = adapter + .add_reaction( + &failed_msg.trigger_msg, + &self.target.reactions_config().emojis.error, + ) + .await; + let _ = adapter + .send_message( + &thread_channel, + &format!("⚠️ {}", format_user_error("dispatch consumer exited unexpectedly")), + ) + .await; + return Err(DispatchError::ConsumerDead); + } + } + Ok(()) + } + + /// Drop all per-thread handles whose key belongs to `(platform, thread_id)`, + /// regardless of grouping, and abort each consumer (§2.5 / §4.4). Returns + /// the total number of buffered messages discarded across all lanes. + /// + /// Matches both Thread keys (`:`) and Lane keys + /// (`::`). Used by `/reset` and + /// `/cancel-all` to clear the entire thread, not just one lane. + /// + /// Disjoint from SendError recovery: removal happens *before* abort, so any + /// fresh `submit` after this returns lands on a lazily-constructed new handle + /// instead of observing `SendError`. + pub fn cancel_buffered_thread(&self, platform: &str, thread_id: &str) -> usize { + let prefix = format!("{platform}:{thread_id}"); + let lane_prefix = format!("{prefix}:"); + let mut map = self.per_thread.lock().unwrap(); + let keys: Vec = map + .keys() + .filter(|k| k.as_str() == prefix || k.starts_with(&lane_prefix)) + .cloned() + .collect(); + let mut dropped = 0; + for k in keys { + if let Some(handle) = map.remove(&k) { + dropped += handle.pending_count(); + handle.consumer.abort(); + } + } + dropped + } + + /// §2.5 race-safe eviction. Caller must hold the `per_thread` mutex. + /// Removes the entry only if its generation matches `my_generation` — + /// protects against evicting a fresh handle that another `submit` lazily + /// inserted between this caller's failed `tx.send` and this call. + /// Returns true if the entry was removed. + fn try_evict_locked( + map: &mut HashMap, + thread_key: &str, + my_generation: u64, + ) -> bool { + if let Some(handle) = map.get(thread_key) { + if handle.generation == my_generation { + map.remove(thread_key); + return true; + } + } + false + } + + /// Remove map entries whose consumer task has finished (idle timeout or + /// unexpected exit). Called periodically from the cleanup task in main.rs + /// to prevent unbounded map growth from one-shot thread keys that never + /// receive a second `submit()`. Returns the number of entries swept. + pub fn sweep_stale(&self) -> usize { + let mut map = self.per_thread.lock().unwrap(); + let before = map.len(); + map.retain(|_, handle| !handle.consumer.is_finished()); + before - map.len() + } + + /// Log buffered-message counts and drop all handles (called on SIGTERM). + pub fn shutdown(&self) { + let mut map = self.per_thread.lock().unwrap(); + for (thread_id, handle) in map.iter() { + let pending = handle.pending_count(); + if pending > 0 { + warn!( + thread_id = %thread_id, + channel = %handle.channel_id, + adapter = %handle.adapter_kind, + buffered_lost = pending, + "shutdown dropped pending messages without dispatch", + ); + } + handle.consumer.abort(); + } + map.clear(); + } +} + +// --------------------------------------------------------------------------- +// consumer_loop +// --------------------------------------------------------------------------- + +#[allow(clippy::too_many_arguments)] +async fn consumer_loop( + thread_key: String, + thread_channel: ChannelRef, + mut rx: tokio::sync::mpsc::Receiver, + target: Arc, + adapter: Arc, + max_batch: usize, + max_tokens: usize, + idle_timeout: Duration, +) { + // `pending` holds a message that exceeded the token cap for the current batch; + // it becomes the first message of the next batch, preserving FIFO. + let mut pending: Option = None; + + loop { + // I1: block until at least one message arrives (zero latency for first message). + // Idle timeout: if no message arrives within `idle_timeout` the consumer + // exits, freeing the task and mpsc. The next `submit` for this thread_key + // will observe `SendError`, evict the stale entry, and lazily spawn a + // fresh consumer (§2.5 generation check prevents mis-eviction). + let first = match pending.take() { + Some(msg) => msg, + None => match tokio::time::timeout(idle_timeout, rx.recv()).await { + Ok(Some(msg)) => msg, + Ok(None) => { + // All senders dropped → shutdown() or cancel_buffered_thread(). + break; + } + Err(_elapsed) => { + debug!( + thread_key = %thread_key, + channel = %thread_channel.channel_id, + "consumer idle timeout, exiting" + ); + break; + } + }, + }; + + // Greedy drain up to max_batch messages or max_tokens. + let mut batch = vec![first]; + let mut cumulative_tokens = batch[0].estimated_tokens; + + while batch.len() < max_batch { + match rx.try_recv() { + Ok(more) => { + if cumulative_tokens + more.estimated_tokens > max_tokens { + // Token cap — save for next turn (FIFO preserved). + pending = Some(more); + break; + } + cumulative_tokens += more.estimated_tokens; + batch.push(more); + } + Err(_) => break, + } + } + + // §2.6: read the freshest snapshot in the batch (batch is non-empty). + let bot_present = batch.last().unwrap().other_bot_present; + + dispatch_batch( + &thread_key, + &thread_channel, + &target, + &adapter, + batch, + bot_present, + ) + .await; + } +} + +// --------------------------------------------------------------------------- +// dispatch_batch +// --------------------------------------------------------------------------- + +async fn dispatch_batch( + thread_key: &str, + thread_channel: &ChannelRef, + target: &Arc, + adapter: &Arc, + batch: Vec, + other_bot_present: bool, +) { + let dispatch_start = Instant::now(); + let batch_size = batch.len(); + let session_key = Dispatcher::session_key(thread_channel); + + // Apply 👀 reaction to every message in the batch before dispatch (§6.7). + // Parallelized so first-token latency isn't paid for N serial reaction RPCs. + let queued_emoji = &target.reactions_config().emojis.queued; + futures_util::future::join_all( + batch + .iter() + .map(|msg| adapter.add_reaction(&msg.trigger_msg, queued_emoji)), + ) + .await; + + // Collect per-event observability data (before consuming the batch). + let tokens_per_event: Vec = batch.iter().map(|m| m.estimated_tokens).collect(); + let wait_ms: Vec = batch + .iter() + .map(|m| m.arrived_at.elapsed().as_millis()) + .collect(); + let senders: Vec = batch.iter().map(|m| m.sender_name.clone()).collect(); + + // Anchor reactions on the last message in the batch (before consuming). + let trigger_msg = batch.last().unwrap().trigger_msg.clone(); + + // Pack all arrival events into one Vec (§3.3). + // Uses into_iter() to avoid deep-copying extra_blocks (may contain base64 image data). + let mut content_blocks: Vec = Vec::new(); + for msg in batch { + let mut event_blocks = + AdapterRouter::pack_arrival_event(&msg.sender_json, &msg.prompt, msg.extra_blocks); + content_blocks.append(&mut event_blocks); + } + let packed_block_count = content_blocks.len(); + + // Ensure session exists. + if let Err(e) = target.ensure_session(&session_key).await { + let user_msg = format_user_error(&e.to_string()); + let _ = adapter + .send_message(thread_channel, &format!("⚠️ {user_msg}")) + .await; + error!("pool error in dispatch_batch: {e}"); + return; + } + + let reactions_config = target.reactions_config().clone(); + let reactions = Arc::new(StatusReactionController::new( + reactions_config.enabled, + adapter.clone(), + trigger_msg, + reactions_config.emojis.clone(), + reactions_config.timing.clone(), + )); + // 👀 already applied above; skip set_queued() to avoid double-reaction. + + let result = target + .stream_prompt_blocks( + adapter, + &session_key, + content_blocks, + thread_channel, + reactions.clone(), + other_bot_present, + ) + .await; + + match &result { + Ok(()) => reactions.set_done().await, + Err(_) => reactions.set_error().await, + } + + let hold_ms = if result.is_ok() { + reactions_config.timing.done_hold_ms + } else { + reactions_config.timing.error_hold_ms + }; + if reactions_config.remove_after_reply { + let reactions = reactions; + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(hold_ms)).await; + reactions.clear().await; + }); + } + + if let Err(ref e) = result { + let _ = adapter + .send_message(thread_channel, &format!("⚠️ {e}")) + .await; + } + + let agent_dispatch_ms = dispatch_start.elapsed().as_millis(); + let span = info_span!( + "dispatch", + channel = %thread_channel.channel_id, + adapter = adapter.platform(), + ); + let _enter = span.enter(); + info!( + thread_key = %thread_key, + events_per_dispatch = batch_size, + packed_block_count = packed_block_count, + agent_dispatch_ms = agent_dispatch_ms, + tokens_per_event = ?tokens_per_event, + wait_ms = ?wait_ms, + senders = ?senders, + "batch dispatched", + ); +} + +// --------------------------------------------------------------------------- +// Token estimation +// --------------------------------------------------------------------------- + +/// Rough token estimate for a buffered message (used for `max_batch_tokens` cap). +/// Intentionally coarse — the goal is a guard rail, not an exact pre-flight. +pub fn estimate_tokens(prompt: &str, extra_blocks: &[ContentBlock]) -> usize { + // ~4 chars per token for text; fixed 512 per image block (conservative). + let text_tokens = prompt.len() / 4 + 1; + let block_tokens: usize = extra_blocks + .iter() + .map(|b| match b { + ContentBlock::Text { text } => text.len() / 4 + 1, + ContentBlock::Image { .. } => 512, + }) + .sum(); + text_tokens + block_tokens +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn estimate_tokens_empty() { + assert!(estimate_tokens("", &[]) >= 1); + } + + #[test] + fn estimate_tokens_text() { + // 400 chars ≈ 100 tokens + let s = "a".repeat(400); + assert_eq!(estimate_tokens(&s, &[]), 101); + } + + #[test] + fn estimate_tokens_image_block() { + let blocks = vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "base64data".into(), + }]; + assert_eq!(estimate_tokens("", &blocks), 1 + 512); + } + + #[test] + fn pack_arrival_event_single() { + let blocks = AdapterRouter::pack_arrival_event( + r#"{"schema":"openab.sender.v1"}"#, + "hello", + vec![], + ); + assert_eq!(blocks.len(), 1); + if let ContentBlock::Text { text } = &blocks[0] { + assert!(text.contains("")); + assert!(text.contains("hello")); + } else { + panic!("expected Text block"); + } + } + + #[test] + fn pack_arrival_event_with_extra_blocks() { + let extra = vec![ + ContentBlock::Text { text: "[Voice transcript]: hi".into() }, + ContentBlock::Image { media_type: "image/png".into(), data: "abc".into() }, + ]; + let blocks = AdapterRouter::pack_arrival_event("{}", "prompt", extra); + // header + 2 extra = 3 blocks + assert_eq!(blocks.len(), 3); + // extra blocks follow the header in arrival order + assert!(matches!(&blocks[1], ContentBlock::Text { text } if text.contains("Voice transcript"))); + assert!(matches!(&blocks[2], ContentBlock::Image { .. })); + } + + #[test] + fn pack_arrival_event_batch_n2() { + // Two arrival events concatenated → 2 header blocks + let mut all: Vec = Vec::new(); + all.extend(AdapterRouter::pack_arrival_event(r#"{"ts":"T1"}"#, "msg1", vec![])); + all.extend(AdapterRouter::pack_arrival_event(r#"{"ts":"T2"}"#, "msg2", vec![])); + assert_eq!(all.len(), 2); + if let ContentBlock::Text { text } = &all[0] { + assert!(text.contains("msg1")); + } + if let ContentBlock::Text { text } = &all[1] { + assert!(text.contains("msg2")); + } + } + + // ADR §3.6 Scenario B — text in one message, image in the next, same author. + // Broker preserves structural truth: image stays in M2 alone, both messages + // carry the same sender_id so the agent can semantically link them. + #[test] + fn pack_arrival_event_scenario_b_image_in_separate_message() { + let mut all: Vec = Vec::new(); + // M1 (alice): "see this image" + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T1"}"#, + "see this image", + vec![], + )); + // M2 (alice): image, no text + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T2"}"#, + "", + vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "imgB".into(), + }], + )); + // header(M1) + header(M2) + image(M2) = 3 blocks + assert_eq!(all.len(), 3); + // M1's header carries text only + if let ContentBlock::Text { text } = &all[0] { + assert!(text.contains(r#""sender_id":"A""#)); + assert!(text.contains(r#""ts":"T1""#)); + assert!(text.contains("see this image")); + } else { + panic!("expected Text header for M1"); + } + // M2's header carries empty prompt (line after is blank) + if let ContentBlock::Text { text } = &all[1] { + assert!(text.contains(r#""ts":"T2""#)); + assert!(text.ends_with("\n\n"), "M2 prompt must be empty: {text:?}"); + } else { + panic!("expected Text header for M2"); + } + // M2's image follows immediately after its header (structural attribution) + assert!(matches!(&all[2], ContentBlock::Image { .. })); + } + + // ADR §3.6 Scenario C — fragmented multi-author batch. + // Repeated sender_id is preserved across non-adjacent messages; bob's interjection + // is kept as-is (no silent drop, no reorder). + #[test] + fn pack_arrival_event_scenario_c_multi_author_interleaved() { + let mut all: Vec = Vec::new(); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T1"}"#, + "see this image", + vec![], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"B","ts":"T2"}"#, + "what?", + vec![], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T3"}"#, + "", + vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "imgC".into(), + }], + )); + // 3 headers + 1 image = 4 blocks + assert_eq!(all.len(), 4); + // Order is preserved (no reorder). + let h1 = match &all[0] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text"), + }; + let h2 = match &all[1] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text"), + }; + let h3 = match &all[2] { + ContentBlock::Text { text } => text, + _ => panic!("expected Text"), + }; + assert!(h1.contains(r#""sender_id":"A""#) && h1.contains("see this image")); + assert!(h2.contains(r#""sender_id":"B""#) && h2.contains("what?")); + assert!(h3.contains(r#""sender_id":"A""#)); + // M3's image attached to M3 only. + assert!(matches!(&all[3], ContentBlock::Image { .. })); + } + + // ADR §3.6 Scenario D — voice-only message in a batch. + // M2 has empty prompt + transcript text block. Per ADR, transcript moves AFTER + // (vs. v0.8.2-beta.1's prepended position). + #[test] + fn pack_arrival_event_scenario_d_voice_only() { + let mut all: Vec = Vec::new(); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T1"}"#, + "look at this", + vec![ContentBlock::Image { + media_type: "image/png".into(), + data: "scr".into(), + }], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"A","ts":"T2"}"#, + "", + vec![ContentBlock::Text { + text: "[Voice message transcript]: hey can we sync about the deploy".into(), + }], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"sender_id":"B","ts":"T3"}"#, + "what?", + vec![], + )); + // header(M1) + image(M1) + header(M2) + transcript(M2) + header(M3) = 5 + assert_eq!(all.len(), 5); + if let ContentBlock::Text { text } = &all[0] { + assert!(text.contains(r#""ts":"T1""#)); + assert!(text.contains("look at this")); + } + assert!(matches!(&all[1], ContentBlock::Image { .. })); + // M2 header has empty prompt; transcript follows AFTER the header (not before). + if let ContentBlock::Text { text } = &all[2] { + assert!(text.contains(r#""ts":"T2""#)); + assert!(text.ends_with("\n\n")); + } + if let ContentBlock::Text { text } = &all[3] { + assert!(text.contains("Voice message transcript")); + assert!(text.contains("sync about the deploy")); + } else { + panic!("expected transcript Text block as M2 attachment"); + } + if let ContentBlock::Text { text } = &all[4] { + assert!(text.contains(r#""sender_id":"B""#)); + assert!(text.contains("what?")); + } + } + + // Token-cap math: a single message that already exceeds max_batch_tokens still + // dispatches alone (the consumer_loop logic admits the first message before + // checking the cap). Verifies estimate_tokens scales with input length. + #[test] + fn estimate_tokens_oversized_single_message() { + // ~24k token text (96000 chars / 4 chars-per-token). + let big = "x".repeat(96_000); + let est = estimate_tokens(&big, &[]); + assert!(est > 24_000, "expected >24k tokens, got {est}"); + } + + // Cumulative token math: two messages whose sum exceeds max_batch_tokens. + // The consumer_loop reads first, then peeks at the next; if cumulative tokens + // > cap, the second is held over to the next batch (FIFO preserved). + #[test] + fn estimate_tokens_cumulative_exceeds_cap() { + let max_tokens = 24_000_usize; + let m1 = estimate_tokens(&"a".repeat(80_000), &[]); + let m2 = estimate_tokens(&"b".repeat(50_000), &[]); + assert!(m1 < max_tokens); + assert!(m1 + m2 > max_tokens, "{m1} + {m2} should exceed cap"); + } + + // ADR §2.5 race-safe eviction. The full SendError path requires a real + // AdapterRouter (concrete struct, not a trait — no easy mock seam), so we + // unit-test the eviction predicate in isolation. End-to-end consumer-death + // recovery is exercised by the manual staging smoke documented in the ADR. + fn dummy_handle(generation: u64) -> ThreadHandle { + let (tx, _rx) = tokio::sync::mpsc::channel::(1); + let consumer = tokio::spawn(async {}); + ThreadHandle { + tx, + consumer, + generation, + channel_id: "C".into(), + adapter_kind: "discord".into(), + } + } + + #[tokio::test] + async fn try_evict_locked_removes_when_generation_matches() { + let mut map: HashMap = HashMap::new(); + map.insert("t".into(), dummy_handle(7)); + assert!(Dispatcher::try_evict_locked(&mut map, "t", 7)); + assert!(map.is_empty()); + } + + // The bug §2.5 prevents: a stale producer (my_gen=7) observing SendError + // must not remove a freshly inserted handle (gen=8) created by another + // submit between the failed send and the eviction attempt. + #[tokio::test] + async fn try_evict_locked_keeps_when_generation_differs() { + let mut map: HashMap = HashMap::new(); + map.insert("t".into(), dummy_handle(8)); + assert!(!Dispatcher::try_evict_locked(&mut map, "t", 7)); + assert_eq!(map.len(), 1); + assert_eq!(map.get("t").unwrap().generation, 8); + } + + #[tokio::test] + async fn try_evict_locked_returns_false_when_absent() { + let mut map: HashMap = HashMap::new(); + assert!(!Dispatcher::try_evict_locked(&mut map, "missing", 0)); + } + + // BatchGrouping → thread_key shape. + fn make_dispatcher(grouping: BatchGrouping) -> Dispatcher { + // The router is wrapped in Arc but never used by `key()` itself; we use + // a dummy AdapterRouter built via the same path main.rs would use. + // For a pure-keying test we'd ideally not need it, but the constructor demands one. + // Construct a minimal router via the public test helpers in adapter.rs if available; + // otherwise we fall back to building one with a dummy SessionPool. + use crate::acp::SessionPool; + let agent_cfg = crate::config::AgentConfig { + command: "/bin/true".into(), + args: vec![], + working_dir: "/tmp".into(), + env: std::collections::HashMap::new(), + inherit_env: vec![], + }; + let pool = Arc::new(SessionPool::new(agent_cfg, 1)); + let router = Arc::new(AdapterRouter::new( + pool, + crate::config::ReactionsConfig::default(), + crate::markdown::TableMode::Off, + )); + Dispatcher::new(router, 10, 24_000, grouping) + } + + #[tokio::test] + async fn key_per_thread_ignores_sender() { + let d = make_dispatcher(BatchGrouping::Thread); + assert_eq!(d.key("discord", "T1", "userA"), "discord:T1"); + assert_eq!(d.key("discord", "T1", "userB"), "discord:T1"); + } + + #[tokio::test] + async fn key_per_lane_includes_sender() { + let d = make_dispatcher(BatchGrouping::Lane); + assert_eq!(d.key("discord", "T1", "userA"), "discord:T1:userA"); + assert_eq!(d.key("discord", "T1", "userB"), "discord:T1:userB"); + // Different threads remain distinct. + assert_eq!(d.key("slack", "T2", "userA"), "slack:T2:userA"); + } + + fn insert_dummy_handle(d: &Dispatcher, key: &str) { + let (tx, _rx) = tokio::sync::mpsc::channel::(10); + let consumer = tokio::spawn(async {}); + let handle = ThreadHandle { + tx, + consumer, + generation: 0, + channel_id: "c".into(), + adapter_kind: "discord".into(), + }; + d.per_thread.lock().unwrap().insert(key.to_string(), handle); + } + + #[tokio::test] + async fn cancel_buffered_thread_drops_per_thread_key() { + let d = make_dispatcher(BatchGrouping::Thread); + insert_dummy_handle(&d, "discord:T1"); + insert_dummy_handle(&d, "discord:T2"); // different thread, must survive + assert_eq!(d.cancel_buffered_thread("discord", "T1"), 0); // no buffered msgs + let map = d.per_thread.lock().unwrap(); + assert!(!map.contains_key("discord:T1")); + assert!(map.contains_key("discord:T2")); + } + + #[tokio::test] + async fn cancel_buffered_thread_drops_all_lanes() { + let d = make_dispatcher(BatchGrouping::Lane); + insert_dummy_handle(&d, "discord:T1:userA"); + insert_dummy_handle(&d, "discord:T1:userB"); + insert_dummy_handle(&d, "discord:T2:userA"); // different thread + insert_dummy_handle(&d, "slack:T1:userA"); // different platform + d.cancel_buffered_thread("discord", "T1"); + let map = d.per_thread.lock().unwrap(); + assert!(!map.contains_key("discord:T1:userA")); + assert!(!map.contains_key("discord:T1:userB")); + assert!(map.contains_key("discord:T2:userA")); + assert!(map.contains_key("slack:T1:userA")); + } + + #[tokio::test] + async fn cancel_buffered_thread_does_not_match_thread_id_prefix() { + // T1 must not match T10 / T11 (substring trap). + let d = make_dispatcher(BatchGrouping::Lane); + insert_dummy_handle(&d, "discord:T1:userA"); + insert_dummy_handle(&d, "discord:T10:userA"); + d.cancel_buffered_thread("discord", "T1"); + let map = d.per_thread.lock().unwrap(); + assert!(!map.contains_key("discord:T1:userA")); + assert!(map.contains_key("discord:T10:userA")); + } + + // Long-running consumer that parks until aborted — used by sweep_stale / + // shutdown tests to exercise the "still alive" path. + fn alive_consumer_handle() -> ThreadHandle { + let (tx, _rx) = tokio::sync::mpsc::channel::(10); + let consumer = tokio::spawn(async { + std::future::pending::<()>().await; + }); + ThreadHandle { + tx, + consumer, + generation: 0, + channel_id: "c".into(), + adapter_kind: "discord".into(), + } + } + + #[tokio::test] + async fn sweep_stale_removes_finished_consumers() { + let d = make_dispatcher(BatchGrouping::Thread); + insert_dummy_handle(&d, "discord:T1"); + insert_dummy_handle(&d, "discord:T2"); + // Yield so the empty-body spawned tasks actually run to completion + // before is_finished() is checked. + tokio::time::sleep(Duration::from_millis(10)).await; + let swept = d.sweep_stale(); + assert_eq!(swept, 2); + assert!(d.per_thread.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn sweep_stale_keeps_running_consumers() { + let d = make_dispatcher(BatchGrouping::Thread); + let abort = { + let h = alive_consumer_handle(); + let a = h.consumer.abort_handle(); + d.per_thread.lock().unwrap().insert("alive".into(), h); + a + }; + let swept = d.sweep_stale(); + assert_eq!(swept, 0); + assert!(d.per_thread.lock().unwrap().contains_key("alive")); + // Cleanup so the parked task doesn't linger across tests. + abort.abort(); + } + + #[tokio::test] + async fn shutdown_clears_all_handles() { + let d = make_dispatcher(BatchGrouping::Thread); + insert_dummy_handle(&d, "k1"); + insert_dummy_handle(&d, "k2"); + insert_dummy_handle(&d, "k3"); + d.shutdown(); + assert!(d.per_thread.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn shutdown_aborts_running_consumers() { + let d = make_dispatcher(BatchGrouping::Thread); + let abort = { + let h = alive_consumer_handle(); + let a = h.consumer.abort_handle(); + d.per_thread.lock().unwrap().insert("k".into(), h); + a + }; + d.shutdown(); + // Give the runtime a tick to process abort + map drop. + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(abort.is_finished()); + } + + // ----------------------------------------------------------------------- + // consumer_loop / dispatch_batch integration tests (NIT 2) + // + // These drive `consumer_loop` directly with a pre-populated mpsc, using + // `MockDispatchTarget` to record the calls that would otherwise hit a + // real `AdapterRouter` (and through it, ACP CLI subprocesses). This + // gives deterministic coverage of the orchestration paths the existing + // unit tests don't reach: greedy drain, token-cap overflow, idle timeout. + // ----------------------------------------------------------------------- + + /// One recorded `stream_prompt_blocks` invocation. + #[derive(Clone)] + struct RecordedDispatch { + block_count: usize, + other_bot_present: bool, + } + + /// Mock `DispatchTarget` — records calls; never touches a real session pool. + struct MockDispatchTarget { + reactions: ReactionsConfig, + calls: Mutex>, + /// If set, `ensure_session` returns this error once. + ensure_err: Mutex>, + /// If set, `stream_prompt_blocks` returns this error once. + stream_err: Mutex>, + } + + impl MockDispatchTarget { + fn new() -> Self { + Self { + reactions: ReactionsConfig::default(), + calls: Mutex::new(Vec::new()), + ensure_err: Mutex::new(None), + stream_err: Mutex::new(None), + } + } + + fn calls(&self) -> Vec { + self.calls.lock().unwrap().clone() + } + } + + #[async_trait] + impl DispatchTarget for MockDispatchTarget { + fn reactions_config(&self) -> &ReactionsConfig { + &self.reactions + } + + async fn ensure_session(&self, _session_key: &str) -> Result<()> { + if let Some(msg) = self.ensure_err.lock().unwrap().take() { + return Err(anyhow::anyhow!(msg)); + } + Ok(()) + } + + async fn stream_prompt_blocks( + &self, + _adapter: &Arc, + _session_key: &str, + content_blocks: Vec, + _thread_channel: &ChannelRef, + _reactions: Arc, + other_bot_present: bool, + ) -> Result<()> { + self.calls.lock().unwrap().push(RecordedDispatch { + block_count: content_blocks.len(), + other_bot_present, + }); + if let Some(msg) = self.stream_err.lock().unwrap().take() { + return Err(anyhow::anyhow!(msg)); + } + Ok(()) + } + } + + /// Mock `ChatAdapter` — every method is a no-op success. The dispatch loop + /// invokes `add_reaction` (queued 👀), `platform`, and on the error path + /// `send_message`; nothing else needs real behavior here. + struct MockChatAdapter; + + #[async_trait] + impl ChatAdapter for MockChatAdapter { + fn platform(&self) -> &'static str { "mock" } + fn message_limit(&self) -> usize { 2000 } + + async fn send_message(&self, channel: &ChannelRef, _content: &str) -> Result { + Ok(MessageRef { channel: channel.clone(), message_id: "mock-msg".into() }) + } + + async fn create_thread( + &self, + channel: &ChannelRef, + _trigger_msg: &MessageRef, + _title: &str, + ) -> Result { + Ok(channel.clone()) + } + + async fn add_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { Ok(()) } + async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { Ok(()) } + fn use_streaming(&self, _other_bot_present: bool) -> bool { false } + } + + fn make_channel(thread: &str) -> ChannelRef { + ChannelRef { + platform: "mock".into(), + channel_id: thread.into(), + thread_id: Some(thread.into()), + parent_id: None, + origin_event_id: None, + } + } + + fn make_msg(prompt: &str, tokens: usize) -> BufferedMessage { + BufferedMessage { + sender_json: r#"{"schema":"openab.sender.v1","sender_id":"u","sender_name":"u"}"#.into(), + sender_name: "u".into(), + prompt: prompt.into(), + extra_blocks: vec![], + trigger_msg: MessageRef { + channel: make_channel("T"), + message_id: format!("m-{prompt}"), + }, + arrived_at: Instant::now(), + estimated_tokens: tokens, + other_bot_present: false, + } + } + + /// Pre-load `msgs` into a fresh mpsc, drop the sender, and run + /// `consumer_loop` to completion. Returns the recorded dispatches. + async fn run_consumer_with_messages( + msgs: Vec, + max_batch: usize, + max_tokens: usize, + ) -> Vec { + let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let adapter: Arc = Arc::new(MockChatAdapter); + let (tx, rx) = tokio::sync::mpsc::channel::(msgs.len().max(1)); + for m in msgs { + tx.send(m).await.unwrap(); + } + drop(tx); + + consumer_loop( + "mock:T".into(), + make_channel("T"), + rx, + target, + adapter, + max_batch, + max_tokens, + Duration::from_secs(60), + ) + .await; + + mock.calls() + } + + #[tokio::test] + async fn consumer_dispatches_single_message_as_one_batch() { + let calls = run_consumer_with_messages(vec![make_msg("hi", 10)], 10, 24_000).await; + assert_eq!(calls.len(), 1); + // pack_arrival_event with no extra_blocks → 1 Text block per message. + assert_eq!(calls[0].block_count, 1); + assert!(!calls[0].other_bot_present); + } + + #[tokio::test] + async fn consumer_greedy_drain_combines_queued_messages_into_one_batch() { + // 3 messages already in the queue when the consumer wakes → greedy + // drain pulls all 3, packs them into one batch, dispatches once. + let calls = run_consumer_with_messages( + vec![make_msg("a", 50), make_msg("b", 50), make_msg("c", 50)], + 10, + 24_000, + ) + .await; + assert_eq!(calls.len(), 1, "expected a single batched dispatch"); + assert_eq!(calls[0].block_count, 3, "one Text block per arrival event"); + } + + #[tokio::test] + async fn consumer_token_cap_splits_batch_preserving_fifo() { + // max_tokens=100, two 80-token messages → cumulative 160 > 100, so + // msg2 becomes `pending` and is dispatched in the next batch. + let calls = + run_consumer_with_messages(vec![make_msg("a", 80), make_msg("b", 80)], 10, 100).await; + assert_eq!(calls.len(), 2, "token cap should split into two batches"); + assert_eq!(calls[0].block_count, 1); + assert_eq!(calls[1].block_count, 1); + } + + #[tokio::test] + async fn consumer_exits_after_idle_timeout_with_no_messages() { + // No messages ever arrive; consumer should exit once `idle_timeout` + // elapses. Keep `tx` alive so the exit path is the timeout, not the + // "all senders dropped" branch. + let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let adapter: Arc = Arc::new(MockChatAdapter); + let (tx, rx) = tokio::sync::mpsc::channel::(1); + let consumer = tokio::spawn(consumer_loop( + "mock:T".into(), + make_channel("T"), + rx, + target, + adapter, + 10, + 24_000, + Duration::from_millis(50), + )); + // Wait enough for the timeout branch + a tick for the task to finish. + tokio::time::sleep(Duration::from_millis(150)).await; + assert!(consumer.is_finished(), "consumer should exit after idle timeout"); + // No dispatches should have been recorded. + assert!(mock.calls().is_empty()); + drop(tx); + } + + #[tokio::test] + async fn submit_evicts_dead_handle_and_retries_with_fresh_consumer() { + // §2.5: if `tx.send()` returns `SendError` (consumer's rx dropped + // mid-flight), `submit` evicts the stale entry under lock and spawns + // a fresh consumer. Manufacture this state by inserting a handle + // whose consumer is still parked but whose rx has been dropped. + let mock = Arc::new(MockDispatchTarget::new()); + let target: Arc = mock.clone(); + let d = Dispatcher::new(target, 10, 24_000, BatchGrouping::Thread); + let adapter: Arc = Arc::new(MockChatAdapter); + + let key = "mock:T".to_string(); + let parked = { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + drop(rx); // closes the channel → next tx.send() yields SendError + let consumer = tokio::spawn(std::future::pending::<()>()); + let abort = consumer.abort_handle(); + let handle = ThreadHandle { + tx, + consumer, + generation: 999, + channel_id: "T".into(), + adapter_kind: "mock".into(), + }; + d.per_thread.lock().unwrap().insert(key.clone(), handle); + abort + }; + + d.submit(key, make_channel("T"), adapter, make_msg("hello", 10)) + .await + .expect("retry should spawn a fresh consumer"); + // Give the freshly spawned consumer time to drain + dispatch. + tokio::time::sleep(Duration::from_millis(50)).await; + + let calls = mock.calls(); + assert_eq!(calls.len(), 1, "fresh consumer should have dispatched the retry"); + assert_eq!(calls[0].block_count, 1); + + parked.abort(); + } +} diff --git a/src/gateway.rs b/src/gateway.rs index 816afb45..28799f1b 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,4 +1,4 @@ -use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef, SenderContext}; +use crate::adapter::{ChannelRef, ChatAdapter, MessageRef, SenderContext}; use anyhow::Result; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; @@ -348,8 +348,9 @@ pub struct GatewayParams { pub async fn run_gateway_adapter( params: GatewayParams, - router: Arc, mut shutdown_rx: tokio::sync::watch::Receiver, + dispatcher: Arc, + router: Arc, ) -> Result<()> { let platform: &'static str = Box::leak(params.platform.into_boxed_str()); @@ -487,6 +488,12 @@ pub async fn run_gateway_adapter( channel_id: event.channel.id.clone(), thread_id: event.channel.thread_id.clone(), is_bot: event.sender.is_bot, + // Gateway: use event timestamp if available, else broker receive time + timestamp: if event.timestamp.is_empty() { + crate::timestamp::now_iso8601() + } else { + event.timestamp.clone() + }, }; let sender_json = serde_json::to_string(&sender_ctx) .unwrap_or_default(); @@ -497,8 +504,10 @@ pub async fn run_gateway_adapter( }; let adapter = adapter.clone(); - let router = router.clone(); let prompt = event.content.text.clone(); + let sender_name = event.sender.name.clone(); + let sender_id = event.sender.id.clone(); + let dispatcher = dispatcher.clone(); // Slash command interception for gateway platforms // (Feishu/LINE/Telegram don't have native slash commands) @@ -506,12 +515,16 @@ pub async fn run_gateway_adapter( // need message_id for streaming edits. let trimmed = prompt.trim(); if trimmed == "/reset" { - let thread_key = format!("{}:{}", event.platform, event.channel.thread_id.as_deref().unwrap_or(&event.channel.id)); - let msg = match router.pool().reset_session(&thread_key).await { - Ok(()) => "🔄 Session reset. Start a new conversation!", - Err(_) => "⚠️ No active session to reset.", + let thread_id_str = event.channel.thread_id.as_deref().unwrap_or(&event.channel.id); + let thread_key = format!("{}:{}", event.platform, thread_id_str); + let dropped = dispatcher.cancel_buffered_thread(event.platform.as_str(), thread_id_str); + let msg = match (router.pool().reset_session(&thread_key).await, dropped) { + (Ok(()), 0) => "🔄 Session reset. Start a new conversation!".to_string(), + (Ok(()), n) => format!("🔄 Session reset. Dropped {n} buffered message(s). Start a new conversation!"), + (Err(_), 0) => "⚠️ No active session to reset.".to_string(), + (Err(_), n) => format!("🔄 Dropped {n} buffered message(s). No active session to reset."), }; - let _ = send_fire_and_forget(&slash_ws_tx, &channel, msg).await; + let _ = send_fire_and_forget(&slash_ws_tx, &channel, &msg).await; continue; } if trimmed == "/cancel" { @@ -541,19 +554,33 @@ pub async fn run_gateway_adapter( channel.clone() }; - if let Err(e) = router - .handle_message( - &adapter, - &thread_channel, - &sender_json, - &prompt, - vec![], - &trigger_msg, - false, - ) + let thread_id = thread_channel + .thread_id + .as_deref() + .unwrap_or(&thread_channel.channel_id); + let thread_key = dispatcher.key( + &thread_channel.platform, + thread_id, + &sender_id, + ); + let estimated_tokens = + crate::dispatch::estimate_tokens(&prompt, &[]); + let buf_msg = crate::dispatch::BufferedMessage { + sender_json, + sender_name, + prompt, + extra_blocks: vec![], + trigger_msg, + arrived_at: std::time::Instant::now(), + estimated_tokens, + // TODO: implement gateway multibot detection + other_bot_present: false, + }; + if let Err(e) = dispatcher + .submit(thread_key, thread_channel, adapter, buf_msg) .await { - error!("gateway message handling error: {e}"); + error!("gateway dispatcher submit error: {e}"); } }); } @@ -592,3 +619,4 @@ pub async fn run_gateway_adapter( backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF); } // outer reconnect loop } + diff --git a/src/main.rs b/src/main.rs index ffe6985c..6372b680 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod bot_turns; mod config; mod cron; mod discord; +mod dispatch; mod error_display; mod format; mod markdown; @@ -13,6 +14,7 @@ mod setup; mod slack; mod stt; mod gateway; +mod timestamp; use adapter::AdapterRouter; use clap::Parser; @@ -20,7 +22,7 @@ use serenity::gateway::GatewayError; use serenity::prelude::*; use std::collections::HashSet; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tracing::{error, info, warn}; /// Wait for SIGINT (ctrl_c) or, on unix, SIGTERM. SIGTERM is what Kubernetes @@ -142,12 +144,21 @@ async fn main() -> anyhow::Result<()> { // Shutdown signal for Slack adapter let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); + // Dispatcher handles tracked here so SIGTERM cleanup can call shutdown() on each (ADR §6.8). + // Also shared with the cleanup task for periodic stale-entry sweeping. + let dispatchers: Arc>>> = Arc::new(Mutex::new(Vec::new())); + // Spawn cleanup task let cleanup_pool = pool.clone(); + let cleanup_dispatchers = dispatchers.clone(); let cleanup_handle = tokio::spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(60)).await; cleanup_pool.cleanup_idle(ttl_secs).await; + // Sweep stale per-thread dispatcher entries (idle-exited consumers). + for d in cleanup_dispatchers.lock().unwrap().iter() { + d.sweep_stale(); + } } }); @@ -188,6 +199,24 @@ async fn main() -> anyhow::Result<()> { let max_bot_turns = slack_cfg.max_bot_turns; let slack_shutdown_rx = shutdown_rx.clone(); let adapter = shared_slack_adapter.clone().expect("shared_slack_adapter must exist when slack config is present"); + // Dispatcher is the sole serialization path for all modes. Message = cap 1 + // (each message dispatches alone, FIFO). Thread / Lane = configured cap; + // grouping decides whether senders share a buffer or get their own lane. + let (slack_cap, slack_grouping) = match slack_cfg.message_processing_mode { + config::MessageProcessingMode::Message => + (1, dispatch::BatchGrouping::Thread), + config::MessageProcessingMode::Thread => + (slack_cfg.max_buffered_messages, dispatch::BatchGrouping::Thread), + config::MessageProcessingMode::Lane => + (slack_cfg.max_buffered_messages, dispatch::BatchGrouping::Lane), + }; + let slack_dispatcher = Arc::new(dispatch::Dispatcher::new( + router.clone(), + slack_cap, + slack_cfg.max_batch_tokens, + slack_grouping, + )); + dispatchers.lock().unwrap().push(slack_dispatcher.clone()); Some(tokio::spawn(async move { if let Err(e) = slack::run_slack_adapter( adapter, @@ -201,8 +230,8 @@ async fn main() -> anyhow::Result<()> { slack_cfg.allow_user_messages, max_bot_turns, stt, - router, slack_shutdown_rx, + slack_dispatcher, ) .await { @@ -218,6 +247,21 @@ async fn main() -> anyhow::Result<()> { let router = router.clone(); let shutdown_rx = shutdown_rx.clone(); info!(url = %gw_cfg.url, "starting gateway adapter"); + let (gw_cap, gw_grouping) = match gw_cfg.message_processing_mode { + config::MessageProcessingMode::Message => + (1, dispatch::BatchGrouping::Thread), + config::MessageProcessingMode::Thread => + (gw_cfg.max_buffered_messages, dispatch::BatchGrouping::Thread), + config::MessageProcessingMode::Lane => + (gw_cfg.max_buffered_messages, dispatch::BatchGrouping::Lane), + }; + let gw_dispatcher = Arc::new(dispatch::Dispatcher::new( + router.clone(), + gw_cap, + gw_cfg.max_batch_tokens, + gw_grouping, + )); + dispatchers.lock().unwrap().push(gw_dispatcher.clone()); let params = gateway::GatewayParams { url: gw_cfg.url, platform: gw_cfg.platform, @@ -229,8 +273,9 @@ async fn main() -> anyhow::Result<()> { allowed_users: gw_cfg.allowed_users, streaming: gw_cfg.streaming, }; + let gw_router = router.clone(); Some(tokio::spawn(async move { - if let Err(e) = gateway::run_gateway_adapter(params, router, shutdown_rx).await { + if let Err(e) = gateway::run_gateway_adapter(params, shutdown_rx, gw_dispatcher, gw_router).await { error!("gateway adapter error: {e}"); } })) @@ -301,6 +346,22 @@ async fn main() -> anyhow::Result<()> { "starting discord adapter" ); + let (discord_cap, discord_grouping) = match discord_cfg.message_processing_mode { + config::MessageProcessingMode::Message => + (1, dispatch::BatchGrouping::Thread), + config::MessageProcessingMode::Thread => + (discord_cfg.max_buffered_messages, dispatch::BatchGrouping::Thread), + config::MessageProcessingMode::Lane => + (discord_cfg.max_buffered_messages, dispatch::BatchGrouping::Lane), + }; + let discord_dispatcher = Arc::new(dispatch::Dispatcher::new( + router.clone(), + discord_cap, + discord_cfg.max_batch_tokens, + discord_grouping, + )); + dispatchers.lock().unwrap().push(discord_dispatcher.clone()); + let handler = discord::Handler { router, allow_all_channels, @@ -318,6 +379,7 @@ async fn main() -> anyhow::Result<()> { max_bot_turns: discord_cfg.max_bot_turns, bot_turns: tokio::sync::Mutex::new(bot_turns::BotTurnTracker::new(discord_cfg.max_bot_turns)), allow_dm: discord_cfg.allow_dm, + dispatcher: discord_dispatcher, }; let intents = GatewayIntents::GUILD_MESSAGES @@ -378,6 +440,10 @@ async fn main() -> anyhow::Result<()> { // cron.rs drains in-flight tasks for up to 30s, so wait slightly longer let _ = tokio::time::timeout(std::time::Duration::from_secs(35), handle).await; } + // Drain per-thread dispatchers and log buffered_lost counts before pool shutdown (ADR §6.8). + for d in dispatchers.lock().unwrap().iter() { + d.shutdown(); + } let shutdown_pool = pool; shutdown_pool.shutdown().await; info!("openab shut down"); diff --git a/src/slack.rs b/src/slack.rs index 7619255f..015bf978 100644 --- a/src/slack.rs +++ b/src/slack.rs @@ -1,5 +1,5 @@ use crate::acp::ContentBlock; -use crate::adapter::{AdapterRouter, ChatAdapter, ChannelRef, MessageRef, SenderContext}; +use crate::adapter::{ChatAdapter, ChannelRef, MessageRef, SenderContext}; use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; use crate::config::{AllowBots, AllowUsers, SttConfig}; use crate::media; @@ -428,50 +428,6 @@ impl ChatAdapter for SlackAdapter { } } -// --- Per-thread async queue (inspired by OpenClaw's KeyedAsyncQueue) --- - -/// Serialize async work per key while allowing unrelated keys to run concurrently. -/// Same-key tasks execute in FIFO order; different keys run in parallel. -/// Idle keys are cleaned up automatically after the last task settles. -struct KeyedAsyncQueue { - tails: tokio::sync::Mutex>>, -} - -impl KeyedAsyncQueue { - fn new() -> Self { - Self { - tails: tokio::sync::Mutex::new(HashMap::new()), - } - } - - /// Acquire a per-key permit. The returned guard must be held for the - /// duration of the async work. Dropping it allows the next queued task - /// for the same key to proceed. - /// - /// Performs lazy cleanup of idle semaphores to prevent unbounded growth - /// in long-running deployments. - async fn acquire(&self, key: &str) -> Option { - let sem = { - let mut tails = self.tails.lock().await; - // Lazy cleanup: evict idle entries (available_permits == 1 means no one is holding or waiting) - if tails.len() > 100 { - tails.retain(|_, sem| Arc::strong_count(sem) > 1 || sem.available_permits() < 1); - } - tails - .entry(key.to_string()) - .or_insert_with(|| Arc::new(tokio::sync::Semaphore::new(1))) - .clone() - }; - match sem.acquire_owned().await { - Ok(permit) => Some(permit), - Err(e) => { - warn!(key, error = %e, "semaphore closed, skipping message"); - None - } - } - } -} - // --- Socket Mode event loop --- /// Hard cap on consecutive bot messages in a thread. Prevents runaway loops. @@ -492,10 +448,9 @@ pub async fn run_slack_adapter( allow_user_messages: AllowUsers, max_bot_turns: u32, stt_config: SttConfig, - router: Arc, mut shutdown_rx: watch::Receiver, + dispatcher: Arc, ) -> Result<()> { - let queue = Arc::new(KeyedAsyncQueue::new()); let bot_token = adapter.bot_token().to_string(); let bot_turns = Arc::new(tokio::sync::Mutex::new(BotTurnTracker::new(max_bot_turns))); @@ -589,19 +544,8 @@ pub async fn run_slack_adapter( let allowed_channels = allowed_channels.clone(); let allowed_users = allowed_users.clone(); let stt_config = stt_config.clone(); - let router = router.clone(); - let queue = queue.clone(); - // Queue key: thread_ts if already in a thread, otherwise ts. - // app_mention always has a channel context, so ts alone - // is unique enough (unlike message events in DMs where - // we prefix with channel_id to avoid ts collisions). - let queue_key = event["thread_ts"] - .as_str() - .or_else(|| event["ts"].as_str()) - .unwrap_or("") - .to_string(); + let dispatcher = dispatcher.clone(); tokio::spawn(async move { - let Some(_permit) = queue.acquire(&queue_key).await else { return }; handle_message( &event, &adapter, @@ -611,7 +555,7 @@ pub async fn run_slack_adapter( &allowed_channels, &allowed_users, &stt_config, - &router, + &dispatcher, ) .await; }); @@ -665,8 +609,7 @@ pub async fn run_slack_adapter( // --- Bot turn tracking --- // Runs before self-check so ALL bot messages (including own) // count toward the per-thread limit. Matches Discord #483. - // Keyed on thread_ts when in a thread, else channel:ts (the - // same key shape used for per-thread queueing below). + // Keyed on thread_ts when in a thread, else channel:ts. // Non-thread messages get a unique key per message, so the // counter never accumulates — intentional, because bot-to-bot // loops only happen inside threads. @@ -821,27 +764,17 @@ pub async fn run_slack_adapter( } } - // Dispatch to handle_message (serialized per thread) + // Dispatch to handle_message (per-thread serialization comes + // from Dispatcher consumer task in batched mode and from + // pool.with_connection in per-message mode). let event = event.clone(); let adapter = adapter.clone(); let bot_token = bot_token.clone(); let allowed_channels = allowed_channels.clone(); let allowed_users = allowed_users.clone(); let stt_config = stt_config.clone(); - let router = router.clone(); - let queue = queue.clone(); - // Queue key: thread_ts if in a thread, otherwise channel:ts. - // Prefixed with channel_id for non-thread messages because - // DMs and channels can have overlapping ts values — the - // prefix ensures keys are globally unique. - let queue_key = event["thread_ts"] - .as_str() - .map(|s| s.to_string()) - .unwrap_or_else(|| { - format!("{}:{}", channel_id, event["ts"].as_str().unwrap_or("")) - }); + let dispatcher = dispatcher.clone(); tokio::spawn(async move { - let Some(_permit) = queue.acquire(&queue_key).await else { return }; handle_message( &event, &adapter, @@ -851,7 +784,7 @@ pub async fn run_slack_adapter( &allowed_channels, &allowed_users, &stt_config, - &router, + &dispatcher, ) .await; }); @@ -922,7 +855,7 @@ async fn handle_message( allowed_channels: &HashSet, allowed_users: &HashSet, stt_config: &SttConfig, - router: &Arc, + dispatcher: &Arc, ) { let channel_id = match event["channel"].as_str() { Some(ch) => ch.to_string(), @@ -1093,6 +1026,7 @@ async fn handle_message( channel_id: channel_id.clone(), thread_id: thread_ts.clone(), is_bot: is_bot_msg, + timestamp: crate::timestamp::slack_ts_to_iso8601(&ts), }; let trigger_msg = MessageRef { @@ -1133,11 +1067,27 @@ async fn handle_message( thread_channel.thread_id.as_deref() .is_some_and(|ts| cache.get(ts).is_some_and(|inst| inst.elapsed() < adapter.session_ttl)) }; - if let Err(e) = router - .handle_message(&adapter_dyn, &thread_channel, &sender_json, &prompt, extra_blocks, &trigger_msg, other_bot_present) + let thread_id = thread_channel + .thread_id + .as_deref() + .unwrap_or(&thread_channel.channel_id); + let thread_key = dispatcher.key("slack", thread_id, &sender.sender_id); + let estimated_tokens = crate::dispatch::estimate_tokens(&prompt, &extra_blocks); + let buf_msg = crate::dispatch::BufferedMessage { + sender_json, + sender_name: sender.sender_name.clone(), + prompt, + extra_blocks, + trigger_msg, + arrived_at: std::time::Instant::now(), + estimated_tokens, + other_bot_present, + }; + if let Err(e) = dispatcher + .submit(thread_key, thread_channel, adapter_dyn, buf_msg) .await { - error!("Slack handle_message error: {e}"); + error!("Slack dispatcher submit error: {e}"); } } diff --git a/src/timestamp.rs b/src/timestamp.rs new file mode 100644 index 00000000..485a7864 --- /dev/null +++ b/src/timestamp.rs @@ -0,0 +1,88 @@ +//! ISO 8601 UTC timestamp helpers — no external crate dependency. +//! +//! Centralizes the Gregorian date math used by Slack (`.` ts strings) +//! and Gateway (`SystemTime::now()`) so both adapters share one implementation. + +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Convert days since the Unix epoch (1970-01-01) to a Gregorian (year, month, day). +/// Algorithm from . +fn days_to_ymd(days: u64) -> (u64, u64, u64) { + let z = days + 719468; + let era = z / 146097; + let doe = z % 146097; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + let y = if m <= 2 { y + 1 } else { y }; + (y, m, d) +} + +/// Format a Unix timestamp (seconds + millis) as ISO 8601 UTC with millisecond precision. +fn unix_to_iso8601(secs: u64, ms: u64) -> String { + let days = secs / 86400; + let time_secs = secs % 86400; + let h = time_secs / 3600; + let m = (time_secs % 3600) / 60; + let s = time_secs % 60; + let (year, month, day) = days_to_ymd(days); + format!("{year:04}-{month:02}-{day:02}T{h:02}:{m:02}:{s:02}.{ms:03}Z") +} + +/// Convert a Slack `ts` string (".") to ISO 8601 UTC. +/// Best-effort; falls back to epoch on parse failure. +pub fn slack_ts_to_iso8601(ts: &str) -> String { + let mut parts = ts.splitn(2, '.'); + let secs = parts.next().unwrap_or("0").parse::().unwrap_or(0); + let frac = parts.next().unwrap_or("000"); + let ms: u64 = frac.chars().take(3).collect::().parse().unwrap_or(0); + unix_to_iso8601(secs, ms) +} + +/// Current wall-clock instant as ISO 8601 UTC with millisecond precision. +pub fn now_iso8601() -> String { + let dur = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + unix_to_iso8601(dur.as_secs(), (dur.subsec_millis()) as u64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn slack_ts_epoch_zero() { + assert_eq!(slack_ts_to_iso8601("0.000000"), "1970-01-01T00:00:00.000Z"); + } + + #[test] + fn slack_ts_keeps_milliseconds() { + // 1714204397 = 2024-04-27T07:53:17 UTC; .123456 → .123 ms + assert_eq!(slack_ts_to_iso8601("1714204397.123456"), "2024-04-27T07:53:17.123Z"); + } + + #[test] + fn slack_ts_missing_fraction_uses_zero() { + assert_eq!(slack_ts_to_iso8601("1714204397"), "2024-04-27T07:53:17.000Z"); + } + + #[test] + fn slack_ts_unparseable_falls_back_to_epoch() { + assert_eq!(slack_ts_to_iso8601("not-a-ts"), "1970-01-01T00:00:00.000Z"); + } + + #[test] + fn now_iso8601_has_expected_shape() { + let s = now_iso8601(); + // YYYY-MM-DDTHH:MM:SS.mmmZ = 24 chars + assert_eq!(s.len(), 24); + assert!(s.ends_with('Z')); + assert_eq!(&s[4..5], "-"); + assert_eq!(&s[10..11], "T"); + assert_eq!(&s[19..20], "."); + } +}