Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions cmd/src/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"flag"
"fmt"
)

var authCommands commander

func init() {
usage := `'src auth' provides authentication-related helper commands.

Usage:

src auth command [command options]

The commands are:

token prints the current authentication token

Use "src auth [command] -h" for more information about a command.
`

flagSet := flag.NewFlagSet("auth", flag.ExitOnError)
handler := func(args []string) error {
authCommands.run(flagSet, "src auth", usage, args)
return nil
}

commands = append(commands, &command{
flagSet: flagSet,
handler: handler,
usageFunc: func() {
fmt.Println(usage)
},
})
}
68 changes: 68 additions & 0 deletions cmd/src/auth_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package main

import (
"context"
"flag"
"fmt"

"github.com/sourcegraph/sourcegraph/lib/errors"

"github.com/sourcegraph/src-cli/internal/oauth"
)

var (
loadOAuthToken = oauth.LoadToken
newOAuthTokenRefresher = func(token *oauth.Token) oauthTokenRefresher {
return oauth.NewTokenRefresher(token)
}
)

type oauthTokenRefresher interface {
GetToken(ctx context.Context) (oauth.Token, error)
}

func init() {
flagSet := flag.NewFlagSet("token", flag.ExitOnError)
usageFunc := func() {
fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src auth token':\n")
flagSet.PrintDefaults()
}

handler := func(args []string) error {
if err := flagSet.Parse(args); err != nil {
return err
}

token, err := resolveAuthToken(context.Background(), cfg)
if err != nil {
return err
}

fmt.Println(token)
return nil
}

authCommands = append(authCommands, &command{
flagSet: flagSet,
handler: handler,
usageFunc: usageFunc,
})
}

func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
if cfg.accessToken != "" {
return cfg.accessToken, nil
}

oauthToken, err := loadOAuthToken(ctx, cfg.endpointURL)
if err != nil {
return "", errors.Wrap(err, "error loading OAuth token; set SRC_ACCESS_TOKEN or run `src login`")
}

token, err := newOAuthTokenRefresher(oauthToken).GetToken(ctx)
if err != nil {
return "", errors.Wrap(err, "refreshing OAuth token")
}

return token.AccessToken, nil
}
128 changes: 128 additions & 0 deletions cmd/src/auth_token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package main

import (
"context"
"fmt"
"net/url"
"testing"

"github.com/sourcegraph/src-cli/internal/oauth"
)

func TestResolveAuthToken(t *testing.T) {
t.Run("uses configured access token before keyring", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()

newRefresherCalled := false
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
newRefresherCalled = true
return fakeOAuthTokenRefresher{}
}

token, err := resolveAuthToken(context.Background(), &config{
accessToken: "access-token",
endpointURL: mustParseURL(t, "https://example.com"),
})
if err != nil {
t.Fatal(err)
}
if token != "access-token" {
t.Fatalf("token = %q, want %q", token, "access-token")
}
if newRefresherCalled {
t.Fatal("expected OAuth token refresher not to be created")
}
})

t.Run("uses stored oauth token", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()

loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
return &oauth.Token{
AccessToken: "oauth-token",
}, nil
}

newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "oauth-token"}}
}

token, err := resolveAuthToken(context.Background(), &config{
endpointURL: mustParseURL(t, "https://example.com"),
})
if err != nil {
t.Fatal(err)
}
if token != "oauth-token" {
t.Fatalf("token = %q, want %q", token, "oauth-token")
}
})

t.Run("refreshes expiring oauth token", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()

loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
return &oauth.Token{AccessToken: "old-token"}, nil
}

newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "new-token"}}
}

token, err := resolveAuthToken(context.Background(), &config{
endpointURL: mustParseURL(t, "https://example.com"),
})
if err != nil {
t.Fatal(err)
}
if token != "new-token" {
t.Fatalf("token = %q, want %q", token, "new-token")
}
})

t.Run("returns refresh error when shared refresh logic fails", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()

loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
return &oauth.Token{AccessToken: "old-token"}, nil
}
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
return fakeOAuthTokenRefresher{err: fmt.Errorf("refresh failed")}
}

