diff --git a/Cargo.lock b/Cargo.lock index f3576ca19..6ecdadf73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3540,6 +3540,7 @@ version = "0.0.0" dependencies = [ "openshell-core", "serde", + "serde_json", "serde_yml", "thiserror 2.0.18", ] diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index bc766c53b..bab01b43e 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -9,6 +9,7 @@ use clap_complete::env::CompleteEnv; use miette::Result; use owo_colors::OwoColorize; use std::io::Write; +use std::path::PathBuf; use openshell_bootstrap::{ edge_token::load_edge_token, get_gateway_metadata, list_gateways, load_active_gateway, @@ -633,18 +634,20 @@ fn normalize_completion_script(output: Vec, executable: &std::path::Path) -> } #[derive(Clone, Debug, ValueEnum)] -enum CliProviderType { - Claude, - Opencode, - Codex, - Copilot, - Generic, - Openai, - Anthropic, - Nvidia, - Gitlab, - Github, - Outlook, +enum ProviderProfileOutput { + Table, + Yaml, + Json, +} + +impl ProviderProfileOutput { + fn as_str(&self) -> &'static str { + match self { + Self::Table => "table", + Self::Yaml => "yaml", + Self::Json => "json", + } + } } #[derive(Clone, Debug, ValueEnum)] @@ -662,24 +665,6 @@ impl From for openshell_cli::ssh::Editor { } } -impl CliProviderType { - fn as_str(&self) -> &'static str { - match self { - Self::Claude => "claude", - Self::Opencode => "opencode", - Self::Codex => "codex", - Self::Copilot => "copilot", - Self::Generic => "generic", - Self::Openai => "openai", - Self::Anthropic => "anthropic", - Self::Nvidia => "nvidia", - Self::Gitlab => "gitlab", - Self::Github => "github", - Self::Outlook => "outlook", - } - } -} - #[derive(Subcommand, Debug)] enum ProviderCommands { /// Create a provider config. @@ -690,8 +675,8 @@ enum ProviderCommands { name: String, /// Provider type. - #[arg(long = "type", value_enum)] - provider_type: CliProviderType, + #[arg(long = "type")] + provider_type: String, /// Load provider credentials/config from existing local state. #[arg(long, conflicts_with = "credentials")] @@ -736,7 +721,15 @@ enum ProviderCommands { /// List available provider profiles. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] - ListProfiles, + ListProfiles { + /// Output format. + #[arg(short = 'o', long = "output", value_enum, default_value_t = ProviderProfileOutput::Table)] + output: ProviderProfileOutput, + }, + + /// Manage provider profiles. + #[command(subcommand, help_template = SUBCOMMAND_HELP_TEMPLATE)] + Profile(ProviderProfileCommands), /// Update an existing provider's credentials or config. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] @@ -771,6 +764,51 @@ enum ProviderCommands { }, } +#[derive(Subcommand, Debug)] +enum ProviderProfileCommands { + /// Export a provider profile. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Export { + /// Provider profile id. + id: String, + + /// Output format. + #[arg(short = 'o', long = "output", value_enum, default_value_t = ProviderProfileOutput::Yaml)] + output: ProviderProfileOutput, + }, + + /// Import provider profiles from a file or directory. + #[command(group = clap::ArgGroup::new("source").required(true).args(["file", "from"]), help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Import { + /// Profile file to import. + #[arg(short = 'f', long = "file", value_hint = ValueHint::FilePath)] + file: Option, + + /// Directory containing profile files to import. + #[arg(long = "from", value_hint = ValueHint::DirPath)] + from: Option, + }, + + /// Validate provider profile files without registering them. + #[command(group = clap::ArgGroup::new("source").required(true).args(["file", "from"]), help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Lint { + /// Profile file to lint. + #[arg(short = 'f', long = "file", value_hint = ValueHint::FilePath)] + file: Option, + + /// Directory containing profile files to lint. + #[arg(long = "from", value_hint = ValueHint::DirPath)] + from: Option, + }, + + /// Delete a custom provider profile. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Delete { + /// Provider profile id. + id: String, + }, +} + // ----------------------------------------------------------------------- // Gateway commands (replaces the old `cluster` / `cluster admin` groups) // ----------------------------------------------------------------------- @@ -2777,9 +2815,35 @@ async fn main() -> Result<()> { } => { run::provider_list(endpoint, limit, offset, names, &tls).await?; } - ProviderCommands::ListProfiles => { - run::provider_list_profiles(endpoint, &tls).await?; + ProviderCommands::ListProfiles { output } => { + run::provider_list_profiles(endpoint, output.as_str(), &tls).await?; } + ProviderCommands::Profile(command) => match command { + ProviderProfileCommands::Export { id, output } => { + run::provider_profile_export(endpoint, &id, output.as_str(), &tls).await?; + } + ProviderProfileCommands::Import { file, from } => { + run::provider_profile_import( + endpoint, + file.as_deref(), + from.as_deref(), + &tls, + ) + .await?; + } + ProviderProfileCommands::Lint { file, from } => { + run::provider_profile_lint( + endpoint, + file.as_deref(), + from.as_deref(), + &tls, + ) + .await?; + } + ProviderProfileCommands::Delete { id } => { + run::provider_profile_delete(endpoint, &id, &tls).await?; + } + }, ProviderCommands::Update { name, from_existing, @@ -3479,9 +3543,113 @@ mod tests { assert!(matches!( cli.command, Some(Commands::Provider { - command: Some(ProviderCommands::ListProfiles) + command: Some(ProviderCommands::ListProfiles { + output: ProviderProfileOutput::Table + }) + }) + )); + } + + #[test] + fn provider_list_profiles_accepts_output_format() { + let cli = Cli::try_parse_from(["openshell", "provider", "list-profiles", "-o", "json"]) + .expect("provider list-profiles -o json should parse"); + + assert!(matches!( + cli.command, + Some(Commands::Provider { + command: Some(ProviderCommands::ListProfiles { + output: ProviderProfileOutput::Json + }) + }) + )); + } + + #[test] + fn provider_profile_commands_parse() { + let export = Cli::try_parse_from([ + "openshell", + "provider", + "profile", + "export", + "custom-api", + "-o", + "yaml", + ]) + .expect("provider profile export should parse"); + assert!(matches!( + export.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Profile(ProviderProfileCommands::Export { + id, + output: ProviderProfileOutput::Yaml + })) + }) if id == "custom-api" + )); + + let import = Cli::try_parse_from([ + "openshell", + "provider", + "profile", + "import", + "--from", + "./profiles", + ]) + .expect("provider profile import should parse"); + assert!(matches!( + import.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Profile(ProviderProfileCommands::Import { + from: Some(_), + .. + })) }) )); + + let delete = + Cli::try_parse_from(["openshell", "provider", "profile", "delete", "custom-api"]) + .expect("provider profile delete should parse"); + assert!(matches!( + delete.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Profile(ProviderProfileCommands::Delete { + id + })) + }) if id == "custom-api" + )); + } + + #[test] + fn provider_create_accepts_custom_profile_type_ids() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "create", + "--name", + "work-github", + "--type", + "github-readonly", + "--credential", + "GITHUB_TOKEN=token", + ]) + .expect("provider create should parse custom profile ids"); + + match cli.command { + Some(Commands::Provider { + command: + Some(ProviderCommands::Create { + name, + provider_type, + credentials, + .. + }), + }) => { + assert_eq!(name, "work-github"); + assert_eq!(provider_type, "github-readonly"); + assert_eq!(credentials, vec!["GITHUB_TOKEN=token"]); + } + other => panic!("expected provider create command, got: {other:?}"), + } } #[test] @@ -3630,29 +3798,4 @@ mod tests { } } } - - /// Ensure every provider registered in `ProviderRegistry` has a - /// corresponding `CliProviderType` variant (and vice-versa). - /// This test would have caught the missing `Copilot` variant from #707. - #[test] - fn cli_provider_types_match_registry() { - let registry = openshell_providers::ProviderRegistry::new(); - let registry_types: std::collections::BTreeSet<&str> = - registry.known_types().into_iter().collect(); - - let cli_types: std::collections::BTreeSet<&str> = - ::value_variants() - .iter() - .map(CliProviderType::as_str) - .collect(); - - assert_eq!( - cli_types, - registry_types, - "CliProviderType variants must match ProviderRegistry.known_types(). \ - CLI-only: {:?}, Registry-only: {:?}", - cli_types.difference(®istry_types).collect::>(), - registry_types.difference(&cli_types).collect::>(), - ); - } } diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 4864c5a8a..4665f810d 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -25,20 +25,24 @@ use openshell_bootstrap::{ use openshell_core::proto::ProviderProfileCategory; use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, ClearDraftChunksRequest, - CreateProviderRequest, CreateSandboxRequest, DeleteProviderRequest, DeleteSandboxRequest, - ExecSandboxRequest, GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, - GetGatewayConfigRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ListProviderProfilesRequest, - ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxesRequest, PolicySource, - PolicyStatus, Provider, ProviderProfile, RejectDraftChunkRequest, Sandbox, SandboxPhase, - SandboxPolicy, SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, SettingScope, - SettingValue, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, + CreateProviderRequest, CreateSandboxRequest, DeleteProviderProfileRequest, + DeleteProviderRequest, DeleteSandboxRequest, ExecSandboxRequest, GetClusterInferenceRequest, + GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, + GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, + GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, + LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, + ListSandboxPoliciesRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, + ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, + Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, + SettingScope, SettingValue, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; use openshell_providers::{ - ProviderRegistry, detect_provider_from_command, normalize_provider_type, + ProviderRegistry, ProviderTypeProfile, detect_provider_from_command, normalize_provider_type, + parse_profile_json, parse_profile_yaml, profile_to_json, profile_to_yaml, profiles_to_json, + profiles_to_yaml, }; use owo_colors::OwoColorize; use std::collections::{HashMap, HashSet, VecDeque}; @@ -3916,9 +3920,33 @@ pub async fn provider_create( let mut client = grpc_client(server, tls).await?; - let provider_type = normalize_provider_type(provider_type) - .ok_or_else(|| miette::miette!("unsupported provider type: {provider_type}"))? - .to_string(); + let provider_type = if let Some(provider_type) = normalize_provider_type(provider_type) { + provider_type.to_string() + } else { + let profile_id = provider_type.trim(); + if profile_id.is_empty() { + return Err(miette::miette!("provider type is required")); + } + let response = client + .get_provider_profile(GetProviderProfileRequest { + id: profile_id.to_string(), + }) + .await; + match response { + Ok(response) => response + .into_inner() + .profile + .map(|profile| profile.id) + .filter(|id| !id.trim().is_empty()) + .unwrap_or_else(|| profile_id.to_string()), + Err(status) if status.code() == Code::NotFound => { + return Err(miette::miette!( + "unsupported provider type or profile: {provider_type}" + )); + } + Err(status) => return Err(status).into_diagnostic(), + } + }; let mut credential_map = parse_credential_pairs(credentials)?; let mut config_map = parse_key_value_pairs(config, "--config")?; @@ -4085,7 +4113,7 @@ pub async fn provider_list( Ok(()) } -pub async fn provider_list_profiles(server: &str, tls: &TlsOptions) -> Result<()> { +pub async fn provider_list_profiles(server: &str, output: &str, tls: &TlsOptions) -> Result<()> { let mut client = grpc_client(server, tls).await?; let response = client .list_provider_profiles(ListProviderProfilesRequest { @@ -4100,6 +4128,23 @@ pub async fn provider_list_profiles(server: &str, tls: &TlsOptions) -> Result<() .cmp(&right.category) .then_with(|| left.id.cmp(&right.id)) }); + let dto_profiles = profiles + .iter() + .map(ProviderTypeProfile::from_proto) + .collect::>(); + + match output { + "yaml" => { + print!("{}", profiles_to_yaml(&dto_profiles).into_diagnostic()?); + return Ok(()); + } + "json" => { + println!("{}", profiles_to_json(&dto_profiles).into_diagnostic()?); + return Ok(()); + } + "table" => {} + _ => return Err(miette!("unsupported output format: {output}")), + } if profiles.is_empty() { println!("No provider profiles found."); @@ -4120,6 +4165,241 @@ pub async fn provider_list_profiles(server: &str, tls: &TlsOptions) -> Result<() Ok(()) } +pub async fn provider_profile_export( + server: &str, + id: &str, + output: &str, + tls: &TlsOptions, +) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .get_provider_profile(GetProviderProfileRequest { id: id.to_string() }) + .await + .into_diagnostic()?; + let profile = response + .into_inner() + .profile + .ok_or_else(|| miette!("provider profile '{id}' not found"))?; + let profile = ProviderTypeProfile::from_proto(&profile); + + match output { + "yaml" => print!("{}", profile_to_yaml(&profile).into_diagnostic()?), + "json" => println!("{}", profile_to_json(&profile).into_diagnostic()?), + "table" => { + return Err(miette!( + "profile export supports '-o yaml' and '-o json'; table output is not supported" + )); + } + _ => return Err(miette!("unsupported output format: {output}")), + } + Ok(()) +} + +pub async fn provider_profile_import( + server: &str, + file: Option<&Path>, + from: Option<&Path>, + tls: &TlsOptions, +) -> Result<()> { + let (items, mut diagnostics) = load_profile_import_items(file, from)?; + if items.is_empty() && diagnostics.is_empty() { + return Err(miette!("no provider profile files found")); + } + if profile_diagnostics_have_errors(&diagnostics) { + print_profile_diagnostics(&diagnostics); + return Err(miette!("provider profile import failed")); + } + + let mut client = grpc_client(server, tls).await?; + if !items.is_empty() { + let response = client + .import_provider_profiles(ImportProviderProfilesRequest { profiles: items }) + .await + .into_diagnostic()? + .into_inner(); + diagnostics.extend(response.diagnostics); + if response.imported { + println!( + "Imported {} provider profile{}.", + response.profiles.len(), + if response.profiles.len() == 1 { + "" + } else { + "s" + } + ); + return Ok(()); + } + } + + print_profile_diagnostics(&diagnostics); + Err(miette!("provider profile import failed")) +} + +pub async fn provider_profile_lint( + server: &str, + file: Option<&Path>, + from: Option<&Path>, + tls: &TlsOptions, +) -> Result<()> { + let (items, mut diagnostics) = load_profile_import_items(file, from)?; + if items.is_empty() && diagnostics.is_empty() { + return Err(miette!("no provider profile files found")); + } + + if !items.is_empty() { + let mut client = grpc_client(server, tls).await?; + let response = client + .lint_provider_profiles(LintProviderProfilesRequest { profiles: items }) + .await + .into_diagnostic()? + .into_inner(); + diagnostics.extend(response.diagnostics); + } + + if profile_diagnostics_have_errors(&diagnostics) { + print_profile_diagnostics(&diagnostics); + return Err(miette!("provider profile lint failed")); + } + + println!("Provider profile lint passed."); + Ok(()) +} + +pub async fn provider_profile_delete(server: &str, id: &str, tls: &TlsOptions) -> Result<()> { + let mut client = grpc_client(server, tls).await?; + let response = client + .delete_provider_profile(DeleteProviderProfileRequest { id: id.to_string() }) + .await + .into_diagnostic()? + .into_inner(); + if response.deleted { + println!("Deleted provider profile '{id}'."); + } else { + println!("Provider profile '{id}' was not deleted."); + } + Ok(()) +} + +fn load_profile_import_items( + file: Option<&Path>, + from: Option<&Path>, +) -> Result<( + Vec, + Vec, +)> { + let paths = profile_source_paths(file, from)?; + let mut items = Vec::new(); + let mut diagnostics = Vec::new(); + for path in paths { + match load_profile_import_item(&path) { + Ok(item) => items.push(item), + Err(diagnostic) => diagnostics.push(diagnostic), + } + } + Ok((items, diagnostics)) +} + +fn profile_source_paths(file: Option<&Path>, from: Option<&Path>) -> Result> { + if let Some(file) = file { + return Ok(vec![file.to_path_buf()]); + } + let Some(from) = from else { + return Ok(Vec::new()); + }; + let mut paths = Vec::new(); + for entry in std::fs::read_dir(from) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read profile directory {}", from.display()))? + { + let entry = entry.into_diagnostic()?; + let path = entry.path(); + if path.is_file() && profile_extension_supported(&path) { + paths.push(path); + } + } + paths.sort(); + Ok(paths) +} + +fn profile_extension_supported(path: &Path) -> bool { + matches!( + path.extension().and_then(|ext| ext.to_str()), + Some("yaml" | "yml" | "json") + ) +} + +fn load_profile_import_item( + path: &Path, +) -> Result { + let source = path.display().to_string(); + let input = std::fs::read_to_string(path).map_err(|err| { + profile_file_diagnostic( + &source, + format!("failed to read provider profile file: {err}"), + ) + })?; + let profile = match path.extension().and_then(|ext| ext.to_str()) { + Some("yaml" | "yml") => parse_profile_yaml(&input), + Some("json") => parse_profile_json(&input), + _ => { + return Err(profile_file_diagnostic( + &source, + "unsupported provider profile file format".to_string(), + )); + } + } + .map_err(|err| profile_file_diagnostic(&source, err.to_string()))?; + + Ok(ProviderProfileImportItem { + profile: Some(profile.to_proto()), + source, + }) +} + +fn profile_file_diagnostic(source: &str, message: String) -> ProviderProfileDiagnostic { + ProviderProfileDiagnostic { + source: source.to_string(), + profile_id: String::new(), + field: "file".to_string(), + message, + severity: "error".to_string(), + } +} + +fn print_profile_diagnostics(diagnostics: &[ProviderProfileDiagnostic]) { + if diagnostics.is_empty() { + return; + } + eprintln!("{}", "Provider profile diagnostics:".red().bold()); + for diagnostic in diagnostics { + let source = if diagnostic.source.is_empty() { + "" + } else { + &diagnostic.source + }; + let profile = if diagnostic.profile_id.is_empty() { + "-".to_string() + } else { + diagnostic.profile_id.clone() + }; + eprintln!( + " {} {} profile={} field={} {}", + diagnostic.severity.as_str().red(), + source, + profile, + diagnostic.field, + diagnostic.message + ); + } +} + +fn profile_diagnostics_have_errors(diagnostics: &[ProviderProfileDiagnostic]) -> bool { + diagnostics + .iter() + .any(|diagnostic| diagnostic.severity == "error") +} + fn display_provider_category(category: i32) -> &'static str { match ProviderProfileCategory::try_from(category).unwrap_or(ProviderProfileCategory::Other) { ProviderProfileCategory::Inference => "INFERENCE", diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 0b29f73f4..15f620e8e 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -265,6 +265,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, request: tonic::Request, diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 69d7b7354..01df4403a 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -191,6 +191,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index cdef9614e..55ed69500 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -12,8 +12,8 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + Provider, ProviderProfile, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use openshell_core::{ObjectId, ObjectName}; @@ -63,6 +63,7 @@ impl Drop for EnvVarGuard { #[derive(Clone, Default)] struct ProviderState { providers: Arc>>, + profiles: Arc>>, } #[derive(Clone, Default)] @@ -205,21 +206,72 @@ impl OpenShell for TestOpenShell { &self, _request: tonic::Request, ) -> Result, Status> { + let mut profiles = openshell_providers::default_profiles() + .iter() + .map(openshell_providers::ProviderTypeProfile::to_proto) + .collect::>(); + profiles.extend(self.state.profiles.lock().await.values().cloned()); Ok(Response::new( - openshell_core::proto::ListProviderProfilesResponse { - profiles: openshell_providers::default_profiles() - .iter() - .map(openshell_providers::ProviderTypeProfile::to_proto) - .collect(), - }, + openshell_core::proto::ListProviderProfilesResponse { profiles }, )) } async fn get_provider_profile( &self, - _request: tonic::Request, + request: tonic::Request, ) -> Result, Status> { - Err(Status::unimplemented("not implemented in test")) + let id = request.into_inner().id; + let profile = if let Some(profile) = openshell_providers::get_default_profile(&id) { + profile.to_proto() + } else { + self.state + .profiles + .lock() + .await + .get(&id) + .cloned() + .ok_or_else(|| Status::not_found("provider profile not found"))? + }; + Ok(Response::new( + openshell_core::proto::ProviderProfileResponse { + profile: Some(profile), + }, + )) + } + + async fn import_provider_profiles( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut profiles = self.state.profiles.lock().await; + let imported = request + .into_inner() + .profiles + .into_iter() + .filter_map(|item| item.profile) + .inspect(|profile| { + profiles.insert(profile.id.clone(), profile.clone()); + }) + .collect::>(); + Ok(Response::new( + openshell_core::proto::ImportProviderProfilesResponse { + diagnostics: Vec::new(), + profiles: imported, + imported: true, + }, + )) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::LintProviderProfilesResponse { + diagnostics: Vec::new(), + valid: true, + }, + )) } async fn update_provider( @@ -281,6 +333,17 @@ impl OpenShell for TestOpenShell { Ok(Response::new(DeleteProviderResponse { deleted })) } + async fn delete_provider_profile( + &self, + request: tonic::Request, + ) -> Result, Status> { + let id = request.into_inner().id; + let deleted = self.state.profiles.lock().await.remove(&id).is_some(); + Ok(Response::new( + openshell_core::proto::DeleteProviderProfileResponse { deleted }, + )) + } + type WatchSandboxStream = tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = @@ -463,6 +526,7 @@ fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { struct TestServer { endpoint: String, tls: TlsOptions, + state: ProviderState, _dir: TempDir, } @@ -483,11 +547,15 @@ async fn run_server() -> TestServer { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let incoming = TcpListenerStream::new(listener); + let state = ProviderState::default(); + let service = TestOpenShell { + state: state.clone(), + }; tokio::spawn(async move { Server::builder() .tls_config(tls_config) .unwrap() - .add_service(OpenShellServer::new(TestOpenShell::default())) + .add_service(OpenShellServer::new(service)) .serve_with_incoming(incoming) .await .unwrap(); @@ -507,6 +575,7 @@ async fn run_server() -> TestServer { TestServer { endpoint, tls, + state, _dir: dir, } } @@ -554,11 +623,245 @@ async fn provider_cli_run_functions_support_full_crud_flow() { async fn provider_list_profiles_cli_uses_profile_browsing_rpc() { let ts = run_server().await; - run::provider_list_profiles(&ts.endpoint, &ts.tls) + run::provider_list_profiles(&ts.endpoint, "table", &ts.tls) .await .expect("provider list-profiles"); } +#[tokio::test] +async fn provider_profile_cli_run_functions_support_custom_profiles() { + let ts = run_server().await; + let dir = tempfile::tempdir().unwrap(); + let profile_path = dir.path().join("custom-api.yaml"); + std::fs::write( + &profile_path, + r" +id: custom-api +display_name: Custom API +category: other +credentials: + - name: api_key + env_vars: [CUSTOM_API_KEY] + auth_style: bearer + header_name: authorization +endpoints: + - host: api.custom.example + port: 443 +binaries: [/usr/bin/custom] +", + ) + .unwrap(); + + run::provider_profile_lint(&ts.endpoint, Some(&profile_path), None, &ts.tls) + .await + .expect("profile lint"); + run::provider_profile_import(&ts.endpoint, Some(&profile_path), None, &ts.tls) + .await + .expect("profile import"); + run::provider_profile_export(&ts.endpoint, "custom-api", "yaml", &ts.tls) + .await + .expect("profile export"); + run::provider_list_profiles(&ts.endpoint, "json", &ts.tls) + .await + .expect("provider list-profiles json"); + run::provider_create( + &ts.endpoint, + "custom-provider", + "custom-api", + false, + &["CUSTOM_API_KEY=abc".to_string()], + &[], + &ts.tls, + ) + .await + .expect("custom profile provider create"); + + let provider = ts + .state + .providers + .lock() + .await + .get("custom-provider") + .cloned() + .expect("custom provider should be stored"); + assert_eq!(provider.r#type, "custom-api"); + + run::provider_delete(&ts.endpoint, &["custom-provider".to_string()], &ts.tls) + .await + .expect("custom provider delete"); + run::provider_profile_delete(&ts.endpoint, "custom-api", &ts.tls) + .await + .expect("profile delete"); +} + +#[tokio::test] +async fn provider_profile_import_from_directory_imports_supported_profile_files() { + let ts = run_server().await; + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("custom-yaml.yaml"), + r" +id: custom-yaml +display_name: Custom YAML +category: other +endpoints: + - host: api.yaml.example + port: 443 +binaries: [/usr/bin/yaml-client] +", + ) + .unwrap(); + std::fs::write( + dir.path().join("custom-json.json"), + r#"{ + "id": "custom-json", + "display_name": "Custom JSON", + "description": "", + "category": "other", + "credentials": [], + "endpoints": [{"host": "api.json.example", "port": 443}], + "binaries": ["/usr/bin/json-client"], + "inference_capable": false +}"#, + ) + .unwrap(); + std::fs::write(dir.path().join("notes.txt"), "ignored").unwrap(); + + run::provider_profile_import(&ts.endpoint, None, Some(dir.path()), &ts.tls) + .await + .expect("profile import --from"); + + run::provider_profile_export(&ts.endpoint, "custom-yaml", "yaml", &ts.tls) + .await + .expect("custom-yaml should be imported"); + run::provider_profile_export(&ts.endpoint, "custom-json", "json", &ts.tls) + .await + .expect("custom-json should be imported"); +} + +#[tokio::test] +#[allow(deprecated)] +async fn provider_profile_import_preserves_advanced_network_policy_fields() { + let ts = run_server().await; + let dir = tempfile::tempdir().unwrap(); + let profile_path = dir.path().join("advanced-api.yaml"); + std::fs::write( + &profile_path, + r" +id: advanced-api +display_name: Advanced API +category: other +endpoints: + - host: api.advanced.example + ports: [443, 8443] + protocol: rest + tls: terminate + enforcement: enforce + rules: + - allow: + method: GET + path: /v1/** + allowed_ips: [10.0.0.0/24] + deny_rules: + - method: POST + path: /admin/** + allow_encoded_slash: true + path: /v1 +binaries: + - path: /usr/bin/advanced + harness: true +", + ) + .unwrap(); + + run::provider_profile_import(&ts.endpoint, Some(&profile_path), None, &ts.tls) + .await + .expect("profile import"); + + let mut client = openshell_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client should connect"); + let profile = client + .get_provider_profile(openshell_core::proto::GetProviderProfileRequest { + id: "advanced-api".to_string(), + }) + .await + .expect("get provider profile") + .into_inner() + .profile + .expect("profile should exist"); + let endpoint = profile.endpoints.first().expect("endpoint should exist"); + assert_eq!(endpoint.ports, vec![443, 8443]); + assert_eq!(endpoint.rules.len(), 1); + assert_eq!(endpoint.deny_rules.len(), 1); + assert_eq!(endpoint.allowed_ips, vec!["10.0.0.0/24"]); + assert!(endpoint.allow_encoded_slash); + assert_eq!(endpoint.path, "/v1"); + assert!(profile.binaries[0].harness); +} + +#[tokio::test] +async fn provider_profile_import_from_directory_parse_error_prevents_partial_import() { + let ts = run_server().await; + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("custom-good.yaml"), + r" +id: custom-good +display_name: Custom Good +category: other +endpoints: + - host: api.good.example + port: 443 +", + ) + .unwrap(); + std::fs::write(dir.path().join("broken.yaml"), "id: [\n").unwrap(); + + let err = run::provider_profile_import(&ts.endpoint, None, Some(dir.path()), &ts.tls) + .await + .expect_err("profile import --from should fail on parse errors"); + assert!( + err.to_string().contains("provider profile import failed"), + "unexpected error: {err}" + ); + + run::provider_profile_export(&ts.endpoint, "custom-good", "yaml", &ts.tls) + .await + .expect_err("valid profiles should not be partially imported after local parse errors"); +} + +#[tokio::test] +async fn provider_profile_lint_from_directory_reports_parse_errors_without_importing() { + let ts = run_server().await; + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("custom-good.yaml"), + r" +id: custom-good +display_name: Custom Good +category: other +endpoints: + - host: api.good.example + port: 443 +", + ) + .unwrap(); + std::fs::write(dir.path().join("broken.yaml"), "id: [\n").unwrap(); + + let err = run::provider_profile_lint(&ts.endpoint, None, Some(dir.path()), &ts.tls) + .await + .expect_err("profile lint --from should fail on parse errors"); + assert!( + err.to_string().contains("provider profile lint failed"), + "unexpected error: {err}" + ); + + run::provider_profile_export(&ts.endpoint, "custom-good", "yaml", &ts.tls) + .await + .expect_err("lint should not import valid profiles"); +} + #[tokio::test] async fn provider_create_rejects_key_only_credentials_without_local_env_value() { let ts = run_server().await; diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index b65bdd684..59162f818 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -245,6 +245,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index bfad9a7d5..ac1ff37c6 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -221,6 +221,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index 90566dcfd..e7ffea61a 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -5,7 +5,9 @@ //! //! These traits provide uniform access to `ObjectMeta` fields across all resource types. -use crate::proto::{InferenceRoute, ObjectForTest, Provider, Sandbox, SshSession}; +use crate::proto::{ + InferenceRoute, ObjectForTest, Provider, Sandbox, SshSession, StoredProviderProfile, +}; use std::collections::HashMap; /// Provides access to the object's unique identifier. @@ -61,6 +63,25 @@ impl ObjectLabels for Provider { } } +// Implementations for StoredProviderProfile +impl ObjectId for StoredProviderProfile { + fn object_id(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.id.as_str()) + } +} + +impl ObjectName for StoredProviderProfile { + fn object_name(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.name.as_str()) + } +} + +impl ObjectLabels for StoredProviderProfile { + fn object_labels(&self) -> Option> { + self.metadata.as_ref().map(|m| m.labels.clone()) + } +} + // Implementations for SshSession impl ObjectId for SshSession { fn object_id(&self) -> &str { diff --git a/crates/openshell-providers/Cargo.toml b/crates/openshell-providers/Cargo.toml index 1a3bda8f6..e82574d73 100644 --- a/crates/openshell-providers/Cargo.toml +++ b/crates/openshell-providers/Cargo.toml @@ -13,6 +13,7 @@ repository.workspace = true [dependencies] openshell-core = { path = "../openshell-core" } serde = { workspace = true } +serde_json = { workspace = true } serde_yml = { workspace = true } thiserror = { workspace = true } diff --git a/crates/openshell-providers/src/lib.rs b/crates/openshell-providers/src/lib.rs index b2bf1e234..3b28030ca 100644 --- a/crates/openshell-providers/src/lib.rs +++ b/crates/openshell-providers/src/lib.rs @@ -17,7 +17,11 @@ pub use openshell_core::proto::Provider; pub use context::{DiscoveryContext, RealDiscoveryContext}; pub use discovery::discover_with_spec; -pub use profiles::{ProviderTypeProfile, default_profiles, get_default_profile}; +pub use profiles::{ + ProfileError, ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, + get_default_profile, normalize_profile_id, parse_profile_json, parse_profile_yaml, + profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml, validate_profile_set, +}; #[derive(Debug, thiserror::Error)] pub enum ProviderError { diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index b9c161d26..8c3f247cf 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -6,11 +6,12 @@ #![allow(deprecated)] // NetworkBinary::harness remains in the public proto for compatibility. use openshell_core::proto::{ - NetworkBinary, NetworkEndpoint, NetworkPolicyRule, ProviderProfile, ProviderProfileCategory, - ProviderProfileCredential, + GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, NetworkBinary, NetworkEndpoint, + NetworkPolicyRule, ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, }; -use serde::{Deserialize, Deserializer, de}; -use std::collections::HashSet; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; +use std::collections::{HashMap, HashSet}; use std::sync::OnceLock; const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ @@ -30,6 +31,8 @@ const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ pub enum ProfileError { #[error("failed to parse provider profile YAML: {0}")] Parse(#[from] serde_yml::Error), + #[error("failed to parse provider profile JSON: {0}")] + JsonParse(#[from] serde_json::Error), #[error("provider profile id is required")] MissingId, #[error("duplicate provider profile id: {0}")] @@ -40,7 +43,33 @@ pub enum ProfileError { DuplicateCredentialEnvVar { id: String, env_var: String }, } -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProfileValidationDiagnostic { + pub source: String, + pub profile_id: String, + pub field: String, + pub message: String, + pub severity: String, +} + +impl ProfileValidationDiagnostic { + fn error( + source: impl Into, + profile_id: impl Into, + field: impl Into, + message: impl Into, + ) -> Self { + Self { + source: source.into(), + profile_id: profile_id.into(), + field: field.into(), + message: message.into(), + severity: "error".to_string(), + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct CredentialProfile { pub name: String, #[serde(default)] @@ -57,19 +86,111 @@ pub struct CredentialProfile { pub query_param: String, } -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +// These YAML/JSON DTOs mirror the network policy protos intentionally. Keep +// every lossless conversion below in sync with proto/sandbox.proto. If a field +// is added to NetworkEndpoint, L7Rule, L7Allow, L7DenyRule, L7QueryMatcher, +// GraphqlOperation, or NetworkBinary, add it here and in both conversion +// directions unless the import/lint path explicitly rejects it. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct EndpointProfile { pub host: String, + #[serde(default, skip_serializing_if = "is_zero")] pub port: u32, - #[serde(default)] + #[serde(default, skip_serializing_if = "String::is_empty")] pub protocol: String, - #[serde(default)] + #[serde(default, skip_serializing_if = "String::is_empty")] + pub tls: String, + #[serde(default, skip_serializing_if = "String::is_empty")] pub access: String, - #[serde(default)] + #[serde(default, skip_serializing_if = "String::is_empty")] pub enforcement: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub rules: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub allowed_ips: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub ports: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub deny_rules: Vec, + #[serde(default, skip_serializing_if = "is_false")] + pub allow_encoded_slash: bool, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub persisted_queries: String, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub graphql_persisted_queries: HashMap, + #[serde(default, skip_serializing_if = "is_zero")] + pub graphql_max_body_bytes: u32, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub path: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct L7RuleProfile { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub allow: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct L7AllowProfile { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub method: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub path: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub command: String, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub query: HashMap, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub operation_type: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub operation_name: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub fields: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct L7DenyRuleProfile { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub method: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub path: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub command: String, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub query: HashMap, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub operation_type: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub operation_name: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub fields: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct L7QueryMatcherProfile { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub glob: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub any: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct GraphqlOperationProfile { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub operation_type: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub operation_name: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub fields: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BinaryProfile { + pub path: String, + pub harness: bool, } -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct ProviderTypeProfile { pub id: String, pub display_name: String, @@ -77,7 +198,8 @@ pub struct ProviderTypeProfile { pub description: String, #[serde( default = "default_category", - deserialize_with = "deserialize_category" + deserialize_with = "deserialize_category", + serialize_with = "serialize_category" )] pub category: ProviderProfileCategory, #[serde(default)] @@ -85,12 +207,43 @@ pub struct ProviderTypeProfile { #[serde(default)] pub endpoints: Vec, #[serde(default)] - pub binaries: Vec, + pub binaries: Vec, #[serde(default)] pub inference_capable: bool, } +// Provider profile import/export is expected to be lossless for the network +// policy fields exposed by the protobuf API. Do not collapse these DTOs into a +// narrower shape; direct gRPC imports and CLI YAML imports must preserve the +// same policy intent through storage and JIT composition. impl ProviderTypeProfile { + #[must_use] + pub fn from_proto(profile: &ProviderProfile) -> Self { + Self { + id: profile.id.clone(), + display_name: profile.display_name.clone(), + description: profile.description.clone(), + category: ProviderProfileCategory::try_from(profile.category) + .unwrap_or(ProviderProfileCategory::Other), + credentials: profile + .credentials + .iter() + .map(|credential| CredentialProfile { + name: credential.name.clone(), + description: credential.description.clone(), + env_vars: credential.env_vars.clone(), + required: credential.required, + auth_style: credential.auth_style.clone(), + header_name: credential.header_name.clone(), + query_param: credential.query_param.clone(), + }) + .collect(), + endpoints: profile.endpoints.iter().map(endpoint_from_proto).collect(), + binaries: profile.binaries.iter().map(binary_from_proto).collect(), + inference_capable: profile.inference_capable, + } + } + #[must_use] pub fn credential_env_vars(&self) -> Vec<&str> { let mut vars = Vec::new(); @@ -125,14 +278,7 @@ impl ProviderTypeProfile { }) .collect(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), - binaries: self - .binaries - .iter() - .map(|path| NetworkBinary { - path: path.clone(), - harness: false, - }) - .collect(), + binaries: self.binaries.iter().map(binary_to_proto).collect(), inference_capable: self.inference_capable, } } @@ -142,14 +288,54 @@ impl ProviderTypeProfile { NetworkPolicyRule { name: rule_name.to_string(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), - binaries: self - .binaries - .iter() - .map(|path| NetworkBinary { - path: path.clone(), - harness: false, - }) - .collect(), + binaries: self.binaries.iter().map(binary_to_proto).collect(), + } + } +} + +impl Serialize for BinaryProfile { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if !self.harness { + return serializer.serialize_str(&self.path); + } + let mut state = serializer.serialize_struct("BinaryProfile", 2)?; + state.serialize_field("path", &self.path)?; + state.serialize_field("harness", &self.harness)?; + state.end() + } +} + +impl<'de> Deserialize<'de> for BinaryProfile { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(untagged)] + enum BinaryProfileInput { + Path(String), + Object(BinaryProfileObject), + } + + #[derive(Deserialize)] + struct BinaryProfileObject { + path: String, + #[serde(default)] + harness: bool, + } + + match BinaryProfileInput::deserialize(deserializer)? { + BinaryProfileInput::Path(path) => Ok(Self { + path, + harness: false, + }), + BinaryProfileInput::Object(binary) => Ok(Self { + path: binary.path, + harness: binary.harness, + }), } } } @@ -158,6 +344,16 @@ fn default_category() -> ProviderProfileCategory { ProviderProfileCategory::Other } +#[allow(clippy::trivially_copy_pass_by_ref)] +fn is_false(value: &bool) -> bool { + !*value +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn is_zero(value: &u32) -> bool { + *value == 0 +} + fn deserialize_category<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, @@ -167,7 +363,19 @@ where .ok_or_else(|| de::Error::custom(format!("unsupported provider profile category: {raw}"))) } -fn provider_profile_category_from_yaml(raw: &str) -> Option { +#[allow(clippy::trivially_copy_pass_by_ref)] +fn serialize_category( + category: &ProviderProfileCategory, + serializer: S, +) -> Result +where + S: Serializer, +{ + serializer.serialize_str(provider_profile_category_to_yaml(*category)) +} + +#[must_use] +pub fn provider_profile_category_from_yaml(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().replace('-', "_").as_str() { "" | "other" => Some(ProviderProfileCategory::Other), "inference" => Some(ProviderProfileCategory::Inference), @@ -180,20 +388,188 @@ fn provider_profile_category_from_yaml(raw: &str) -> Option &'static str { + match category { + ProviderProfileCategory::Inference => "inference", + ProviderProfileCategory::Agent => "agent", + ProviderProfileCategory::SourceControl => "source_control", + ProviderProfileCategory::Messaging => "messaging", + ProviderProfileCategory::Data => "data", + ProviderProfileCategory::Knowledge => "knowledge", + ProviderProfileCategory::Other | ProviderProfileCategory::Unspecified => "other", + } +} + fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { NetworkEndpoint { host: endpoint.host.clone(), port: endpoint.port, protocol: endpoint.protocol.clone(), - tls: String::new(), + tls: endpoint.tls.clone(), enforcement: endpoint.enforcement.clone(), access: endpoint.access.clone(), - rules: Vec::new(), - allowed_ips: Vec::new(), - ports: Vec::new(), - deny_rules: Vec::new(), - allow_encoded_slash: false, - ..Default::default() + rules: endpoint.rules.iter().map(rule_to_proto).collect(), + allowed_ips: endpoint.allowed_ips.clone(), + ports: endpoint.ports.clone(), + deny_rules: endpoint.deny_rules.iter().map(deny_rule_to_proto).collect(), + allow_encoded_slash: endpoint.allow_encoded_slash, + persisted_queries: endpoint.persisted_queries.clone(), + graphql_persisted_queries: endpoint + .graphql_persisted_queries + .iter() + .map(|(name, operation)| (name.clone(), graphql_operation_to_proto(operation))) + .collect(), + graphql_max_body_bytes: endpoint.graphql_max_body_bytes, + path: endpoint.path.clone(), + } +} + +fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { + EndpointProfile { + host: endpoint.host.clone(), + port: endpoint.port, + protocol: endpoint.protocol.clone(), + tls: endpoint.tls.clone(), + access: endpoint.access.clone(), + enforcement: endpoint.enforcement.clone(), + rules: endpoint.rules.iter().map(rule_from_proto).collect(), + allowed_ips: endpoint.allowed_ips.clone(), + ports: endpoint.ports.clone(), + deny_rules: endpoint + .deny_rules + .iter() + .map(deny_rule_from_proto) + .collect(), + allow_encoded_slash: endpoint.allow_encoded_slash, + persisted_queries: endpoint.persisted_queries.clone(), + graphql_persisted_queries: endpoint + .graphql_persisted_queries + .iter() + .map(|(name, operation)| (name.clone(), graphql_operation_from_proto(operation))) + .collect(), + graphql_max_body_bytes: endpoint.graphql_max_body_bytes, + path: endpoint.path.clone(), + } +} + +fn binary_to_proto(binary: &BinaryProfile) -> NetworkBinary { + NetworkBinary { + path: binary.path.clone(), + harness: binary.harness, + } +} + +fn binary_from_proto(binary: &NetworkBinary) -> BinaryProfile { + BinaryProfile { + path: binary.path.clone(), + harness: binary.harness, + } +} + +fn rule_to_proto(rule: &L7RuleProfile) -> L7Rule { + L7Rule { + allow: rule.allow.as_ref().map(allow_to_proto), + } +} + +fn rule_from_proto(rule: &L7Rule) -> L7RuleProfile { + L7RuleProfile { + allow: rule.allow.as_ref().map(allow_from_proto), + } +} + +fn allow_to_proto(allow: &L7AllowProfile) -> L7Allow { + L7Allow { + method: allow.method.clone(), + path: allow.path.clone(), + command: allow.command.clone(), + query: allow + .query + .iter() + .map(|(name, matcher)| (name.clone(), query_matcher_to_proto(matcher))) + .collect(), + operation_type: allow.operation_type.clone(), + operation_name: allow.operation_name.clone(), + fields: allow.fields.clone(), + } +} + +fn allow_from_proto(allow: &L7Allow) -> L7AllowProfile { + L7AllowProfile { + method: allow.method.clone(), + path: allow.path.clone(), + command: allow.command.clone(), + query: allow + .query + .iter() + .map(|(name, matcher)| (name.clone(), query_matcher_from_proto(matcher))) + .collect(), + operation_type: allow.operation_type.clone(), + operation_name: allow.operation_name.clone(), + fields: allow.fields.clone(), + } +} + +fn deny_rule_to_proto(rule: &L7DenyRuleProfile) -> L7DenyRule { + L7DenyRule { + method: rule.method.clone(), + path: rule.path.clone(), + command: rule.command.clone(), + query: rule + .query + .iter() + .map(|(name, matcher)| (name.clone(), query_matcher_to_proto(matcher))) + .collect(), + operation_type: rule.operation_type.clone(), + operation_name: rule.operation_name.clone(), + fields: rule.fields.clone(), + } +} + +fn deny_rule_from_proto(rule: &L7DenyRule) -> L7DenyRuleProfile { + L7DenyRuleProfile { + method: rule.method.clone(), + path: rule.path.clone(), + command: rule.command.clone(), + query: rule + .query + .iter() + .map(|(name, matcher)| (name.clone(), query_matcher_from_proto(matcher))) + .collect(), + operation_type: rule.operation_type.clone(), + operation_name: rule.operation_name.clone(), + fields: rule.fields.clone(), + } +} + +fn query_matcher_to_proto(matcher: &L7QueryMatcherProfile) -> L7QueryMatcher { + L7QueryMatcher { + glob: matcher.glob.clone(), + any: matcher.any.clone(), + } +} + +fn query_matcher_from_proto(matcher: &L7QueryMatcher) -> L7QueryMatcherProfile { + L7QueryMatcherProfile { + glob: matcher.glob.clone(), + any: matcher.any.clone(), + } +} + +fn graphql_operation_to_proto(operation: &GraphqlOperationProfile) -> GraphqlOperation { + GraphqlOperation { + operation_type: operation.operation_type.clone(), + operation_name: operation.operation_name.clone(), + fields: operation.fields.clone(), + } +} + +fn graphql_operation_from_proto(operation: &GraphqlOperation) -> GraphqlOperationProfile { + GraphqlOperationProfile { + operation_type: operation.operation_type.clone(), + operation_name: operation.operation_name.clone(), + fields: operation.fields.clone(), } } @@ -201,6 +577,26 @@ pub fn parse_profile_yaml(input: &str) -> Result(input)?) } +pub fn parse_profile_json(input: &str) -> Result { + Ok(serde_json::from_str::(input)?) +} + +pub fn profile_to_yaml(profile: &ProviderTypeProfile) -> Result { + Ok(serde_yml::to_string(profile)?) +} + +pub fn profile_to_json(profile: &ProviderTypeProfile) -> Result { + Ok(serde_json::to_string_pretty(profile)?) +} + +pub fn profiles_to_yaml(profiles: &[ProviderTypeProfile]) -> Result { + Ok(serde_yml::to_string(profiles)?) +} + +pub fn profiles_to_json(profiles: &[ProviderTypeProfile]) -> Result { + Ok(serde_json::to_string_pretty(profiles)?) +} + pub fn parse_profile_catalog_yamls( inputs: &[&str], ) -> Result, ProfileError> { @@ -214,38 +610,215 @@ pub fn parse_profile_catalog_yamls( } fn validate_profiles(profiles: &[ProviderTypeProfile]) -> Result<(), ProfileError> { - let mut ids = HashSet::new(); - for profile in profiles { - if profile.id.trim().is_empty() { + let diagnostics = validate_profile_set( + &profiles + .iter() + .map(|profile| (String::new(), profile.clone())) + .collect::>(), + ); + if let Some(diagnostic) = diagnostics.first() { + if diagnostic.field == "id" && diagnostic.message == "provider profile id is required" { return Err(ProfileError::MissingId); } - if !ids.insert(profile.id.clone()) { - return Err(ProfileError::DuplicateId(profile.id.clone())); + if diagnostic.field == "id" + && diagnostic + .message + .starts_with("duplicate provider profile id") + { + return Err(ProfileError::DuplicateId(diagnostic.profile_id.clone())); + } + if diagnostic.field.starts_with("credentials.env_vars") { + return Err(ProfileError::DuplicateCredentialEnvVar { + id: diagnostic.profile_id.clone(), + env_var: diagnostic + .message + .trim_start_matches("duplicate credential env var '") + .trim_end_matches('\'') + .to_string(), + }); + } + if diagnostic.field.starts_with("endpoints") + && let Some(profile) = profiles + .iter() + .find(|profile| profile.id == diagnostic.profile_id) + && let Some(endpoint) = profile + .endpoints + .iter() + .find(|endpoint| !endpoint_is_valid(endpoint)) + { + return Err(ProfileError::InvalidEndpoint { + id: profile.id.clone(), + host: endpoint.host.clone(), + port: endpoint.port, + }); + } + } + + Ok(()) +} + +#[must_use] +pub fn normalize_profile_id(input: &str) -> Option { + let id = input.trim().to_ascii_lowercase(); + if is_valid_profile_id(&id) { + Some(id) + } else { + None + } +} + +fn is_valid_profile_id(id: &str) -> bool { + !id.is_empty() + && !id.starts_with('-') + && !id.ends_with('-') + && id.split('-').all(|part| { + !part.is_empty() + && part + .bytes() + .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit()) + }) +} + +#[must_use] +pub fn validate_profile_set( + profiles: &[(String, ProviderTypeProfile)], +) -> Vec { + let mut diagnostics = Vec::new(); + let mut ids = HashSet::new(); + for (source, profile) in profiles { + let raw_profile_id = profile.id.as_str(); + let profile_id = raw_profile_id.trim(); + if profile_id.is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + "", + "id", + "provider profile id is required", + )); + } else if normalize_profile_id(raw_profile_id).as_deref() != Some(raw_profile_id) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "id", + "provider profile id must be lowercase kebab-case using only a-z, 0-9, and '-'", + )); + } else if !ids.insert(profile_id.to_string()) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "id", + format!("duplicate provider profile id: {profile_id}"), + )); + } + + let mut credential_names = HashSet::new(); + for credential in &profile.credentials { + let credential_name = credential.name.trim(); + if credential_name.is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.name", + "credential name is required", + )); + } else if !credential_names.insert(credential_name.to_string()) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.name", + format!("duplicate credential name: {credential_name}"), + )); + } } let mut env_vars = HashSet::new(); for credential in &profile.credentials { for env_var in &credential.env_vars { - if !env_vars.insert(env_var) { - return Err(ProfileError::DuplicateCredentialEnvVar { - id: profile.id.clone(), - env_var: env_var.clone(), - }); + if env_var.trim().is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.env_vars", + "credential env var must not be empty", + )); + } else if !env_vars.insert(env_var.trim().to_string()) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.env_vars", + format!("duplicate credential env var '{env_var}'"), + )); } } + + let auth_style = credential.auth_style.trim().to_ascii_lowercase(); + match auth_style.as_str() { + "" | "basic" => {} + "bearer" | "header" => { + if credential.header_name.trim().is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.header_name", + format!("header_name is required for {auth_style} auth"), + )); + } + } + "query" => { + if credential.query_param.trim().is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.query_param", + "query_param is required for query auth", + )); + } + } + _ => diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.auth_style", + format!("unsupported auth_style: {}", credential.auth_style), + )), + } } - for endpoint in &profile.endpoints { - if endpoint.host.trim().is_empty() || endpoint.port == 0 || endpoint.port > 65_535 { - return Err(ProfileError::InvalidEndpoint { - id: profile.id.clone(), - host: endpoint.host.clone(), - port: endpoint.port, - }); + for (index, endpoint) in profile.endpoints.iter().enumerate() { + if !endpoint_is_valid(endpoint) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + format!("endpoints[{index}]"), + format!("invalid endpoint '{}:{}'", endpoint.host, endpoint.port), + )); + } + } + + for (index, binary) in profile.binaries.iter().enumerate() { + if binary.path.trim().is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + format!("binaries[{index}]"), + "binary path must not be empty", + )); } } } - Ok(()) + diagnostics +} + +fn endpoint_is_valid(endpoint: &EndpointProfile) -> bool { + if endpoint.host.trim().is_empty() { + return false; + } + if !endpoint.ports.is_empty() { + return endpoint + .ports + .iter() + .all(|port| (1..=65_535).contains(port)); + } + (1..=65_535).contains(&endpoint.port) } static DEFAULT_PROFILES: OnceLock> = OnceLock::new(); @@ -272,8 +845,9 @@ mod tests { use openshell_core::proto::ProviderProfileCategory; use super::{ - ProfileError, default_profiles, get_default_profile, parse_profile_catalog_yamls, - parse_profile_yaml, + ProfileError, ProviderTypeProfile, default_profiles, get_default_profile, + normalize_profile_id, parse_profile_catalog_yamls, parse_profile_json, parse_profile_yaml, + profile_to_json, profile_to_yaml, validate_profile_set, }; #[test] @@ -328,6 +902,209 @@ credentials: assert_eq!(profile.credential_env_vars(), vec!["EXAMPLE_API_KEY"]); } + #[test] + fn profile_json_round_trip_preserves_compact_dto_shape() { + let profile = get_default_profile("github").expect("github profile"); + let json = profile_to_json(profile).expect("profile should serialize"); + let parsed = parse_profile_json(&json).expect("profile should parse"); + + assert_eq!(parsed.id, "github"); + assert_eq!(parsed.category, ProviderProfileCategory::SourceControl); + assert_eq!(parsed.binaries[0].path, "/usr/bin/gh"); + } + + #[test] + fn profile_yaml_round_trip_preserves_full_network_policy_fields() { + let profile = parse_profile_yaml( + r" +id: advanced +display_name: Advanced +category: other +endpoints: + - host: api.example.com + ports: [443, 8443] + protocol: rest + tls: terminate + enforcement: enforce + access: read-only + rules: + - allow: + method: GET + path: /v1/** + query: + state: + any: [open, closed] + allowed_ips: [10.0.0.0/24] + deny_rules: + - method: POST + path: /admin/** + allow_encoded_slash: true + persisted_queries: allow_registered + graphql_persisted_queries: + hash-a: + operation_type: query + operation_name: Viewer + fields: [viewer] + graphql_max_body_bytes: 131072 + path: /graphql +binaries: + - path: /usr/bin/custom + harness: true +", + ) + .expect("profile should parse"); + let diagnostics = validate_profile_set(&[("advanced.yaml".to_string(), profile.clone())]); + assert!( + diagnostics.is_empty(), + "unexpected diagnostics: {diagnostics:?}" + ); + + let proto = profile.to_proto(); + let endpoint = proto.endpoints.first().expect("endpoint should exist"); + assert_eq!(endpoint.port, 0); + assert_eq!(endpoint.ports, vec![443, 8443]); + assert_eq!(endpoint.tls, "terminate"); + assert_eq!(endpoint.allowed_ips, vec!["10.0.0.0/24"]); + assert!(endpoint.allow_encoded_slash); + assert_eq!(endpoint.persisted_queries, "allow_registered"); + assert_eq!(endpoint.graphql_max_body_bytes, 131_072); + assert_eq!(endpoint.path, "/graphql"); + assert_eq!( + endpoint + .rules + .first() + .and_then(|rule| rule.allow.as_ref()) + .map(|allow| allow.method.as_str()), + Some("GET") + ); + assert_eq!(endpoint.deny_rules[0].method, "POST"); + assert_eq!( + endpoint + .graphql_persisted_queries + .get("hash-a") + .map(|operation| operation.operation_name.as_str()), + Some("Viewer") + ); + assert!(proto.binaries[0].harness); + + let reparsed = parse_profile_yaml(&profile_to_yaml(&profile).expect("serialize YAML")) + .expect("serialized profile should parse"); + let reprotoo = reparsed.to_proto(); + assert_eq!(reprotoo.endpoints[0].rules.len(), 1); + assert_eq!(reprotoo.endpoints[0].deny_rules.len(), 1); + assert_eq!(reprotoo.endpoints[0].ports, vec![443, 8443]); + assert!(reprotoo.binaries[0].harness); + } + + #[test] + fn validate_profile_set_returns_all_discoverable_diagnostics() { + let profile = parse_profile_yaml( + r#" +id: broken +display_name: Broken +credentials: + - name: api_key + env_vars: [BROKEN_TOKEN] + auth_style: query + - name: api_key + env_vars: [BROKEN_TOKEN, ""] + auth_style: unknown +endpoints: + - host: "" + port: 0 +binaries: ["", /usr/bin/broken] +"#, + ) + .expect("profile should parse"); + + let diagnostics = validate_profile_set(&[("broken.yaml".to_string(), profile)]); + let messages = diagnostics + .iter() + .map(|diagnostic| diagnostic.message.as_str()) + .collect::>(); + + assert!(messages.contains(&"duplicate credential name: api_key")); + assert!(messages.contains(&"duplicate credential env var 'BROKEN_TOKEN'")); + assert!(messages.contains(&"credential env var must not be empty")); + assert!(messages.contains(&"query_param is required for query auth")); + assert!(messages.contains(&"unsupported auth_style: unknown")); + assert!( + messages + .iter() + .any(|message| message.starts_with("invalid endpoint")) + ); + assert!(messages.contains(&"binary path must not be empty")); + } + + #[test] + fn validate_profile_set_rejects_noncanonical_profile_ids() { + let profiles = [ + ( + "space.yaml".to_string(), + ProviderTypeProfile { + id: " alex-api ".to_string(), + display_name: "Space".to_string(), + description: String::new(), + category: ProviderProfileCategory::Other, + credentials: Vec::new(), + endpoints: Vec::new(), + binaries: Vec::new(), + inference_capable: false, + }, + ), + ( + "underscore.yaml".to_string(), + ProviderTypeProfile { + id: "alex_api".to_string(), + display_name: "Underscore".to_string(), + description: String::new(), + category: ProviderProfileCategory::Other, + credentials: Vec::new(), + endpoints: Vec::new(), + binaries: Vec::new(), + inference_capable: false, + }, + ), + ( + "case.yaml".to_string(), + ProviderTypeProfile { + id: "Alex-API".to_string(), + display_name: "Case".to_string(), + description: String::new(), + category: ProviderProfileCategory::Other, + credentials: Vec::new(), + endpoints: Vec::new(), + binaries: Vec::new(), + inference_capable: false, + }, + ), + ]; + + let diagnostics = validate_profile_set(&profiles); + let id_errors = diagnostics + .iter() + .filter(|diagnostic| diagnostic.field == "id") + .collect::>(); + + assert_eq!(id_errors.len(), 3); + assert!( + id_errors + .iter() + .all(|diagnostic| diagnostic.message.contains("lowercase kebab-case")) + ); + } + + #[test] + fn normalize_profile_id_trims_and_lowercases_valid_ids() { + assert_eq!( + normalize_profile_id(" Alex-API "), + Some("alex-api".to_string()) + ); + assert_eq!(normalize_profile_id("alex_api"), None); + assert_eq!(normalize_profile_id("-alex"), None); + assert_eq!(normalize_profile_id("alex--api"), None); + } + #[test] fn parse_profile_catalog_yamls_rejects_duplicate_ids() { let err = parse_profile_catalog_yamls(&[ diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 31970e9c5..87af948ed 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -12,23 +12,25 @@ use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveAllDraftChunksResponse, ApproveDraftChunkRequest, ApproveDraftChunkResponse, ClearDraftChunksRequest, ClearDraftChunksResponse, CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - EditDraftChunkRequest, EditDraftChunkResponse, ExecSandboxEvent, ExecSandboxRequest, - GatewayMessage, GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, - GetDraftPolicyResponse, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, - GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, + DeleteProviderProfileRequest, DeleteProviderProfileResponse, DeleteProviderRequest, + DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, EditDraftChunkRequest, + EditDraftChunkResponse, ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, + GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, GetDraftPolicyResponse, + GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderProfileRequest, + GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, + GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxPoliciesRequest, - ListSandboxPoliciesResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, PushSandboxLogsResponse, - RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, ReportPolicyStatusRequest, - ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, - SupervisorMessage, UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, - UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, + HealthRequest, HealthResponse, ImportProviderProfilesRequest, ImportProviderProfilesResponse, + LintProviderProfilesRequest, LintProviderProfilesResponse, ListProviderProfilesRequest, + ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxesRequest, + ListSandboxesResponse, ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, + PushSandboxLogsResponse, RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, + ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, + SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, SupervisorMessage, + UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, + UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, }; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; @@ -258,17 +260,28 @@ impl OpenShell for OpenShellService { &self, request: Request, ) -> Result, Status> { - Ok(provider::handle_list_provider_profiles( - &self.state, - request, - )) + provider::handle_list_provider_profiles(&self.state, request).await } async fn get_provider_profile( &self, request: Request, ) -> Result, Status> { - provider::handle_get_provider_profile(&self.state, request) + provider::handle_get_provider_profile(&self.state, request).await + } + + async fn import_provider_profiles( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_import_provider_profiles(&self.state, request).await + } + + async fn lint_provider_profiles( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_lint_provider_profiles(&self.state, request).await } async fn update_provider( @@ -285,6 +298,13 @@ impl OpenShell for OpenShellService { provider::handle_delete_provider(&self.state, request).await } + async fn delete_provider_profile( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_delete_provider_profile(&self.state, request).await + } + // --- Config / Policy --- async fn get_sandbox_config( diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 6caa81031..2c62c930a 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -46,7 +46,7 @@ use openshell_ocsf::{ use openshell_policy::{ PolicyMergeOp, ProviderPolicyLayer, compose_effective_policy, merge_policy, }; -use openshell_providers::get_default_profile; +use openshell_providers::{get_default_profile, normalize_provider_type}; use prost::Message; use sha2::{Digest, Sha256}; use std::collections::{BTreeMap, HashMap}; @@ -498,13 +498,28 @@ async fn profile_provider_policy_layers( .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; let provider_type = provider.r#type.trim(); - let Some(profile) = get_default_profile(provider_type) else { - warn!( - provider_name = %name, - provider_type, - "provider type has no default profile; skipping provider policy layer" - ); - continue; + let profile = if let Some(canonical_type) = normalize_provider_type(provider_type) { + let Some(profile) = get_default_profile(canonical_type) else { + warn!( + provider_name = %name, + provider_type, + "legacy provider type has no profile; skipping provider policy layer" + ); + continue; + }; + profile.clone() + } else { + let Some(profile) = + super::provider::get_provider_type_profile(store, provider_type).await? + else { + warn!( + provider_name = %name, + provider_type, + "provider type has no profile; skipping provider policy layer" + ); + continue; + }; + profile }; let rule_name = openshell_policy::provider_rule_name(provider.object_name()); @@ -2868,6 +2883,150 @@ mod tests { assert!(layers.is_empty()); } + #[tokio::test] + async fn provider_policy_layers_skip_custom_profile_for_legacy_provider_type() { + let store = Store::connect("sqlite::memory:").await.unwrap(); + store + .put_message(&test_provider("custom-provider", "generic")) + .await + .unwrap(); + store + .put_message(&openshell_core::proto::StoredProviderProfile { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "profile-generic".to_string(), + name: "generic".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + }), + profile: Some(openshell_core::proto::ProviderProfile { + id: "generic".to_string(), + display_name: "Generic Override".to_string(), + description: String::new(), + category: openshell_core::proto::ProviderProfileCategory::Other as i32, + credentials: Vec::new(), + endpoints: vec![NetworkEndpoint { + host: "backdoor.example".to_string(), + port: 443, + ..Default::default() + }], + binaries: Vec::new(), + inference_capable: false, + }), + }) + .await + .unwrap(); + + let layers = profile_provider_policy_layers(&store, &["custom-provider".to_string()]) + .await + .unwrap(); + + assert!(layers.is_empty()); + } + + #[tokio::test] + #[allow(deprecated)] + async fn provider_policy_layers_include_custom_provider_profiles() { + let store = Store::connect("sqlite::memory:").await.unwrap(); + store + .put_message(&test_provider("work-custom", "custom-api")) + .await + .unwrap(); + store + .put_message(&openshell_core::proto::StoredProviderProfile { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "profile-custom-api".to_string(), + name: "custom-api".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + }), + profile: Some(openshell_core::proto::ProviderProfile { + id: "custom-api".to_string(), + display_name: "Custom API".to_string(), + description: String::new(), + category: openshell_core::proto::ProviderProfileCategory::Other as i32, + credentials: Vec::new(), + endpoints: vec![NetworkEndpoint { + host: "api.custom.example".to_string(), + protocol: "rest".to_string(), + ports: vec![443, 8443], + allowed_ips: vec!["10.0.0.0/24".to_string()], + rules: vec![L7Rule { + allow: Some(openshell_core::proto::L7Allow { + method: "GET".to_string(), + path: "/v1/**".to_string(), + ..Default::default() + }), + }], + allow_encoded_slash: true, + path: "/v1".to_string(), + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/custom".to_string(), + harness: true, + }], + inference_capable: false, + }), + }) + .await + .unwrap(); + + let layers = profile_provider_policy_layers(&store, &["work-custom".to_string()]) + .await + .unwrap(); + + assert_eq!(layers.len(), 1); + assert_eq!(layers[0].rule_name, "_provider_work_custom"); + assert_eq!(layers[0].rule.endpoints[0].host, "api.custom.example"); + assert_eq!(layers[0].rule.endpoints[0].ports, vec![443, 8443]); + assert_eq!(layers[0].rule.endpoints[0].rules.len(), 1); + assert_eq!(layers[0].rule.endpoints[0].allowed_ips, vec!["10.0.0.0/24"]); + assert!(layers[0].rule.endpoints[0].allow_encoded_slash); + assert_eq!(layers[0].rule.endpoints[0].path, "/v1"); + assert!(layers[0].rule.binaries[0].harness); + } + + #[tokio::test] + async fn provider_policy_layers_normalize_custom_provider_type_ids() { + let store = Store::connect("sqlite::memory:").await.unwrap(); + store + .put_message(&test_provider("work-custom", " Custom-API ")) + .await + .unwrap(); + store + .put_message(&openshell_core::proto::StoredProviderProfile { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "profile-custom-api".to_string(), + name: "custom-api".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + }), + profile: Some(openshell_core::proto::ProviderProfile { + id: "custom-api".to_string(), + display_name: "Custom API".to_string(), + description: String::new(), + category: openshell_core::proto::ProviderProfileCategory::Other as i32, + credentials: Vec::new(), + endpoints: vec![NetworkEndpoint { + host: "api.custom.example".to_string(), + port: 443, + ..Default::default() + }], + binaries: Vec::new(), + inference_capable: false, + }), + }) + .await + .unwrap(); + + let layers = profile_provider_policy_layers(&store, &["work-custom".to_string()]) + .await + .unwrap(); + + assert_eq!(layers.len(), 1); + assert_eq!(layers[0].rule.endpoints[0].host, "api.custom.example"); + } + #[tokio::test] async fn provider_policy_layers_include_known_provider_profiles() { let store = Store::connect("sqlite::memory:").await.unwrap(); diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 3ff2547b1..2f4893073 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -5,7 +5,7 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> -use crate::persistence::{ObjectType, Store, generate_name}; +use crate::persistence::{ObjectName, ObjectType, Store, generate_name}; use openshell_core::proto::Provider; use prost::Message; use tonic::Status; @@ -273,12 +273,18 @@ impl ObjectType for Provider { use crate::ServerState; use openshell_core::proto::{ - CreateProviderRequest, DeleteProviderRequest, DeleteProviderResponse, - GetProviderProfileRequest, GetProviderRequest, ListProviderProfilesRequest, - ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, - ProviderProfileResponse, ProviderResponse, UpdateProviderRequest, + CreateProviderRequest, DeleteProviderProfileRequest, DeleteProviderProfileResponse, + DeleteProviderRequest, DeleteProviderResponse, GetProviderProfileRequest, GetProviderRequest, + ImportProviderProfilesRequest, ImportProviderProfilesResponse, LintProviderProfilesRequest, + LintProviderProfilesResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, + ListProvidersRequest, ListProvidersResponse, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, ProviderProfileResponse, ProviderResponse, Sandbox, + StoredProviderProfile, UpdateProviderRequest, +}; +use openshell_providers::{ + ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, get_default_profile, + normalize_profile_id, normalize_provider_type, validate_profile_set, }; -use openshell_providers::{default_profiles, get_default_profile}; use std::sync::Arc; use tonic::{Request, Response}; @@ -320,32 +326,39 @@ pub(super) async fn handle_list_providers( Ok(Response::new(ListProvidersResponse { providers })) } -pub(super) fn handle_list_provider_profiles( - _state: &Arc, +impl ObjectType for StoredProviderProfile { + fn object_type() -> &'static str { + "provider_profile" + } +} + +pub(super) async fn handle_list_provider_profiles( + state: &Arc, request: Request, -) -> Response { +) -> Result, Status> { let request = request.into_inner(); let limit = clamp_limit(request.limit, 100, MAX_PAGE_SIZE) as usize; let offset = request.offset as usize; - let profiles = default_profiles() - .iter() + let mut profiles = merged_provider_profiles(state.store.as_ref()).await?; + profiles.sort_by(|left, right| left.id.cmp(&right.id)); + let profiles = profiles + .into_iter() .skip(offset) .take(limit) - .map(openshell_providers::ProviderTypeProfile::to_proto) + .map(|profile| profile.to_proto()) .collect(); - Response::new(ListProviderProfilesResponse { profiles }) + Ok(Response::new(ListProviderProfilesResponse { profiles })) } -pub(super) fn handle_get_provider_profile( - _state: &Arc, +pub(super) async fn handle_get_provider_profile( + state: &Arc, request: Request, ) -> Result, Status> { let id = request.into_inner().id; - if id.trim().is_empty() { - return Err(Status::invalid_argument("id is required")); - } - let profile = get_default_profile(id.trim()) + let id = normalize_profile_id_request(&id)?; + let profile = get_provider_type_profile(state.store.as_ref(), &id) + .await? .ok_or_else(|| Status::not_found("provider profile not found"))? .to_proto(); @@ -354,6 +367,315 @@ pub(super) fn handle_get_provider_profile( })) } +pub(super) async fn handle_import_provider_profiles( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + let (profiles, mut diagnostics) = profiles_from_import_items(&request.profiles); + add_empty_profile_set_diagnostic(&profiles, &mut diagnostics); + diagnostics.extend(profile_conflict_diagnostics(state.store.as_ref(), &profiles).await?); + diagnostics.extend(validate_profile_set(&profiles)); + + if has_errors(&diagnostics) { + return Ok(Response::new(ImportProviderProfilesResponse { + diagnostics: diagnostics.into_iter().map(proto_diagnostic).collect(), + profiles: Vec::new(), + imported: false, + })); + } + + let mut imported = Vec::with_capacity(profiles.len()); + for (_, profile) in profiles { + let stored = stored_provider_profile(profile.to_proto()); + state + .store + .put_message(&stored) + .await + .map_err(|e| Status::internal(format!("persist provider profile failed: {e}")))?; + imported.push(stored.profile.unwrap_or_default()); + } + + Ok(Response::new(ImportProviderProfilesResponse { + diagnostics: Vec::new(), + profiles: imported, + imported: true, + })) +} + +pub(super) async fn handle_lint_provider_profiles( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + let (profiles, mut diagnostics) = profiles_from_import_items(&request.profiles); + add_empty_profile_set_diagnostic(&profiles, &mut diagnostics); + diagnostics.extend(profile_conflict_diagnostics(state.store.as_ref(), &profiles).await?); + diagnostics.extend(validate_profile_set(&profiles)); + let valid = !has_errors(&diagnostics); + + Ok(Response::new(LintProviderProfilesResponse { + diagnostics: diagnostics.into_iter().map(proto_diagnostic).collect(), + valid, + })) +} + +pub(super) async fn handle_delete_provider_profile( + state: &Arc, + request: Request, +) -> Result, Status> { + let id = request.into_inner().id; + let id = normalize_profile_id_request(&id)?; + if get_default_profile(&id).is_some() { + return Err(Status::failed_precondition( + "built-in provider profiles cannot be deleted", + )); + } + + let existing = state + .store + .get_message_by_name::(&id) + .await + .map_err(|e| Status::internal(format!("fetch provider profile failed: {e}")))?; + if existing.is_none() { + return Err(Status::not_found("provider profile not found")); + } + + let blocking_sandboxes = sandboxes_using_profile(state.store.as_ref(), &id).await?; + if !blocking_sandboxes.is_empty() { + return Err(Status::failed_precondition(format!( + "provider profile '{id}' is in use by sandboxes: {}", + blocking_sandboxes.join(", ") + ))); + } + + let deleted = state + .store + .delete_by_name(StoredProviderProfile::object_type(), &id) + .await + .map_err(|e| Status::internal(format!("delete provider profile failed: {e}")))?; + + Ok(Response::new(DeleteProviderProfileResponse { deleted })) +} + +pub(super) async fn get_provider_type_profile( + store: &Store, + id: &str, +) -> Result, Status> { + let Some(id) = normalize_profile_id(id) else { + return Ok(None); + }; + if let Some(profile) = get_default_profile(&id) { + return Ok(Some(profile.clone())); + } + let profile = store + .get_message_by_name::(&id) + .await + .map_err(|e| Status::internal(format!("fetch provider profile failed: {e}")))? + .and_then(|stored| stored.profile) + .map(|profile| ProviderTypeProfile::from_proto(&profile)); + Ok(profile) +} + +async fn merged_provider_profiles(store: &Store) -> Result, Status> { + let mut profiles = default_profiles().to_vec(); + profiles.extend( + custom_provider_profiles(store) + .await? + .into_iter() + .filter_map(|stored| stored.profile) + .map(|profile| ProviderTypeProfile::from_proto(&profile)), + ); + Ok(profiles) +} + +async fn custom_provider_profiles(store: &Store) -> Result, Status> { + let records = store + .list(StoredProviderProfile::object_type(), 10_000, 0) + .await + .map_err(|e| Status::internal(format!("list provider profiles failed: {e}")))?; + + let mut profiles = Vec::with_capacity(records.len()); + for record in records { + let profile = StoredProviderProfile::decode(record.payload.as_slice()) + .map_err(|e| Status::internal(format!("decode provider profile failed: {e}")))?; + profiles.push(profile); + } + Ok(profiles) +} + +fn normalize_profile_id_request(id: &str) -> Result { + if id.trim().is_empty() { + return Err(Status::invalid_argument("id is required")); + } + normalize_profile_id(id).ok_or_else(|| { + Status::invalid_argument("id must be lowercase kebab-case using only a-z, 0-9, and '-'") + }) +} + +fn profiles_from_import_items( + items: &[ProviderProfileImportItem], +) -> ( + Vec<(String, ProviderTypeProfile)>, + Vec, +) { + let mut profiles = Vec::new(); + let mut diagnostics = Vec::new(); + for item in items { + let source = item.source.clone(); + let Some(profile) = item.profile.as_ref() else { + diagnostics.push(ProfileValidationDiagnostic { + source, + profile_id: String::new(), + field: "profile".to_string(), + message: "provider profile is required".to_string(), + severity: "error".to_string(), + }); + continue; + }; + profiles.push((source, ProviderTypeProfile::from_proto(profile))); + } + (profiles, diagnostics) +} + +fn add_empty_profile_set_diagnostic( + profiles: &[(String, ProviderTypeProfile)], + diagnostics: &mut Vec, +) { + if profiles.is_empty() && diagnostics.is_empty() { + diagnostics.push(ProfileValidationDiagnostic { + source: String::new(), + profile_id: String::new(), + field: "profiles".to_string(), + message: "at least one provider profile is required".to_string(), + severity: "error".to_string(), + }); + } +} + +async fn profile_conflict_diagnostics( + store: &Store, + profiles: &[(String, ProviderTypeProfile)], +) -> Result, Status> { + let mut diagnostics = Vec::new(); + for (source, profile) in profiles { + let Some(id) = normalize_profile_id(&profile.id) else { + continue; + }; + if get_default_profile(&id).is_some() { + diagnostics.push(ProfileValidationDiagnostic { + source: source.clone(), + profile_id: id.clone(), + field: "id".to_string(), + message: format!("provider profile '{id}' is built-in and cannot be overwritten"), + severity: "error".to_string(), + }); + continue; + } + if let Some(provider_type) = normalize_provider_type(&id) { + diagnostics.push(ProfileValidationDiagnostic { + source: source.clone(), + profile_id: id.clone(), + field: "id".to_string(), + message: format!( + "provider profile id '{id}' is reserved for legacy provider type '{provider_type}'" + ), + severity: "error".to_string(), + }); + continue; + } + if store + .get_message_by_name::(&id) + .await + .map_err(|e| Status::internal(format!("fetch provider profile failed: {e}")))? + .is_some() + { + diagnostics.push(ProfileValidationDiagnostic { + source: source.clone(), + profile_id: id.clone(), + field: "id".to_string(), + message: format!("custom provider profile '{id}' already exists"), + severity: "error".to_string(), + }); + } + } + Ok(diagnostics) +} + +fn stored_provider_profile(profile: ProviderProfile) -> StoredProviderProfile { + use crate::persistence::current_time_ms; + let now_ms = current_time_ms().unwrap_or_default(); + StoredProviderProfile { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: uuid::Uuid::new_v4().to_string(), + name: profile.id.clone(), + created_at_ms: now_ms, + labels: std::collections::HashMap::new(), + }), + profile: Some(profile), + } +} + +fn proto_diagnostic(diagnostic: ProfileValidationDiagnostic) -> ProviderProfileDiagnostic { + ProviderProfileDiagnostic { + source: diagnostic.source, + profile_id: diagnostic.profile_id, + field: diagnostic.field, + message: diagnostic.message, + severity: diagnostic.severity, + } +} + +fn has_errors(diagnostics: &[ProfileValidationDiagnostic]) -> bool { + diagnostics + .iter() + .any(|diagnostic| diagnostic.severity == "error") +} + +async fn sandboxes_using_profile(store: &Store, profile_id: &str) -> Result, Status> { + let mut blocking = Vec::new(); + let mut offset = 0; + loop { + let records = store + .list(Sandbox::object_type(), 1000, offset) + .await + .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; + if records.is_empty() { + break; + } + offset = offset + .checked_add( + u32::try_from(records.len()) + .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, + ) + .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; + + for record in records { + let sandbox = Sandbox::decode(record.payload.as_slice()) + .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; + let Some(spec) = sandbox.spec.as_ref() else { + continue; + }; + for provider_name in &spec.providers { + let Some(provider) = store + .get_message_by_name::(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + else { + continue; + }; + if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) { + blocking.push(sandbox.object_name().to_string()); + break; + } + } + } + } + blocking.sort(); + blocking.dedup(); + Ok(blocking) +} + pub(super) async fn handle_update_provider( state: &Arc, request: Request, @@ -395,7 +717,10 @@ mod tests { use crate::tracing_bus::TracingLogBus; use openshell_core::Config; use openshell_core::proto::{ - GetProviderProfileRequest, ListProviderProfilesRequest, ProviderProfileCategory, + DeleteProviderProfileRequest, GetProviderProfileRequest, ImportProviderProfilesRequest, + L7Allow, L7Rule, LintProviderProfilesRequest, ListProviderProfilesRequest, NetworkBinary, + NetworkEndpoint, ProviderProfile, ProviderProfileCategory, ProviderProfileImportItem, + Sandbox, SandboxSpec, }; use openshell_core::{ObjectId, ObjectName}; use std::collections::HashMap; @@ -443,6 +768,29 @@ mod tests { } } + fn custom_profile(id: &str) -> ProviderProfile { + ProviderProfile { + id: id.to_string(), + display_name: format!("{id} Profile"), + description: String::new(), + category: ProviderProfileCategory::Other as i32, + credentials: Vec::new(), + endpoints: Vec::new(), + binaries: Vec::new(), + inference_capable: false, + } + } + + fn custom_profile_with_invalid_endpoint(id: &str) -> ProviderProfile { + let mut profile = custom_profile(id); + profile.endpoints.push(NetworkEndpoint { + host: String::new(), + port: 0, + ..Default::default() + }); + profile + } + async fn test_server_state() -> Arc { let store = Arc::new( Store::connect("sqlite::memory:?cache=shared") @@ -474,6 +822,8 @@ mod tests { offset: 0, }), ) + .await + .unwrap() .into_inner(); let github = response @@ -503,6 +853,7 @@ mod tests { id: "github".to_string(), }), ) + .await .unwrap() .into_inner() .profile @@ -519,10 +870,458 @@ mod tests { id: "generic".to_string(), }), ) + .await .unwrap_err(); assert_eq!(generic_err.code(), Code::NotFound); } + #[tokio::test] + async fn import_provider_profile_lists_and_gets_custom_profile() { + let state = test_server_state().await; + let response = handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("custom-api")), + source: "custom-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.imported); + assert!(response.diagnostics.is_empty()); + + let listed = handle_list_provider_profiles( + &state, + Request::new(ListProviderProfilesRequest { + limit: 100, + offset: 0, + }), + ) + .await + .unwrap() + .into_inner(); + assert!( + listed + .profiles + .iter() + .any(|profile| profile.id == "custom-api") + ); + + let fetched = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { + id: "custom-api".to_string(), + }), + ) + .await + .unwrap() + .into_inner() + .profile + .unwrap(); + assert_eq!(fetched.id, "custom-api"); + } + + #[tokio::test] + async fn import_provider_profile_rejects_builtin_overwrite() { + let state = test_server_state().await; + let response = handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("github")), + source: "github.yaml".to_string(), + }], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.imported); + assert!( + response + .diagnostics + .iter() + .any(|diagnostic| diagnostic.message.contains("built-in")) + ); + } + + #[tokio::test] + async fn import_provider_profile_rejects_legacy_provider_type_ids() { + let state = test_server_state().await; + let response = handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("generic")), + source: "generic.yaml".to_string(), + }], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.imported); + assert!( + response + .diagnostics + .iter() + .any(|diagnostic| diagnostic.message.contains("reserved")) + ); + + let missing = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { + id: "generic".to_string(), + }), + ) + .await + .unwrap_err(); + assert_eq!(missing.code(), Code::NotFound); + } + + #[tokio::test] + async fn import_provider_profile_rejects_noncanonical_ids() { + let state = test_server_state().await; + let response = handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ + ProviderProfileImportItem { + profile: Some(custom_profile(" alex-api ")), + source: "space.yaml".to_string(), + }, + ProviderProfileImportItem { + profile: Some(custom_profile("alex_api")), + source: "underscore.yaml".to_string(), + }, + ProviderProfileImportItem { + profile: Some(custom_profile("Alex-API")), + source: "case.yaml".to_string(), + }, + ], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.imported); + assert_eq!( + response + .diagnostics + .iter() + .filter(|diagnostic| diagnostic.message.contains("lowercase kebab-case")) + .count(), + 3 + ); + } + + #[tokio::test] + async fn provider_profile_get_and_delete_normalize_request_ids() { + let state = test_server_state().await; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("alex-api")), + source: "alex-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + + let fetched = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { + id: " Alex-API ".to_string(), + }), + ) + .await + .unwrap() + .into_inner() + .profile + .unwrap(); + assert_eq!(fetched.id, "alex-api"); + + let deleted = handle_delete_provider_profile( + &state, + Request::new(DeleteProviderProfileRequest { + id: " Alex-API ".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(deleted.deleted); + } + + #[tokio::test] + async fn import_provider_profiles_rejects_mixed_batch_without_partial_import() { + let state = test_server_state().await; + let response = handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ + ProviderProfileImportItem { + profile: Some(custom_profile("bulk-one")), + source: "bulk-one.yaml".to_string(), + }, + ProviderProfileImportItem { + profile: Some(custom_profile_with_invalid_endpoint("bulk-bad")), + source: "bulk-bad.yaml".to_string(), + }, + ProviderProfileImportItem { + profile: Some(custom_profile("bulk-two")), + source: "bulk-two.yaml".to_string(), + }, + ], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.imported); + assert!(response.profiles.is_empty()); + assert!(response.diagnostics.iter().any(|diagnostic| { + diagnostic.profile_id == "bulk-bad" + && diagnostic.field == "endpoints[0]" + && diagnostic.message.contains("invalid endpoint") + })); + + for id in ["bulk-one", "bulk-two"] { + let missing = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { id: id.to_string() }), + ) + .await + .unwrap_err(); + assert_eq!(missing.code(), Code::NotFound); + } + } + + #[tokio::test] + #[allow(deprecated)] + async fn import_provider_profiles_preserves_advanced_proto_policy_fields() { + let state = test_server_state().await; + let response = handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(ProviderProfile { + id: "advanced-api".to_string(), + display_name: "Advanced API".to_string(), + description: String::new(), + category: ProviderProfileCategory::Other as i32, + credentials: Vec::new(), + endpoints: vec![NetworkEndpoint { + host: "api.advanced.example".to_string(), + protocol: "rest".to_string(), + ports: vec![443, 8443], + allowed_ips: vec!["10.0.0.0/24".to_string()], + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "GET".to_string(), + path: "/v1/**".to_string(), + ..Default::default() + }), + }], + allow_encoded_slash: true, + path: "/v1".to_string(), + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/advanced".to_string(), + harness: true, + }], + inference_capable: false, + }), + source: "advanced-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.imported); + + let fetched = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { + id: "advanced-api".to_string(), + }), + ) + .await + .unwrap() + .into_inner() + .profile + .expect("profile should exist"); + let endpoint = fetched.endpoints.first().expect("endpoint should exist"); + assert_eq!(endpoint.ports, vec![443, 8443]); + assert_eq!(endpoint.allowed_ips, vec!["10.0.0.0/24"]); + assert_eq!(endpoint.rules.len(), 1); + assert_eq!( + endpoint.rules[0] + .allow + .as_ref() + .map(|allow| allow.path.as_str()), + Some("/v1/**") + ); + assert!(endpoint.allow_encoded_slash); + assert_eq!(endpoint.path, "/v1"); + assert!(fetched.binaries[0].harness); + } + + #[tokio::test] + async fn lint_provider_profiles_reports_mixed_batch_diagnostics() { + let state = test_server_state().await; + let response = handle_lint_provider_profiles( + &state, + Request::new(LintProviderProfilesRequest { + profiles: vec![ + ProviderProfileImportItem { + profile: Some(custom_profile("lint-one")), + source: "lint-one.yaml".to_string(), + }, + ProviderProfileImportItem { + profile: Some(custom_profile_with_invalid_endpoint("lint-bad")), + source: "lint-bad.yaml".to_string(), + }, + ProviderProfileImportItem { + profile: Some(custom_profile("lint-two")), + source: "lint-two.yaml".to_string(), + }, + ], + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(!response.valid); + assert!(response.diagnostics.iter().any(|diagnostic| { + diagnostic.profile_id == "lint-bad" + && diagnostic.field == "endpoints[0]" + && diagnostic.message.contains("invalid endpoint") + })); + + for id in ["lint-one", "lint-two"] { + let missing = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { id: id.to_string() }), + ) + .await + .unwrap_err(); + assert_eq!(missing.code(), Code::NotFound); + } + } + + #[tokio::test] + async fn delete_provider_profile_rejects_builtin_and_in_use_custom_profiles() { + let state = test_server_state().await; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("custom-api")), + source: "custom-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + + let builtin_err = handle_delete_provider_profile( + &state, + Request::new(DeleteProviderProfileRequest { + id: "github".to_string(), + }), + ) + .await + .unwrap_err(); + assert_eq!(builtin_err.code(), Code::FailedPrecondition); + + create_provider_record( + state.store.as_ref(), + provider_with_values("custom-provider", "custom-api"), + ) + .await + .unwrap(); + state + .store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sandbox-id".to_string(), + name: "sandbox-using-custom".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + }), + spec: Some(SandboxSpec { + providers: vec!["custom-provider".to_string()], + ..Default::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + let in_use_err = handle_delete_provider_profile( + &state, + Request::new(DeleteProviderProfileRequest { + id: "custom-api".to_string(), + }), + ) + .await + .unwrap_err(); + assert_eq!(in_use_err.code(), Code::FailedPrecondition); + assert!(in_use_err.message().contains("sandbox-using-custom")); + } + + #[tokio::test] + async fn delete_provider_profile_removes_unused_custom_profile() { + let state = test_server_state().await; + handle_import_provider_profiles( + &state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(custom_profile("custom-api")), + source: "custom-api.yaml".to_string(), + }], + }), + ) + .await + .unwrap(); + + let deleted = handle_delete_provider_profile( + &state, + Request::new(DeleteProviderProfileRequest { + id: "custom-api".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(deleted.deleted); + + let missing = handle_get_provider_profile( + &state, + Request::new(GetProviderProfileRequest { + id: "custom-api".to_string(), + }), + ) + .await + .unwrap_err(); + assert_eq!(missing.code(), Code::NotFound); + } + #[tokio::test] async fn provider_crud_round_trip_and_semantics() { let store = Store::connect("sqlite::memory:?cache=shared") diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index e5f9dc4e9..7b16ee991 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -524,6 +524,30 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { Err(tonic::Status::unimplemented("test")) } + async fn import_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> + { + Err(tonic::Status::unimplemented("test")) + } + + async fn lint_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> + { + Err(tonic::Status::unimplemented("test")) + } + + async fn delete_provider_profile( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> + { + Err(tonic::Status::unimplemented("test")) + } + async fn update_provider( &self, _: tonic::Request, diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index ed6ed398f..39df0819f 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -185,6 +185,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index c91dd6061..49c6f9c92 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -149,6 +149,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index 6942d66f7..d6a244e49 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -162,6 +162,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index d77cfd375..f8519cdc7 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -187,6 +187,27 @@ impl OpenShell for RelayGateway { Err(Status::unimplemented("unused")) } + async fn import_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn lint_provider_profiles( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + + async fn delete_provider_profile( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn delete_provider( &self, _: tonic::Request, diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index f196edb07..173a7225d 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -179,6 +179,27 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } + async fn import_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn lint_provider_profiles( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn delete_provider_profile( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + async fn update_provider( &self, _request: tonic::Request, diff --git a/proto/openshell.proto b/proto/openshell.proto index 529ee0629..a4a18ce82 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -59,12 +59,24 @@ service OpenShell { rpc GetProviderProfile(GetProviderProfileRequest) returns (ProviderProfileResponse); + // Import custom provider type profiles. + rpc ImportProviderProfiles(ImportProviderProfilesRequest) + returns (ImportProviderProfilesResponse); + + // Validate provider type profiles without registering them. + rpc LintProviderProfiles(LintProviderProfilesRequest) + returns (LintProviderProfilesResponse); + // Update an existing provider by name. rpc UpdateProvider(UpdateProviderRequest) returns (ProviderResponse); // Delete a provider by name. rpc DeleteProvider(DeleteProviderRequest) returns (DeleteProviderResponse); + // Delete a custom provider type profile by id. + rpc DeleteProviderProfile(DeleteProviderProfileRequest) + returns (DeleteProviderProfileResponse); + // Get sandbox settings by id (called by sandbox entrypoint and poll loop). rpc GetSandboxConfig(openshell.sandbox.v1.GetSandboxConfigRequest) returns (openshell.sandbox.v1.GetSandboxConfigResponse); @@ -581,6 +593,21 @@ message GetProviderProfileRequest { string id = 1; } +// Provider profile payload with optional source metadata for diagnostics. +message ProviderProfileImportItem { + ProviderProfile profile = 1; + string source = 2; +} + +// Provider profile validation diagnostic. +message ProviderProfileDiagnostic { + string source = 1; + string profile_id = 2; + string field = 3; + string message = 4; + string severity = 5; +} + // Provider credential declaration. message ProviderProfileCredential { string name = 1; @@ -616,6 +643,12 @@ message ProviderProfile { bool inference_capable = 8; } +// Stored custom provider profile object. +message StoredProviderProfile { + openshell.datamodel.v1.ObjectMeta metadata = 1; + ProviderProfile profile = 2; +} + // Provider profile response. message ProviderProfileResponse { ProviderProfile profile = 1; @@ -626,11 +659,44 @@ message ListProviderProfilesResponse { repeated ProviderProfile profiles = 1; } +// Import custom provider profiles request. +message ImportProviderProfilesRequest { + repeated ProviderProfileImportItem profiles = 1; +} + +// Import custom provider profiles response. +message ImportProviderProfilesResponse { + repeated ProviderProfileDiagnostic diagnostics = 1; + repeated ProviderProfile profiles = 2; + bool imported = 3; +} + +// Lint provider profiles request. +message LintProviderProfilesRequest { + repeated ProviderProfileImportItem profiles = 1; +} + +// Lint provider profiles response. +message LintProviderProfilesResponse { + repeated ProviderProfileDiagnostic diagnostics = 1; + bool valid = 2; +} + // Delete provider response. message DeleteProviderResponse { bool deleted = 1; } +// Delete custom provider profile request. +message DeleteProviderProfileRequest { + string id = 1; +} + +// Delete custom provider profile response. +message DeleteProviderProfileResponse { + bool deleted = 1; +} + // Get sandbox provider environment request. message GetSandboxProviderEnvironmentRequest { // The sandbox ID.