Skip to content
Open
144 changes: 137 additions & 7 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/databricks/cli/experimental/ssh/internal/keys"
"github.com/databricks/cli/experimental/ssh/internal/proxy"
"github.com/databricks/cli/experimental/ssh/internal/sessions"
"github.com/databricks/cli/experimental/ssh/internal/sshconfig"
"github.com/databricks/cli/experimental/ssh/internal/vscode"
sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
Expand Down Expand Up @@ -99,11 +100,11 @@ type ClientOptions struct {
}

func (o *ClientOptions) Validate() error {
if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" && o.Accelerator == "" {
return errors.New("please provide --cluster or --accelerator flag")
}
if o.Accelerator != "" && o.ConnectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
if o.Accelerator != "" && o.ClusterID != "" {
return errors.New("--accelerator flag can only be used with serverless compute, not with --cluster")
}
// Consider removing this check when we enable serverless CPU connections. Ideally Jobs API should do the validation
// for us, but they don't plan on doing it in the nearest future. For now we should not forget to check if there are
Expand All @@ -128,7 +129,7 @@ func (o *ClientOptions) Validate() error {
}

// IsServerlessMode reports whether this connection targets serverless compute:
// no explicit cluster ID, and either a connection name or an accelerator was given.
func (o *ClientOptions) IsServerlessMode() bool {
	hasServerlessSelector := o.ConnectionName != "" || o.Accelerator != ""
	return o.ClusterID == "" && hasServerlessSelector
}

// SessionIdentifier returns the unique identifier for the session.
Expand Down Expand Up @@ -208,9 +209,16 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
cancel()
}()

// For serverless without explicit --name: auto-generate or reconnect to existing session.
if opts.IsServerlessMode() && opts.ConnectionName == "" && !opts.ProxyMode {
if err := opts.resolveServerlessSession(ctx, client); err != nil {
return err
}
}

sessionID := opts.SessionIdentifier()
if sessionID == "" {
return errors.New("either --cluster or --name must be provided")
return errors.New("either --cluster or --accelerator must be provided")
}

if !opts.ProxyMode {
Expand Down Expand Up @@ -332,6 +340,26 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
cmdio.LogString(ctx, "Connected!")
}

// Persist the session for future reconnects.
if opts.IsServerlessMode() && !opts.ProxyMode {
currentUser, userErr := client.CurrentUser.Me(ctx)
sessionUserName := ""
if userErr == nil {
sessionUserName = currentUser.UserName
}
err = sessions.Add(ctx, sessions.Session{
Name: opts.ConnectionName,
Accelerator: opts.Accelerator,
WorkspaceHost: client.Config.Host,
UserName: sessionUserName,
CreatedAt: time.Now(),
ClusterID: clusterID,
})
if err != nil {
log.Warnf(ctx, "Failed to save session state: %v", err)
}
}

if opts.ProxyMode {
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
} else if opts.IDE != "" {
Expand Down Expand Up @@ -690,7 +718,6 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
}
serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if err == nil {
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
break
} else if retries < maxRetries-1 {
time.Sleep(2 * time.Second)
Expand All @@ -704,3 +731,106 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC

return userName, serverPort, effectiveClusterID, nil
}

// resolveServerlessSession handles auto-generation and reconnection for serverless sessions.
// It checks local state for existing sessions matching the workspace, accelerator, and user,
// probes them to see if they're still alive, and prompts the user to reconnect or create new.
func (o *ClientOptions) resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceClient) error {
version := build.GetInfo().Version

me, err := client.CurrentUser.Me(ctx)
if err != nil {
return fmt.Errorf("failed to get current user: %w", err)
}

matching, err := sessions.FindMatching(ctx, client.Config.Host, o.Accelerator, me.UserName)
if err != nil {
log.Warnf(ctx, "Failed to load session state: %v", err)
}

// Probe sessions to find alive ones (limit to 5 most recent to avoid latency).
const maxProbe = 5
if len(matching) > maxProbe {
matching = matching[len(matching)-maxProbe:]
}

var alive []sessions.Session
for _, s := range matching {
_, _, _, probeErr := getServerMetadata(ctx, client, s.Name, s.ClusterID, version, o.Liteswap)
if probeErr == nil {
alive = append(alive, s)
} else if errors.Is(probeErr, errServerMetadata) {
// Only clean up when the server is definitively gone (metadata endpoint returns not-found).
// Transient errors (network, auth) should not trigger cleanup.
cleanupStaleSession(ctx, client, s, version)
} else {
log.Warnf(ctx, "Transient error probing session %s, skipping: %v", s.Name, probeErr)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Agent Swarm Review] [Critical]

Any probe error is treated as proof that the session is stale.

resolveServerlessSession() calls cleanupStaleSession() for every getServerMetadata() failure. That probe can fail for transient auth, network, workspace API, or version-mismatch reasons. In those cases the CLI will delete local SSH config, remove the session from state, and best-effort delete secret scopes and workspace content for a session that may still be alive.

Both reviewers flagged this. Isaac confirmed Critical in cross-review due to irreversible blast radius.

Suggestion: Only run destructive cleanup on definitive stale signals (e.g., 404/not-found). For transient errors, keep the session and surface a warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Now only errServerMetadata (definitive not-found) triggers cleanup. Transient errors (network, auth) are logged as warnings and the session is kept.

}

if len(alive) > 0 && cmdio.IsPromptSupported(ctx) {
choices := make([]string, 0, len(alive)+1)
for _, s := range alive {
choices = append(choices, fmt.Sprintf("Reconnect to %s (started %s)", s.Name, s.CreatedAt.Format(time.RFC822)))
}
choices = append(choices, "Create new session")

choice, choiceErr := cmdio.AskSelect(ctx, "Found existing sessions:", choices)
if choiceErr != nil {
return fmt.Errorf("failed to prompt user: %w", choiceErr)
}

for i, s := range alive {
if choice == choices[i] {
o.ConnectionName = s.Name
cmdio.LogString(ctx, "Reconnecting to session: "+s.Name)
return nil
}
}
}

// No alive session selected — generate a new name.
o.ConnectionName = sessions.GenerateSessionName(o.Accelerator, client.Config.Host)
cmdio.LogString(ctx, "Creating new session: "+o.ConnectionName)
return nil
}

