From d70450e8d6ae02d232c9c70a9c9706475f5e7f6a Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 16:30:03 +0200 Subject: [PATCH 01/14] add refresh to oauthdevice.Client From d35663e424a3bfdb6bc2b689e5ddfb4d422792b3 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 4 Dec 2025 11:53:35 +0200 Subject: [PATCH 02/14] add internal/keyring package to use 99designs keyring - rename keyring to store - make keyring struct src-cli and set label on secret --- go.mod | 7 ++++ go.sum | 14 ++++++++ internal/keyring/keyring.go | 66 +++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 internal/keyring/keyring.go diff --git a/go.mod b/go.mod index 3c9e3eb338..2cce32f1dd 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,8 @@ require ( cloud.google.com/go/auth v0.17.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/monitoring v1.24.2 // indirect + github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect + github.com/99designs/keyring v1.2.2 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.50.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.50.0 // indirect @@ -64,6 +66,7 @@ require ( github.com/clipperhouse/uax29/v2 v2.2.0 // indirect github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect github.com/containerd/stargz-snapshotter/estargz v0.14.3 // indirect + github.com/danieljoos/wincred v1.2.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/cli v24.0.4+incompatible // indirect github.com/docker/distribution v2.8.2+incompatible // indirect @@ -71,6 +74,7 @@ require ( github.com/docker/docker-credential-helpers v0.8.0 // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dvsekhvalnov/jose2go v1.5.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect github.com/felixge/fgprof v0.9.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -78,6 +82,7 @@ require ( github.com/go-chi/chi/v5 v5.2.2 // indirect github.com/go-jose/go-jose/v4 v4.1.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/gofrs/uuid/v5 v5.0.0 // indirect github.com/google/gnostic-models v0.6.9 // indirect github.com/google/go-containerregistry v0.19.1 // indirect @@ -85,6 +90,7 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect + github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/iancoleman/strcase v0.3.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect @@ -95,6 +101,7 @@ require ( github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/morikuni/aec v1.0.0 // indirect + github.com/mtibben/percent v0.2.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc4 // indirect github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect diff --git a/go.sum b/go.sum index f47d1d10c9..be3b08291b 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,10 @@ cloud.google.com/go/storage v1.50.0 h1:3TbVkzTooBvnZsk7WaAQfOsNrdoM8QHusXA1cpk6Q cloud.google.com/go/storage v1.50.0/go.mod h1:l7XeiD//vx5lfqE3RavfmU9yvk5Pp0Zhcv482poyafY= cloud.google.com/go/trace v1.11.6 h1:2O2zjPzqPYAHrn3OKl029qlqG6W8ZdYaOWRyr8NgMT4= cloud.google.com/go/trace v1.11.6/go.mod h1:GA855OeDEBiBMzcckLPE2kDunIpC72N+Pq8WFieFjnI= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= +github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= +github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= @@ -139,6 +143,8 @@ github.com/creack/goselect v0.1.2/go.mod h1:a/NhLweNvqIYMuxcMOuWY516Cimucms3DglD github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= +github.com/danieljoos/wincred v1.2.0 h1:ozqKHaLK0W/ii4KVbbvluM91W2H3Sh0BncbUNPS7jLE= +github.com/danieljoos/wincred v1.2.0/go.mod h1:FzQLLMKBFdvu+osBrnFODiv32YGwCfx0SkRa/eYHgec= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -165,6 +171,8 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= +github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= @@ -212,6 +220,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= @@ -258,6 +268,8 @@ github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853 h1:cLN4IBkmkYZNnk7E github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY= github.com/hexops/autogold v1.3.1 h1:YgxF9OHWbEIUjhDbpnLhgVsjUDsiHDTyDfy2lrfdlzo= github.com/hexops/autogold v1.3.1/go.mod h1:sQO+mQUCVfxOKPht+ipDSkJ2SCJ7BNJVHZexsXqWMx4= @@ -361,6 +373,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs= +github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go new file mode 100644 index 0000000000..b86eddc123 --- /dev/null +++ b/internal/keyring/keyring.go @@ -0,0 +1,66 @@ +// Package keyring provides secure credential storage using the system keychain. +package keyring + +import ( + "github.com/99designs/keyring" + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const ( + serviceName = "sourcegraph-cli" + + KeyOAuth = "oauth" +) + +// Store provides secure credential storage operations. +type Store struct { + ring keyring.Keyring +} + +// Open opens the system keyring for the Sourcegraph CLI. +func Open() (*Store, error) { + ring, err := keyring.Open(keyring.Config{ + ServiceName: serviceName, + KeychainName: "login", // This is the default name for the keychain where MacOS puts all login passwords + KeychainTrustApplication: true, // the keychain can trust src-cli! + }) + if err != nil { + return nil, errors.Wrap(err, "opening keyring") + } + return &Store{ring: ring}, nil +} + +// Set stores a key-value pair in the keyring. +func (s *Store) Set(key string, data []byte) error { + err := s.ring.Set(keyring.Item{ + Key: key, + Data: data, + Label: key, + }) + if err != nil { + return errors.Wrap(err, "storing item in keyring") + } + return nil +} + +// Get retrieves a value by key from the keyring. +// Returns nil, nil if the key is not found. +func (s *Store) Get(key string) ([]byte, error) { + item, err := s.ring.Get(key) + if err != nil { + if err == keyring.ErrKeyNotFound { + return nil, nil + } + return nil, errors.Wrap(err, "getting item from keyring") + } + return item.Data, nil +} + +// Delete removes a key from the keyring. +func (s *Store) Delete(key string) error { + err := s.ring.Remove(key) + if err != nil && err != keyring.ErrKeyNotFound { + return errors.Wrap(err, "removing item from keyring") + } + return nil +} From 0569974b20c3f55fcb42ea5a34229be1cb6c559d Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 4 Dec 2025 14:21:58 +0200 Subject: [PATCH 03/14] add OAuth Transport and use it if no access token - create token struct from TokenResponse - Token converts expiresIn to a timestamp - Store the token with the endpoint suffix - OAuth transport and use when available in api client --- cmd/src/login.go | 25 ++++++++--- cmd/src/login_test.go | 9 +++- cmd/src/main.go | 16 ++++++- internal/api/api.go | 33 ++++++++++---- internal/keyring/keyring.go | 6 +-- internal/oauth/flow.go | 77 +++++++++++++++++++++++++++----- internal/oauth/flow_test.go | 13 ++++-- internal/oauth/http_transport.go | 49 ++++++++++++++++++++ 8 files changed, 191 insertions(+), 37 deletions(-) create mode 100644 internal/oauth/http_transport.go diff --git a/cmd/src/login.go b/cmd/src/login.go index e85a0bd4e9..1c9191ff79 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -13,6 +13,7 @@ import ( "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/keyring" "github.com/sourcegraph/src-cli/internal/oauth" ) @@ -122,6 +123,13 @@ func loginCmd(ctx context.Context, p loginParams) error { noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint + secretStore, err := keyring.Open() + if err != nil { + printProblem(fmt.Sprintf("could not open keyring for secret storage: %s", err)) + } + + cfg.Endpoint = endpointArg + if p.useOAuth { token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) if err != nil { @@ -130,8 +138,11 @@ func loginCmd(ctx context.Context, p loginParams) error { return cmderrors.ExitCode1 } - cfg.AccessToken = token - cfg.Endpoint = endpointArg + if err := oauth.StoreToken(secretStore, token); err != nil { + printProblem(fmt.Sprintf("Failed to store token in keyring store: %s", err)) + return cmderrors.ExitCode1 + } + client = cfg.apiClient(p.apiFlags, out) } else if noToken || endpointConflict { fmt.Fprintln(out) @@ -179,10 +190,10 @@ func loginCmd(ctx context.Context, p loginParams) error { return nil } -func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (string, error) { +func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { authResp, err := client.Start(ctx, endpoint, nil) if err != nil { - return "", err + return nil, err } authURL := authResp.VerificationURIComplete @@ -204,12 +215,12 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli interval = 5 * time.Second } - tokenResp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) + resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) if err != nil { - return "", err + return nil, err } - return tokenResp.AccessToken, nil + return resp.Token(endpoint), nil } func openInBrowser(url string) error { diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index ef7d01e019..37d3202227 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/oauth" ) func TestLogin(t *testing.T) { @@ -18,7 +19,13 @@ func TestLogin(t *testing.T) { t.Helper() var out bytes.Buffer - err = loginCmd(context.Background(), loginParams{cfg: cfg, client: cfg.apiClient(nil, io.Discard), endpoint: endpointArg, out: &out}) + err = loginCmd(context.Background(), loginParams{ + cfg: cfg, + client: cfg.apiClient(nil, io.Discard), + endpoint: endpointArg, + out: &out, + deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), + }) return strings.TrimSpace(out.String()), err } diff --git a/cmd/src/main.go b/cmd/src/main.go index edfb1073d7..7fccd3e396 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -15,6 +15,8 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/src-cli/internal/api" + "github.com/sourcegraph/src-cli/internal/keyring" + "github.com/sourcegraph/src-cli/internal/oauthdevice" ) const usageText = `src is a tool that provides access to Sourcegraph instances. @@ -122,7 +124,7 @@ type config struct { // apiClient returns an api.Client built from the configuration. func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { - return api.NewClient(api.ClientOpts{ + opts := api.ClientOpts{ Endpoint: c.Endpoint, AccessToken: c.AccessToken, AdditionalHeaders: c.AdditionalHeaders, @@ -130,7 +132,17 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { Out: out, ProxyURL: c.ProxyURL, ProxyPath: c.ProxyPath, - }) + } + store, err := keyring.Open() + if err != nil { + panic("HALP") + } + + if t, err := oauthdevice.LoadToken(store, c.Endpoint); err == nil { + opts.OAuthToken = t + } + + return api.NewClient(opts) } // readConfig reads the config file from the given path. diff --git a/internal/api/api.go b/internal/api/api.go index 5f750c1d4a..2a2dbe6415 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -18,6 +18,7 @@ import ( "github.com/kballard/go-shellquote" "github.com/mattn/go-isatty" + "github.com/sourcegraph/src-cli/internal/oauthdevice" "github.com/sourcegraph/src-cli/internal/version" ) @@ -85,21 +86,35 @@ type ClientOpts struct { ProxyURL *url.URL ProxyPath string + + OAuthToken *oauthdevice.Token } -func buildTransport(opts ClientOpts, flags *Flags) *http.Transport { - transport := http.DefaultTransport.(*http.Transport).Clone() +func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { + var transport http.RoundTripper + { + tp := http.DefaultTransport.(*http.Transport).Clone() - if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } + if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { + tp.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + if tp.TLSClientConfig == nil { + tp.TLSClientConfig = &tls.Config{} + } + + if opts.ProxyURL != nil || opts.ProxyPath != "" { + tp = withProxyTransport(tp, opts.ProxyURL, opts.ProxyPath) + } - if transport.TLSClientConfig == nil { - transport.TLSClientConfig = &tls.Config{} + transport = tp } - if opts.ProxyURL != nil || opts.ProxyPath != "" { - transport = withProxyTransport(transport, opts.ProxyURL, opts.ProxyPath) + if opts.AccessToken == "" && opts.OAuthToken != nil { + transport = &oauthdevice.Transport{ + Base: transport, + Token: opts.OAuthToken, + } } return transport diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go index b86eddc123..47b18e03bb 100644 --- a/internal/keyring/keyring.go +++ b/internal/keyring/keyring.go @@ -6,11 +6,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" ) -const ( - serviceName = "sourcegraph-cli" - - KeyOAuth = "oauth" -) +const serviceName = "sourcegraph-cli" // Store provides secure credential storage operations. type Store struct { diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 584244cc43..3aadf75b99 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/sourcegraph/src-cli/internal/keyring" + "github.com/sourcegraph/sourcegraph/lib/errors" ) @@ -23,6 +25,9 @@ const ( // wellKnownPath is the path on the sourcegraph server where clients can discover OAuth configuration wellKnownPath = "/.well-known/openid-configuration" + // Key used to store the token in the store + KeyOAuth = "Sourcegraph CLI key storage" + GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" ScopeOpenID string = "openid" @@ -54,11 +59,18 @@ type DeviceAuthResponse struct { type TokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` - TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in,omitempty"` + TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` } +type Token struct { + Endpoint string `json:"endpoint"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + type ErrorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description,omitempty"` @@ -68,7 +80,7 @@ type Client interface { Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) - Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) + Refresh(ctx context.Context, token *Token) (*TokenResponse, error) } type httpClient struct { @@ -310,22 +322,20 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str } // Refresh exchanges a refresh token for a new access token. -func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) { - endpoint = strings.TrimRight(endpoint, "/") - - config, err := c.Discover(ctx, endpoint) +func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) { + config, err := c.Discover(ctx, token.Endpoint) if err != nil { - return nil, errors.Wrap(err, "OIDC discovery failed") + errors.Wrap(err, "failed to discover OIDC configuration") } if config.TokenEndpoint == "" { - return nil, errors.New("token endpoint not found in OIDC configuration") + errors.New("OIDC configuration has no token endpoint") } data := url.Values{} data.Set("client_id", c.clientID) data.Set("grant_type", "refresh_token") - data.Set("refresh_token", refreshToken) + data.Set("refresh_token", token.RefreshToken) req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode())) if err != nil { @@ -358,5 +368,52 @@ func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) return nil, errors.Wrap(err, "parsing refresh token response") } - return &tokenResp, nil + return &tokenResp, err +} + +func (t *TokenResponse) Token(endpoint string) *Token { + return &Token{ + Endpoint: strings.TrimRight(endpoint, "/"), + RefreshToken: t.RefreshToken, + AccessToken: t.AccessToken, + ExpiresAt: time.Now().Add(time.Second * time.Duration(t.ExpiresIn)), + } +} + +func (t *Token) HasExpired() bool { + return time.Now().After(t.ExpiresAt) +} + +func (t *Token) ExpiringIn(d time.Duration) bool { + future := time.Now().Add(d) + return future.After(t.ExpiresAt) +} + +func StoreToken(store *keyring.Store, token *Token) error { + data, err := json.Marshal(token) + if err != nil { + return errors.Wrap(err, "failed to marshal token") + } + + if token.Endpoint == "" { + return errors.New("token endpoint cannot be empty when storing the token") + } + + key := fmt.Sprintf("%s <%s>", KeyOAuth, token.Endpoint) + return store.Set(key, data) +} + +func LoadToken(store *keyring.Store, endpoint string) (*Token, error) { + key := fmt.Sprintf("%s <%s>", KeyOAuth, endpoint) + var t Token + data, err := store.Get(key) + if err != nil { + return nil, errors.Wrap(err, "failed to get token from store") + } + + if err := json.Unmarshal(data, &t); err != nil { + return nil, errors.Wrap(err, "failed to unmarshall token") + } + + return &t, nil } diff --git a/internal/oauth/flow_test.go b/internal/oauth/flow_test.go index 8c82d9d119..0268195317 100644 --- a/internal/oauth/flow_test.go +++ b/internal/oauth/flow_test.go @@ -267,9 +267,9 @@ func TestStart_NoDeviceEndpoint(t *testing.T) { func TestPoll_Success(t *testing.T) { wantToken := TokenResponse{ AccessToken: "test-access-token", - TokenType: "Bearer", ExpiresIn: 3600, Scope: "read write", + TokenType: "Bearer", } server := newTestServer(t, testServerOptions{ @@ -313,6 +313,7 @@ func TestPoll_Success(t *testing.T) { if resp.TokenType != wantToken.TokenType { t.Errorf("TokenType = %q, want %q", resp.TokenType, wantToken.TokenType) } + } func TestPoll_AuthorizationPending(t *testing.T) { @@ -527,8 +528,8 @@ func TestRefresh_Success(t *testing.T) { json.NewEncoder(w).Encode(TokenResponse{ AccessToken: "new-access-token", RefreshToken: "new-refresh-token", - TokenType: "Bearer", ExpiresIn: 3600, + TokenType: "Bearer", }) }, }, @@ -536,7 +537,13 @@ func TestRefresh_Success(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - resp, err := client.Refresh(context.Background(), server.URL, "test-refresh-token") + token := &Token{ + Endpoint: server.URL, + AccessToken: "new-access-token", + RefreshToken: "test-refresh-token", + ExpiresAt: time.Now().Add(time.Second * time.Duration(3600)), + } + resp, err := client.Refresh(context.Background(), token) if err != nil { t.Fatalf("Refresh() error = %v", err) } diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go new file mode 100644 index 0000000000..483b45108d --- /dev/null +++ b/internal/oauth/http_transport.go @@ -0,0 +1,49 @@ +package oauthdevice + +import ( + "context" + "net/http" + "time" +) + +var _ http.Transport + +var _ http.RoundTripper = (*Transport)(nil) + +type Transport struct { + Base http.RoundTripper + Token *Token +} + +// RoundTrip implements http.RoundTripper. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + token, err := maybeRefresh(ctx, t.Token) + if err != nil { + return nil, err + } + t.Token = token + + req2 := req.Clone(req.Context()) + req2.Header.Set("Authorization", "Bearer "+t.Token.AccessToken) + + if t.Base != nil { + return t.Base.RoundTrip(req2) + } + return http.DefaultTransport.RoundTrip(req2) +} + +func maybeRefresh(ctx context.Context, token *Token) (*Token, error) { + // token has NOT expired or NOT about to expire in 30s + if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) { + return token, nil + } + client := NewClient(DefaultClientID) + + resp, err := client.Refresh(ctx, token) + if err != nil { + return nil, err + } + + return resp.Token(token.Endpoint), nil +} From a94399da028bd1952704a55ec81355e67a0b3d55 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Mon, 8 Dec 2025 17:02:29 +0200 Subject: [PATCH 04/14] add secrets package to manage secrets - Add secret store that supports different backends - We use a registry map for a few secrets and the registry gets persisted as one secret to the keyring. We don't waant to create a keyring secret for every different secret - Store is opened once to load the registry. - use secretStorage to store oauth tokens --- cmd/src/login.go | 8 +- cmd/src/main.go | 12 ++- internal/oauth/flow.go | 25 ++++-- internal/{keyring => secrets}/keyring.go | 27 +++--- internal/secrets/store.go | 104 +++++++++++++++++++++++ 5 files changed, 140 insertions(+), 36 deletions(-) rename internal/{keyring => secrets}/keyring.go (66%) create mode 100644 internal/secrets/store.go diff --git a/cmd/src/login.go b/cmd/src/login.go index 1c9191ff79..2c10d5936a 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -13,7 +13,6 @@ import ( "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/cmderrors" - "github.com/sourcegraph/src-cli/internal/keyring" "github.com/sourcegraph/src-cli/internal/oauth" ) @@ -123,11 +122,6 @@ func loginCmd(ctx context.Context, p loginParams) error { noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint - secretStore, err := keyring.Open() - if err != nil { - printProblem(fmt.Sprintf("could not open keyring for secret storage: %s", err)) - } - cfg.Endpoint = endpointArg if p.useOAuth { @@ -138,7 +132,7 @@ func loginCmd(ctx context.Context, p loginParams) error { return cmderrors.ExitCode1 } - if err := oauth.StoreToken(secretStore, token); err != nil { + if err := oauth.StoreToken(token); err != nil { printProblem(fmt.Sprintf("Failed to store token in keyring store: %s", err)) return cmderrors.ExitCode1 } diff --git a/cmd/src/main.go b/cmd/src/main.go index 7fccd3e396..24dca7551e 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -15,7 +15,6 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/src-cli/internal/api" - "github.com/sourcegraph/src-cli/internal/keyring" "github.com/sourcegraph/src-cli/internal/oauthdevice" ) @@ -133,13 +132,12 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { ProxyURL: c.ProxyURL, ProxyPath: c.ProxyPath, } - store, err := keyring.Open() - if err != nil { - panic("HALP") - } - if t, err := oauthdevice.LoadToken(store, c.Endpoint); err == nil { - opts.OAuthToken = t + // Only use OAuth if we do not have SRC_ACCESS_TOKEN set + if c.AccessToken == "" { + if t, err := oauthdevice.LoadToken(c.Endpoint); err == nil { + opts.OAuthToken = t + } } return api.NewClient(opts) diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 3aadf75b99..5fd2bdbe8a 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "github.com/sourcegraph/src-cli/internal/keyring" + "github.com/sourcegraph/src-cli/internal/secrets" "github.com/sourcegraph/sourcegraph/lib/errors" ) @@ -389,7 +389,11 @@ func (t *Token) ExpiringIn(d time.Duration) bool { return future.After(t.ExpiresAt) } -func StoreToken(store *keyring.Store, token *Token) error { +func StoreToken(token *Token) error { + store, err := secrets.Store() + if err != nil { + return err + } data, err := json.Marshal(token) if err != nil { return errors.Wrap(err, "failed to marshal token") @@ -399,18 +403,23 @@ func StoreToken(store *keyring.Store, token *Token) error { return errors.New("token endpoint cannot be empty when storing the token") } - key := fmt.Sprintf("%s <%s>", KeyOAuth, token.Endpoint) - return store.Set(key, data) + key := fmt.Sprintf("oauth[%s]", token.Endpoint) + return store.Put(key, data) } -func LoadToken(store *keyring.Store, endpoint string) (*Token, error) { - key := fmt.Sprintf("%s <%s>", KeyOAuth, endpoint) - var t Token +func LoadToken(endpoint string) (*Token, error) { + store, err := secrets.Store() + if err != nil { + return nil, err + } + + key := fmt.Sprintf("oauth[%s]", endpoint) data, err := store.Get(key) if err != nil { - return nil, errors.Wrap(err, "failed to get token from store") + return nil, err } + var t Token if err := json.Unmarshal(data, &t); err != nil { return nil, errors.Wrap(err, "failed to unmarshall token") } diff --git a/internal/keyring/keyring.go b/internal/secrets/keyring.go similarity index 66% rename from internal/keyring/keyring.go rename to internal/secrets/keyring.go index 47b18e03bb..3acd981b18 100644 --- a/internal/keyring/keyring.go +++ b/internal/secrets/keyring.go @@ -1,5 +1,4 @@ -// Package keyring provides secure credential storage using the system keychain. -package keyring +package secrets import ( "github.com/99designs/keyring" @@ -8,13 +7,13 @@ import ( const serviceName = "sourcegraph-cli" -// Store provides secure credential storage operations. -type Store struct { +// keyringStore provides secure credential storage operations. +type keyringStore struct { ring keyring.Keyring } -// Open opens the system keyring for the Sourcegraph CLI. -func Open() (*Store, error) { +// open opens the system keyring for the Sourcegraph CLI. +func openKeyring() (*keyringStore, error) { ring, err := keyring.Open(keyring.Config{ ServiceName: serviceName, KeychainName: "login", // This is the default name for the keychain where MacOS puts all login passwords @@ -23,12 +22,12 @@ func Open() (*Store, error) { if err != nil { return nil, errors.Wrap(err, "opening keyring") } - return &Store{ring: ring}, nil + return &keyringStore{ring: ring}, nil } // Set stores a key-value pair in the keyring. -func (s *Store) Set(key string, data []byte) error { - err := s.ring.Set(keyring.Item{ +func (k *keyringStore) Put(key string, data []byte) error { + err := k.ring.Set(keyring.Item{ Key: key, Data: data, Label: key, @@ -41,11 +40,11 @@ func (s *Store) Set(key string, data []byte) error { // Get retrieves a value by key from the keyring. // Returns nil, nil if the key is not found. -func (s *Store) Get(key string) ([]byte, error) { - item, err := s.ring.Get(key) +func (k *keyringStore) Get(key string) ([]byte, error) { + item, err := k.ring.Get(key) if err != nil { if err == keyring.ErrKeyNotFound { - return nil, nil + return nil, ErrSecretNotFound } return nil, errors.Wrap(err, "getting item from keyring") } @@ -53,8 +52,8 @@ func (s *Store) Get(key string) ([]byte, error) { } // Delete removes a key from the keyring. -func (s *Store) Delete(key string) error { - err := s.ring.Remove(key) +func (k *keyringStore) Delete(key string) error { + err := k.ring.Remove(key) if err != nil && err != keyring.ErrKeyNotFound { return errors.Wrap(err, "removing item from keyring") } diff --git a/internal/secrets/store.go b/internal/secrets/store.go new file mode 100644 index 0000000000..dda306859e --- /dev/null +++ b/internal/secrets/store.go @@ -0,0 +1,104 @@ +package secrets + +import ( + "encoding/json" + "sync" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const keyRegistry = "secret-registry" + +var ErrSecretNotFound = errors.New("secret not found") + +var openOnce = sync.OnceValues(Open) + +type SecretStorage interface { + Get(key string) ([]byte, error) + Put(key string, data []byte) error + Delete(key string) error +} + +type store struct { + backend SecretStorage + registry map[string][]byte + + mu sync.Mutex +} + +func Store() (SecretStorage, error) { + return openOnce() +} + +func Open() (SecretStorage, error) { + keyring, err := openKeyring() + if err != nil { + return nil, err + } + + registry, err := getRegistry(keyring) + if err != nil { + return nil, err + } + s := &store{ + backend: keyring, + registry: registry, + } + + return s, nil +} + +func getRegistry(s SecretStorage) (map[string][]byte, error) { + data, err := s.Get(keyRegistry) + if err != nil { + return nil, errors.Wrap(err, "failed to load registry from backing store") + } + + var registry map[string][]byte + if err := json.Unmarshal(data, ®istry); err != nil { + return nil, errors.Wrap(err, "failed to decode registry from backing store") + } + + return registry, nil +} + +func saveRegistry(s SecretStorage, registry map[string][]byte) error { + data, err := json.Marshal(®istry) + if err != nil { + return errors.Wrap(err, "registry encoding failure") + } + + if err = s.Put(keyRegistry, data); err != nil { + return errors.Wrap(err, "failed to persist registry to backing store") + } + + return nil +} + +func (s *store) Get(key string) ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.registry[key] + if !ok { + return nil, ErrSecretNotFound + } + + return v, nil +} + +func (s *store) Put(key string, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.registry[key] = data + + return saveRegistry(s.backend, s.registry) +} + +func (s *store) Delete(key string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.registry, key) + return saveRegistry(s.backend, s.registry) +} From 682808d8fb7470d4c1ea23f68146be868b939693 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Fri, 23 Jan 2026 12:12:12 +0200 Subject: [PATCH 05/14] secrets: switch to zalando/go-keyring and add context support Amp-Thread-ID: https://ampcode.com/threads/T-019bea47-5179-7418-86cf-bf1d4cc93d28 Co-authored-by: Amp - minor keyring refactor --- cmd/src/login.go | 2 +- cmd/src/main.go | 3 +- go.mod | 11 +-- go.sum | 24 +++--- internal/oauth/flow.go | 24 +++--- internal/oauth/http_transport_test.go | 100 +++++++++++++++++++++++++ internal/secrets/keyring.go | 97 +++++++++++++++--------- internal/secrets/store.go | 104 -------------------------- 8 files changed, 192 insertions(+), 173 deletions(-) create mode 100644 internal/oauth/http_transport_test.go delete mode 100644 internal/secrets/store.go diff --git a/cmd/src/login.go b/cmd/src/login.go index 2c10d5936a..b2ce6fae8a 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -132,7 +132,7 @@ func loginCmd(ctx context.Context, p loginParams) error { return cmderrors.ExitCode1 } - if err := oauth.StoreToken(token); err != nil { + if err := oauth.StoreToken(ctx, token); err != nil { printProblem(fmt.Sprintf("Failed to store token in keyring store: %s", err)) return cmderrors.ExitCode1 } diff --git a/cmd/src/main.go b/cmd/src/main.go index 24dca7551e..f9794ee23e 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "flag" "io" @@ -135,7 +136,7 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { // Only use OAuth if we do not have SRC_ACCESS_TOKEN set if c.AccessToken == "" { - if t, err := oauthdevice.LoadToken(c.Endpoint); err == nil { + if t, err := oauthdevice.LoadToken(context.Background(), c.Endpoint); err == nil { opts.OAuthToken = t } } diff --git a/go.mod b/go.mod index 2cce32f1dd..0a04972e5a 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/sourcegraph/sourcegraph/lib v0.0.0-20240709083501-1af563b61442 github.com/stretchr/testify v1.11.1 github.com/tliron/glsp v0.2.2 + github.com/zalando/go-keyring v0.2.6 golang.org/x/sync v0.18.0 google.golang.org/api v0.256.0 google.golang.org/protobuf v1.36.10 @@ -41,12 +42,11 @@ require ( ) require ( + al.essio.dev/pkg/shellescape v1.5.1 // indirect cel.dev/expr v0.24.0 // indirect cloud.google.com/go/auth v0.17.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/monitoring v1.24.2 // indirect - github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect - github.com/99designs/keyring v1.2.2 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.50.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.50.0 // indirect @@ -66,7 +66,7 @@ require ( github.com/clipperhouse/uax29/v2 v2.2.0 // indirect github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect github.com/containerd/stargz-snapshotter/estargz v0.14.3 // indirect - github.com/danieljoos/wincred v1.2.0 // indirect + github.com/danieljoos/wincred v1.2.2 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/cli v24.0.4+incompatible // indirect github.com/docker/distribution v2.8.2+incompatible // indirect @@ -74,7 +74,6 @@ require ( github.com/docker/docker-credential-helpers v0.8.0 // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/dvsekhvalnov/jose2go v1.5.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect github.com/felixge/fgprof v0.9.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -82,7 +81,7 @@ require ( github.com/go-chi/chi/v5 v5.2.2 // indirect github.com/go-jose/go-jose/v4 v4.1.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect + github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/gofrs/uuid/v5 v5.0.0 // indirect github.com/google/gnostic-models v0.6.9 // indirect github.com/google/go-containerregistry v0.19.1 // indirect @@ -90,7 +89,6 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect - github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/iancoleman/strcase v0.3.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect @@ -101,7 +99,6 @@ require ( github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/morikuni/aec v1.0.0 // indirect - github.com/mtibben/percent v0.2.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc4 // indirect github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect diff --git a/go.sum b/go.sum index be3b08291b..6cbdc71412 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +al.essio.dev/pkg/shellescape v1.5.1 h1:86HrALUujYS/h+GtqoB26SBEdkWfmMI6FubjXlsXyho= +al.essio.dev/pkg/shellescape v1.5.1/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA= @@ -20,10 +22,6 @@ cloud.google.com/go/storage v1.50.0 h1:3TbVkzTooBvnZsk7WaAQfOsNrdoM8QHusXA1cpk6Q cloud.google.com/go/storage v1.50.0/go.mod h1:l7XeiD//vx5lfqE3RavfmU9yvk5Pp0Zhcv482poyafY= cloud.google.com/go/trace v1.11.6 h1:2O2zjPzqPYAHrn3OKl029qlqG6W8ZdYaOWRyr8NgMT4= cloud.google.com/go/trace v1.11.6/go.mod h1:GA855OeDEBiBMzcckLPE2kDunIpC72N+Pq8WFieFjnI= -github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= -github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= -github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= -github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= @@ -143,8 +141,8 @@ github.com/creack/goselect v0.1.2/go.mod h1:a/NhLweNvqIYMuxcMOuWY516Cimucms3DglD github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= -github.com/danieljoos/wincred v1.2.0 h1:ozqKHaLK0W/ii4KVbbvluM91W2H3Sh0BncbUNPS7jLE= -github.com/danieljoos/wincred v1.2.0/go.mod h1:FzQLLMKBFdvu+osBrnFODiv32YGwCfx0SkRa/eYHgec= +github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= +github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -171,8 +169,6 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= -github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= @@ -220,8 +216,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= -github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= -github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= @@ -253,6 +249,8 @@ github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgY github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= @@ -268,8 +266,6 @@ github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853 h1:cLN4IBkmkYZNnk7E github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= -github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= -github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY= github.com/hexops/autogold v1.3.1 h1:YgxF9OHWbEIUjhDbpnLhgVsjUDsiHDTyDfy2lrfdlzo= github.com/hexops/autogold v1.3.1/go.mod h1:sQO+mQUCVfxOKPht+ipDSkJ2SCJ7BNJVHZexsXqWMx4= @@ -373,8 +369,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= -github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs= -github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= @@ -509,6 +503,8 @@ github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk= github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= +github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= +github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 5fd2bdbe8a..4105c228bd 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -25,9 +25,6 @@ const ( // wellKnownPath is the path on the sourcegraph server where clients can discover OAuth configuration wellKnownPath = "/.well-known/openid-configuration" - // Key used to store the token in the store - KeyOAuth = "Sourcegraph CLI key storage" - GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" ScopeOpenID string = "openid" @@ -35,6 +32,10 @@ const ( ScopeEmail string = "email" ScopeOfflineAccess string = "offline_access" ScopeUserAll string = "user:all" + + // storeKeyFmt is the format of the key name that will be used to store a value + // typically the last element is the endpoint the value is for ie. src:oauth:https://sourcegraph.sourcegraph.com + storeKeyFmt string = "src:oauth:%s" ) var defaultScopes = []string{ScopeEmail, ScopeOfflineAccess, ScopeOpenID, ScopeProfile, ScopeUserAll} @@ -389,8 +390,12 @@ func (t *Token) ExpiringIn(d time.Duration) bool { return future.After(t.ExpiresAt) } -func StoreToken(token *Token) error { - store, err := secrets.Store() +func oauthKey(endpoint string) string { + return fmt.Sprintf(storeKeyFmt, endpoint) +} + +func StoreToken(ctx context.Context, token *Token) error { + store, err := secrets.Open(ctx) if err != nil { return err } @@ -403,17 +408,16 @@ func StoreToken(token *Token) error { return errors.New("token endpoint cannot be empty when storing the token") } - key := fmt.Sprintf("oauth[%s]", token.Endpoint) - return store.Put(key, data) + return store.Put(oauthKey(token.Endpoint), data) } -func LoadToken(endpoint string) (*Token, error) { - store, err := secrets.Store() +func LoadToken(ctx context.Context, endpoint string) (*Token, error) { + store, err := secrets.Open(ctx) if err != nil { return nil, err } - key := fmt.Sprintf("oauth[%s]", endpoint) + key := oauthKey(endpoint) data, err := store.Get(key) if err != nil { return nil, err diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go new file mode 100644 index 0000000000..2260a31426 --- /dev/null +++ b/internal/oauth/http_transport_test.go @@ -0,0 +1,100 @@ +package oauthdevice + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +type mockRoundTripper struct { + handler func(*http.Request) (*http.Response, error) +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.handler(req) +} + +func TestTransport_SetsAuthorizationHeader(t *testing.T) { + var capturedAuth string + + transport := &Transport{ + Base: &mockRoundTripper{ + handler: func(req *http.Request) (*http.Response, error) { + capturedAuth = req.Header.Get("Authorization") + return &http.Response{StatusCode: 200}, nil + }, + }, + Token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(time.Hour), + }, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + + if capturedAuth != "Bearer test-token" { + t.Errorf("Authorization = %q, want %q", capturedAuth, "Bearer test-token") + } +} + +func TestMaybeRefresh_RefreshesExpiredToken(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`)) + }, + }, + }) + defer server.Close() + + token := &Token{ + Endpoint: server.URL, + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), // expired + } + + result, err := maybeRefresh(context.Background(), token) + if err != nil { + t.Fatalf("maybeRefresh() error = %v", err) + } + + if result.AccessToken != "new-token" { + t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token") + } +} + +func TestMaybeRefresh_RefreshesTokenExpiringSoon(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`)) + }, + }, + }) + defer server.Close() + + token := &Token{ + Endpoint: server.URL, + AccessToken: "expiring-soon-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(10 * time.Second), // expires in 10s (< 30s threshold) + } + + result, err := maybeRefresh(context.Background(), token) + if err != nil { + t.Fatalf("maybeRefresh() error = %v", err) + } + + if result.AccessToken != "new-token" { + t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token") + } +} diff --git a/internal/secrets/keyring.go b/internal/secrets/keyring.go index 3acd981b18..d3f18eccc7 100644 --- a/internal/secrets/keyring.go +++ b/internal/secrets/keyring.go @@ -1,61 +1,86 @@ package secrets import ( - "github.com/99designs/keyring" + "context" + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/zalando/go-keyring" ) -const serviceName = "sourcegraph-cli" +var ErrSecretNotFound = errors.New("secret not found") + +// Store provides secure credential storage operations. +type Store interface { + Get(key string) ([]byte, error) + Put(key string, data []byte) error + Delete(key string) error +} -// keyringStore provides secure credential storage operations. type keyringStore struct { - ring keyring.Keyring + ctx context.Context + serviceName string } -// open opens the system keyring for the Sourcegraph CLI. -func openKeyring() (*keyringStore, error) { - ring, err := keyring.Open(keyring.Config{ - ServiceName: serviceName, - KeychainName: "login", // This is the default name for the keychain where MacOS puts all login passwords - KeychainTrustApplication: true, // the keychain can trust src-cli! - }) - if err != nil { - return nil, errors.Wrap(err, "opening keyring") +// Open opens the system keyring for the Sourcegraph CLI. +func Open(ctx context.Context) (Store, error) { + return &keyringStore{ctx: ctx, serviceName: "Sourcegraph CLI"}, nil +} + +// withContext runs fn in a goroutine and returns its result, or ctx.Err() if the context is cancelled first. +func withContext[T any](ctx context.Context, fn func() (T, error)) (T, error) { + type result struct { + val T + err error + } + ch := make(chan result, 1) + go func() { + val, err := fn() + ch <- result{val, err} + }() + + select { + case <-ctx.Done(): + var zero T + return zero, ctx.Err() + case r := <-ch: + return r.val, r.err } - return &keyringStore{ring: ring}, nil } -// Set stores a key-value pair in the keyring. +// Put stores a key-value pair in the keyring. func (k *keyringStore) Put(key string, data []byte) error { - err := k.ring.Set(keyring.Item{ - Key: key, - Data: data, - Label: key, + _, err := withContext(k.ctx, func() (struct{}, error) { + err := keyring.Set(k.serviceName, key, string(data)) + if err != nil { + return struct{}{}, errors.Wrap(err, "storing item in keyring") + } + return struct{}{}, nil }) - if err != nil { - return errors.Wrap(err, "storing item in keyring") - } - return nil + return err } // Get retrieves a value by key from the keyring. -// Returns nil, nil if the key is not found. func (k *keyringStore) Get(key string) ([]byte, error) { - item, err := k.ring.Get(key) - if err != nil { - if err == keyring.ErrKeyNotFound { - return nil, ErrSecretNotFound + return withContext(k.ctx, func() ([]byte, error) { + secret, err := keyring.Get(k.serviceName, key) + if err != nil { + if err == keyring.ErrNotFound { + return nil, ErrSecretNotFound + } + return nil, errors.Wrap(err, "getting item from keyring") } - return nil, errors.Wrap(err, "getting item from keyring") - } - return item.Data, nil + return []byte(secret), nil + }) } // Delete removes a key from the keyring. func (k *keyringStore) Delete(key string) error { - err := k.ring.Remove(key) - if err != nil && err != keyring.ErrKeyNotFound { - return errors.Wrap(err, "removing item from keyring") - } - return nil + _, err := withContext(k.ctx, func() (struct{}, error) { + err := keyring.Delete(k.serviceName, key) + if err != nil && err != keyring.ErrNotFound { + return struct{}{}, errors.Wrap(err, "removing item from keyring") + } + return struct{}{}, nil + }) + return err } diff --git a/internal/secrets/store.go b/internal/secrets/store.go deleted file mode 100644 index dda306859e..0000000000 --- a/internal/secrets/store.go +++ /dev/null @@ -1,104 +0,0 @@ -package secrets - -import ( - "encoding/json" - "sync" - - "github.com/sourcegraph/sourcegraph/lib/errors" -) - -const keyRegistry = "secret-registry" - -var ErrSecretNotFound = errors.New("secret not found") - -var openOnce = sync.OnceValues(Open) - -type SecretStorage interface { - Get(key string) ([]byte, error) - Put(key string, data []byte) error - Delete(key string) error -} - -type store struct { - backend SecretStorage - registry map[string][]byte - - mu sync.Mutex -} - -func Store() (SecretStorage, error) { - return openOnce() -} - -func Open() (SecretStorage, error) { - keyring, err := openKeyring() - if err != nil { - return nil, err - } - - registry, err := getRegistry(keyring) - if err != nil { - return nil, err - } - s := &store{ - backend: keyring, - registry: registry, - } - - return s, nil -} - -func getRegistry(s SecretStorage) (map[string][]byte, error) { - data, err := s.Get(keyRegistry) - if err != nil { - return nil, errors.Wrap(err, "failed to load registry from backing store") - } - - var registry map[string][]byte - if err := json.Unmarshal(data, ®istry); err != nil { - return nil, errors.Wrap(err, "failed to decode registry from backing store") - } - - return registry, nil -} - -func saveRegistry(s SecretStorage, registry map[string][]byte) error { - data, err := json.Marshal(®istry) - if err != nil { - return errors.Wrap(err, "registry encoding failure") - } - - if err = s.Put(keyRegistry, data); err != nil { - return errors.Wrap(err, "failed to persist registry to backing store") - } - - return nil -} - -func (s *store) Get(key string) ([]byte, error) { - s.mu.Lock() - defer s.mu.Unlock() - v, ok := s.registry[key] - if !ok { - return nil, ErrSecretNotFound - } - - return v, nil -} - -func (s *store) Put(key string, data []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.registry[key] = data - - return saveRegistry(s.backend, s.registry) -} - -func (s *store) Delete(key string) error { - s.mu.Lock() - defer s.mu.Unlock() - - delete(s.registry, key) - return saveRegistry(s.backend, s.registry) -} From 0a6bca8c7cd47a6066c2c0b24358a8c0629518a7 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 11:25:32 +0200 Subject: [PATCH 06/14] token refresh - use token.ClientID during refresh - best effort store refresh token --- internal/oauth/http_transport.go | 15 +++- internal/oauth/http_transport_test.go | 98 +++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go index 483b45108d..7fe028824c 100644 --- a/internal/oauth/http_transport.go +++ b/internal/oauth/http_transport.go @@ -15,14 +15,23 @@ type Transport struct { Token *Token } +// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during +// tests to swap out the implementation out with a mock +var storeRefreshedTokenFn = StoreToken + // RoundTrip implements http.RoundTripper. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() + prevToken := t.Token token, err := maybeRefresh(ctx, t.Token) if err != nil { return nil, err } t.Token = token + if token != prevToken { + // try to save the token if we fail let the request continue with in memory token + _ = storeRefreshedTokenFn(ctx, token) + } req2 := req.Clone(req.Context()) req2.Header.Set("Authorization", "Bearer "+t.Token.AccessToken) @@ -38,12 +47,14 @@ func maybeRefresh(ctx context.Context, token *Token) (*Token, error) { if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) { return token, nil } - client := NewClient(DefaultClientID) + client := NewClient(token.ClientID) resp, err := client.Refresh(ctx, token) if err != nil { return nil, err } - return resp.Token(token.Endpoint), nil + next := resp.Token(token.Endpoint) + next.ClientID = token.ClientID + return next, nil } diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go index 2260a31426..c79c4bc723 100644 --- a/internal/oauth/http_transport_test.go +++ b/internal/oauth/http_transport_test.go @@ -2,6 +2,7 @@ package oauthdevice import ( "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -98,3 +99,100 @@ func TestMaybeRefresh_RefreshesTokenExpiringSoon(t *testing.T) { t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token") } } + +func TestTransport_RefreshPersistence(t *testing.T) { + tests := []struct { + name string + needsRefresh bool + persistErr error + wantAuthHeaderVal string + wantStoreCalls int + }{ + { + name: "persists refreshed token", + needsRefresh: true, + wantAuthHeaderVal: "Bearer new-token", + wantStoreCalls: 1, + }, + { + name: "does not persist unchanged token", + wantAuthHeaderVal: "Bearer valid-token", + wantStoreCalls: 0, + }, + { + name: "persist failure does not fail request", + needsRefresh: true, + persistErr: errors.New("persist failed"), + wantAuthHeaderVal: "Bearer new-token", + wantStoreCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + originalStoreFn := storeRefreshedTokenFn + defer func() { storeRefreshedTokenFn = originalStoreFn }() + + var storeCalls int + var storedToken *Token + storeRefreshedTokenFn = func(_ context.Context, token *Token) error { + storeCalls++ + storedToken = token + return tt.persistErr + } + + token := &Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(time.Hour), + } + if tt.needsRefresh { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`)) + }, + }, + }) + defer server.Close() + token.Endpoint = server.URL + token.AccessToken = "expired-token" + token.RefreshToken = "refresh-token" + token.ExpiresAt = time.Now().Add(-time.Hour) + } + + var capturedAuth string + transport := &Transport{ + Base: &mockRoundTripper{ + handler: func(req *http.Request) (*http.Response, error) { + capturedAuth = req.Header.Get("Authorization") + return &http.Response{StatusCode: 200}, nil + }, + }, + Token: token, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + + if capturedAuth != tt.wantAuthHeaderVal { + t.Errorf("Authorization = %q, want %q", capturedAuth, tt.wantAuthHeaderVal) + } + if storeCalls != tt.wantStoreCalls { + t.Errorf("store calls = %d, want %d", storeCalls, tt.wantStoreCalls) + } + + if tt.needsRefresh { + if storedToken == nil { + t.Fatal("stored token is nil") + } + if storedToken.AccessToken != "new-token" { + t.Errorf("stored AccessToken = %q, want %q", storedToken.AccessToken, "new-token") + } + } + }) + } +} From 0f2e720ce71dace9be24086cd0c4eb07047b388a Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 11:26:35 +0200 Subject: [PATCH 07/14] better error handling handle oauth discovery failure and set client id on token - use SRC CLI client id as default and handle discovery failures - add clientID flag and set it on the token improve error message and panic in apiClient if no usable token - warn if we fail to store the token on login - panic if apiClient has no accessToken or OAuth token to use --- cmd/src/login.go | 24 ++++++++++++++++++------ cmd/src/main.go | 3 +++ internal/oauth/flow.go | 23 +++++++++++------------ internal/oauth/flow_test.go | 16 ++++++++++++++++ 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index b2ce6fae8a..4f755a530f 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -74,6 +74,7 @@ Examples: endpoint: endpoint, out: os.Stdout, useOAuth: *useOAuth, + oauthClientID: *OAuthClientID, apiFlags: apiFlags, deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), }) @@ -92,6 +93,7 @@ type loginParams struct { endpoint string out io.Writer useOAuth bool + oauthClientID string apiFlags *api.Flags deviceFlowClient oauth.Client } @@ -125,7 +127,7 @@ func loginCmd(ctx context.Context, p loginParams) error { cfg.Endpoint = endpointArg if p.useOAuth { - token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) + token, err := runOAuthDeviceFlow(ctx, endpointArg, p.oauthClientID, out, p.deviceFlowClient) if err != nil { printProblem(fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) fmt.Fprintln(out, createAccessTokenMessage) @@ -133,11 +135,19 @@ func loginCmd(ctx context.Context, p loginParams) error { } if err := oauth.StoreToken(ctx, token); err != nil { - printProblem(fmt.Sprintf("Failed to store token in keyring store: %s", err)) - return cmderrors.ExitCode1 + fmt.Fprintln(out) + fmt.Fprintf(out, "āš ļø Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err) } - client = cfg.apiClient(p.apiFlags, out) + client = api.NewClient(api.ClientOpts{ + Endpoint: cfg.Endpoint, + AdditionalHeaders: cfg.AdditionalHeaders, + Flags: p.apiFlags, + Out: out, + ProxyURL: cfg.ProxyURL, + ProxyPath: cfg.ProxyPath, + OAuthToken: token, + }) } else if noToken || endpointConflict { fmt.Fprintln(out) switch { @@ -184,7 +194,7 @@ func loginCmd(ctx context.Context, p loginParams) error { return nil } -func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { +func runOAuthDeviceFlow(ctx context.Context, endpoint, clientID string, out io.Writer, client oauth.Client) (*oauth.Token, error) { authResp, err := client.Start(ctx, endpoint, nil) if err != nil { return nil, err @@ -214,7 +224,9 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli return nil, err } - return resp.Token(endpoint), nil + token := resp.Token(endpoint) + token.ClientID = clientID + return token, nil } func openInBrowser(url string) error { diff --git a/cmd/src/main.go b/cmd/src/main.go index f9794ee23e..a842c5b1b4 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -138,6 +138,9 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { if c.AccessToken == "" { if t, err := oauthdevice.LoadToken(context.Background(), c.Endpoint); err == nil { opts.OAuthToken = t + } else { + // TODO(burmudar): should return an error instead + panic("No access token set and no OAuth token found either - unable to create api client: " + err.Error()) } } diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 4105c228bd..a63330b888 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -3,6 +3,7 @@ package oauth import ( + "cmp" "context" "encoding/json" "fmt" @@ -67,6 +68,7 @@ type TokenResponse struct { type Token struct { Endpoint string `json:"endpoint"` + ClientID string `json:"client_id,omitempty"` AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` ExpiresAt time.Time `json:"expires_at"` @@ -92,17 +94,14 @@ type httpClient struct { } func NewClient(clientID string) Client { - return &httpClient{ - clientID: clientID, - client: &http.Client{ - Timeout: 30 * time.Second, - }, - configCache: make(map[string]*OIDCConfiguration), - } + return NewClientWithHTTPClient(clientID, &http.Client{ + Timeout: 30 * time.Second, + }) } -func NewClientWithHTTPClient(c *http.Client) Client { +func NewClientWithHTTPClient(clientID string, c *http.Client) Client { return &httpClient{ + clientID: cmp.Or(clientID, DefaultClientID), client: c, configCache: make(map[string]*OIDCConfiguration), } @@ -170,7 +169,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string } data := url.Values{} - data.Set("client_id", DefaultClientID) + data.Set("client_id", c.clientID) if len(scopes) > 0 { data.Set("scope", strings.Join(scopes, " ")) } else { @@ -284,7 +283,7 @@ func (e *PollError) Error() string { func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) { data := url.Values{} - data.Set("client_id", DefaultClientID) + data.Set("client_id", c.clientID) data.Set("device_code", deviceCode) data.Set("grant_type", GrantTypeDeviceCode) @@ -326,11 +325,11 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) { config, err := c.Discover(ctx, token.Endpoint) if err != nil { - errors.Wrap(err, "failed to discover OIDC configuration") + return nil, errors.Wrap(err, "failed to discover OIDC configuration") } if config.TokenEndpoint == "" { - errors.New("OIDC configuration has no token endpoint") + return nil, errors.New("OIDC configuration has no token endpoint") } data := url.Values{} diff --git a/internal/oauth/flow_test.go b/internal/oauth/flow_test.go index 0268195317..0b1ad5dc93 100644 --- a/internal/oauth/flow_test.go +++ b/internal/oauth/flow_test.go @@ -555,3 +555,19 @@ func TestRefresh_Success(t *testing.T) { t.Errorf("RefreshToken = %q, want %q", resp.RefreshToken, "new-refresh-token") } } + +func TestRefresh_DiscoverFailure(t *testing.T) { + client := NewClient(DefaultClientID) + token := &Token{ + Endpoint: "http://127.0.0.1:1", + RefreshToken: "test-refresh-token", + } + + _, err := client.Refresh(context.Background(), token) + if err == nil { + t.Fatal("Refresh() expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to discover OIDC configuration") { + t.Errorf("error = %q, want discovery failure context", err.Error()) + } +} From 4c8fb193b2b70bf9b1e1ee1c2996fced41b0c5e4 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 14:43:24 +0200 Subject: [PATCH 08/14] fix removal of client id --- cmd/src/login.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index 4f755a530f..01cdda5523 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -74,7 +74,6 @@ Examples: endpoint: endpoint, out: os.Stdout, useOAuth: *useOAuth, - oauthClientID: *OAuthClientID, apiFlags: apiFlags, deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), }) @@ -93,7 +92,6 @@ type loginParams struct { endpoint string out io.Writer useOAuth bool - oauthClientID string apiFlags *api.Flags deviceFlowClient oauth.Client } @@ -127,7 +125,7 @@ func loginCmd(ctx context.Context, p loginParams) error { cfg.Endpoint = endpointArg if p.useOAuth { - token, err := runOAuthDeviceFlow(ctx, endpointArg, p.oauthClientID, out, p.deviceFlowClient) + token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) if err != nil { printProblem(fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) fmt.Fprintln(out, createAccessTokenMessage) @@ -194,7 +192,7 @@ func loginCmd(ctx context.Context, p loginParams) error { return nil } -func runOAuthDeviceFlow(ctx context.Context, endpoint, clientID string, out io.Writer, client oauth.Client) (*oauth.Token, error) { +func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { authResp, err := client.Start(ctx, endpoint, nil) if err != nil { return nil, err @@ -225,7 +223,7 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint, clientID string, out io.W } token := resp.Token(endpoint) - token.ClientID = clientID + token.ClientID = client.ClientID() return token, nil } From 077ecd896686e621b6c4f995c22d109896bd4728 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 14:43:41 +0200 Subject: [PATCH 09/14] fix package rename --- cmd/src/main.go | 4 ++-- internal/api/api.go | 6 +++--- internal/oauth/flow.go | 5 +++++ internal/oauth/http_transport.go | 2 +- internal/oauth/http_transport_test.go | 2 +- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cmd/src/main.go b/cmd/src/main.go index a842c5b1b4..e506730841 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -16,7 +16,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/src-cli/internal/api" - "github.com/sourcegraph/src-cli/internal/oauthdevice" + "github.com/sourcegraph/src-cli/internal/oauth" ) const usageText = `src is a tool that provides access to Sourcegraph instances. @@ -136,7 +136,7 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { // Only use OAuth if we do not have SRC_ACCESS_TOKEN set if c.AccessToken == "" { - if t, err := oauthdevice.LoadToken(context.Background(), c.Endpoint); err == nil { + if t, err := oauth.LoadToken(context.Background(), c.Endpoint); err == nil { opts.OAuthToken = t } else { // TODO(burmudar): should return an error instead diff --git a/internal/api/api.go b/internal/api/api.go index 2a2dbe6415..f1fb8f6443 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -18,7 +18,7 @@ import ( "github.com/kballard/go-shellquote" "github.com/mattn/go-isatty" - "github.com/sourcegraph/src-cli/internal/oauthdevice" + "github.com/sourcegraph/src-cli/internal/oauth" "github.com/sourcegraph/src-cli/internal/version" ) @@ -87,7 +87,7 @@ type ClientOpts struct { ProxyURL *url.URL ProxyPath string - OAuthToken *oauthdevice.Token + OAuthToken *oauth.Token } func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { @@ -111,7 +111,7 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { } if opts.AccessToken == "" && opts.OAuthToken != nil { - transport = &oauthdevice.Transport{ + transport = &oauth.Transport{ Base: transport, Token: opts.OAuthToken, } diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index a63330b888..92819fec6e 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -80,6 +80,7 @@ type ErrorResponse struct { } type Client interface { + ClientID() string Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) @@ -107,6 +108,10 @@ func NewClientWithHTTPClient(clientID string, c *http.Client) Client { } } +func (c *httpClient) ClientID() string { + return c.clientID +} + // Discover fetches the openid-configuration which contains all the routes a client should // use for authorization, device flows, tokens etc. // diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go index 7fe028824c..4361367008 100644 --- a/internal/oauth/http_transport.go +++ b/internal/oauth/http_transport.go @@ -1,4 +1,4 @@ -package oauthdevice +package oauth import ( "context" diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go index c79c4bc723..148619b9cd 100644 --- a/internal/oauth/http_transport_test.go +++ b/internal/oauth/http_transport_test.go @@ -1,4 +1,4 @@ -package oauthdevice +package oauth import ( "context" From a6e9155358c3c99c613079af1272f78f47a9b04b Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 15:15:07 +0200 Subject: [PATCH 10/14] report a nicer error when OAuth fails --- internal/api/api.go | 19 +++++++++++++++---- internal/oauth/http_transport.go | 5 +++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index f1fb8f6443..423cb73997 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -183,6 +183,7 @@ func (c *client) createHTTPRequest(ctx context.Context, method, p string, body i } else { req.Header.Set("User-Agent", "src-cli/"+version.BuildTag) } + if c.opts.AccessToken != "" { req.Header.Set("Authorization", "token "+c.opts.AccessToken) } @@ -264,10 +265,20 @@ func (r *request) do(ctx context.Context, result any) (bool, error) { // confirm the status code. You can test this easily with e.g. an invalid // endpoint like -endpoint=https://google.com if resp.StatusCode != http.StatusOK { - if resp.StatusCode == http.StatusUnauthorized && isatty.IsCygwinTerminal(os.Stdout.Fd()) { - fmt.Println("You may need to specify or update your access token to use this endpoint.") - fmt.Println("See https://github.com/sourcegraph/src-cli#readme") - fmt.Println("") + if resp.StatusCode == http.StatusUnauthorized { + if oauth.IsOAuthTransport(r.client.httpClient.Transport) { + fmt.Println("The OAuth token is invalid. Please check that the Sourcegraph CLI client is still authorized.") + fmt.Println("") + fmt.Println("To re-authorize, run: src login") + fmt.Println("") + fmt.Println("Learn more at https://github.com/sourcegraph/src-cli#readme") + fmt.Println("") + } + if isatty.IsCygwinTerminal(os.Stdout.Fd()) { + fmt.Println("You may need to specify or update your access token to use this endpoint.") + fmt.Println("See https://github.com/sourcegraph/src-cli#readme") + fmt.Println("") + } } body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go index 4361367008..9407bc625f 100644 --- a/internal/oauth/http_transport.go +++ b/internal/oauth/http_transport.go @@ -58,3 +58,8 @@ func maybeRefresh(ctx context.Context, token *Token) (*Token, error) { next.ClientID = token.ClientID return next, nil } + +func IsOAuthTransport(trp http.RoundTripper) bool { + _, ok := trp.(*Transport) + return ok +} From 42193acbc4cb16b4c5ce12ebb3c9696bbe6973b2 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 15:26:00 +0200 Subject: [PATCH 11/14] refactor token refresh tests --- internal/oauth/http_transport_test.go | 237 ++++++++++++-------------- 1 file changed, 105 insertions(+), 132 deletions(-) diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go index 148619b9cd..1e8a599e0d 100644 --- a/internal/oauth/http_transport_test.go +++ b/internal/oauth/http_transport_test.go @@ -9,127 +9,129 @@ import ( "time" ) -type mockRoundTripper struct { - handler func(*http.Request) (*http.Response, error) -} - -func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return m.handler(req) -} - -func TestTransport_SetsAuthorizationHeader(t *testing.T) { - var capturedAuth string - - transport := &Transport{ - Base: &mockRoundTripper{ - handler: func(req *http.Request) (*http.Response, error) { - capturedAuth = req.Header.Get("Authorization") - return &http.Response{StatusCode: 200}, nil - }, - }, - Token: &Token{ - AccessToken: "test-token", - ExpiresAt: time.Now().Add(time.Hour), - }, - } - - req := httptest.NewRequest("GET", "http://example.com", nil) - _, err := transport.RoundTrip(req) - if err != nil { - t.Fatalf("RoundTrip() error = %v", err) - } +type roundTripperFunc func(*http.Request) (*http.Response, error) - if capturedAuth != "Bearer test-token" { - t.Errorf("Authorization = %q, want %q", capturedAuth, "Bearer test-token") - } +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) } -func TestMaybeRefresh_RefreshesExpiredToken(t *testing.T) { - server := newTestServer(t, testServerOptions{ +func newRefreshServer(t *testing.T, accessToken string) *httptest.Server { + t.Helper() + return newTestServer(t, testServerOptions{ handlers: map[string]http.HandlerFunc{ - testTokenPath: func(w http.ResponseWriter, r *http.Request) { + testTokenPath: func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`)) + _, _ = w.Write([]byte(`{"access_token":"` + accessToken + `","refresh_token":"new-refresh","expires_in":3600}`)) }, }, }) - defer server.Close() - - token := &Token{ - Endpoint: server.URL, - AccessToken: "expired-token", - RefreshToken: "refresh-token", - ExpiresAt: time.Now().Add(-time.Hour), // expired - } - - result, err := maybeRefresh(context.Background(), token) - if err != nil { - t.Fatalf("maybeRefresh() error = %v", err) - } - - if result.AccessToken != "new-token" { - t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token") - } } -func TestMaybeRefresh_RefreshesTokenExpiringSoon(t *testing.T) { - server := newTestServer(t, testServerOptions{ - handlers: map[string]http.HandlerFunc{ - testTokenPath: func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`)) - }, - }, - }) +func TestMaybeRefresh(t *testing.T) { + server := newRefreshServer(t, "new-token") defer server.Close() - token := &Token{ - Endpoint: server.URL, - AccessToken: "expiring-soon-token", - RefreshToken: "refresh-token", - ExpiresAt: time.Now().Add(10 * time.Second), // expires in 10s (< 30s threshold) - } - - result, err := maybeRefresh(context.Background(), token) - if err != nil { - t.Fatalf("maybeRefresh() error = %v", err) + tests := []struct { + name string + token *Token + wantAccess string + wantSame bool + }{ + { + name: "unchanged when still valid", + token: &Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(time.Hour), + }, + wantAccess: "valid-token", + wantSame: true, + }, + { + name: "refreshes expired token", + token: &Token{ + Endpoint: server.URL, + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), + }, + wantAccess: "new-token", + }, + { + name: "refreshes token expiring soon", + token: &Token{ + Endpoint: server.URL, + AccessToken: "expiring-soon-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(10 * time.Second), + }, + wantAccess: "new-token", + }, } - if result.AccessToken != "new-token" { - t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := maybeRefresh(context.Background(), tt.token) + if err != nil { + t.Fatalf("maybeRefresh() error = %v", err) + } + if got.AccessToken != tt.wantAccess { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, tt.wantAccess) + } + if tt.wantSame && got != tt.token { + t.Errorf("token pointer changed for unexpired token") + } + }) } } -func TestTransport_RefreshPersistence(t *testing.T) { +func TestTransportRoundTrip(t *testing.T) { tests := []struct { - name string - needsRefresh bool - persistErr error - wantAuthHeaderVal string - wantStoreCalls int + name string + token *Token + persistErr error + wantAuthHeader string + wantStoreCalls int }{ { - name: "persists refreshed token", - needsRefresh: true, - wantAuthHeaderVal: "Bearer new-token", - wantStoreCalls: 1, + name: "uses existing token without persisting", + token: &Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(time.Hour), + }, + wantAuthHeader: "Bearer valid-token", + wantStoreCalls: 0, }, { - name: "does not persist unchanged token", - wantAuthHeaderVal: "Bearer valid-token", - wantStoreCalls: 0, + name: "persists refreshed token", + token: &Token{ + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), + }, + wantAuthHeader: "Bearer new-token", + wantStoreCalls: 1, }, { - name: "persist failure does not fail request", - needsRefresh: true, - persistErr: errors.New("persist failed"), - wantAuthHeaderVal: "Bearer new-token", - wantStoreCalls: 1, + name: "ignores persist failures", + token: &Token{ + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), + }, + persistErr: errors.New("persist failed"), + wantAuthHeader: "Bearer new-token", + wantStoreCalls: 1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + if tt.wantStoreCalls > 0 { + server := newRefreshServer(t, "new-token") + defer server.Close() + tt.token.Endpoint = server.URL + } + originalStoreFn := storeRefreshedTokenFn defer func() { storeRefreshedTokenFn = originalStoreFn }() @@ -141,57 +143,28 @@ func TestTransport_RefreshPersistence(t *testing.T) { return tt.persistErr } - token := &Token{ - AccessToken: "valid-token", - ExpiresAt: time.Now().Add(time.Hour), - } - if tt.needsRefresh { - server := newTestServer(t, testServerOptions{ - handlers: map[string]http.HandlerFunc{ - testTokenPath: func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`)) - }, - }, - }) - defer server.Close() - token.Endpoint = server.URL - token.AccessToken = "expired-token" - token.RefreshToken = "refresh-token" - token.ExpiresAt = time.Now().Add(-time.Hour) - } - var capturedAuth string - transport := &Transport{ - Base: &mockRoundTripper{ - handler: func(req *http.Request) (*http.Response, error) { - capturedAuth = req.Header.Get("Authorization") - return &http.Response{StatusCode: 200}, nil - }, - }, - Token: token, + tr := &Transport{ + Base: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedAuth = req.Header.Get("Authorization") + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), + Token: tt.token, } - req := httptest.NewRequest("GET", "http://example.com", nil) - _, err := transport.RoundTrip(req) + _, err := tr.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil)) if err != nil { t.Fatalf("RoundTrip() error = %v", err) } - if capturedAuth != tt.wantAuthHeaderVal { - t.Errorf("Authorization = %q, want %q", capturedAuth, tt.wantAuthHeaderVal) + if capturedAuth != tt.wantAuthHeader { + t.Errorf("Authorization = %q, want %q", capturedAuth, tt.wantAuthHeader) } if storeCalls != tt.wantStoreCalls { t.Errorf("store calls = %d, want %d", storeCalls, tt.wantStoreCalls) } - - if tt.needsRefresh { - if storedToken == nil { - t.Fatal("stored token is nil") - } - if storedToken.AccessToken != "new-token" { - t.Errorf("stored AccessToken = %q, want %q", storedToken.AccessToken, "new-token") - } + if tt.wantStoreCalls > 0 && (storedToken == nil || storedToken.AccessToken != "new-token") { + t.Errorf("stored token = %#v, want access token %q", storedToken, "new-token") } }) } From 239125f6456de32c85aa71f5d4c8aa4c097ac895 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 15:41:28 +0200 Subject: [PATCH 12/14] fix ci - remove panic - use lib/errors from sg --- cmd/src/main.go | 3 --- internal/oauth/http_transport_test.go | 3 ++- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cmd/src/main.go b/cmd/src/main.go index e506730841..41e5c55cd0 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -138,9 +138,6 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { if c.AccessToken == "" { if t, err := oauth.LoadToken(context.Background(), c.Endpoint); err == nil { opts.OAuthToken = t - } else { - // TODO(burmudar): should return an error instead - panic("No access token set and no OAuth token found either - unable to create api client: " + err.Error()) } } diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go index 1e8a599e0d..4dac832d05 100644 --- a/internal/oauth/http_transport_test.go +++ b/internal/oauth/http_transport_test.go @@ -2,11 +2,12 @@ package oauth import ( "context" - "errors" "net/http" "net/http/httptest" "testing" "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) type roundTripperFunc func(*http.Request) (*http.Response, error) From 7978ccec681a516e37838ec20fa2402331efaac6 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 15:56:52 +0200 Subject: [PATCH 13/14] inform user of oauth login --- cmd/src/login.go | 27 ++++++++++++++------------- cmd/src/login_test.go | 8 ++++---- internal/api/api.go | 2 +- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index 01cdda5523..5a73ef4cc8 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -112,7 +112,9 @@ func loginCmd(ctx context.Context, p loginParams) error { export SRC_ACCESS_TOKEN=(your access token) To verify that it's working, run the login command again. -`, endpointArg, endpointArg) + + Alternatively, you can try logging in using OAuth by running: src login --oauth %s +`, endpointArg, endpointArg, endpointArg) if cfg.ConfigFilePath != "" { fmt.Fprintln(out) @@ -121,8 +123,17 @@ func loginCmd(ctx context.Context, p loginParams) error { noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint - - cfg.Endpoint = endpointArg + if !p.useOAuth && (noToken || endpointConflict) { + fmt.Fprintln(out) + switch { + case noToken: + printProblem("No access token is configured.") + case endpointConflict: + printProblem(fmt.Sprintf("The configured endpoint is %s, not %s.", cfg.Endpoint, endpointArg)) + } + fmt.Fprintln(out, createAccessTokenMessage) + return cmderrors.ExitCode1 + } if p.useOAuth { token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) @@ -146,16 +157,6 @@ func loginCmd(ctx context.Context, p loginParams) error { ProxyPath: cfg.ProxyPath, OAuthToken: token, }) - } else if noToken || endpointConflict { - fmt.Fprintln(out) - switch { - case noToken: - printProblem("No access token is configured.") - case endpointConflict: - printProblem(fmt.Sprintf("The configured endpoint is %s, not %s.", cfg.Endpoint, endpointArg)) - } - fmt.Fprintln(out, createAccessTokenMessage) - return cmderrors.ExitCode1 } // See if the user is already authenticated. diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 37d3202227..ab7a15056a 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -34,7 +34,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "āŒ Problem: No access token is configured.\n\nšŸ›  To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again." + wantOut := "āŒ Problem: No access token is configured.\n\nšŸ›  To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com" if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -45,7 +45,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "āŒ Problem: No access token is configured.\n\nšŸ›  To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again." + wantOut := "āŒ Problem: No access token is configured.\n\nšŸ›  To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com" if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -56,7 +56,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "āš ļø Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\nāŒ Problem: No access token is configured.\n\nšŸ›  To fix: Create an access token by going to https://example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again." + wantOut := "āš ļø Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\nāŒ Problem: No access token is configured.\n\nšŸ›  To fix: Create an access token by going to https://example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://example.com" if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -74,7 +74,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "āŒ Problem: Invalid access token.\n\nšŸ›  To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)" + wantOut := "āŒ Problem: Invalid access token.\n\nšŸ›  To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)" wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint) if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) diff --git a/internal/api/api.go b/internal/api/api.go index 423cb73997..ef9f822a7a 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -269,7 +269,7 @@ func (r *request) do(ctx context.Context, result any) (bool, error) { if oauth.IsOAuthTransport(r.client.httpClient.Transport) { fmt.Println("The OAuth token is invalid. Please check that the Sourcegraph CLI client is still authorized.") fmt.Println("") - fmt.Println("To re-authorize, run: src login") + fmt.Printf("To re-authorize, run: src login --oauth %s\n", r.client.opts.Endpoint) fmt.Println("") fmt.Println("Learn more at https://github.com/sourcegraph/src-cli#readme") fmt.Println("") From ba0344deda657b3aa2689e12df98a8e0bac87f2a Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 16:30:54 +0200 Subject: [PATCH 14/14] fix spelling --- internal/oauth/flow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 92819fec6e..7f22be3530 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -429,7 +429,7 @@ func LoadToken(ctx context.Context, endpoint string) (*Token, error) { var t Token if err := json.Unmarshal(data, &t); err != nil { - return nil, errors.Wrap(err, "failed to unmarshall token") + return nil, errors.Wrap(err, "failed to unmarshal token") } return &t, nil