Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions command/oauth/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,18 @@ websites. Learn more at https://en.wikipedia.org/wiki/OAuth.
This command by default performs the authorization flow with a preconfigured
Google application, but a custom one can be set combining the flags
**--client-id**, **--client-secret**, and **--provider**. The provider value
must be set to the OIDC discovery document (.well-known/openid-configuration)
endpoint. If Google is used this flag is not necessary, but the appropriate
must be set to the issuer URL of an OpenID Connect or OAuth 2.0 authorization
server. Its endpoints are discovered from the OpenID Connect discovery document
(.well-known/openid-configuration), falling back to the RFC 8414 OAuth 2.0
authorization server metadata document (.well-known/oauth-authorization-server).
If Google is used this flag is not necessary, but the appropriate
value would be be https://accounts.google.com or
https://accounts.google.com/.well-known/openid-configuration

A plaintext **http://** provider URL (e.g. a local test instance of an identity
provider) is only accepted together with the **--insecure** flag, because the
discovery request and tokens would otherwise be sent without TLS.

## EXAMPLES

Do the OAuth 2.0 flow using the default client:
Expand Down Expand Up @@ -328,8 +335,11 @@ type options struct {

// Validate validates the options.
func (o *options) Validate() error {
if o.Provider != "google" && o.Provider != "github" && !strings.HasPrefix(o.Provider, "https://") {
return errors.New("use a valid provider: google or github")
if o.Provider != "google" && o.Provider != "github" &&
!strings.HasPrefix(o.Provider, "https://") && !strings.HasPrefix(o.Provider, "http://") {
return errors.Errorf("invalid value '%s' for flag '--provider'; "+
"expected 'google', 'github', or an OIDC/OAuth issuer URL "+
"(e.g. https://accounts.google.com)", o.Provider)
}
if o.CallbackListener != "" {
if _, _, err := net.SplitHostPort(o.CallbackListener); err != nil {
Expand Down Expand Up @@ -363,6 +373,12 @@ func oauthCmd(c *cli.Context) error {
if err := opts.Validate(); err != nil {
return err
}
// A plaintext http:// provider transmits the discovery request and tokens
// without TLS, so it is only allowed behind the --insecure flag. This is
// primarily useful for local test instances of an identity provider.
if strings.HasPrefix(opts.Provider, "http://") && !c.Bool("insecure") {
return errs.RequiredInsecureFlag(c, "provider")
}
if (opts.Provider != "google" || c.IsSet("authorization-endpoint")) && !c.IsSet("client-id") {
return errors.New("flag '--client-id' required with '--provider'")
}
Expand Down Expand Up @@ -681,31 +697,64 @@ func newOauth(provider, clientID, clientSecret, authzEp, deviceAuthzEp, tokenEp,
}, nil
}

// discoveryPaths are the well-known metadata locations probed, in order, to
// resolve a provider's authorization-server endpoints. The OpenID Connect
// Discovery 1.0 path is tried first, falling back to the RFC 8414 OAuth 2.0
// Authorization Server Metadata path. See
// https://openid.net/specs/openid-connect-discovery-1_0.html and
// https://tools.ietf.org/html/rfc8414#section-5
var discoveryPaths = []string{
"/.well-known/openid-configuration",
"/.well-known/oauth-authorization-server",
}

func disco(provider string) (map[string]interface{}, error) {
u, err := url.Parse(provider)
if err != nil {
return nil, err
return nil, errors.Wrapf(err, "error parsing provider '%s'", provider)
}
// TODO: OIDC and OAuth specify two different ways of constructing this
// URL. This is the OIDC way. Probably want to try both. See
// https://tools.ietf.org/html/rfc8414#section-5
if !strings.Contains(u.Path, "/.well-known/openid-configuration") {
u.Path = path.Join(u.Path, "/.well-known/openid-configuration")

// If the provider already points directly at a well-known metadata
// document, fetch it as-is instead of appending another suffix.
if strings.Contains(u.Path, "/.well-known/") {
return fetchDiscovery(u.String())
}
resp, err := http.Get(u.String())

base := u.Path
var firstErr error
for _, p := range discoveryPaths {
u.Path = path.Join(base, p)
details, err := fetchDiscovery(u.String())
if err == nil {
return details, nil
}
if firstErr == nil {
firstErr = err
}
}
return nil, firstErr
}

// fetchDiscovery retrieves and parses an authorization-server metadata
// document from the given URL.
func fetchDiscovery(rawurl string) (map[string]interface{}, error) {
resp, err := http.Get(rawurl) // #nosec G704 -- request intentionally relies on user configuration
if err != nil {
return nil, errors.Wrapf(err, "error retrieving %s", u.String())
return nil, errors.Wrapf(err, "error retrieving %s", rawurl)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, errors.Errorf("error retrieving %s: unexpected status code %d", rawurl, resp.StatusCode)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrapf(err, "error retrieving %s", u.String())
return nil, errors.Wrapf(err, "error retrieving %s", rawurl)
}
details := make(map[string]interface{})
if err = json.Unmarshal(b, &details); err != nil {
return nil, errors.Wrapf(err, "error reading %s: unsupported format", u.String())
if err := json.Unmarshal(b, &details); err != nil {
return nil, errors.Wrapf(err, "error reading %s: unsupported format", rawurl)
}
return details, err
return details, nil
}

// postForm simulates http.PostForm but adds the header "Accept:
Expand Down
138 changes: 138 additions & 0 deletions command/oauth/cmd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package oauth

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestOptions_Validate(t *testing.T) {
tests := []struct {
name string
provider string
wantErr bool
}{
{"google", "google", false},
{"github", "github", false},
{"https issuer", "https://accounts.google.com", false},
{"https issuer with path", "https://sso.example.org/realms/homelab", false},
// http:// is accepted by Validate(); the --insecure gating is
// enforced separately in oauthCmd.
{"http issuer", "http://localhost:8080/realms/test", false},
{"bare name", "keycloak", true},
{"ftp issuer", "ftp://accounts.google.com", true},
{"empty", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := &options{Provider: tt.provider}
err := o.Validate()
if tt.wantErr {
if err == nil {
t.Fatalf("Validate() error = nil, want error for provider %q", tt.provider)
}
return
}
if err != nil {
t.Fatalf("Validate() error = %v, want nil for provider %q", err, tt.provider)
}
})
}
}

const (
oidcDoc = `{"authorization_endpoint":"https://idp.example/auth","token_endpoint":"https://idp.example/token"}`
oauthDoc = `{"authorization_endpoint":"https://idp.example/oauth/auth","token_endpoint":"https://idp.example/oauth/token"}`
)

func TestDisco(t *testing.T) {
t.Run("oidc discovery", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid-configuration" {
_, _ = w.Write([]byte(oidcDoc))
return
}
http.NotFound(w, r)
}))
defer srv.Close()

d, err := disco(srv.URL)
if err != nil {
t.Fatalf("disco() error = %v", err)
}
if got := d["token_endpoint"]; got != "https://idp.example/token" {
t.Fatalf("token_endpoint = %v, want https://idp.example/token", got)
}
})

t.Run("rfc8414 oauth fallback", func(t *testing.T) {
// Only the OAuth authorization server metadata path is served; the
// OIDC path 404s, exercising the RFC 8414 fallback.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/oauth-authorization-server" {
_, _ = w.Write([]byte(oauthDoc))
return
}
http.NotFound(w, r)
}))
defer srv.Close()

