Skip to content

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,7 @@ public String authType() {
List<String> buildHostArgs(String cliPath, DatabricksConfig config) {
List<String> cmd =
new ArrayList<>(Arrays.asList(cliPath, "auth", "token", "--host", config.getHost()));
if (config.getExperimentalIsUnifiedHost() != null && config.getExperimentalIsUnifiedHost()) {
// For unified hosts, pass account_id, workspace_id, and experimental flag
cmd.add("--experimental-is-unified-host");
if (config.getAccountId() != null) {
cmd.add("--account-id");
cmd.add(config.getAccountId());
}
if (config.getWorkspaceId() != null) {
cmd.add("--workspace-id");
cmd.add(config.getWorkspaceId());
}
} else if (config.getClientType() == ClientType.ACCOUNT) {
if (config.getClientType() == ClientType.ACCOUNT) {
cmd.add("--account-id");
cmd.add(config.getAccountId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import java.time.Duration;
import java.util.*;
import org.apache.http.HttpMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DatabricksConfig {
private static final Logger LOG = LoggerFactory.getLogger(DatabricksConfig.class);

private CredentialsProvider credentialsProvider = new DefaultCredentialsProvider();

@ConfigAttribute(env = "DATABRICKS_HOST")
Expand Down Expand Up @@ -219,12 +223,28 @@ private synchronized DatabricksConfig innerResolve() {
sortScopes();
ConfigLoader.fixHostIfNeeded(this);
initHttp();
tryResolveHostMetadata();
return this;
} catch (DatabricksException e) {
throw ConfigLoader.makeNicerError(e.getMessage(), e, this);
}
}

/**
 * Attempts to resolve host metadata from the well-known discovery endpoint. Failures are logged
 * at debug level and otherwise ignored, since not all hosts support the discovery endpoint and
 * resolution is best-effort.
 */
private void tryResolveHostMetadata() {
  // Nothing to discover without a configured host.
  if (host == null) {
    return;
  }
  try {
    resolveHostMetadata();
  } catch (Exception e) {
    // Best-effort: hosts without the discovery endpoint are expected. Pass the exception to the
    // logger so the stack trace is available at debug level. JVM Errors are left to propagate
    // rather than being swallowed (the previous catch of Throwable hid them).
    LOG.debug("Failed to resolve host metadata", e);
  }
}

// Sort scopes in-place for better de-duplication in the refresh token cache.
private void sortScopes() {
if (scopes != null && !scopes.isEmpty()) {
Expand All @@ -250,11 +270,6 @@ public synchronized Map<String, String> authenticate() throws DatabricksExceptio
}
Map<String, String> headers = new HashMap<>(headerFactory.headers());

// For unified hosts with workspace operations, add the X-Databricks-Org-Id header
if (getHostType() == HostType.UNIFIED && workspaceId != null && !workspaceId.isEmpty()) {
headers.put("X-Databricks-Org-Id", workspaceId);
}

return headers;
} catch (DatabricksException e) {
String msg = String.format("%s auth: %s", credentialsProvider.authType(), e.getMessage());
Expand Down Expand Up @@ -712,23 +727,14 @@ public boolean isAws() {
}

/**
 * Returns {@code true} if the configured host is an accounts host, determined purely from the
 * host URL prefix. Returns {@code false} when no host is configured.
 */
public boolean isAccountClient() {
  if (host == null) {
    return false;
  }
  // Accounts hosts use a fixed URL prefix; "accounts-dod." covers US DoD deployments.
  return host.startsWith("https://accounts.") || host.startsWith("https://accounts-dod.");
}

/** Returns the host type based on configuration settings and host URL. */
/** Returns the host type based on the host URL pattern. */
public HostType getHostType() {
if (experimentalIsUnifiedHost != null && experimentalIsUnifiedHost) {
return HostType.UNIFIED;
}
if (host == null) {
return HostType.WORKSPACE;
}
Expand All @@ -738,15 +744,10 @@ public HostType getHostType() {
return HostType.WORKSPACE;
}

/** Returns the client type based on host type and workspace ID configuration. */
/** Returns the client type based on host type. */
public ClientType getClientType() {
HostType hostType = getHostType();
switch (hostType) {
case UNIFIED:
// For unified hosts, client type depends on whether workspaceId is set
return (workspaceId != null && !workspaceId.isEmpty())
? ClientType.WORKSPACE
: ClientType.ACCOUNT;
case ACCOUNTS:
return ClientType.ACCOUNT;
case WORKSPACE:
Expand Down Expand Up @@ -864,6 +865,10 @@ void resolveHostMetadata() throws IOException {
"discovery_url is not configured and could not be resolved from host metadata");
}
}
// For account hosts, use the accountId as the token audience if not already set.
if (tokenAudience == null && getClientType() == ClientType.ACCOUNT && accountId != null) {
tokenAudience = accountId;
}
}

private OpenIDConnectEndpoints fetchOidcEndpointsFromDiscovery() {
Expand All @@ -879,24 +884,11 @@ private OpenIDConnectEndpoints fetchOidcEndpointsFromDiscovery() {
return null;
}

private OpenIDConnectEndpoints getUnifiedOidcEndpoints(String accountId) throws IOException {
if (accountId == null || accountId.isEmpty()) {
throw new DatabricksException(
"account_id is required for unified host OIDC endpoint discovery");
}
String prefix = getHost() + "/oidc/accounts/" + accountId;
return new OpenIDConnectEndpoints(prefix + "/v1/token", prefix + "/v1/authorize");
}

private OpenIDConnectEndpoints fetchDefaultOidcEndpoints() throws IOException {
if (getHost() == null) {
return null;
}

// For unified hosts, use account-based OIDC endpoints
if (getHostType() == HostType.UNIFIED) {
return getUnifiedOidcEndpoints(getAccountId());
}
if (isAccountClient() && getAccountId() != null) {
String prefix = getHost() + "/oidc/accounts/" + getAccountId();
return new OpenIDConnectEndpoints(prefix + "/v1/token", prefix + "/v1/authorize");
Expand Down Expand Up @@ -962,6 +954,9 @@ private DatabricksConfig clone(Set<String> fieldsToSkip) {
if (fieldsToSkip.contains(f.getName())) {
continue;
}
if (java.lang.reflect.Modifier.isStatic(f.getModifiers())) {
continue;
}
try {
f.set(newConfig, f.get(this));
} catch (IllegalAccessException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,12 @@ public HeaderFactory configure(DatabricksConfig config) {
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", String.format("Bearer %s", idToken.getTokenValue()));

if (config.getClientType() == ClientType.ACCOUNT) {
AccessToken token;
try {
token = finalServiceAccountCredentials.createScoped(GCP_SCOPES).refreshAccessToken();
} catch (IOException e) {
String message =
"Failed to refresh access token from Google service account credentials.";
LOG.error(message + e);
throw new DatabricksException(message, e);
}
try {
AccessToken token =
finalServiceAccountCredentials.createScoped(GCP_SCOPES).refreshAccessToken();
headers.put(SA_ACCESS_TOKEN_HEADER, token.getTokenValue());
} catch (IOException e) {
LOG.warn("Failed to refresh GCP SA access token, skipping header: {}", e.getMessage());
}

return headers;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,11 @@ public HeaderFactory configure(DatabricksConfig config) {
throw new DatabricksException(message, e);
}

if (config.getClientType() == ClientType.ACCOUNT) {
try {
headers.put(
SA_ACCESS_TOKEN_HEADER, gcpScopedCredentials.refreshAccessToken().getTokenValue());
} catch (IOException e) {
String message = "Failed to refresh access token from scoped id token credentials.";
LOG.error(message + e);
throw new DatabricksException(message, e);
}
try {
headers.put(
SA_ACCESS_TOKEN_HEADER, gcpScopedCredentials.refreshAccessToken().getTokenValue());
} catch (IOException e) {
LOG.warn("Failed to refresh GCP SA access token, skipping header: {}", e.getMessage());
}

return headers;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,5 @@ public enum HostType {
WORKSPACE,

/** Traditional accounts host. */
ACCOUNTS,

/** Unified host supporting both workspace and account operations. */
UNIFIED
ACCOUNTS
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public class HostMetadata {
@JsonProperty("workspace_id")
private String workspaceId;

@JsonProperty("cloud")
private String cloud;

public HostMetadata() {}

public HostMetadata(String oidcEndpoint, String accountId, String workspaceId) {
Expand All @@ -28,6 +31,13 @@ public HostMetadata(String oidcEndpoint, String accountId, String workspaceId) {
this.workspaceId = workspaceId;
}

/**
 * Constructs host metadata with all known discovery fields.
 *
 * @param oidcEndpoint the OIDC endpoint advertised by the host, may be null
 * @param accountId the account ID associated with the host, may be null
 * @param workspaceId the workspace ID associated with the host, may be null
 * @param cloud the value of the {@code cloud} field of the host metadata, may be null
 */
public HostMetadata(String oidcEndpoint, String accountId, String workspaceId, String cloud) {
this.oidcEndpoint = oidcEndpoint;
this.accountId = accountId;
this.workspaceId = workspaceId;
this.cloud = cloud;
}

/** Returns the OIDC endpoint reported in the host metadata, or null if not set. */
public String getOidcEndpoint() {
return oidcEndpoint;
}
Expand All @@ -39,4 +49,8 @@ public String getAccountId() {
/** Returns the workspace ID reported in the host metadata, or null if not set. */
public String getWorkspaceId() {
return workspaceId;
}

/** Returns the value of the {@code cloud} field of the host metadata, or null if not set. */
public String getCloud() {
return cloud;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static org.junit.jupiter.api.Assertions.*;

import com.databricks.sdk.core.ClientType;
import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.HostType;
import com.databricks.sdk.service.provisioning.Workspace;
Expand Down Expand Up @@ -49,27 +48,49 @@ public void testGetWorkspaceClientForUnifiedHost() {

WorkspaceClient workspaceClient = accountClient.getWorkspaceClient(workspace);

// Should have the same host
// Should have the same host (unified hosts reuse the same host)
assertEquals(unifiedHost, workspaceClient.config().getHost());

// Should have workspace ID set
assertEquals("123456", workspaceClient.config().getWorkspaceId());

// Should be workspace client type (on unified host)
assertEquals(ClientType.WORKSPACE, workspaceClient.config().getClientType());

// Host type should still be unified
assertEquals(HostType.UNIFIED, workspaceClient.config().getHostType());
// Host type is WORKSPACE (determined from URL pattern, not unified flag)
assertEquals(HostType.WORKSPACE, workspaceClient.config().getHostType());
}

@Test
public void testGetWorkspaceClientForUnifiedHostType() {
// Verify unified host type is correctly detected
DatabricksConfig config =
public void testGetWorkspaceClientForSpogHostDoesNotMutateAccountConfig() {
String spogHost = "https://mycompany.databricks.com";
DatabricksConfig accountConfig =
new DatabricksConfig()
.setHost("https://unified.databricks.com")
.setExperimentalIsUnifiedHost(true);
.setHost(spogHost)
.setExperimentalIsUnifiedHost(true)
.setAccountId("test-account")
.setToken("test-token");

AccountClient accountClient = new AccountClient(accountConfig);

// Get workspace client for first workspace
Workspace workspace1 = new Workspace();
workspace1.setWorkspaceId(111L);
workspace1.setDeploymentName("ws-1");
WorkspaceClient wc1 = accountClient.getWorkspaceClient(workspace1);

// Get workspace client for second workspace
Workspace workspace2 = new Workspace();
workspace2.setWorkspaceId(222L);
workspace2.setDeploymentName("ws-2");
WorkspaceClient wc2 = accountClient.getWorkspaceClient(workspace2);

// Each workspace client should have its own workspace ID
assertEquals("111", wc1.config().getWorkspaceId());
assertEquals("222", wc2.config().getWorkspaceId());

// Account config should not have been mutated
assertNull(accountConfig.getWorkspaceId());

assertEquals(HostType.UNIFIED, config.getHostType());
// Both should share the same SPOG host
assertEquals(spogHost, wc1.config().getHost());
assertEquals(spogHost, wc2.config().getHost());
}
}
Loading
Loading