_, err := resolveAuthToken(context.Background(), &config{
endpointURL: mustParseURL(t, "https://example.com"),
})
if err == nil {
t.Fatal("expected error")
}
})
}

func stubAuthTokenDependencies(t *testing.T) func() {
t.Helper()

prevLoad := loadOAuthToken
prevNewRefresher := newOAuthTokenRefresher

return func() {
loadOAuthToken = prevLoad
newOAuthTokenRefresher = prevNewRefresher
}
}

type fakeOAuthTokenRefresher struct {
token oauth.Token
err error
}

func (r fakeOAuthTokenRefresher) GetToken(context.Context) (oauth.Token, error) {
if r.err != nil {
return oauth.Token{}, r.err
}
return r.token, nil
}
1 change: 1 addition & 0 deletions cmd/src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ The options are:

The commands are:

auth authentication helper commands
api interacts with the Sourcegraph GraphQL API
batch manages batch changes
code-intel manages code intelligence data
Expand Down
5 changes: 1 addition & 4 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
}

if opts.AccessToken == "" && opts.OAuthToken != nil {
transport = &oauth.Transport{
Base: transport,
Token: opts.OAuthToken,
}
transport = oauth.NewTransport(transport, opts.OAuthToken)
}

return transport
Expand Down
55 changes: 34 additions & 21 deletions internal/oauth/http_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,24 @@ var _ http.Transport

var _ http.RoundTripper = (*Transport)(nil)

const defaultRefreshWindow = 5 * time.Minute

type Transport struct {
Base http.RoundTripper
//Token is a OAuth token (which has a refresh token) that should be used during roundtrip to automatically
//refresh the OAuth access token once the current one has expired or is soon to expire
Token *Token
Base http.RoundTripper
refresher *TokenRefresher
}

//mu is a mutex that should be acquired whenever token used
mu sync.Mutex
type TokenRefresher struct {
token *Token
mu sync.Mutex
}

func NewTokenRefresher(token *Token) *TokenRefresher {
return &TokenRefresher{token: token}
}

func NewTransport(base http.RoundTripper, token *Token) *Transport {
return &Transport{Base: base, refresher: NewTokenRefresher(token)}
}

// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during
Expand All @@ -30,8 +40,7 @@ var storeRefreshedTokenFn = StoreToken
// RoundTrip implements http.RoundTripper.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()

token, err := t.getToken(ctx)
token, err := t.refresher.GetToken(ctx)
if err != nil {
return nil, err
}
Expand All @@ -45,36 +54,40 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(req2)
}

// getToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning.
// GetToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning.
// Once the token is refreshed, the in-memory token is updated and a best effort is made to store the token.
//
// If storing the token fails, no error is returned. An error is only returned if refreshing the token
// fails.
func (t *Transport) getToken(ctx context.Context) (Token, error) {
t.mu.Lock()
defer t.mu.Unlock()
func (r *TokenRefresher) GetToken(ctx context.Context) (Token, error) {
r.mu.Lock()
defer r.mu.Unlock()

prevToken := t.Token
token, err := maybeRefresh(ctx, t.Token)
prevToken := r.token
token, err := maybeRefreshToken(ctx, r.token)
if err != nil {
return Token{}, err
}
t.Token = token
r.token = token
if token != prevToken {
// Try to save the token.
// If we fail let the request continue with the in-memory token
_ = storeRefreshedTokenFn(ctx, token)
}

return *t.Token, nil
return *r.token, nil
}

// maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s
// it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original
// token is returned.
func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
// maybeRefreshToken conditionally refreshes the token. If the token has expired or is
// expiring within the default refresh window, it will be refreshed and the updated token returned.
// Otherwise, no refresh occurs and the original token is returned.
func maybeRefreshToken(ctx context.Context, token *Token) (*Token, error) {
if token == nil {
return nil, errors.New("token is nil")
}

// token has NOT expired and is NOT about to expire in 30s
if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) {
if !(token.HasExpired() || token.ExpiringIn(defaultRefreshWindow)) {
return token, nil
}
client := NewClient(token.ClientID)
Expand Down
Loading
Loading