Skip to content
56 changes: 48 additions & 8 deletions cmd/auth/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"net/url"
"strings"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
"gopkg.in/ini.v1"
Expand All @@ -30,6 +32,8 @@ func canonicalHost(host string) (string, error) {

var ErrNoMatchingProfiles = errors.New("no matching profiles found")

const shellQuotedSpecialChars = " \t\n\r\"\\$`!#&|;(){}[]<>?*~'"

func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, error) {
var candidates []*ini.Section
configuredHost, err := canonicalHost(cfg.Host)
Expand Down Expand Up @@ -122,16 +126,22 @@ func newEnvCommand() *cobra.Command {
if err != nil {
return err
}
vars := map[string]string{}
for _, a := range config.ConfigAttributes {
if a.IsZero(cfg) {
continue
}
envValue := a.GetString(cfg)
for _, envName := range a.EnvVars {
vars[envName] = envValue
// Output KEY=VALUE lines when the user explicitly passes --output text.
if cmd.Flag("output").Changed && root.OutputType(cmd) == flags.OutputText {
w := cmd.OutOrStdout()
for _, a := range config.ConfigAttributes {
if a.IsZero(cfg) {
continue
}
v := a.GetString(cfg)
for _, envName := range a.EnvVars {
fmt.Fprintf(w, "%s=%s\n", envName, quoteEnvValue(v))
}
}
return nil
}

vars := collectEnvVars(cfg)
raw, err := json.MarshalIndent(map[string]any{
"env": vars,
}, "", " ")
Expand All @@ -144,3 +154,33 @@ func newEnvCommand() *cobra.Command {

return cmd
}

// collectEnvVars returns the environment variables for the given config
// as a map from env var name to value.
func collectEnvVars(cfg *config.Config) map[string]string {
vars := map[string]string{}
for _, a := range config.ConfigAttributes {
if a.IsZero(cfg) {
continue
}
v := a.GetString(cfg)
for _, envName := range a.EnvVars {
vars[envName] = v
}
}
return vars
}

// quoteEnvValue quotes a value for KEY=VALUE output if it contains spaces or
// shell-special characters. Single quotes prevent shell expansion, and
// embedded single quotes use the POSIX-compatible '\” sequence.
func quoteEnvValue(v string) string {
if v == "" {
return `''`
}
needsQuoting := strings.ContainsAny(v, shellQuotedSpecialChars)
if !needsQuoting {
return v
}
return "'" + strings.ReplaceAll(v, "'", "'\\''") + "'"
}
99 changes: 99 additions & 0 deletions cmd/auth/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package auth

import (
"bytes"
"testing"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestQuoteEnvValue(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{name: "simple value", in: "hello", want: "hello"},
{name: "empty value", in: "", want: `''`},
{name: "value with space", in: "hello world", want: "'hello world'"},
{name: "value with tab", in: "hello\tworld", want: "'hello\tworld'"},
{name: "value with double quote", in: `say "hi"`, want: "'say \"hi\"'"},
{name: "value with backslash", in: `path\to`, want: "'path\\to'"},
{name: "url value", in: "https://example.com", want: "https://example.com"},
{name: "value with dollar", in: "price$5", want: "'price$5'"},
{name: "value with backtick", in: "hello`world", want: "'hello`world'"},
{name: "value with bang", in: "hello!world", want: "'hello!world'"},
{name: "value with single quote", in: "it's", want: "'it'\\''s'"},
{name: "value with newline", in: "line1\nline2", want: "'line1\nline2'"},
{name: "value with carriage return", in: "line1\rline2", want: "'line1\rline2'"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := quoteEnvValue(c.in)
assert.Equal(t, c.want, got)
})
}
}

func TestEnvCommand_TextOutput(t *testing.T) {
cases := []struct {
name string
args []string
wantJSON bool
}{
{
name: "default output is JSON",
args: []string{"--host", "https://test.cloud.databricks.com"},
wantJSON: true,
},
{
name: "explicit --output text produces KEY=VALUE lines",
args: []string{"--host", "https://test.cloud.databricks.com", "--output", "text"},
wantJSON: false,
},
{
name: "explicit --output json produces JSON",
args: []string{"--host", "https://test.cloud.databricks.com", "--output", "json"},
wantJSON: true,
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
parent := &cobra.Command{Use: "databricks"}
outputFlag := flags.OutputText
parent.PersistentFlags().VarP(&outputFlag, "output", "o", "output type: text or json")

envCmd := newEnvCommand()
parent.AddCommand(envCmd)
parent.SetContext(cmdio.MockDiscard(t.Context()))

// Set DATABRICKS_TOKEN so the SDK's config.Authenticate succeeds
// without hitting a real endpoint.
t.Setenv("DATABRICKS_TOKEN", "test-token-value")

var buf bytes.Buffer
parent.SetOut(&buf)
parent.SetArgs(append([]string{"env"}, c.args...))

err := parent.Execute()
require.NoError(t, err)

output := buf.String()
if c.wantJSON {
assert.Contains(t, output, "{")
assert.Contains(t, output, "DATABRICKS_HOST")
} else {
assert.NotContains(t, output, "{")
assert.Contains(t, output, "DATABRICKS_HOST=")
assert.Contains(t, output, "=")
// Verify KEY=VALUE format (no JSON structure)
assert.NotContains(t, output, `"env"`)
}
})
}
}
24 changes: 18 additions & 6 deletions cmd/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"strings"
"time"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/credentials/u2m/cache"
Expand Down Expand Up @@ -83,17 +85,27 @@ using a client ID and secret is not supported.`,
if err != nil {
return err
}
raw, err := json.MarshalIndent(t, "", " ")
if err != nil {
return err
}
_, _ = cmd.OutOrStdout().Write(raw)
return nil
return writeTokenOutput(cmd, t)
}

return cmd
}

func writeTokenOutput(cmd *cobra.Command, t *oauth2.Token) error {
// Output plain token when the user explicitly passes --output text.
if cmd.Flag("output").Changed && root.OutputType(cmd) == flags.OutputText {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), t.AccessToken)
return nil
}

raw, err := json.MarshalIndent(t, "", " ")
if err != nil {
return err
}
_, _ = cmd.OutOrStdout().Write(raw)
return nil
}

type loadTokenArgs struct {
// authArguments is the parsed auth arguments, including the host and optionally the account ID.
authArguments *auth.AuthArguments
Expand Down
105 changes: 105 additions & 0 deletions cmd/auth/token_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"bytes"
"context"
"net/http"
"testing"
Expand All @@ -10,8 +11,10 @@ import (
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
Expand Down Expand Up @@ -729,3 +732,105 @@ func (e errProfiler) LoadProfiles(context.Context, profile.ProfileMatchFunction)
func (e errProfiler) GetPath(context.Context) (string, error) {
return "<error>", nil
}

func TestTokenCommand_TextOutput(t *testing.T) {
profiler := profile.InMemoryProfiler{
Profiles: profile.Profiles{
{
Name: "test-ws",
Host: "https://test-ws.cloud.databricks.com",
},
},
}
tokenCache := &inMemoryTokenCache{
Tokens: map[string]*oauth2.Token{
"test-ws": {
RefreshToken: "test-ws",
Expiry: time.Now().Add(1 * time.Hour),
},
},
}
persistentAuthOpts := []u2m.PersistentAuthOption{
u2m.WithTokenCache(tokenCache),
u2m.WithOAuthEndpointSupplier(&MockApiClient{}),
u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}),
}

cases := []struct {
name string
args []string
wantSubstr string
wantJSON bool
}{
{
name: "default output is JSON",
args: []string{"--profile", "test-ws"},
wantSubstr: `"access_token"`,
wantJSON: true,
},
{
name: "explicit --output json produces JSON",
args: []string{"--profile", "test-ws", "--output", "json"},
wantSubstr: `"access_token"`,
wantJSON: true,
},
{
name: "explicit --output text produces plain token with newline",
args: []string{"--profile", "test-ws", "--output", "text"},
wantSubstr: "new-access-token\n",
wantJSON: false,
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
authArgs := &auth.AuthArguments{}

parent := &cobra.Command{Use: "databricks"}
outputFlag := flags.OutputText
parent.PersistentFlags().VarP(&outputFlag, "output", "o", "output type: text or json")
parent.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile")

tokenCmd := newTokenCommand(authArgs)
// Override RunE to inject test profiler and token cache while reusing
// the production output formatter.
tokenCmd.RunE = func(cmd *cobra.Command, args []string) error {
profileName := ""
if f := cmd.Flag("profile"); f != nil {
profileName = f.Value.String()
}
tok, err := loadToken(cmd.Context(), loadTokenArgs{
authArguments: authArgs,
profileName: profileName,
args: args,
tokenTimeout: 1 * time.Hour,
profiler: profiler,
persistentAuthOpts: persistentAuthOpts,
})
if err != nil {
return err
}
return writeTokenOutput(cmd, tok)
}

parent.AddCommand(tokenCmd)
parent.SetContext(ctx)

var buf bytes.Buffer
parent.SetOut(&buf)
parent.SetArgs(append([]string{"token"}, c.args...))

err := parent.Execute()
assert.NoError(t, err)

output := buf.String()
assert.Contains(t, output, c.wantSubstr)
if c.wantJSON {
assert.Contains(t, output, "{")
} else {
assert.NotContains(t, output, "{")
}
})
}
}
Loading