d, err := disco(srv.URL)
if err != nil {
t.Fatalf("disco() error = %v", err)
}
if got := d["token_endpoint"]; got != "https://idp.example/oauth/token" {
t.Fatalf("token_endpoint = %v, want https://idp.example/oauth/token", got)
}
})

t.Run("explicit well-known url fetched as-is", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid-configuration" {
_, _ = w.Write([]byte(oidcDoc))
return
}
http.NotFound(w, r)
}))
defer srv.Close()

d, err := disco(srv.URL + "/.well-known/openid-configuration")
if err != nil {
t.Fatalf("disco() error = %v", err)
}
if d["authorization_endpoint"] != "https://idp.example/auth" {
t.Fatalf("unexpected metadata: %v", d)
}
})

t.Run("non-200 reports status code", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "nope", http.StatusNotFound)
}))
defer srv.Close()

_, err := disco(srv.URL)
if err == nil {
t.Fatal("disco() error = nil, want error")
}
if !strings.Contains(err.Error(), "404") {
t.Fatalf("disco() error = %v, want mention of status code 404", err)
}
})

t.Run("invalid json reports unsupported format", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("<html>not json</html>"))
}))
defer srv.Close()

_, err := disco(srv.URL)
if err == nil {
t.Fatal("disco() error = nil, want error")
}
if !strings.Contains(err.Error(), "unsupported format") {
t.Fatalf("disco() error = %v, want unsupported format", err)
}
})
}