diff --git a/Cargo.lock b/Cargo.lock index 2d0bc6ce2..94cd99f24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3148,6 +3148,19 @@ dependencies = [ "z3", ] +[[package]] +name = "openshell-provider-auth" +version = "0.0.0" +dependencies = [ + "base64 0.22.1", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "url", +] + [[package]] name = "openshell-providers" version = "0.0.0" @@ -3193,6 +3206,7 @@ dependencies = [ "openshell-core", "openshell-ocsf", "openshell-policy", + "openshell-provider-auth", "openshell-router", "rand_core 0.6.4", "rcgen", @@ -3215,6 +3229,7 @@ dependencies = [ "tracing", "tracing-appender", "tracing-subscriber", + "url", "uuid", "webpki-roots 1.0.6", ] diff --git a/architecture/sandbox-providers.md b/architecture/sandbox-providers.md index fe5d48a97..e77fa9c56 100644 --- a/architecture/sandbox-providers.md +++ b/architecture/sandbox-providers.md @@ -94,6 +94,7 @@ pub trait ProviderPlugin: Send + Sync { | `nvidia.rs` | `NVIDIA_API_KEY` | *(none)* | | `gitlab.rs` | `GITLAB_TOKEN`, `GLAB_TOKEN`, `CI_JOB_TOKEN` | `~/.config/glab-cli/config.yml` | | `github.rs` | `GITHUB_TOKEN`, `GH_TOKEN` | `~/.config/gh/hosts.yml` | +| `microsoft_agent_s2s.rs` | `AZURE_TENANT_ID`, `A365_BLUEPRINT_CLIENT_ID`, `A365_BLUEPRINT_CLIENT_SECRET`, `A365_RUNTIME_AGENT_ID`, `A365_ALLOWED_AUDIENCES`, `A365_OBSERVABILITY_RESOURCE`, `A365_REQUIRED_ROLES` | *(none)* | | `outlook.rs` | *(none)* | *(none)* | `generic` and `outlook` are stubs — `discover_existing()` always returns `None`. @@ -241,16 +242,25 @@ variables (injected into the pod spec by the gateway's Kubernetes sandbox creati In `run_sandbox()` (`crates/openshell-sandbox/src/lib.rs`): -1. loads the sandbox policy via gRPC (`GetSandboxSettings`), +1. loads the sandbox policy via gRPC (`GetSandboxConfig`), 2. fetches provider credentials via gRPC (`GetSandboxProviderEnvironment`), -3. if the fetch fails, continues with an empty map (graceful degradation with a warning). +3. if the fetch fails, continues with an empty map (graceful degradation with a warning), +4. starts any provider-specific runtime resolvers, such as `microsoft-agent-s2s`. -The returned `provider_env` `HashMap` is immediately transformed into: +Most returned provider credentials are transformed into: - a child-visible env map with placeholder values such as `openshell:resolve:env:ANTHROPIC_API_KEY`, and - a supervisor-only in-memory registry mapping each placeholder back to its real secret. +`microsoft-agent-s2s` is handled differently. Its blueprint secret and broker inputs are +removed from the child env path, used only by the sandbox supervisor to start a local +token resolver, and replaced with non-secret resolver metadata: + +- `OPENSHELL_MICROSOFT_AGENT_S2S_TOKEN_URL` +- `OPENSHELL_MICROSOFT_AGENT_S2S_DEFAULT_AUDIENCE` when one default audience is known +- `A365_TOKEN_PROVIDER_URL` as a compatibility alias for runtimes that expect A365 naming + The placeholder env map is threaded to the entrypoint process spawner and SSH server. The registry is threaded to the proxy so it can rewrite outbound headers. diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 6239978d7..18c3a78ff 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -604,6 +604,7 @@ enum CliProviderType { Opencode, Codex, Copilot, + MicrosoftAgentS2s, Generic, Openai, Anthropic, @@ -635,6 +636,7 @@ impl CliProviderType { Self::Opencode => "opencode", Self::Codex => "codex", Self::Copilot => "copilot", + Self::MicrosoftAgentS2s => "microsoft-agent-s2s", Self::Generic => "generic", Self::Openai => "openai", Self::Anthropic => "anthropic", diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index a3dd6826f..0d06bd40b 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -600,6 +600,72 @@ async fn explicit_provider_name_auto_creates_when_valid_type() { ); } +/// When `--provider microsoft-agent-s2s` is passed, no provider named +/// "microsoft-agent-s2s" exists, and it is a valid provider type, the CLI +/// should auto-create a provider using discovered Agent ID S2S credentials. +#[tokio::test] +async fn explicit_microsoft_agent_s2s_provider_name_auto_creates_when_valid_type() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[ + ("AZURE_TENANT_ID", "tenant-id"), + ("A365_BLUEPRINT_CLIENT_ID", "blueprint-client-id"), + ("A365_BLUEPRINT_CLIENT_SECRET", "blueprint-secret"), + ("A365_RUNTIME_AGENT_ID", "runtime-agent-id"), + ("A365_ALLOWED_AUDIENCES", "api://aud-a,api://aud-b"), + ("A365_OBSERVABILITY_RESOURCE", "observability-resource"), + ("A365_REQUIRED_ROLES", "Agent365.Observability.OtelWrite"), + ]); + + let mut client = openshell_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let result = run::ensure_required_providers( + &mut client, + &["microsoft-agent-s2s".to_string()], + &[], + Some(true), + ) + .await + .expect("should auto-create the provider"); + + assert_eq!(result, vec!["microsoft-agent-s2s".to_string()]); + + let providers = ts.openshell.state.providers.lock().await; + let provider = providers + .get("microsoft-agent-s2s") + .expect("microsoft-agent-s2s provider should exist"); + assert_eq!(provider.r#type, "microsoft-agent-s2s"); + assert_eq!( + provider.credentials.get("AZURE_TENANT_ID"), + Some(&"tenant-id".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_BLUEPRINT_CLIENT_ID"), + Some(&"blueprint-client-id".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_BLUEPRINT_CLIENT_SECRET"), + Some(&"blueprint-secret".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_RUNTIME_AGENT_ID"), + Some(&"runtime-agent-id".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_ALLOWED_AUDIENCES"), + Some(&"api://aud-a,api://aud-b".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_OBSERVABILITY_RESOURCE"), + Some(&"observability-resource".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_REQUIRED_ROLES"), + Some(&"Agent365.Observability.OtelWrite".to_string()) + ); +} + /// When `--provider my-custom-thing` is passed and "my-custom-thing" is not a /// known provider type, the CLI should return an error. #[tokio::test] diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index dc6ec9d4c..7c27ab885 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -28,32 +28,48 @@ use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}; use tonic::{Response, Status}; -struct EnvVarGuard { +static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +struct SavedVar { key: &'static str, original: Option, } +struct EnvVarGuard { + vars: Vec, + _lock: std::sync::MutexGuard<'static, ()>, +} + #[allow(unsafe_code)] impl EnvVarGuard { - fn set(key: &'static str, value: &str) -> Self { - let original = std::env::var(key).ok(); - unsafe { - std::env::set_var(key, value); + fn set(pairs: &[(&'static str, &str)]) -> Self { + let lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut vars = Vec::with_capacity(pairs.len()); + for &(key, value) in pairs { + let original = std::env::var(key).ok(); + unsafe { + std::env::set_var(key, value); + } + vars.push(SavedVar { key, original }); } - Self { key, original } + Self { vars, _lock: lock } } } #[allow(unsafe_code)] impl Drop for EnvVarGuard { fn drop(&mut self) { - if let Some(value) = &self.original { - unsafe { - std::env::set_var(self.key, value); - } - } else { - unsafe { - std::env::remove_var(self.key); + for var in &self.vars { + if let Some(value) = &var.original { + unsafe { + std::env::set_var(var.key, value); + } + } else { + unsafe { + std::env::remove_var(var.key); + } } } } @@ -545,7 +561,7 @@ async fn provider_create_rejects_key_only_credentials_without_local_env_value() #[tokio::test] async fn provider_create_supports_generic_type_and_env_lookup_credentials() { let ts = run_server().await; - let _guard = EnvVarGuard::set("NAV_GENERIC_TEST_KEY", "generic-value"); + let _guard = EnvVarGuard::set(&[("NAV_GENERIC_TEST_KEY", "generic-value")]); run::provider_create( &ts.endpoint, @@ -577,6 +593,73 @@ async fn provider_create_supports_generic_type_and_env_lookup_credentials() { ); } +#[tokio::test] +async fn provider_create_from_existing_supports_microsoft_agent_s2s_type() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[ + ("AZURE_TENANT_ID", "tenant-id"), + ("A365_BLUEPRINT_CLIENT_ID", "blueprint-client-id"), + ("A365_BLUEPRINT_CLIENT_SECRET", "blueprint-secret"), + ("A365_RUNTIME_AGENT_ID", "runtime-agent-id"), + ("A365_ALLOWED_AUDIENCES", "api://aud-a,api://aud-b"), + ("A365_OBSERVABILITY_RESOURCE", "observability-resource"), + ("A365_REQUIRED_ROLES", "Agent365.Observability.OtelWrite"), + ]); + + run::provider_create( + &ts.endpoint, + "my-microsoft-agent-s2s", + "microsoft-agent-s2s", + true, + &[], + &[], + &ts.tls, + ) + .await + .expect("provider create"); + + let mut client = openshell_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client should connect"); + let response = client + .get_provider(GetProviderRequest { + name: "my-microsoft-agent-s2s".to_string(), + }) + .await + .expect("get provider should succeed") + .into_inner(); + let provider = response.provider.expect("provider should exist"); + assert_eq!(provider.r#type, "microsoft-agent-s2s"); + assert_eq!( + provider.credentials.get("AZURE_TENANT_ID"), + Some(&"tenant-id".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_BLUEPRINT_CLIENT_ID"), + Some(&"blueprint-client-id".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_BLUEPRINT_CLIENT_SECRET"), + Some(&"blueprint-secret".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_RUNTIME_AGENT_ID"), + Some(&"runtime-agent-id".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_ALLOWED_AUDIENCES"), + Some(&"api://aud-a,api://aud-b".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_OBSERVABILITY_RESOURCE"), + Some(&"observability-resource".to_string()) + ); + assert_eq!( + provider.credentials.get("A365_REQUIRED_ROLES"), + Some(&"Agent365.Observability.OtelWrite".to_string()) + ); +} + #[tokio::test] async fn provider_create_rejects_combined_from_existing_and_credentials() { let ts = run_server().await; @@ -603,7 +686,7 @@ async fn provider_create_rejects_combined_from_existing_and_credentials() { #[tokio::test] async fn provider_create_rejects_empty_env_var_for_key_only_credential() { let ts = run_server().await; - let _guard = EnvVarGuard::set("NAV_EMPTY_ENV_KEY", ""); + let _guard = EnvVarGuard::set(&[("NAV_EMPTY_ENV_KEY", "")]); let err = run::provider_create( &ts.endpoint, @@ -627,7 +710,7 @@ async fn provider_create_rejects_empty_env_var_for_key_only_credential() { #[tokio::test] async fn provider_create_supports_nvidia_type_with_nvidia_api_key() { let ts = run_server().await; - let _guard = EnvVarGuard::set("NVIDIA_API_KEY", "nvapi-live-test"); + let _guard = EnvVarGuard::set(&[("NVIDIA_API_KEY", "nvapi-live-test")]); run::provider_create( &ts.endpoint, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 50a4fa651..7110eb86f 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -718,6 +718,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { let _env = test_env(&fake_ssh_dir, &xdg_dir); let tls = test_tls(&server); install_fake_ssh(&fake_ssh_dir); + let forward_port = unused_local_port(); run::sandbox_create( &server.endpoint, @@ -732,7 +733,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, &[], None, - Some(openshell_core::forward::ForwardSpec::new(8080)), + Some(openshell_core::forward::ForwardSpec::new(forward_port)), &["echo".to_string(), "OK".to_string()], Some(false), Some(false), @@ -744,3 +745,8 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { assert!(deleted_names(&server).await.is_empty()); } + +fn unused_local_port() -> u16 { + let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).unwrap(); + listener.local_addr().unwrap().port() +} diff --git a/crates/openshell-provider-auth/Cargo.toml b/crates/openshell-provider-auth/Cargo.toml new file mode 100644 index 000000000..8cc09b5e7 --- /dev/null +++ b/crates/openshell-provider-auth/Cargo.toml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-provider-auth" +description = "Runtime provider authentication brokers for OpenShell" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +base64 = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +url = { workspace = true } + +[lints] +workspace = true diff --git a/crates/openshell-provider-auth/src/lib.rs b/crates/openshell-provider-auth/src/lib.rs new file mode 100644 index 000000000..38df98014 --- /dev/null +++ b/crates/openshell-provider-auth/src/lib.rs @@ -0,0 +1,6 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Runtime authentication brokers for `OpenShell` providers. + +pub mod microsoft_s2s; diff --git a/crates/openshell-provider-auth/src/microsoft_s2s.rs b/crates/openshell-provider-auth/src/microsoft_s2s.rs new file mode 100644 index 000000000..20c81f381 --- /dev/null +++ b/crates/openshell-provider-auth/src/microsoft_s2s.rs @@ -0,0 +1,990 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use base64::Engine as _; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use tokio::sync::Mutex; +use url::Url; + +const AZURE_TOKEN_EXCHANGE_SCOPE: &str = "api://AzureADTokenExchange/.default"; +const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str = + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; +const DEFAULT_AUTHORITY_HOST: &str = "https://login.microsoftonline.com"; +const DEFAULT_REFRESH_SKEW: Duration = Duration::from_secs(300); + +#[derive(Debug, thiserror::Error)] +pub enum MicrosoftS2sError { + #[error("invalid Microsoft S2S provider config: {0}")] + InvalidConfig(String), + #[error("audience '{0}' is not allowed by provider config")] + AudienceDenied(String), + #[error("failed to build token endpoint URL: {0}")] + Url(String), + #[error("Microsoft token request failed with HTTP {status}: {body}")] + TokenHttp { status: StatusCode, body: String }, + #[error("Microsoft token request failed: {0}")] + TokenTransport(String), + #[error("Microsoft token response did not include an access token")] + MissingAccessToken, + #[error("Microsoft token claim validation failed: {0}")] + ClaimValidation(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MicrosoftS2sConfig { + pub tenant_id: String, + pub blueprint_client_id: String, + pub blueprint_client_secret: String, + pub runtime_agent_id: String, + pub allowed_audiences: Vec, + pub observability_resource: Option, + pub required_roles: Vec, +} + +impl MicrosoftS2sConfig { + pub fn from_provider_maps( + credentials: &HashMap, + config: &HashMap, + ) -> Result { + let provider_value = |key: &str| { + credentials + .get(key) + .or_else(|| config.get(key)) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + }; + + let allowed_audiences = provider_value("A365_ALLOWED_AUDIENCES") + .map(|value| split_csv(&value)) + .unwrap_or_default(); + let required_roles = provider_value("A365_REQUIRED_ROLES") + .map(|value| split_csv(&value)) + .unwrap_or_default(); + + let cfg = Self { + tenant_id: provider_value("AZURE_TENANT_ID").unwrap_or_default(), + blueprint_client_id: provider_value("A365_BLUEPRINT_CLIENT_ID").unwrap_or_default(), + blueprint_client_secret: provider_value("A365_BLUEPRINT_CLIENT_SECRET") + .unwrap_or_default(), + runtime_agent_id: provider_value("A365_RUNTIME_AGENT_ID").unwrap_or_default(), + allowed_audiences, + observability_resource: provider_value("A365_OBSERVABILITY_RESOURCE"), + required_roles, + }; + cfg.validate()?; + Ok(cfg) + } + + pub fn validate(&self) -> Result<(), MicrosoftS2sError> { + require_non_empty("AZURE_TENANT_ID", &self.tenant_id)?; + require_non_empty("A365_BLUEPRINT_CLIENT_ID", &self.blueprint_client_id)?; + require_non_empty( + "A365_BLUEPRINT_CLIENT_SECRET", + &self.blueprint_client_secret, + )?; + require_non_empty("A365_RUNTIME_AGENT_ID", &self.runtime_agent_id)?; + + if self.allowed_audiences.is_empty() && self.observability_resource.is_none() { + return Err(MicrosoftS2sError::InvalidConfig( + "at least one allowed audience or observability resource is required".to_string(), + )); + } + + Ok(()) + } + + fn allowed_audience_set(&self) -> BTreeSet { + let mut allowed = self + .allowed_audiences + .iter() + .map(|audience| normalize_audience(audience)) + .filter(|audience| !audience.is_empty()) + .collect::>(); + if let Some(resource) = &self.observability_resource { + let normalized = normalize_audience(resource); + if !normalized.is_empty() { + allowed.insert(normalized); + } + } + allowed + } +} + +#[derive(Debug, Clone)] +pub struct MicrosoftS2sBrokerOptions { + pub authority_host: Url, + pub refresh_skew: Duration, +} + +impl Default for MicrosoftS2sBrokerOptions { + fn default() -> Self { + Self { + authority_host: Url::parse(DEFAULT_AUTHORITY_HOST) + .expect("default authority host should parse"), + refresh_skew: DEFAULT_REFRESH_SKEW, + } + } +} + +#[derive(Clone, Debug)] +pub struct MicrosoftS2sBroker { + config: Arc, + client: reqwest::Client, + authority_host: Url, + refresh_skew: Duration, + cache: Arc>>, +} + +impl MicrosoftS2sBroker { + pub fn new(config: MicrosoftS2sConfig) -> Result { + Self::with_options(config, MicrosoftS2sBrokerOptions::default()) + } + + pub fn with_options( + config: MicrosoftS2sConfig, + options: MicrosoftS2sBrokerOptions, + ) -> Result { + config.validate()?; + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(30)) + .build() + .map_err(|e| MicrosoftS2sError::TokenTransport(e.to_string()))?; + Ok(Self { + config: Arc::new(config), + client, + authority_host: options.authority_host, + refresh_skew: options.refresh_skew, + cache: Arc::new(Mutex::new(HashMap::new())), + }) + } + + pub async fn authorization_header( + &self, + audience: &str, + ) -> Result { + let token = self.access_token(audience).await?; + Ok(AuthorizationHeader { + value: format!("Bearer {}", token.access_token), + expires_at_unix: token.expires_at_unix, + cache_hit: token.cache_hit, + }) + } + + pub async fn access_token( + &self, + audience: &str, + ) -> Result { + let audience = normalize_audience(audience); + self.ensure_allowed_audience(&audience)?; + + let cache_key = CacheKey { + tenant_id: self.config.tenant_id.clone(), + runtime_agent_id: self.config.runtime_agent_id.clone(), + audience: audience.clone(), + }; + + if let Some(cached) = self.cached_token(&cache_key).await { + return Ok(BrokeredAccessToken { + access_token: cached.access_token, + expires_at_unix: cached.expires_at_unix, + cache_hit: true, + }); + } + + let assertion = self.fetch_blueprint_assertion().await?; + let token = self + .fetch_runtime_agent_token(&audience, &assertion) + .await?; + self.validate_runtime_token_claims(&audience, &token.access_token)?; + + let expires_at = token.expires_at(self.refresh_skew); + let expires_at_unix = token.expires_at_unix(); + let cached = CachedToken { + access_token: token.access_token, + expires_at, + expires_at_unix, + }; + self.cache.lock().await.insert(cache_key, cached.clone()); + + Ok(BrokeredAccessToken { + access_token: cached.access_token, + expires_at_unix: cached.expires_at_unix, + cache_hit: false, + }) + } + + pub async fn evict(&self, audience: &str) { + let cache_key = CacheKey { + tenant_id: self.config.tenant_id.clone(), + runtime_agent_id: self.config.runtime_agent_id.clone(), + audience: normalize_audience(audience), + }; + self.cache.lock().await.remove(&cache_key); + } + + fn ensure_allowed_audience(&self, audience: &str) -> Result<(), MicrosoftS2sError> { + let allowed = self.config.allowed_audience_set(); + if allowed.contains(audience) { + Ok(()) + } else { + Err(MicrosoftS2sError::AudienceDenied(audience.to_string())) + } + } + + async fn cached_token(&self, cache_key: &CacheKey) -> Option { + let cached = self.cache.lock().await.get(cache_key).cloned()?; + if Instant::now() < cached.expires_at { + Some(cached) + } else { + None + } + } + + async fn fetch_blueprint_assertion(&self) -> Result { + let endpoint = self.token_endpoint()?; + let form = [ + ("grant_type", "client_credentials"), + ("client_id", self.config.blueprint_client_id.as_str()), + ( + "client_secret", + self.config.blueprint_client_secret.as_str(), + ), + ("scope", AZURE_TOKEN_EXCHANGE_SCOPE), + ("fmi_path", self.config.runtime_agent_id.as_str()), + ]; + self.post_token_form(endpoint, &form).await + } + + async fn fetch_runtime_agent_token( + &self, + audience: &str, + assertion: &TokenResponse, + ) -> Result { + let endpoint = self.token_endpoint()?; + let scope = default_scope_for_audience(audience); + let form = [ + ("grant_type", "client_credentials"), + ("client_id", self.config.runtime_agent_id.as_str()), + ("client_assertion", assertion.access_token.as_str()), + ("client_assertion_type", CLIENT_ASSERTION_TYPE_JWT_BEARER), + ("scope", scope.as_str()), + ]; + self.post_token_form(endpoint, &form).await + } + + async fn post_token_form( + &self, + endpoint: Url, + form: &[(&str, &str)], + ) -> Result { + let response = self + .client + .post(endpoint) + .form(form) + .send() + .await + .map_err(|e| MicrosoftS2sError::TokenTransport(e.to_string()))?; + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| MicrosoftS2sError::TokenTransport(e.to_string()))?; + + if !status.is_success() { + return Err(MicrosoftS2sError::TokenHttp { + status, + body: sanitize_error_body(&body), + }); + } + + let parsed = serde_json::from_str::(&body).map_err(|e| { + MicrosoftS2sError::TokenTransport(format!("failed to parse token response: {e}")) + })?; + if parsed.access_token.trim().is_empty() { + return Err(MicrosoftS2sError::MissingAccessToken); + } + Ok(parsed) + } + + fn token_endpoint(&self) -> Result { + self.authority_host + .join(&format!( + "{}/oauth2/v2.0/token", + self.config.tenant_id.trim_matches('/') + )) + .map_err(|e| MicrosoftS2sError::Url(e.to_string())) + } + + fn validate_runtime_token_claims( + &self, + audience: &str, + token: &str, + ) -> Result<(), MicrosoftS2sError> { + let claims = JwtClaims::decode_unverified(token)?; + claims.expect_audience(audience)?; + claims.expect_tenant(&self.config.tenant_id)?; + claims.expect_runtime_agent(&self.config.runtime_agent_id)?; + claims.expect_app_token()?; + claims.expect_roles(&self.config.required_roles)?; + claims.expect_not_expired()?; + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthorizationHeader { + pub value: String, + pub expires_at_unix: Option, + pub cache_hit: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BrokeredAccessToken { + pub access_token: String, + pub expires_at_unix: Option, + pub cache_hit: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct CacheKey { + tenant_id: String, + runtime_agent_id: String, + audience: String, +} + +#[derive(Debug, Clone)] +struct CachedToken { + access_token: String, + expires_at: Instant, + expires_at_unix: Option, +} + +#[derive(Debug, Clone, Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: Option, + #[serde(default)] + expires_on: Option, +} + +impl TokenResponse { + fn expires_at(&self, refresh_skew: Duration) -> Instant { + let ttl = self.expires_in.unwrap_or(3600); + let ttl = Duration::from_secs(ttl); + Instant::now() + ttl.saturating_sub(refresh_skew) + } + + fn expires_at_unix(&self) -> Option { + if let Some(expires_on) = &self.expires_on + && let Ok(value) = expires_on.parse::() + { + return Some(value); + } + let expires_in = self.expires_in?; + let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); + Some(now.saturating_add(expires_in)) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct JwtClaims { + #[serde(default)] + aud: AudienceClaim, + #[serde(default)] + tid: Option, + #[serde(default)] + azp: Option, + #[serde(default)] + appid: Option, + #[serde(default)] + oid: Option, + #[serde(default)] + sub: Option, + #[serde(default)] + idtyp: Option, + #[serde(default)] + roles: Vec, + #[serde(default)] + scp: Option, + #[serde(default)] + exp: Option, + #[serde(default)] + nbf: Option, +} + +impl JwtClaims { + fn decode_unverified(token: &str) -> Result { + let mut parts = token.split('.'); + let _header = parts.next(); + let payload = parts + .next() + .ok_or_else(|| MicrosoftS2sError::ClaimValidation("token is not a JWT".to_string()))?; + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(payload) + .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload)) + .map_err(|e| { + MicrosoftS2sError::ClaimValidation(format!("token payload decode failed: {e}")) + })?; + serde_json::from_slice(&decoded).map_err(|e| { + MicrosoftS2sError::ClaimValidation(format!("token payload parse failed: {e}")) + }) + } + + fn expect_audience(&self, audience: &str) -> Result<(), MicrosoftS2sError> { + if self + .aud + .values() + .iter() + .any(|actual| normalize_audience(actual) == audience) + { + Ok(()) + } else { + Err(MicrosoftS2sError::ClaimValidation(format!( + "audience claim does not include '{audience}'" + ))) + } + } + + fn expect_tenant(&self, tenant_id: &str) -> Result<(), MicrosoftS2sError> { + match self.tid.as_deref() { + Some(actual) if actual.eq_ignore_ascii_case(tenant_id) => Ok(()), + Some(actual) => Err(MicrosoftS2sError::ClaimValidation(format!( + "tenant claim '{actual}' does not match expected tenant" + ))), + None => Err(MicrosoftS2sError::ClaimValidation( + "missing tenant claim".to_string(), + )), + } + } + + fn expect_runtime_agent(&self, runtime_agent_id: &str) -> Result<(), MicrosoftS2sError> { + let expected = runtime_agent_id.to_ascii_lowercase(); + let matches = [&self.azp, &self.appid, &self.oid, &self.sub] + .into_iter() + .flatten() + .any(|value| value.to_ascii_lowercase() == expected); + if matches { + Ok(()) + } else { + Err(MicrosoftS2sError::ClaimValidation( + "token does not represent the runtime agent identity".to_string(), + )) + } + } + + fn expect_app_token(&self) -> Result<(), MicrosoftS2sError> { + match self.idtyp.as_deref() { + Some("app") => Ok(()), + Some(actual) => Err(MicrosoftS2sError::ClaimValidation(format!( + "expected app token, got idtyp='{actual}'" + ))), + None => Err(MicrosoftS2sError::ClaimValidation( + "missing idtyp claim".to_string(), + )), + } + } + + fn expect_roles(&self, required_roles: &[String]) -> Result<(), MicrosoftS2sError> { + for required in required_roles { + if !self.roles.iter().any(|role| role == required) { + return Err(MicrosoftS2sError::ClaimValidation(format!( + "missing required role '{required}'" + ))); + } + } + Ok(()) + } + + fn expect_not_expired(&self) -> Result<(), MicrosoftS2sError> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| MicrosoftS2sError::ClaimValidation(e.to_string()))? + .as_secs(); + if let Some(nbf) = self.nbf + && now.saturating_add(60) < nbf + { + return Err(MicrosoftS2sError::ClaimValidation( + "token is not valid yet".to_string(), + )); + } + if let Some(exp) = self.exp + && exp <= now.saturating_sub(60) + { + return Err(MicrosoftS2sError::ClaimValidation( + "token is expired".to_string(), + )); + } + Ok(()) + } +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(untagged)] +enum AudienceClaim { + One(String), + Many(Vec), + #[default] + Missing, +} + +impl AudienceClaim { + fn values(&self) -> Vec<&str> { + match self { + Self::One(value) => vec![value.as_str()], + Self::Many(values) => values.iter().map(String::as_str).collect(), + Self::Missing => Vec::new(), + } + } +} + +fn require_non_empty(name: &str, value: &str) -> Result<(), MicrosoftS2sError> { + if value.trim().is_empty() { + Err(MicrosoftS2sError::InvalidConfig(format!( + "{name} is required" + ))) + } else { + Ok(()) + } +} + +fn normalize_audience(input: &str) -> String { + input + .trim() + .trim_end_matches("/.default") + .trim_end_matches('/') + .to_string() +} + +fn split_csv(value: &str) -> Vec { + value + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToString::to_string) + .collect() +} + +fn default_scope_for_audience(audience: &str) -> String { + format!("{}/.default", normalize_audience(audience)) +} + +fn sanitize_error_body(body: &str) -> String { + const MAX_ERROR_BODY: usize = 1024; + body.chars() + .filter(|ch| !ch.is_control() || *ch == '\n' || *ch == '\t') + .take(MAX_ERROR_BODY) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use base64::engine::general_purpose::URL_SAFE_NO_PAD; + use std::net::SocketAddr; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + const TENANT: &str = "11111111-1111-4111-8111-111111111111"; + const BLUEPRINT: &str = "22222222-2222-4222-8222-222222222222"; + const RUNTIME_AGENT: &str = "33333333-3333-4333-8333-333333333333"; + const RESOURCE: &str = "api://44444444-4444-4444-8444-444444444444"; + + fn config() -> MicrosoftS2sConfig { + MicrosoftS2sConfig { + tenant_id: TENANT.to_string(), + blueprint_client_id: BLUEPRINT.to_string(), + blueprint_client_secret: "secret".to_string(), + runtime_agent_id: RUNTIME_AGENT.to_string(), + allowed_audiences: vec![RESOURCE.to_string()], + observability_resource: None, + required_roles: vec!["Agent365.Observability.OtelWrite".to_string()], + } + } + + fn broker(server: &FakeTokenServer) -> MicrosoftS2sBroker { + MicrosoftS2sBroker::with_options( + config(), + MicrosoftS2sBrokerOptions { + authority_host: Url::parse(&server.uri()).expect("fake server URL"), + refresh_skew: Duration::from_secs(60), + }, + ) + .expect("broker") + } + + fn jwt(claims: serde_json::Value) -> String { + let header = serde_json::json!({"alg": "none", "typ": "JWT"}); + let header = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).unwrap()); + let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&claims).unwrap()); + format!("{header}.{payload}.signature") + } + + fn runtime_token() -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + jwt(serde_json::json!({ + "aud": RESOURCE, + "tid": TENANT, + "azp": RUNTIME_AGENT, + "oid": RUNTIME_AGENT, + "sub": RUNTIME_AGENT, + "idtyp": "app", + "roles": ["Agent365.Observability.OtelWrite"], + "nbf": now.saturating_sub(30), + "exp": now + 3600 + })) + } + + #[test] + fn builds_config_from_provider_maps() { + let credentials = HashMap::from([ + ("AZURE_TENANT_ID".to_string(), TENANT.to_string()), + ( + "A365_BLUEPRINT_CLIENT_SECRET".to_string(), + "secret".to_string(), + ), + ]); + let config = HashMap::from([ + ( + "A365_BLUEPRINT_CLIENT_ID".to_string(), + BLUEPRINT.to_string(), + ), + ( + "A365_RUNTIME_AGENT_ID".to_string(), + RUNTIME_AGENT.to_string(), + ), + ( + "A365_ALLOWED_AUDIENCES".to_string(), + format!("{RESOURCE}, api://extra/.default"), + ), + ( + "A365_REQUIRED_ROLES".to_string(), + "Agent365.Observability.OtelWrite".to_string(), + ), + ]); + + let cfg = MicrosoftS2sConfig::from_provider_maps(&credentials, &config) + .expect("provider maps should build config"); + + assert_eq!(cfg.tenant_id, TENANT); + assert_eq!(cfg.blueprint_client_id, BLUEPRINT); + assert_eq!(cfg.blueprint_client_secret, "secret"); + assert_eq!(cfg.runtime_agent_id, RUNTIME_AGENT); + assert_eq!( + cfg.allowed_audiences, + vec![RESOURCE.to_string(), "api://extra/.default".to_string()] + ); + assert_eq!( + cfg.required_roles, + vec!["Agent365.Observability.OtelWrite".to_string()] + ); + } + + #[derive(Debug, Default)] + struct FakeTokenState { + runtime_token: Mutex, + blueprint_requests: Mutex, + runtime_requests: Mutex, + } + + #[derive(Clone)] + struct FakeTokenServer { + addr: SocketAddr, + state: Arc, + } + + impl FakeTokenServer { + async fn start(runtime_token: String) -> Self { + let state = Arc::new(FakeTokenState { + runtime_token: Mutex::new(runtime_token), + blueprint_requests: Mutex::new(0), + runtime_requests: Mutex::new(0), + }); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind fake token server"); + let addr = listener.local_addr().expect("fake token server addr"); + let server_state = state.clone(); + tokio::spawn(async move { + loop { + let Ok((stream, _peer)) = listener.accept().await else { + break; + }; + let state = server_state.clone(); + tokio::spawn(async move { + handle_token_connection(stream, state).await; + }); + } + }); + Self { addr, state } + } + + fn uri(&self) -> String { + format!("http://{}", self.addr) + } + + async fn request_counts(&self) -> (usize, usize) { + ( + *self.state.blueprint_requests.lock().await, + *self.state.runtime_requests.lock().await, + ) + } + } + + async fn handle_token_connection( + mut stream: tokio::net::TcpStream, + state: Arc, + ) { + let mut buffer = Vec::new(); + let mut temp = [0_u8; 1024]; + let mut content_length = None; + let mut header_end = None; + + loop { + let read = stream.read(&mut temp).await.expect("read fake request"); + if read == 0 { + return; + } + buffer.extend_from_slice(&temp[..read]); + if header_end.is_none() + && let Some(pos) = find_header_end(&buffer) + { + header_end = Some(pos); + let headers = String::from_utf8_lossy(&buffer[..pos]); + content_length = parse_content_length(&headers); + } + if let (Some(end), Some(len)) = (header_end, content_length) + && buffer.len() >= end + 4 + len + { + break; + } + } + + let end = header_end.expect("headers should be present"); + let len = content_length.expect("content length should be present"); + let body = &buffer[end + 4..end + 4 + len]; + let form = url::form_urlencoded::parse(body) + .into_owned() + .collect::>(); + let response = token_response_for_form(&state, &form).await; + stream + .write_all(response.as_bytes()) + .await + .expect("write fake response"); + } + + async fn token_response_for_form( + state: &Arc, + form: &HashMap, + ) -> String { + if form + .get("client_id") + .is_some_and(|value| value == BLUEPRINT) + { + assert_eq!( + form.get("grant_type").map(String::as_str), + Some("client_credentials") + ); + assert_eq!( + form.get("scope").map(String::as_str), + Some(AZURE_TOKEN_EXCHANGE_SCOPE) + ); + assert_eq!( + form.get("fmi_path").map(String::as_str), + Some(RUNTIME_AGENT) + ); + *state.blueprint_requests.lock().await += 1; + return json_response( + 200, + serde_json::json!({ + "token_type": "Bearer", + "expires_in": 3600, + "access_token": "blueprint-assertion" + }), + ); + } + + if form + .get("client_id") + .is_some_and(|value| value == RUNTIME_AGENT) + { + assert_eq!( + form.get("grant_type").map(String::as_str), + Some("client_credentials") + ); + assert_eq!( + form.get("client_assertion").map(String::as_str), + Some("blueprint-assertion") + ); + assert_eq!( + form.get("client_assertion_type").map(String::as_str), + Some(CLIENT_ASSERTION_TYPE_JWT_BEARER) + ); + let expected_scope = format!("{RESOURCE}/.default"); + assert_eq!( + form.get("scope").map(String::as_str), + Some(expected_scope.as_str()) + ); + *state.runtime_requests.lock().await += 1; + let runtime_token = state.runtime_token.lock().await.clone(); + return json_response( + 200, + serde_json::json!({ + "token_type": "Bearer", + "expires_in": 3600, + "access_token": runtime_token + }), + ); + } + + json_response( + 400, + serde_json::json!({"error": "unexpected token request"}), + ) + } + + fn find_header_end(buffer: &[u8]) -> Option { + buffer.windows(4).position(|window| window == b"\r\n\r\n") + } + + fn parse_content_length(headers: &str) -> Option { + headers.lines().find_map(|line| { + let (name, value) = line.split_once(':')?; + if name.eq_ignore_ascii_case("content-length") { + value.trim().parse().ok() + } else { + None + } + }) + } + + fn json_response(status: u16, body: serde_json::Value) -> String { + let reason = if status == 200 { "OK" } else { "Bad Request" }; + let body = body.to_string(); + format!( + "HTTP/1.1 {status} {reason}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) + } + + #[tokio::test] + async fn mints_runtime_agent_token_with_two_step_exchange() { + let runtime_token = runtime_token(); + let server = FakeTokenServer::start(runtime_token.clone()).await; + + let token = broker(&server) + .access_token(RESOURCE) + .await + .expect("token should mint"); + + assert_eq!(token.access_token, runtime_token); + assert!(!token.cache_hit); + assert!(token.expires_at_unix.is_some()); + assert_eq!(server.request_counts().await, (1, 1)); + } + + #[tokio::test] + async fn returns_cached_token_for_same_audience() { + let runtime_token = runtime_token(); + let server = FakeTokenServer::start(runtime_token.clone()).await; + let broker = broker(&server); + + let first = broker.access_token(RESOURCE).await.expect("first token"); + let second = broker.access_token(RESOURCE).await.expect("cached token"); + + assert_eq!(first.access_token, second.access_token); + assert!(!first.cache_hit); + assert!(second.cache_hit); + assert_eq!(server.request_counts().await, (1, 1)); + } + + #[tokio::test] + async fn rejects_unallowed_audience_before_network_call() { + let server = FakeTokenServer::start(runtime_token()).await; + let err = broker(&server) + .access_token("api://not-allowed") + .await + .expect_err("audience should be denied"); + + assert!(matches!(err, MicrosoftS2sError::AudienceDenied(_))); + assert_eq!(server.request_counts().await, (0, 0)); + } + + #[tokio::test] + async fn validates_runtime_agent_claims() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let wrong_agent_token = jwt(serde_json::json!({ + "aud": RESOURCE, + "tid": TENANT, + "azp": "a185cf21-03c8-4bf1-919a-ec8f0782118d", + "idtyp": "app", + "nbf": now.saturating_sub(30), + "exp": now + 3600 + })); + let server = FakeTokenServer::start(wrong_agent_token).await; + + let err = broker(&server) + .access_token(RESOURCE) + .await + .expect_err("wrong runtime agent should fail validation"); + + assert!(matches!(err, MicrosoftS2sError::ClaimValidation(_))); + assert!( + err.to_string().contains("runtime agent identity"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn validates_required_roles() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let missing_role_token = jwt(serde_json::json!({ + "aud": RESOURCE, + "tid": TENANT, + "azp": RUNTIME_AGENT, + "oid": RUNTIME_AGENT, + "sub": RUNTIME_AGENT, + "idtyp": "app", + "roles": ["Other.Role"], + "nbf": now.saturating_sub(30), + "exp": now + 3600 + })); + let server = FakeTokenServer::start(missing_role_token).await; + + let err = broker(&server) + .access_token(RESOURCE) + .await + .expect_err("missing required role should fail validation"); + + assert!(matches!(err, MicrosoftS2sError::ClaimValidation(_))); + assert!( + err.to_string().contains("missing required role"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn returns_authorization_header() { + let server = FakeTokenServer::start(runtime_token()).await; + + let header = broker(&server) + .authorization_header(RESOURCE) + .await + .expect("authorization header"); + + assert!(header.value.starts_with("Bearer ")); + assert!(!header.cache_hit); + } +} diff --git a/crates/openshell-providers/src/lib.rs b/crates/openshell-providers/src/lib.rs index e2bcc0c09..c622204a2 100644 --- a/crates/openshell-providers/src/lib.rs +++ b/crates/openshell-providers/src/lib.rs @@ -79,6 +79,7 @@ impl ProviderRegistry { registry.register(providers::codex::CodexProvider); registry.register(providers::copilot::CopilotProvider); registry.register(providers::opencode::OpencodeProvider); + registry.register(providers::microsoft_agent_s2s::MicrosoftAgentS2sProvider); registry.register(providers::generic::GenericProvider); registry.register(providers::openai::OpenaiProvider); registry.register(providers::anthropic::AnthropicProvider); @@ -131,6 +132,9 @@ pub fn normalize_provider_type(input: &str) -> Option<&'static str> { "codex" => Some("codex"), "copilot" => Some("copilot"), "opencode" => Some("opencode"), + "microsoft-agent-s2s" | "agent-s2s" | "agent-id-s2s" | "a365-s2s" => { + Some("microsoft-agent-s2s") + } "generic" => Some("generic"), "openai" => Some("openai"), "anthropic" => Some("anthropic"), @@ -162,6 +166,18 @@ mod tests { assert_eq!(normalize_provider_type("glab"), Some("gitlab")); assert_eq!(normalize_provider_type("gh"), Some("github")); assert_eq!(normalize_provider_type("CLAUDE"), Some("claude")); + assert_eq!( + normalize_provider_type("agent-s2s"), + Some("microsoft-agent-s2s") + ); + assert_eq!( + normalize_provider_type("agent-id-s2s"), + Some("microsoft-agent-s2s") + ); + assert_eq!( + normalize_provider_type("a365-s2s"), + Some("microsoft-agent-s2s") + ); assert_eq!(normalize_provider_type("generic"), Some("generic")); assert_eq!(normalize_provider_type("openai"), Some("openai")); assert_eq!(normalize_provider_type("anthropic"), Some("anthropic")); diff --git a/crates/openshell-providers/src/providers/microsoft_agent_s2s.rs b/crates/openshell-providers/src/providers/microsoft_agent_s2s.rs new file mode 100644 index 000000000..d0930e231 --- /dev/null +++ b/crates/openshell-providers/src/providers/microsoft_agent_s2s.rs @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, discover_with_spec, +}; + +pub struct MicrosoftAgentS2sProvider; + +pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { + id: "microsoft-agent-s2s", + credential_env_vars: &[ + "AZURE_TENANT_ID", + "A365_BLUEPRINT_CLIENT_ID", + "A365_BLUEPRINT_CLIENT_SECRET", + "A365_RUNTIME_AGENT_ID", + "A365_ALLOWED_AUDIENCES", + "A365_OBSERVABILITY_RESOURCE", + "A365_REQUIRED_ROLES", + ], +}; + +impl ProviderPlugin for MicrosoftAgentS2sProvider { + fn id(&self) -> &'static str { + SPEC.id + } + + fn discover_existing(&self) -> Result, ProviderError> { + discover_with_spec(&SPEC, &RealDiscoveryContext) + } + + fn credential_env_vars(&self) -> &'static [&'static str] { + SPEC.credential_env_vars + } +} + +#[cfg(test)] +mod tests { + use super::SPEC; + use crate::discover_with_spec; + use crate::test_helpers::MockDiscoveryContext; + + #[test] + fn discovers_microsoft_agent_s2s_env_credentials() { + let ctx = MockDiscoveryContext::new() + .with_env("AZURE_TENANT_ID", "tenant-id") + .with_env("A365_BLUEPRINT_CLIENT_ID", "blueprint-client-id") + .with_env("A365_BLUEPRINT_CLIENT_SECRET", "blueprint-secret") + .with_env("A365_RUNTIME_AGENT_ID", "runtime-agent-id") + .with_env("A365_ALLOWED_AUDIENCES", "api://aud-a,api://aud-b") + .with_env("A365_OBSERVABILITY_RESOURCE", "observability-resource") + .with_env("A365_REQUIRED_ROLES", "Agent365.Observability.OtelWrite"); + let discovered = discover_with_spec(&SPEC, &ctx) + .expect("discovery") + .expect("provider"); + assert_eq!( + discovered.credentials.get("AZURE_TENANT_ID"), + Some(&"tenant-id".to_string()) + ); + assert_eq!( + discovered.credentials.get("A365_BLUEPRINT_CLIENT_ID"), + Some(&"blueprint-client-id".to_string()) + ); + assert_eq!( + discovered.credentials.get("A365_BLUEPRINT_CLIENT_SECRET"), + Some(&"blueprint-secret".to_string()) + ); + assert_eq!( + discovered.credentials.get("A365_RUNTIME_AGENT_ID"), + Some(&"runtime-agent-id".to_string()) + ); + assert_eq!( + discovered.credentials.get("A365_ALLOWED_AUDIENCES"), + Some(&"api://aud-a,api://aud-b".to_string()) + ); + assert_eq!( + discovered.credentials.get("A365_OBSERVABILITY_RESOURCE"), + Some(&"observability-resource".to_string()) + ); + assert_eq!( + discovered.credentials.get("A365_REQUIRED_ROLES"), + Some(&"Agent365.Observability.OtelWrite".to_string()) + ); + } +} diff --git a/crates/openshell-providers/src/providers/mod.rs b/crates/openshell-providers/src/providers/mod.rs index 6fe395135..966c7058b 100644 --- a/crates/openshell-providers/src/providers/mod.rs +++ b/crates/openshell-providers/src/providers/mod.rs @@ -8,6 +8,7 @@ pub mod copilot; pub mod generic; pub mod github; pub mod gitlab; +pub mod microsoft_agent_s2s; pub mod nvidia; pub mod openai; pub mod opencode; diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 78d8ac741..63d7e8a18 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -18,6 +18,7 @@ path = "src/main.rs" openshell-core = { path = "../openshell-core" } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } +openshell-provider-auth = { path = "../openshell-provider-auth" } openshell-router = { path = "../openshell-router" } # Async runtime @@ -65,6 +66,7 @@ ipnet = "2" # Serialization serde_json = { workspace = true } serde_yml = { workspace = true } +url = { workspace = true } # Logging tracing = { workspace = true } diff --git a/crates/openshell-sandbox/src/child_env.rs b/crates/openshell-sandbox/src/child_env.rs index 914e06ea5..c45e7305f 100644 --- a/crates/openshell-sandbox/src/child_env.rs +++ b/crates/openshell-sandbox/src/child_env.rs @@ -3,7 +3,7 @@ use std::path::Path; -const LOCAL_NO_PROXY: &str = "127.0.0.1,localhost,::1"; +const LOCAL_NO_PROXY: &str = "127.0.0.1,localhost,::1,10.200.0.1"; pub(crate) fn proxy_env_vars(proxy_url: &str) -> [(&'static str, String); 9] { [ @@ -59,9 +59,9 @@ mod tests { let stdout = String::from_utf8(output.stdout).expect("utf8"); assert!(stdout.contains("HTTP_PROXY=http://10.200.0.1:3128")); - assert!(stdout.contains("NO_PROXY=127.0.0.1,localhost,::1")); + assert!(stdout.contains("NO_PROXY=127.0.0.1,localhost,::1,10.200.0.1")); assert!(stdout.contains("NODE_USE_ENV_PROXY=1")); - assert!(stdout.contains("no_proxy=127.0.0.1,localhost,::1")); + assert!(stdout.contains("no_proxy=127.0.0.1,localhost,::1,10.200.0.1")); } #[test] diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 34ee80bb5..6aac8e8fd 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -17,6 +17,7 @@ pub mod opa; mod policy; mod process; pub mod procfs; +mod provider_tokens; pub mod proxy; mod sandbox; mod secrets; @@ -267,7 +268,7 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. // This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { + let mut provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { match grpc_client::fetch_provider_environment(endpoint, id).await { Ok(env) => { ocsf_emit!( @@ -300,9 +301,8 @@ pub async fn run_sandbox( } else { std::collections::HashMap::new() }; - - let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); - let secret_resolver = secret_resolver.map(Arc::new); + let provider_token_resolver_port = + provider_tokens::microsoft_agent_s2s_resolver_port(&provider_env); // Create identity cache for SHA256 TOFU when OPA is active let identity_cache = opa_engine @@ -387,7 +387,11 @@ pub async fn run_sandbox( .as_ref() .and_then(|p| p.http_addr) .map_or(3128, |addr| addr.port()); - if let Err(e) = ns.install_bypass_rules(proxy_port) { + let provider_token_ports = provider_token_resolver_port + .iter() + .copied() + .collect::>(); + if let Err(e) = ns.install_bypass_rules(proxy_port, &provider_token_ports) { ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) .severity(SeverityId::Medium) @@ -423,6 +427,40 @@ pub async fn run_sandbox( // listener and workload process are exposed. apply_supervisor_startup_hardening()?; + #[cfg(target_os = "linux")] + let provider_token_resolver_bind_addr = { + let ip = netns.as_ref().map_or( + std::net::IpAddr::from([127, 0, 0, 1]), + NetworkNamespace::host_ip, + ); + SocketAddr::new(ip, provider_token_resolver_port.unwrap_or(0)) + }; + + #[cfg(not(target_os = "linux"))] + let provider_token_resolver_bind_addr = + SocketAddr::from(([127, 0, 0, 1], provider_token_resolver_port.unwrap_or(0))); + + let prepared_provider_tokens = provider_tokens::prepare_microsoft_agent_s2s( + &mut provider_env, + provider_token_resolver_bind_addr, + ) + .await?; + if !prepared_provider_tokens.environment.is_empty() { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "enabled") + .message("Started microsoft-agent-s2s provider token resolver") + .build() + ); + } + let provider_token_environment = prepared_provider_tokens.environment; + let _provider_token_resolver = prepared_provider_tokens.handle; + let (mut provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); + provider_env.extend(provider_token_environment); + let secret_resolver = secret_resolver.map(Arc::new); + // Shared PID: set after process spawn so the proxy can look up // the entrypoint process's /proc/net/tcp for identity binding. let entrypoint_pid = Arc::new(AtomicU32::new(0)); diff --git a/crates/openshell-sandbox/src/provider_tokens.rs b/crates/openshell-sandbox/src/provider_tokens.rs new file mode 100644 index 000000000..0db9afdbe --- /dev/null +++ b/crates/openshell-sandbox/src/provider_tokens.rs @@ -0,0 +1,437 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Sandbox-local provider token resolvers. + +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_provider_auth::microsoft_s2s::{MicrosoftS2sBroker, MicrosoftS2sConfig}; +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; + +const MAX_REQUEST_HEADER_BYTES: usize = 8192; +const MICROSOFT_AGENT_S2S_TOKEN_PATH: &str = "/v1/microsoft-agent-s2s/token"; +const MICROSOFT_AGENT_S2S_RESOLVER_PORT: u16 = 3130; +const TOKEN_URL_ENV: &str = "OPENSHELL_MICROSOFT_AGENT_S2S_TOKEN_URL"; +const DEFAULT_AUDIENCE_ENV: &str = "OPENSHELL_MICROSOFT_AGENT_S2S_DEFAULT_AUDIENCE"; +const A365_TOKEN_PROVIDER_URL_ENV: &str = "A365_TOKEN_PROVIDER_URL"; + +const MICROSOFT_AGENT_S2S_KEYS: &[&str] = &[ + "AZURE_TENANT_ID", + "A365_BLUEPRINT_CLIENT_ID", + "A365_BLUEPRINT_CLIENT_SECRET", + "A365_RUNTIME_AGENT_ID", + "A365_ALLOWED_AUDIENCES", + "A365_OBSERVABILITY_RESOURCE", + "A365_REQUIRED_ROLES", +]; + +const MICROSOFT_AGENT_S2S_MARKER_KEYS: &[&str] = &[ + "A365_BLUEPRINT_CLIENT_ID", + "A365_BLUEPRINT_CLIENT_SECRET", + "A365_RUNTIME_AGENT_ID", +]; + +pub(crate) struct PreparedProviderTokenResolver { + pub environment: HashMap, + pub handle: Option, +} + +pub(crate) fn microsoft_agent_s2s_resolver_port( + provider_env: &HashMap, +) -> Option { + contains_microsoft_agent_s2s_inputs(provider_env).then_some(MICROSOFT_AGENT_S2S_RESOLVER_PORT) +} + +#[derive(Debug)] +pub(crate) struct ProviderTokenResolverHandle { + local_addr: SocketAddr, + token_path: String, + join: JoinHandle<()>, +} + +impl ProviderTokenResolverHandle { + fn url(&self) -> String { + format!("http://{}{}", self.local_addr, self.token_path) + } +} + +impl Drop for ProviderTokenResolverHandle { + fn drop(&mut self) { + self.join.abort(); + } +} + +pub(crate) async fn prepare_microsoft_agent_s2s( + raw_provider_env: &mut HashMap, + bind_addr: SocketAddr, +) -> Result { + if !contains_microsoft_agent_s2s_inputs(raw_provider_env) { + return Ok(PreparedProviderTokenResolver { + environment: HashMap::new(), + handle: None, + }); + } + + let provider_map = remove_microsoft_agent_s2s_inputs(raw_provider_env); + let config = MicrosoftS2sConfig::from_provider_maps(&provider_map, &HashMap::new()) + .into_diagnostic() + .wrap_err("invalid microsoft-agent-s2s provider configuration")?; + let default_audience = default_audience(&config); + let broker = MicrosoftS2sBroker::new(config) + .into_diagnostic() + .wrap_err("failed to initialize microsoft-agent-s2s token broker")?; + let handle = start_microsoft_agent_s2s_resolver(broker, default_audience.clone(), bind_addr) + .await + .wrap_err("failed to start microsoft-agent-s2s token resolver")?; + + let environment = resolver_environment(handle.url(), default_audience); + + Ok(PreparedProviderTokenResolver { + environment, + handle: Some(handle), + }) +} + +fn contains_microsoft_agent_s2s_inputs(provider_env: &HashMap) -> bool { + MICROSOFT_AGENT_S2S_MARKER_KEYS + .iter() + .any(|key| provider_env.contains_key(*key)) +} + +fn remove_microsoft_agent_s2s_inputs( + provider_env: &mut HashMap, +) -> HashMap { + let mut removed = HashMap::new(); + for key in MICROSOFT_AGENT_S2S_KEYS { + if let Some(value) = provider_env.remove(*key) { + removed.insert((*key).to_string(), value); + } + } + removed +} + +fn default_audience(config: &MicrosoftS2sConfig) -> Option { + config + .observability_resource + .clone() + .or_else(|| match config.allowed_audiences.as_slice() { + [only] => Some(only.clone()), + _ => None, + }) +} + +fn resolver_environment( + resolver_url: String, + default_audience: Option, +) -> HashMap { + let mut environment = HashMap::from([ + (TOKEN_URL_ENV.to_string(), resolver_url.clone()), + (A365_TOKEN_PROVIDER_URL_ENV.to_string(), resolver_url), + ]); + if let Some(audience) = default_audience { + environment.insert(DEFAULT_AUDIENCE_ENV.to_string(), audience); + } + environment +} + +async fn start_microsoft_agent_s2s_resolver( + broker: MicrosoftS2sBroker, + default_audience: Option, + bind_addr: SocketAddr, +) -> Result { + let listener = TcpListener::bind(bind_addr).await.into_diagnostic()?; + let local_addr = listener.local_addr().into_diagnostic()?; + let token_path = format!("{MICROSOFT_AGENT_S2S_TOKEN_PATH}/{}", uuid::Uuid::new_v4()); + let token_path_for_task = token_path.clone(); + + let join = tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((stream, _peer)) => { + let broker = broker.clone(); + let default_audience = default_audience.clone(); + let token_path = token_path_for_task.clone(); + tokio::spawn(async move { + if let Err(err) = handle_microsoft_agent_s2s_connection( + stream, + broker, + default_audience, + token_path, + ) + .await + { + warn!(error = %err, "microsoft-agent-s2s token resolver request failed"); + } + }); + } + Err(err) => { + warn!(error = %err, "microsoft-agent-s2s token resolver accept failed"); + break; + } + } + } + }); + + Ok(ProviderTokenResolverHandle { + local_addr, + token_path, + join, + }) +} + +async fn handle_microsoft_agent_s2s_connection( + mut stream: TcpStream, + broker: MicrosoftS2sBroker, + default_audience: Option, + token_path: String, +) -> Result<()> { + let request = read_http_request(&mut stream).await?; + let response = match parse_token_request(&request, default_audience.as_deref(), &token_path) { + Ok(audience) => match broker.access_token(&audience).await { + Ok(token) => json_response( + 200, + "OK", + serde_json::json!({ + "access_token": token.access_token, + "token_type": "Bearer", + "expires_at_unix": token.expires_at_unix, + "cache_hit": token.cache_hit, + }), + ), + Err(err) => json_response( + 502, + "Bad Gateway", + serde_json::json!({ "error": err.to_string() }), + ), + }, + Err(err) => err.into_response(), + }; + stream + .write_all(response.as_bytes()) + .await + .into_diagnostic()?; + Ok(()) +} + +async fn read_http_request(stream: &mut TcpStream) -> Result { + let mut buffer = Vec::new(); + let mut chunk = [0_u8; 1024]; + loop { + let read = stream.read(&mut chunk).await.into_diagnostic()?; + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if buffer.windows(4).any(|window| window == b"\r\n\r\n") { + break; + } + if buffer.len() > MAX_REQUEST_HEADER_BYTES { + return Err(miette::miette!("token resolver request headers too large")); + } + } + String::from_utf8(buffer).into_diagnostic() +} + +fn parse_token_request( + request: &str, + default_audience: Option<&str>, + expected_path: &str, +) -> std::result::Result { + let request_line = request + .lines() + .next() + .ok_or_else(|| HttpError::new(400, "Bad Request", "missing HTTP request line"))?; + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or_default(); + let target = parts.next().unwrap_or_default(); + let _version = parts.next().unwrap_or_default(); + + if method != "GET" { + return Err(HttpError::new( + 405, + "Method Not Allowed", + "method not allowed", + )); + } + + let (path, query) = target + .split_once('?') + .map_or((target, ""), |(path, query)| (path, query)); + if path != expected_path { + return Err(HttpError::new(404, "Not Found", "token endpoint not found")); + } + + let audience = url::form_urlencoded::parse(query.as_bytes()) + .find_map(|(key, value)| (key == "audience").then(|| value.into_owned())) + .or_else(|| default_audience.map(ToOwned::to_owned)) + .ok_or_else(|| { + HttpError::new(400, "Bad Request", "audience query parameter is required") + })?; + + if audience.trim().is_empty() { + return Err(HttpError::new( + 400, + "Bad Request", + "audience must not be empty", + )); + } + + debug!(audience = %audience, "microsoft-agent-s2s token resolver request accepted"); + Ok(audience) +} + +#[derive(Debug)] +struct HttpError { + status: u16, + reason: &'static str, + message: &'static str, +} + +impl HttpError { + const fn new(status: u16, reason: &'static str, message: &'static str) -> Self { + Self { + status, + reason, + message, + } + } + + fn into_response(self) -> String { + json_response( + self.status, + self.reason, + serde_json::json!({ "error": self.message }), + ) + } +} + +fn json_response(status: u16, reason: &str, body: serde_json::Value) -> String { + let body = body.to_string(); + format!( + "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", + body.len() + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn prepare_does_nothing_without_microsoft_s2s_inputs() { + let mut provider_env = HashMap::from([("API_KEY".to_string(), "secret".to_string())]); + let prepared = prepare_microsoft_agent_s2s(&mut provider_env, ([127, 0, 0, 1], 0).into()) + .await + .expect("prepare"); + + assert!(prepared.environment.is_empty()); + assert!(prepared.handle.is_none()); + assert_eq!(provider_env.get("API_KEY"), Some(&"secret".to_string())); + } + + #[test] + fn resolver_environment_exposes_only_local_token_metadata() { + let environment = resolver_environment( + "http://127.0.0.1:3130/v1/microsoft-agent-s2s/token/capability".to_string(), + Some("api://resource".to_string()), + ); + + assert_eq!( + environment.get(TOKEN_URL_ENV), + Some(&"http://127.0.0.1:3130/v1/microsoft-agent-s2s/token/capability".to_string()) + ); + assert_eq!( + environment.get(A365_TOKEN_PROVIDER_URL_ENV), + environment.get(TOKEN_URL_ENV) + ); + assert_eq!( + environment.get(DEFAULT_AUDIENCE_ENV), + Some(&"api://resource".to_string()) + ); + assert!(!environment.contains_key("A365_BLUEPRINT_CLIENT_SECRET")); + assert!(!environment.contains_key("A365_BLUEPRINT_CLIENT_ID")); + } + + #[test] + fn parse_uses_default_audience_when_query_is_absent() { + let request = "GET /v1/microsoft-agent-s2s/token/cap HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let audience = parse_token_request( + request, + Some("api://default"), + "/v1/microsoft-agent-s2s/token/cap", + ) + .expect("audience"); + assert_eq!(audience, "api://default"); + } + + #[test] + fn parse_decodes_audience_query_param() { + let request = "GET /v1/microsoft-agent-s2s/token/cap?audience=api%3A%2F%2Fresource HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let audience = parse_token_request(request, None, "/v1/microsoft-agent-s2s/token/cap") + .expect("audience"); + assert_eq!(audience, "api://resource"); + } + + #[test] + fn parse_rejects_missing_audience_without_default() { + let request = "GET /v1/microsoft-agent-s2s/token/cap HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let err = parse_token_request(request, None, "/v1/microsoft-agent-s2s/token/cap") + .expect_err("missing audience should fail"); + assert_eq!(err.status, 400); + } + + #[test] + fn parse_rejects_guessable_base_path() { + let request = "GET /v1/microsoft-agent-s2s/token?audience=api%3A%2F%2Fresource HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let err = parse_token_request(request, None, "/v1/microsoft-agent-s2s/token/cap") + .expect_err("base path should not authorize"); + assert_eq!(err.status, 404); + } + + #[test] + fn json_response_disables_token_caching() { + let response = json_response(200, "OK", serde_json::json!({"access_token": "token"})); + + assert!(response.contains("Cache-Control: no-store\r\n")); + assert!(response.contains("Pragma: no-cache\r\n")); + } + + #[test] + fn removes_broker_inputs_before_child_env_injection() { + let mut provider_env = HashMap::from([ + ("AZURE_TENANT_ID".to_string(), "tenant".to_string()), + ( + "A365_BLUEPRINT_CLIENT_ID".to_string(), + "blueprint".to_string(), + ), + ( + "A365_BLUEPRINT_CLIENT_SECRET".to_string(), + "secret".to_string(), + ), + ( + "A365_RUNTIME_AGENT_ID".to_string(), + "runtime-agent".to_string(), + ), + ( + "A365_ALLOWED_AUDIENCES".to_string(), + "api://resource".to_string(), + ), + ("API_KEY".to_string(), "kept".to_string()), + ]); + + let removed = remove_microsoft_agent_s2s_inputs(&mut provider_env); + + assert_eq!( + removed.get("A365_BLUEPRINT_CLIENT_SECRET"), + Some(&"secret".to_string()) + ); + assert!(!provider_env.contains_key("A365_BLUEPRINT_CLIENT_SECRET")); + assert!(!provider_env.contains_key("A365_BLUEPRINT_CLIENT_ID")); + assert!(!provider_env.contains_key("A365_RUNTIME_AGENT_ID")); + assert_eq!(provider_env.get("API_KEY"), Some(&"kept".to_string())); + } +} diff --git a/crates/openshell-sandbox/src/sandbox/linux/netns.rs b/crates/openshell-sandbox/src/sandbox/linux/netns.rs index 37d11f0c3..c0290fb4a 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/netns.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/netns.rs @@ -235,10 +235,11 @@ impl NetworkNamespace { /// Install iptables rules for bypass detection inside the namespace. /// /// Sets up OUTPUT chain rules that: - /// 1. ACCEPT traffic destined for the proxy (host_ip:proxy_port) - /// 2. ACCEPT loopback traffic - /// 3. ACCEPT established/related connections (response packets) - /// 4. LOG + REJECT all other TCP/UDP traffic (bypass attempts) + /// 1. ACCEPT traffic destined for the proxy (`host_ip:proxy_port`) + /// 2. ACCEPT traffic destined for supervisor-owned host services + /// 3. ACCEPT loopback traffic + /// 4. ACCEPT established/related connections (response packets) + /// 5. LOG + REJECT all other TCP/UDP traffic (bypass attempts) /// /// This provides two benefits: /// - **Fast-fail UX**: applications get immediate ECONNREFUSED instead of @@ -249,7 +250,11 @@ impl NetworkNamespace { /// Degrades gracefully if `iptables` is not available — the namespace /// still provides isolation via routing, just without fast-fail and /// diagnostic logging. - pub fn install_bypass_rules(&self, proxy_port: u16) -> Result<()> { + pub fn install_bypass_rules( + &self, + proxy_port: u16, + additional_host_ports: &[u16], + ) -> Result<()> { // Check if iptables is available before attempting to install rules. let iptables_path = match find_iptables() { Some(path) => path, @@ -281,6 +286,7 @@ impl NetworkNamespace { &iptables_path, &host_ip_str, &proxy_port_str, + additional_host_ports, &log_prefix, ) { openshell_ocsf::ocsf_emit!( @@ -336,25 +342,18 @@ impl NetworkNamespace { iptables_cmd: &str, host_ip: &str, proxy_port: &str, + additional_host_ports: &[u16], log_prefix: &str, ) -> Result<()> { // Rule 1: ACCEPT traffic to the proxy - run_iptables_netns( - &self.name, - iptables_cmd, - &[ - "-A", - "OUTPUT", - "-d", - &format!("{host_ip}/32"), - "-p", - "tcp", - "--dport", - proxy_port, - "-j", - "ACCEPT", - ], - )?; + self.install_host_tcp_accept_rule(iptables_cmd, host_ip, proxy_port)?; + + for port in additional_host_ports { + let port = port.to_string(); + if port != proxy_port { + self.install_host_tcp_accept_rule(iptables_cmd, host_ip, &port)?; + } + } // Rule 2: ACCEPT loopback traffic run_iptables_netns( @@ -486,6 +485,30 @@ impl NetworkNamespace { Ok(()) } + fn install_host_tcp_accept_rule( + &self, + iptables_cmd: &str, + host_ip: &str, + port: &str, + ) -> Result<()> { + run_iptables_netns( + &self.name, + iptables_cmd, + &[ + "-A", + "OUTPUT", + "-d", + &format!("{host_ip}/32"), + "-p", + "tcp", + "--dport", + port, + "-j", + "ACCEPT", + ], + ) + } + /// Install IPv6 bypass detection rules. /// /// Similar to `install_bypass_rules_for` but omits the proxy ACCEPT rule diff --git a/docs/sandboxes/manage-providers.mdx b/docs/sandboxes/manage-providers.mdx index fbfc4d380..0b913713f 100644 --- a/docs/sandboxes/manage-providers.mdx +++ b/docs/sandboxes/manage-providers.mdx @@ -83,8 +83,10 @@ Pass one or more `--provider` flags when creating a sandbox: openshell sandbox create --provider my-claude --provider my-github -- claude ``` -Each `--provider` flag attaches one provider. The sandbox receives all -credentials from every attached provider at runtime. +Each `--provider` flag attaches one provider. Most providers become placeholder +environment variables in the sandbox. Runtime broker providers, such as +`microsoft-agent-s2s`, expose a local resolver URL instead of raw broker +credentials. Providers cannot be added to a running sandbox. If you need to attach an @@ -125,6 +127,22 @@ The proxy resolves credential placeholders in the following parts of an HTTP req The proxy does not modify request bodies, cookies, or response content. +### Microsoft Agent S2S Token Resolver + +The `microsoft-agent-s2s` provider does not inject `A365_BLUEPRINT_CLIENT_SECRET` +or other broker inputs into the agent process. The sandbox supervisor keeps those +values in OpenShell-managed memory and starts a local token resolver. + +Agent runtimes can read: + +- `OPENSHELL_MICROSOFT_AGENT_S2S_TOKEN_URL` +- `OPENSHELL_MICROSOFT_AGENT_S2S_DEFAULT_AUDIENCE` +- `A365_TOKEN_PROVIDER_URL` + +Call the token URL with `GET ?audience=` to receive a short-lived +runtime-agent access token. If a single default audience is configured, the +`audience` query parameter can be omitted. + ### Fail-closed behavior If the proxy detects a credential placeholder in a request but cannot resolve it, it rejects the request with HTTP 500 instead of forwarding the raw placeholder to the upstream server. This prevents accidental credential leakage in server logs or error responses. @@ -158,6 +176,7 @@ The following provider types are supported. | `generic` | User-defined | Any service with custom credentials | | `github` | `GITHUB_TOKEN`, `GH_TOKEN` | GitHub API, `gh` CLI — refer to [Github Sandbox](/tutorials/github-sandbox) | | `gitlab` | `GITLAB_TOKEN`, `GLAB_TOKEN`, `CI_JOB_TOKEN` | GitLab API, `glab` CLI | +| `microsoft-agent-s2s` | Local resolver URL only; broker inputs stay supervisor-only | Microsoft Agent ID S2S/runtime-agent token broker inputs | | `nvidia` | `NVIDIA_API_KEY` | NVIDIA API Catalog | | `openai` | `OPENAI_API_KEY` | Any OpenAI-compatible endpoint. Set `--config OPENAI_BASE_URL` to point to the provider. Refer to [Configure](/inference/configure). | | `opencode` | `OPENCODE_API_KEY`, `OPENROUTER_API_KEY`, `OPENAI_API_KEY` | opencode tool |