// cleanupStaleSession removes all local and remote artifacts for a stale session.
// Every step is best-effort: individual failures are logged at debug level and
// never abort the remaining cleanup work.
func cleanupStaleSession(ctx context.Context, client *databricks.WorkspaceClient, s sessions.Session, version string) {
	// Local SSH key material lives in a per-session directory; drop the whole directory.
	if keyPath, err := keys.GetLocalSSHKeyPath(ctx, s.Name, ""); err == nil {
		os.RemoveAll(filepath.Dir(keyPath))
	}

	// Drop the generated Host entry from the user's SSH config.
	if err := sshconfig.RemoveHostConfig(ctx, s.Name); err != nil {
		log.Debugf(ctx, "Failed to remove SSH config for %s: %v", s.Name, err)
	}

	// Delete the secret scope (best-effort); the scope name embeds the current user name.
	if me, err := client.CurrentUser.Me(ctx); err == nil {
		scopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, s.Name)
		if err := client.Secrets.DeleteScope(ctx, workspace.DeleteScope{Scope: scopeName}); err != nil {
			log.Debugf(ctx, "Failed to delete secret scope %s: %v", scopeName, err)
		}
	}

	// Remove the workspace content directory (best-effort).
	if contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, s.Name); err == nil {
		if err := client.Workspace.Delete(ctx, workspace.Delete{Path: contentDir, Recursive: true}); err != nil {
			log.Debugf(ctx, "Failed to delete workspace content for %s: %v", s.Name, err)
		}
	}

	// Forget the session in local state.
	if err := sessions.Remove(ctx, s.Name); err != nil {
		log.Debugf(ctx, "Failed to remove session %s from state: %v", s.Name, err)
	}

	log.Infof(ctx, "Cleaned up stale session: %s", s.Name)
}
17 changes: 11 additions & 6 deletions experimental/ssh/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ func TestValidate(t *testing.T) {
wantErr string
}{
{
name: "no cluster or connection name",
name: "no cluster or connection name or accelerator",
opts: client.ClientOptions{},
wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)",
wantErr: "please provide --cluster or --accelerator flag",
},
{
name: "proxy mode skips cluster/name check",
Expand All @@ -31,9 +31,13 @@ func TestValidate(t *testing.T) {
opts: client.ClientOptions{ClusterID: "abc-123"},
},
{
name: "accelerator without connection name",
name: "accelerator with cluster ID",
opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"},
wantErr: "--accelerator flag can only be used with serverless compute (--name flag)",
wantErr: "--accelerator flag can only be used with serverless compute, not with --cluster",
},
{
name: "accelerator only (auto-generate session name)",
opts: client.ClientOptions{Accelerator: "GPU_1xA10"},
},
{
name: "connection name without accelerator",
Expand Down Expand Up @@ -64,8 +68,9 @@ func TestValidate(t *testing.T) {
wantErr: `invalid accelerator value: "CPU_1x", expected "GPU_1xA10" or "GPU_8xH100"`,
},
{
name: "both cluster ID and connection name",
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"},
name: "both cluster ID and connection name",
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"},
wantErr: `--accelerator flag can only be used with serverless compute, not with --cluster`,
},
{
name: "proxy mode with invalid connection name",
Expand Down
40 changes: 40 additions & 0 deletions experimental/ssh/internal/sessions/namegen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package sessions

import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"time"
)

// acceleratorPrefixes maps known accelerator types to short human-readable prefixes.
var acceleratorPrefixes = map[string]string{
	"GPU_1xA10":  "gpu-a10",
	"GPU_8xH100": "gpu-h100",
}

// GenerateSessionName creates a human-readable session name from the accelerator type
// and workspace host. The workspace host is hashed into the name to avoid SSH known_hosts
// conflicts when connecting to different workspaces.
// Format: databricks-<prefix>-<date>-<workspace_hash><random_hex>.
func GenerateSessionName(accelerator, workspaceHost string) string {
	// Fall back to a mechanical prefix (lowercased, underscores to dashes)
	// for accelerator types we don't have a curated short name for.
	prefix, known := acceleratorPrefixes[accelerator]
	if !known {
		prefix = strings.ToLower(strings.ReplaceAll(accelerator, "_", "-"))
	}

	// Short host digest keeps names distinct across workspaces; md5 is used
	// purely as a cheap fingerprint here, not for security.
	hostDigest := md5.Sum([]byte(workspaceHost))
	hostTag := hex.EncodeToString(hostDigest[:])[:4]

	// Three random bytes (six hex chars) make concurrent same-day names unique.
	entropy := make([]byte, 3)
	if _, err := rand.Read(entropy); err != nil {
		panic(fmt.Sprintf("crypto/rand.Read failed: %v", err))
	}

	var name strings.Builder
	name.WriteString("databricks-")
	name.WriteString(prefix)
	name.WriteString("-")
	name.WriteString(time.Now().Format("20060102"))
	name.WriteString("-")
	name.WriteString(hostTag)
	name.WriteString(hex.EncodeToString(entropy))
	return name.String()
}
Loading
Loading