From f45448499e2109cd837aa7b7df183707fc543e10 Mon Sep 17 00:00:00 2001 From: markmnl Date: Sun, 14 Jun 2026 16:05:02 +0800 Subject: [PATCH 1/2] api-keys --- README.md | 156 +++++++++++- api_keys.sql | 28 +++ cmd/fmsg-webapi/apikey_cli.go | 132 ++++++++++ cmd/fmsg-webapi/main.go | 110 ++++++--- internal/apiauth/apikey.go | 66 +++++ internal/apiauth/apikey_test.go | 56 +++++ internal/apiauth/store.go | 307 ++++++++++++++++++++++++ internal/apiauth/subaccount.go | 49 ++++ internal/apiauth/token.go | 105 ++++++++ internal/handlers/subaccounts.go | 257 ++++++++++++++++++++ internal/handlers/token.go | 82 +++++++ internal/handlers/ws.go | 3 +- internal/middleware/cors.go | 2 +- internal/middleware/jwt.go | 399 ++++++++++++++++++------------- internal/middleware/jwt_test.go | 393 ++++++++++++------------------ 15 files changed, 1692 insertions(+), 453 deletions(-) create mode 100644 api_keys.sql create mode 100644 cmd/fmsg-webapi/apikey_cli.go create mode 100644 internal/apiauth/apikey.go create mode 100644 internal/apiauth/apikey_test.go create mode 100644 internal/apiauth/store.go create mode 100644 internal/apiauth/subaccount.go create mode 100644 internal/apiauth/token.go create mode 100644 internal/handlers/subaccounts.go create mode 100644 internal/handlers/token.go diff --git a/README.md b/README.md index 810fd6a..3c59417 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,11 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | `FMSG_JWT_ISSUER` | *(prod, required with JWKS)* | Expected `iss` claim value (e.g. `https://idp.example.com/`). Tokens with a different issuer are rejected. This must exactly match the token issuer. | | `FMSG_JWT_AUDIENCE` | *(prod, required with JWKS)* | Expected `aud` claim value for this application or API. | | `FMSG_JWT_ADDRESS_CLAIM` | *(prod, required with JWKS)* | JWT claim name containing the fmsg address in `@user@domain` form, e.g. `fmsg_address` or a namespaced custom claim. | -| `FMSG_API_JWT_SECRET` | *(dev)* | HMAC secret for HS256 token verification. Used only in dev mode (when `FMSG_JWT_JWKS_URL` is unset). Prefix with `base64:` to supply a base64-encoded key. Either this or `FMSG_JWT_JWKS_URL` must be set. | +| `FMSG_API_TOKEN_ED25519_PRIVATE_KEY` | *(optional)* | Base64-encoded Ed25519 private key or seed used to mint first-party JWTs from API keys. Required to enable `/fmsg/token` and sub-account routes. | +| `FMSG_API_TOKEN_ISSUER` | `fmsg-webapi` | Issuer for first-party API-key JWTs. | +| `FMSG_API_TOKEN_AUDIENCE` | `fmsg-webapi` | Audience for first-party API-key JWTs. | +| `FMSG_API_TOKEN_TTL` | `12h` | Lifetime of JWTs minted by `POST /fmsg/token`. | +| `FMSG_TRUSTED_PROXIES` | *(optional)* | Comma-separated trusted proxy CIDRs/IPs for Gin client IP resolution. Leave unset to use direct client addresses for API-key CIDR checks. | | `FMSG_TLS_CERT` | *(optional)* | Path to the TLS certificate file (e.g. `/etc/letsencrypt/live/example.com/fullchain.pem`). When set with `FMSG_TLS_KEY`, enables HTTPS. | | `FMSG_TLS_KEY` | *(optional)* | Path to the TLS private key file (e.g. `/etc/letsencrypt/live/example.com/privkey.pem`). Must be set together with `FMSG_TLS_CERT`. | | `FMSG_API_PORT` | `443` (TLS) / `8000` (plain) | TCP port to listen on. | @@ -36,8 +40,13 @@ A `.env` file placed in the working directory is loaded automatically at startup ## Authentication -All `/fmsg/*` routes require an `Authorization: Bearer ` header. The API -operates in one of two verification modes, selected automatically at startup: +Most `/fmsg/*` routes require an `Authorization: Bearer ` header. The +API can enable either or both authentication methods at startup: + +- RS256/JWKS tokens from an external identity provider. +- First-party Ed25519 JWTs minted by `POST /fmsg/token` from opaque API keys. + +Startup fails unless at least one method is configured. ### RS256 (production, JWKS-backed JWTs) @@ -68,11 +77,60 @@ includes the configured address claim. Whether that token is an ID token or access token is determined by the identity provider configuration for the deployment. -### HMAC (development) +### API Keys And First-Party JWTs + +Active when `FMSG_API_TOKEN_ED25519_PRIVATE_KEY` is set. Programmatic clients +authenticate with opaque API keys bound to sub-account addresses. The server +stores only API-key hashes and exchanges valid keys for short-lived Ed25519 JWTs. + +API keys are sent only to `POST /fmsg/token`: + +```http +Authorization: Bearer fmsgk__ +``` + +The returned JWT contains `sub` (the sub-account address), `owner`, `api_key_id`, +`iss`, `aud`, `iat`, and `exp`. Protected routes re-check the backing key row on +each request, so deleting a sub-account or expiring its key invalidates existing +tokens before their normal expiry. + +An RS256-authenticated owner can perform normal message routes as one of their +sub-accounts without changing request bodies: + +```http +X-FMSG-Act-As: @user_bot@example.com +``` + +The requested sub-account must be owned by the authenticated user and must exist +in fmsgid. + +Apply [api_keys.sql](api_keys.sql) before enabling API-key auth. + +To set a custom per-owner sub-account limit, insert an owner config row: -Active when `FMSG_JWT_JWKS_URL` is unset. Tokens must be HS256-signed with the -shared secret in `FMSG_API_JWT_SECRET`. Required claims are `sub` and `exp`; -`iat`/`nbf` are honoured when present. +```sql +INSERT INTO fmsg_api_sub_account (owner_addr, agent, max_sub_accounts) +VALUES ('@alice@example.com', '', 10) +ON CONFLICT (owner_addr, agent) +DO UPDATE SET max_sub_accounts = EXCLUDED.max_sub_accounts; +``` + +Operators can bootstrap or rotate keys without RS256 by using the built-in CLI +command. It uses the standard `PG*` connection environment variables and prints +the plaintext API key once: + +```bash +go run ./cmd/fmsg-webapi api-key create \ + -owner @alice@example.com \ + -agent bot \ + -cidr 203.0.113.0/24 \ + -expires 2026-12-31T00:00:00Z + +go run ./cmd/fmsg-webapi api-key rotate \ + -owner @alice@example.com \ + -agent bot \ + -expires 2027-03-31T00:00:00Z +``` ## Building @@ -101,6 +159,8 @@ export FMSG_JWT_JWKS_URL=https://idp.example.com/.well-known/jwks.json export FMSG_JWT_ISSUER=https://idp.example.com/ export FMSG_JWT_AUDIENCE=fmsg-web-client export FMSG_JWT_ADDRESS_CLAIM=fmsg_address +# Optional: also enable programmatic API keys. +# export FMSG_API_TOKEN_ED25519_PRIVATE_KEY=$(openssl rand -base64 32) export FMSG_TLS_CERT=/etc/letsencrypt/live/example.com/fullchain.pem export FMSG_TLS_KEY=/etc/letsencrypt/live/example.com/privkey.pem export PGHOST=localhost @@ -122,7 +182,7 @@ proxying `https://fmsgapi.example.com/` to `http://127.0.0.1:8000/`). ```bash export FMSG_DATA_DIR=/var/lib/fmsgd/ -export FMSG_API_JWT_SECRET=changeme +export FMSG_API_TOKEN_ED25519_PRIVATE_KEY=$(openssl rand -base64 32) export PGHOST=localhost export PGUSER=fmsg export PGPASSWORD=secret @@ -141,7 +201,10 @@ the HTTP server and kept alive by its own ping/pong heartbeat. ## API Routes -All routes are prefixed with `/fmsg` and require a valid `Authorization: Bearer ` header. The one exception is the WebSocket route `/fmsg/ws`, which additionally accepts the token via an `access_token` query parameter (browsers cannot set headers on a WebSocket). +All routes are prefixed with `/fmsg`. `POST /fmsg/token` accepts an API key and +returns a JWT. Other routes require a valid `Authorization: Bearer ` +header. The WebSocket route `/fmsg/ws` additionally accepts the token via an +`access_token` query parameter (browsers cannot set headers on a WebSocket). Rate limiting is enforced at the host level (e.g. `nftables`) rather than in the application. @@ -151,6 +214,11 @@ the application. | `GET` | `/fmsg` | List messages for user | | `GET` | `/fmsg/sent` | List authored messages (sent + drafts) | | `GET` | `/fmsg/ws` | WebSocket for pushed event notifications | +| `POST` | `/fmsg/token` | Exchange an API key for a JWT | +| `GET` | `/fmsg/sub-accounts` | List owned sub-accounts | +| `POST` | `/fmsg/sub-accounts` | Create a sub-account API key | +| `POST` | `/fmsg/sub-accounts/:agent/rotate-key` | Rotate a sub-account API key | +| `DELETE` | `/fmsg/sub-accounts/:agent` | Delete a sub-account | | `POST` | `/fmsg` | Create a draft message | | `GET` | `/fmsg/:id` | Retrieve a message | | `PUT` | `/fmsg/:id` | Update a draft message | @@ -168,6 +236,76 @@ the application. The `/fmsg/push/subscribe` routes are registered only when Web Push is configured (see [Web Push](#web-push)). +The `/fmsg/token` and `/fmsg/sub-accounts*` routes are registered only when +API-key auth is configured with `FMSG_API_TOKEN_ED25519_PRIVATE_KEY`. + +### POST `/fmsg/token` + +Exchanges an opaque API key for a short-lived JWT. + +**Authentication:** `Authorization: Bearer fmsgk__`. + +The key must be unexpired, match the stored hash, be used from an allowed CIDR, +and belong to a sub-account that exists in fmsgid. + +**Response:** + +```json +{ + "access_token": "eyJ...", + "token_type": "Bearer", + "expires_in": 43200, + "expires_at": "2026-12-31T12:00:00Z" +} +``` + +### GET `/fmsg/sub-accounts` + +Lists sub-accounts owned by the RS256-authenticated user. + +**Response:** + +```json +{ + "max_sub_accounts": 5, + "sub_accounts": [ + { + "agent": "bot", + "addr": "@alice_bot@example.com", + "key_id": "abc", + "allowed_cidrs": ["203.0.113.0/24"], + "key_expires_at": "2026-12-31T00:00:00Z" + } + ] +} +``` + +### POST `/fmsg/sub-accounts` + +Creates a sub-account and returns its plaintext API key once. Requires RS256 +owner authentication. + +```json +{ + "agent": "bot", + "allowed_cidrs": ["203.0.113.0/24"], + "key_expires_at": "2026-12-31T00:00:00Z" +} +``` + +The derived address is `@user_bot@domain`. `agent` may contain letters, digits, +dots, and hyphens, but not underscores. + +### POST `/fmsg/sub-accounts/:agent/rotate-key` + +Rotates a sub-account API key and returns the new plaintext key once. Requires +`key_expires_at`; `allowed_cidrs` may be supplied to replace the existing ranges. + +### DELETE `/fmsg/sub-accounts/:agent` + +Deletes a sub-account row and revokes future token exchange. Existing JWTs for +that key are rejected on their next protected-route request. + ### GET `/fmsg/ws` Upgrades the connection to a WebSocket over which the server pushes events that diff --git a/api_keys.sql b/api_keys.sql new file mode 100644 index 0000000..0e1a06d --- /dev/null +++ b/api_keys.sql @@ -0,0 +1,28 @@ +CREATE TABLE IF NOT EXISTS fmsg_api_sub_account ( + owner_addr varchar(255) NOT NULL, + agent varchar(64) NOT NULL, + sub_addr varchar(255), + key_id varchar(64), + key_hash bytea, + allowed_cidrs cidr[], + key_expires_at timestamptz, + max_sub_accounts int NOT NULL DEFAULT 5, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + PRIMARY KEY (owner_addr, agent), + UNIQUE (sub_addr), + UNIQUE (key_id), + CHECK (max_sub_accounts > 0), + CHECK ( + (agent = '' AND sub_addr IS NULL AND key_id IS NULL AND key_hash IS NULL AND allowed_cidrs IS NULL AND key_expires_at IS NULL) + OR + (agent <> '' AND sub_addr IS NOT NULL AND key_id IS NOT NULL AND key_hash IS NOT NULL AND allowed_cidrs IS NOT NULL AND cardinality(allowed_cidrs) > 0 AND key_expires_at IS NOT NULL) + ), + CHECK (agent = '' OR agent NOT LIKE '%\_%' ESCAPE '\') +); + +CREATE INDEX IF NOT EXISTS fmsg_api_sub_account_owner_idx + ON fmsg_api_sub_account ((lower(owner_addr))); + +CREATE INDEX IF NOT EXISTS fmsg_api_sub_account_sub_idx + ON fmsg_api_sub_account ((lower(sub_addr))); diff --git a/cmd/fmsg-webapi/apikey_cli.go b/cmd/fmsg-webapi/apikey_cli.go new file mode 100644 index 0000000..803dfce --- /dev/null +++ b/cmd/fmsg-webapi/apikey_cli.go @@ -0,0 +1,132 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "strings" + "time" + + "github.com/markmnl/fmsg-webapi/internal/apiauth" + "github.com/markmnl/fmsg-webapi/internal/db" + "github.com/markmnl/fmsg-webapi/internal/middleware" +) + +func runAPIKeyCLI(ctx context.Context, args []string) error { + if len(args) == 0 { + return fmt.Errorf("usage: api-key create|rotate -owner @user@domain -agent name -cidr 203.0.113.0/24 -expires 2026-12-31T00:00:00Z") + } + switch args[0] { + case "create": + return runAPIKeyCreate(ctx, args[1:]) + case "rotate": + return runAPIKeyRotate(ctx, args[1:]) + default: + return fmt.Errorf("unknown api-key command %q", args[0]) + } +} + +func runAPIKeyCreate(ctx context.Context, args []string) error { + fs := flag.NewFlagSet("api-key create", flag.ContinueOnError) + fs.SetOutput(os.Stderr) + owner := fs.String("owner", "", "owner fmsg address") + agent := fs.String("agent", "", "sub-account agent name") + cidrs := fs.String("cidr", "", "comma-separated allowed CIDR ranges") + expiresRaw := fs.String("expires", "", "API key expiry as RFC3339 timestamp") + if err := fs.Parse(args); err != nil { + return err + } + + subAddr, allowed, expires, key, hash, err := prepareCLIKeyInputs(*owner, *agent, *cidrs, *expiresRaw) + if err != nil { + return err + } + if len(allowed) == 0 { + return fmt.Errorf("cidr is required for create") + } + database, err := db.New(ctx, "") + if err != nil { + return err + } + defer database.Close() + + store := apiauth.NewStore(database) + if err := store.Create(ctx, *owner, *agent, subAddr, key.ID, hash, allowed, expires); err != nil { + return err + } + printCLIKey(*owner, *agent, subAddr, key) + return nil +} + +func runAPIKeyRotate(ctx context.Context, args []string) error { + fs := flag.NewFlagSet("api-key rotate", flag.ContinueOnError) + fs.SetOutput(os.Stderr) + owner := fs.String("owner", "", "owner fmsg address") + agent := fs.String("agent", "", "sub-account agent name") + cidrs := fs.String("cidr", "", "comma-separated allowed CIDR ranges; omit to keep existing") + expiresRaw := fs.String("expires", "", "API key expiry as RFC3339 timestamp") + if err := fs.Parse(args); err != nil { + return err + } + + subAddr, allowed, expires, key, hash, err := prepareCLIKeyInputs(*owner, *agent, *cidrs, *expiresRaw) + if err != nil { + return err + } + database, err := db.New(ctx, "") + if err != nil { + return err + } + defer database.Close() + + store := apiauth.NewStore(database) + replaceCIDRs := strings.TrimSpace(*cidrs) != "" + gotSubAddr, err := store.RotateKey(ctx, *owner, *agent, key.ID, hash, expires, allowed, replaceCIDRs) + if err != nil { + return err + } + if !strings.EqualFold(gotSubAddr, subAddr) { + return fmt.Errorf("stored sub-account address %s does not match derived address %s", gotSubAddr, subAddr) + } + printCLIKey(*owner, *agent, subAddr, key) + return nil +} + +func prepareCLIKeyInputs(owner, agent, cidrsRaw, expiresRaw string) (string, []string, time.Time, apiauth.APIKey, []byte, error) { + if !middleware.IsValidAddr(owner) { + return "", nil, time.Time{}, apiauth.APIKey{}, nil, fmt.Errorf("owner must be an fmsg address") + } + subAddr, err := apiauth.DeriveSubAccountAddr(owner, agent) + if err != nil { + return "", nil, time.Time{}, apiauth.APIKey{}, nil, err + } + expires, err := time.Parse(time.RFC3339, expiresRaw) + if err != nil || !expires.After(time.Now()) { + return "", nil, time.Time{}, apiauth.APIKey{}, nil, fmt.Errorf("expires must be a future RFC3339 timestamp") + } + var allowed []string + if strings.TrimSpace(cidrsRaw) != "" { + for _, cidr := range strings.Split(cidrsRaw, ",") { + allowed = append(allowed, strings.TrimSpace(cidr)) + } + } + if len(allowed) > 0 { + if err := apiauth.ValidateCIDRs(allowed); err != nil { + return "", nil, time.Time{}, apiauth.APIKey{}, nil, fmt.Errorf("invalid CIDR: %w", err) + } + } + key, err := apiauth.GenerateAPIKey() + if err != nil { + return "", nil, time.Time{}, apiauth.APIKey{}, nil, err + } + return subAddr, allowed, expires, key, apiauth.HashAPIKey(key.Value), nil +} + +func printCLIKey(owner, agent, subAddr string, key apiauth.APIKey) { + fmt.Printf("owner=%s\n", owner) + fmt.Printf("agent=%s\n", agent) + fmt.Printf("sub_addr=%s\n", subAddr) + fmt.Printf("key_id=%s\n", key.ID) + fmt.Printf("api_key=%s\n", key.Value) +} diff --git a/cmd/fmsg-webapi/main.go b/cmd/fmsg-webapi/main.go index 4348132..a3a2b77 100644 --- a/cmd/fmsg-webapi/main.go +++ b/cmd/fmsg-webapi/main.go @@ -3,7 +3,6 @@ package main import ( "context" "crypto/tls" - "encoding/base64" "errors" "log" "net/http" @@ -16,6 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/joho/godotenv" + "github.com/markmnl/fmsg-webapi/internal/apiauth" "github.com/markmnl/fmsg-webapi/internal/db" "github.com/markmnl/fmsg-webapi/internal/handlers" "github.com/markmnl/fmsg-webapi/internal/middleware" @@ -25,16 +25,26 @@ func main() { // Load .env file if present (ignore error when absent). _ = godotenv.Load() + if len(os.Args) > 1 && os.Args[1] == "api-key" { + if err := runAPIKeyCLI(context.Background(), os.Args[2:]); err != nil { + log.Fatalf("api-key: %v", err) + } + return + } + // Required configuration. dataDir := mustEnv("FMSG_DATA_DIR") - // JWT configuration. Mode is selected automatically: - // * RS256 (prod, JWKS-backed JWTs) when FMSG_JWT_JWKS_URL is set. - // * HMAC (dev) otherwise, using FMSG_API_JWT_SECRET. + // JWT configuration. RS256 provider JWTs and first-party Ed25519 API + // tokens can be enabled independently. jwksURL := os.Getenv("FMSG_JWT_JWKS_URL") jwtIssuer := os.Getenv("FMSG_JWT_ISSUER") jwtAudience := os.Getenv("FMSG_JWT_AUDIENCE") jwtAddressClaim := os.Getenv("FMSG_JWT_ADDRESS_CLAIM") + apiTokenPrivate := os.Getenv("FMSG_API_TOKEN_ED25519_PRIVATE_KEY") + apiTokenIssuer := envOrDefault("FMSG_API_TOKEN_ISSUER", apiauth.DefaultTokenIssuer) + apiTokenAudience := envOrDefault("FMSG_API_TOKEN_AUDIENCE", apiauth.DefaultTokenAudience) + apiTokenTTL := envOrDefaultDuration("FMSG_API_TOKEN_TTL", apiauth.DefaultTokenTTL) // TLS configuration (optional — omit both to run plain HTTP). tlsCert := os.Getenv("FMSG_TLS_CERT") @@ -74,14 +84,27 @@ func main() { defer database.Close() log.Println("connected to PostgreSQL") - // Initialise JWT middleware. - jwtCfg, err := buildJWTConfig(ctx, jwksURL, jwtIssuer, jwtAudience, jwtAddressClaim, idURL) + apiStore := apiauth.NewStore(database) + var tokenIssuer *apiauth.TokenIssuer + if apiTokenPrivate != "" { + privateKey, err := apiauth.ParseEd25519PrivateKey(apiTokenPrivate) + if err != nil { + log.Fatalf("failed to parse FMSG_API_TOKEN_ED25519_PRIVATE_KEY: %v", err) + } + tokenIssuer = apiauth.NewTokenIssuer(privateKey, apiTokenIssuer, apiTokenAudience, apiTokenTTL) + log.Printf("API token auth enabled (issuer=%s, audience=%s, ttl=%s)", tokenIssuer.Issuer(), tokenIssuer.Audience(), tokenIssuer.TTL()) + } else { + log.Println("API token auth disabled (FMSG_API_TOKEN_ED25519_PRIVATE_KEY not set)") + } + + // Initialise authentication middleware. + jwtCfg, err := buildJWTConfig(ctx, jwksURL, jwtIssuer, jwtAudience, jwtAddressClaim, idURL, tokenIssuer, apiStore) if err != nil { - log.Fatalf("failed to configure JWT: %v", err) + log.Fatalf("failed to configure auth: %v", err) } jwtMiddleware, err := middleware.New(jwtCfg) if err != nil { - log.Fatalf("failed to initialise JWT middleware: %v", err) + log.Fatalf("failed to initialise auth middleware: %v", err) } // The WebSocket endpoint authenticates outside the Gin middleware chain, @@ -93,6 +116,13 @@ func main() { // Create Gin router. router := gin.Default() + if trustedProxies := parseCSV(os.Getenv("FMSG_TRUSTED_PROXIES")); len(trustedProxies) > 0 { + if err := router.SetTrustedProxies(trustedProxies); err != nil { + log.Fatalf("invalid FMSG_TRUSTED_PROXIES: %v", err) + } + } else if err := router.SetTrustedProxies(nil); err != nil { + log.Fatalf("failed to disable trusted proxies: %v", err) + } // CORS must run before authentication so that browser preflight (OPTIONS) // requests, which do not carry the Authorization header, are answered @@ -131,10 +161,23 @@ func main() { go hub.Run(context.Background()) wsHandler := handlers.NewWSHandler(jwtVerifier, hub, corsOrigins) + if tokenIssuer != nil { + tokenHandler := handlers.NewTokenHandler(apiStore, tokenIssuer, idURL) + router.POST("/fmsg/token", tokenHandler.Exchange) + } + // Register routes under /fmsg, all protected by JWT. fmsg := router.Group("/fmsg") fmsg.Use(jwtMiddleware) { + if tokenIssuer != nil { + subAccountHandler := handlers.NewSubAccountHandler(apiStore, idURL) + fmsg.GET("/sub-accounts", subAccountHandler.List) + fmsg.POST("/sub-accounts", subAccountHandler.Create) + fmsg.POST("/sub-accounts/:agent/rotate-key", subAccountHandler.RotateKey) + fmsg.DELETE("/sub-accounts/:agent", subAccountHandler.Delete) + } + fmsg.GET("", msgHandler.List) fmsg.GET("/sent", msgHandler.Sent) fmsg.POST("", msgHandler.Create) @@ -233,10 +276,19 @@ func envOrDefaultInt(key string, defaultValue int) int { return defaultValue } -// buildJWTConfig assembles a middleware.Config from environment-derived -// inputs, picking RS256 (prod, JWKS-backed JWTs) when a JWKS URL is supplied -// and falling back to HMAC (dev) otherwise. -func buildJWTConfig(ctx context.Context, jwksURL, issuer, audience, addressClaim, idURL string) (middleware.Config, error) { +func envOrDefaultDuration(key string, defaultValue time.Duration) time.Duration { + if v := os.Getenv(key); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + log.Fatalf("environment variable %s must be a Go duration such as 12h: %v", key, err) + } + return d + } + return defaultValue +} + +// buildJWTConfig assembles a middleware.Config from environment-derived inputs. +func buildJWTConfig(ctx context.Context, jwksURL, issuer, audience, addressClaim, idURL string, tokenIssuer *apiauth.TokenIssuer, apiStore *apiauth.Store) (middleware.Config, error) { cfg := middleware.Config{ Issuer: issuer, Audience: audience, @@ -252,33 +304,21 @@ func buildJWTConfig(ctx context.Context, jwksURL, issuer, audience, addressClaim if err != nil { return cfg, err } - cfg.Mode = middleware.ModeRS256 cfg.JWKS = k.Keyfunc - log.Printf("JWT mode: RS256 (issuer=%s, jwks=%s, audience=%s, address_claim=%s)", issuer, jwksURL, audience, addressClaim) - return cfg, nil + log.Printf("RS256 auth enabled (issuer=%s, jwks=%s, audience=%s, address_claim=%s)", issuer, jwksURL, audience, addressClaim) + } else { + log.Println("RS256 auth disabled (FMSG_JWT_JWKS_URL not set)") } - secret := os.Getenv("FMSG_API_JWT_SECRET") - if secret == "" { - return cfg, errors.New("either FMSG_JWT_JWKS_URL (prod) or FMSG_API_JWT_SECRET (dev) must be set") + if tokenIssuer != nil { + cfg.APIPublicKey = tokenIssuer.PublicKey() + cfg.APIIssuer = tokenIssuer.Issuer() + cfg.APIAudience = tokenIssuer.Audience() + cfg.APIKeys = apiStore } - cfg.Mode = middleware.ModeHMAC - cfg.HMACKey = parseSecret(secret) - log.Println("JWT mode: HMAC (development)") - return cfg, nil -} -// parseSecret returns the HMAC key bytes for the given secret string. -// If s begins with "base64:" the remainder is base64-decoded; otherwise the -// raw string bytes are used. -func parseSecret(s string) []byte { - const prefix = "base64:" - if strings.HasPrefix(s, prefix) { - b, err := base64.StdEncoding.DecodeString(s[len(prefix):]) - if err != nil { - log.Fatalf("FMSG_API_JWT_SECRET has base64: prefix but is not valid base64: %v", err) - } - return b + if cfg.JWKS == nil && len(cfg.APIPublicKey) == 0 { + return cfg, errors.New("either FMSG_JWT_JWKS_URL or FMSG_API_TOKEN_ED25519_PRIVATE_KEY must be set") } - return []byte(s) + return cfg, nil } diff --git a/internal/apiauth/apikey.go b/internal/apiauth/apikey.go new file mode 100644 index 0000000..26e00dd --- /dev/null +++ b/internal/apiauth/apikey.go @@ -0,0 +1,66 @@ +package apiauth + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "strings" +) + +const ( + KeyPrefix = "fmsgk" +) + +var ( + ErrInvalidAPIKey = errors.New("invalid api key") +) + +type APIKey struct { + ID string + Secret string + Value string +} + +func GenerateAPIKey() (APIKey, error) { + id, err := randomURLToken(12) + if err != nil { + return APIKey{}, err + } + secret, err := randomURLToken(32) + if err != nil { + return APIKey{}, err + } + value := fmt.Sprintf("%s_%s_%s", KeyPrefix, id, secret) + return APIKey{ID: id, Secret: secret, Value: value}, nil +} + +func ParseAPIKey(value string) (APIKey, error) { + parts := strings.Split(value, "_") + if len(parts) != 3 || parts[0] != KeyPrefix || parts[1] == "" || parts[2] == "" { + return APIKey{}, ErrInvalidAPIKey + } + return APIKey{ID: parts[1], Secret: parts[2], Value: value}, nil +} + +func HashAPIKey(value string) []byte { + sum := sha256.Sum256([]byte(value)) + out := make([]byte, len(sum)) + copy(out, sum[:]) + return out +} + +func APIKeyHashMatches(value string, hash []byte) bool { + got := HashAPIKey(value) + return subtle.ConstantTimeCompare(got, hash) == 1 +} + +func randomURLToken(n int) (string, error) { + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} diff --git a/internal/apiauth/apikey_test.go b/internal/apiauth/apikey_test.go new file mode 100644 index 0000000..186cb6f --- /dev/null +++ b/internal/apiauth/apikey_test.go @@ -0,0 +1,56 @@ +package apiauth + +import ( + "encoding/base64" + "strings" + "testing" +) + +func TestGenerateParseAndHashAPIKey(t *testing.T) { + key, err := GenerateAPIKey() + if err != nil { + t.Fatal(err) + } + if !strings.HasPrefix(key.Value, KeyPrefix+"_") { + t.Fatalf("key prefix = %q", key.Value) + } + parsed, err := ParseAPIKey(key.Value) + if err != nil { + t.Fatal(err) + } + if parsed.ID != key.ID || parsed.Secret != key.Secret { + t.Fatalf("parsed key = %#v, want %#v", parsed, key) + } + hash := HashAPIKey(key.Value) + if !APIKeyHashMatches(key.Value, hash) { + t.Fatal("hash should match original key") + } + if APIKeyHashMatches(key.Value+"x", hash) { + t.Fatal("hash should not match modified key") + } +} + +func TestDeriveSubAccountAddr(t *testing.T) { + got, err := DeriveSubAccountAddr("@alice@example.com", "bot-1") + if err != nil { + t.Fatal(err) + } + if got != "@alice_bot-1@example.com" { + t.Fatalf("addr = %q", got) + } + if _, err := DeriveSubAccountAddr("@alice@example.com", "bad_agent"); err == nil { + t.Fatal("underscore agent should fail") + } +} + +func TestParseEd25519PrivateKeyAcceptsSeed(t *testing.T) { + seed := make([]byte, 32) + encoded := base64.StdEncoding.EncodeToString(seed) + key, err := ParseEd25519PrivateKey(encoded) + if err != nil { + t.Fatal(err) + } + if len(key) != 64 { + t.Fatalf("private key length = %d", len(key)) + } +} diff --git a/internal/apiauth/store.go b/internal/apiauth/store.go new file mode 100644 index 0000000..bdd082f --- /dev/null +++ b/internal/apiauth/store.go @@ -0,0 +1,307 @@ +package apiauth + +import ( + "context" + "crypto/subtle" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "time" + + "github.com/jackc/pgx/v5" + + "github.com/markmnl/fmsg-webapi/internal/db" +) + +const DefaultMaxSubAccounts = 5 + +var ( + ErrNotFound = errors.New("sub-account not found") + ErrAlreadyExists = errors.New("sub-account already exists") + ErrLimitExceeded = errors.New("sub-account limit exceeded") + ErrCIDRDenied = errors.New("source IP not allowed") + ErrKeyExpired = errors.New("api key expired") + ErrKeyRevoked = errors.New("api key revoked") + ErrInvalidRemoteIP = errors.New("invalid source IP") +) + +type Store struct { + DB *db.DB +} + +type SubAccount struct { + OwnerAddr string + Agent string + Addr string + KeyID string + AllowedCIDRs []string + KeyExpiresAt time.Time + MaxSubAccounts int +} + +type APIKeyIdentity struct { + OwnerAddr string + SubAddr string + KeyID string +} + +func NewStore(database *db.DB) *Store { + return &Store{DB: database} +} + +func (s *Store) List(ctx context.Context, ownerAddr string) (int, []SubAccount, error) { + max, err := s.MaxSubAccounts(ctx, ownerAddr) + if err != nil { + return 0, nil, err + } + rows, err := s.DB.Pool.Query(ctx, + `SELECT owner_addr, agent, sub_addr, key_id, + ARRAY(SELECT cidr_value::text FROM unnest(allowed_cidrs) AS x(cidr_value)), + key_expires_at, max_sub_accounts + FROM fmsg_api_sub_account + WHERE lower(owner_addr) = lower($1) AND agent <> '' + ORDER BY agent`, ownerAddr) + if err != nil { + return 0, nil, err + } + defer rows.Close() + + var out []SubAccount + for rows.Next() { + var a SubAccount + if err := rows.Scan(&a.OwnerAddr, &a.Agent, &a.Addr, &a.KeyID, &a.AllowedCIDRs, &a.KeyExpiresAt, &a.MaxSubAccounts); err != nil { + return 0, nil, err + } + out = append(out, a) + } + if err := rows.Err(); err != nil { + return 0, nil, err + } + return max, out, nil +} + +func (s *Store) MaxSubAccounts(ctx context.Context, ownerAddr string) (int, error) { + var max int + err := s.DB.Pool.QueryRow(ctx, + `SELECT max_sub_accounts + FROM fmsg_api_sub_account + WHERE lower(owner_addr) = lower($1) AND agent = ''`, ownerAddr).Scan(&max) + if errors.Is(err, pgx.ErrNoRows) { + return DefaultMaxSubAccounts, nil + } + if err != nil { + return 0, err + } + return max, nil +} + +func (s *Store) Create(ctx context.Context, ownerAddr, agent, subAddr, keyID string, keyHash []byte, allowedCIDRs []string, keyExpiresAt time.Time) error { + tx, err := s.DB.Pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) //nolint:errcheck + + max, err := maxSubAccountsTx(ctx, tx, ownerAddr) + if err != nil { + return err + } + var count int + if err := tx.QueryRow(ctx, + `SELECT count(*) + FROM fmsg_api_sub_account + WHERE lower(owner_addr) = lower($1) AND agent <> ''`, ownerAddr).Scan(&count); err != nil { + return err + } + if count >= max { + return ErrLimitExceeded + } + + _, err = tx.Exec(ctx, + `INSERT INTO fmsg_api_sub_account + (owner_addr, agent, sub_addr, key_id, key_hash, allowed_cidrs, key_expires_at, max_sub_accounts, updated_at) + VALUES ($1, $2, $3, $4, $5, $6::cidr[], $7, $8, now())`, + ownerAddr, agent, subAddr, keyID, keyHash, allowedCIDRs, keyExpiresAt, max) + if isUniqueViolation(err) { + return ErrAlreadyExists + } + if err != nil { + return err + } + return tx.Commit(ctx) +} + +func (s *Store) RotateKey(ctx context.Context, ownerAddr, agent, keyID string, keyHash []byte, keyExpiresAt time.Time, allowedCIDRs []string, replaceCIDRs bool) (string, error) { + var subAddr string + var err error + if replaceCIDRs { + err = s.DB.Pool.QueryRow(ctx, + `UPDATE fmsg_api_sub_account + SET key_id = $3, key_hash = $4, key_expires_at = $5, allowed_cidrs = $6::cidr[], updated_at = now() + WHERE lower(owner_addr) = lower($1) AND agent = $2 AND agent <> '' + RETURNING sub_addr`, + ownerAddr, agent, keyID, keyHash, keyExpiresAt, allowedCIDRs).Scan(&subAddr) + } else { + err = s.DB.Pool.QueryRow(ctx, + `UPDATE fmsg_api_sub_account + SET key_id = $3, key_hash = $4, key_expires_at = $5, updated_at = now() + WHERE lower(owner_addr) = lower($1) AND agent = $2 AND agent <> '' + RETURNING sub_addr`, + ownerAddr, agent, keyID, keyHash, keyExpiresAt).Scan(&subAddr) + } + if isUniqueViolation(err) { + return "", ErrAlreadyExists + } + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrNotFound + } + if err != nil { + return "", err + } + return subAddr, nil +} + +func (s *Store) Delete(ctx context.Context, ownerAddr, agent string) error { + tag, err := s.DB.Pool.Exec(ctx, + `DELETE FROM fmsg_api_sub_account + WHERE lower(owner_addr) = lower($1) AND agent = $2 AND agent <> ''`, + ownerAddr, agent) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return ErrNotFound + } + return nil +} + +func (s *Store) ValidateAPIKey(ctx context.Context, apiKey, remoteAddr string) (APIKeyIdentity, error) { + parsed, err := ParseAPIKey(apiKey) + if err != nil { + return APIKeyIdentity{}, err + } + + var ident APIKeyIdentity + var hash []byte + var cidrs []string + var expires time.Time + err = s.DB.Pool.QueryRow(ctx, + `SELECT owner_addr, sub_addr, key_id, key_hash, + ARRAY(SELECT cidr_value::text FROM unnest(allowed_cidrs) AS x(cidr_value)), + key_expires_at + FROM fmsg_api_sub_account + WHERE key_id = $1 AND agent <> ''`, parsed.ID). + Scan(&ident.OwnerAddr, &ident.SubAddr, &ident.KeyID, &hash, &cidrs, &expires) + if errors.Is(err, pgx.ErrNoRows) { + return APIKeyIdentity{}, ErrInvalidAPIKey + } + if err != nil { + return APIKeyIdentity{}, err + } + if subtle.ConstantTimeCompare(HashAPIKey(apiKey), hash) != 1 { + return APIKeyIdentity{}, ErrInvalidAPIKey + } + if time.Now().After(expires) { + return APIKeyIdentity{}, ErrKeyExpired + } + if err := remoteAllowed(remoteAddr, cidrs); err != nil { + return APIKeyIdentity{}, err + } + return ident, nil +} + +func (s *Store) ValidateToken(ctx context.Context, keyID, ownerAddr, subAddr, remoteAddr string) error { + var cidrs []string + var expires time.Time + err := s.DB.Pool.QueryRow(ctx, + `SELECT ARRAY(SELECT cidr_value::text FROM unnest(allowed_cidrs) AS x(cidr_value)), key_expires_at + FROM fmsg_api_sub_account + WHERE key_id = $1 + AND lower(owner_addr) = lower($2) + AND lower(sub_addr) = lower($3) + AND agent <> ''`, + keyID, ownerAddr, subAddr).Scan(&cidrs, &expires) + if errors.Is(err, pgx.ErrNoRows) { + return ErrKeyRevoked + } + if err != nil { + return err + } + if time.Now().After(expires) { + return ErrKeyExpired + } + return remoteAllowed(remoteAddr, cidrs) +} + +func (s *Store) ValidateActAs(ctx context.Context, ownerAddr, subAddr string) error { + var exists bool + err := s.DB.Pool.QueryRow(ctx, + `SELECT true + FROM fmsg_api_sub_account + WHERE lower(owner_addr) = lower($1) + AND lower(sub_addr) = lower($2) + AND agent <> ''`, + ownerAddr, subAddr).Scan(&exists) + if errors.Is(err, pgx.ErrNoRows) { + return ErrNotFound + } + if err != nil { + return err + } + return nil +} + +func maxSubAccountsTx(ctx context.Context, tx pgx.Tx, ownerAddr string) (int, error) { + var max int + err := tx.QueryRow(ctx, + `SELECT max_sub_accounts + FROM fmsg_api_sub_account + WHERE lower(owner_addr) = lower($1) AND agent = ''`, ownerAddr).Scan(&max) + if errors.Is(err, pgx.ErrNoRows) { + return DefaultMaxSubAccounts, nil + } + if err != nil { + return 0, err + } + return max, nil +} + +func remoteAllowed(remoteAddr string, cidrs []string) error { + addr, err := parseRemoteAddr(remoteAddr) + if err != nil { + return err + } + for _, raw := range cidrs { + prefix, err := netip.ParsePrefix(strings.TrimSpace(raw)) + if err != nil { + return fmt.Errorf("invalid stored CIDR %q: %w", raw, err) + } + if prefix.Contains(addr) { + return nil + } + } + return ErrCIDRDenied +} + +func parseRemoteAddr(remoteAddr string) (netip.Addr, error) { + if host, _, err := net.SplitHostPort(remoteAddr); err == nil { + remoteAddr = host + } + addr, err := netip.ParseAddr(remoteAddr) + if err != nil { + return netip.Addr{}, ErrInvalidRemoteIP + } + if addr.Is4In6() { + addr = addr.Unmap() + } + return addr, nil +} + +func isUniqueViolation(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "SQLSTATE 23505") +} diff --git a/internal/apiauth/subaccount.go b/internal/apiauth/subaccount.go new file mode 100644 index 0000000..7d2d0e1 --- /dev/null +++ b/internal/apiauth/subaccount.go @@ -0,0 +1,49 @@ +package apiauth + +import ( + "errors" + "net/netip" + "strings" + "unicode" +) + +var ErrInvalidAgent = errors.New("invalid agent") + +func ValidateAgent(agent string) error { + if agent == "" || len(agent) > 64 || strings.Contains(agent, "_") || strings.Contains(agent, "@") { + return ErrInvalidAgent + } + for _, r := range agent { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '-' || r == '.' { + continue + } + return ErrInvalidAgent + } + return nil +} + +func DeriveSubAccountAddr(ownerAddr, agent string) (string, error) { + if err := ValidateAgent(agent); err != nil { + return "", err + } + if !strings.HasPrefix(ownerAddr, "@") { + return "", ErrInvalidAgent + } + parts := strings.SplitN(ownerAddr[1:], "@", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", ErrInvalidAgent + } + return "@" + parts[0] + "_" + agent + "@" + parts[1], nil +} + +func ValidateCIDRs(cidrs []string) error { + if len(cidrs) == 0 { + return errors.New("at least one CIDR is required") + } + for _, raw := range cidrs { + if _, err := netip.ParsePrefix(strings.TrimSpace(raw)); err != nil { + return err + } + } + return nil +} diff --git a/internal/apiauth/token.go b/internal/apiauth/token.go new file mode 100644 index 0000000..60bbfee --- /dev/null +++ b/internal/apiauth/token.go @@ -0,0 +1,105 @@ +package apiauth + +import ( + "crypto/ed25519" + "encoding/base64" + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + DefaultTokenIssuer = "fmsg-webapi" + DefaultTokenAudience = "fmsg-webapi" + DefaultTokenTTL = 12 * time.Hour +) + +type TokenIssuer struct { + privateKey ed25519.PrivateKey + publicKey ed25519.PublicKey + issuer string + audience string + ttl time.Duration +} + +type TokenClaims struct { + OwnerAddr string `json:"owner"` + APIKeyID string `json:"api_key_id"` + jwt.RegisteredClaims +} + +func NewTokenIssuer(privateKey ed25519.PrivateKey, issuer, audience string, ttl time.Duration) *TokenIssuer { + if issuer == "" { + issuer = DefaultTokenIssuer + } + if audience == "" { + audience = DefaultTokenAudience + } + if ttl == 0 { + ttl = DefaultTokenTTL + } + pub := privateKey.Public().(ed25519.PublicKey) + return &TokenIssuer{privateKey: privateKey, publicKey: pub, issuer: issuer, audience: audience, ttl: ttl} +} + +func (i *TokenIssuer) PublicKey() ed25519.PublicKey { + return i.publicKey +} + +func (i *TokenIssuer) Issuer() string { + return i.issuer +} + +func (i *TokenIssuer) Audience() string { + return i.audience +} + +func (i *TokenIssuer) TTL() time.Duration { + return i.ttl +} + +func (i *TokenIssuer) Mint(ownerAddr, subAddr, keyID string, now time.Time) (string, time.Time, error) { + expires := now.Add(i.ttl) + claims := TokenClaims{ + OwnerAddr: ownerAddr, + APIKeyID: keyID, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: i.issuer, + Subject: subAddr, + Audience: jwt.ClaimStrings{i.audience}, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(expires), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + signed, err := tok.SignedString(i.privateKey) + if err != nil { + return "", time.Time{}, err + } + return signed, expires, nil +} + +func ParseEd25519PrivateKey(s string) (ed25519.PrivateKey, error) { + raw, err := base64.StdEncoding.DecodeString(s) + if err != nil { + raw, err = base64.RawStdEncoding.DecodeString(s) + } + if err != nil { + raw, err = base64.RawURLEncoding.DecodeString(s) + } + if err != nil { + return nil, fmt.Errorf("decode Ed25519 private key: %w", err) + } + switch len(raw) { + case ed25519.PrivateKeySize: + key := make(ed25519.PrivateKey, ed25519.PrivateKeySize) + copy(key, raw) + return key, nil + case ed25519.SeedSize: + return ed25519.NewKeyFromSeed(raw), nil + default: + return nil, errors.New("Ed25519 private key must be a base64-encoded 32-byte seed or 64-byte private key") + } +} diff --git a/internal/handlers/subaccounts.go b/internal/handlers/subaccounts.go new file mode 100644 index 0000000..30dcd0f --- /dev/null +++ b/internal/handlers/subaccounts.go @@ -0,0 +1,257 @@ +package handlers + +import ( + "errors" + "log" + "net/http" + "time" + + "github.com/gin-gonic/gin" + + "github.com/markmnl/fmsg-webapi/internal/apiauth" + "github.com/markmnl/fmsg-webapi/internal/middleware" +) + +type SubAccountHandler struct { + store *apiauth.Store + idURL string +} + +func NewSubAccountHandler(store *apiauth.Store, idURL string) *SubAccountHandler { + return &SubAccountHandler{store: store, idURL: idURL} +} + +type subAccountInput struct { + Agent string `json:"agent"` + AllowedCIDRs []string `json:"allowed_cidrs"` + KeyExpiresAt string `json:"key_expires_at"` +} + +type rotateKeyInput struct { + AllowedCIDRs []string `json:"allowed_cidrs"` + KeyExpiresAt string `json:"key_expires_at"` +} + +type subAccountResponse struct { + Agent string `json:"agent"` + Addr string `json:"addr"` + KeyID string `json:"key_id,omitempty"` + AllowedCIDRs []string `json:"allowed_cidrs"` + KeyExpiresAt string `json:"key_expires_at"` + APIKey string `json:"api_key,omitempty"` +} + +func (h *SubAccountHandler) List(c *gin.Context) { + owner, ok := requireRS256Owner(c) + if !ok { + return + } + + max, accounts, err := h.store.List(c.Request.Context(), owner) + if err != nil { + log.Printf("sub-accounts list: owner=%s: %v", owner, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sub-accounts"}) + return + } + + out := make([]subAccountResponse, 0, len(accounts)) + for _, a := range accounts { + out = append(out, subAccountResponse{ + Agent: a.Agent, + Addr: a.Addr, + KeyID: a.KeyID, + AllowedCIDRs: a.AllowedCIDRs, + KeyExpiresAt: a.KeyExpiresAt.UTC().Format(time.RFC3339), + }) + } + c.JSON(http.StatusOK, gin.H{"max_sub_accounts": max, "sub_accounts": out}) +} + +func (h *SubAccountHandler) Create(c *gin.Context) { + owner, ok := requireRS256Owner(c) + if !ok { + return + } + + var in subAccountInput + if err := c.ShouldBindJSON(&in); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + expires, err := parseRequiredExpiry(in.KeyExpiresAt) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "key_expires_at must be a future RFC3339 timestamp"}) + return + } + if err := apiauth.ValidateCIDRs(in.AllowedCIDRs); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "allowed_cidrs must contain valid CIDR ranges"}) + return + } + subAddr, err := apiauth.DeriveSubAccountAddr(owner, in.Agent) + if err != nil || !middleware.IsValidAddr(subAddr) { + c.JSON(http.StatusBadRequest, gin.H{"error": "agent must be 1-64 letters/digits/dots/hyphens and contain no underscores"}) + return + } + if !checkAcceptingFmsgID(c, h.idURL, subAddr) { + return + } + + key, hash, err := newPlaintextKey() + if err != nil { + log.Printf("sub-account create: key generation failed: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate api key"}) + return + } + if err := h.store.Create(c.Request.Context(), owner, in.Agent, subAddr, key.ID, hash, in.AllowedCIDRs, expires); err != nil { + respondSubAccountStoreError(c, err) + return + } + c.JSON(http.StatusCreated, subAccountResponse{ + Agent: in.Agent, + Addr: subAddr, + KeyID: key.ID, + AllowedCIDRs: in.AllowedCIDRs, + KeyExpiresAt: expires.UTC().Format(time.RFC3339), + APIKey: key.Value, + }) +} + +func (h *SubAccountHandler) RotateKey(c *gin.Context) { + owner, ok := requireRS256Owner(c) + if !ok { + return + } + agent := c.Param("agent") + if err := apiauth.ValidateAgent(agent); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent"}) + return + } + expectedSubAddr, err := apiauth.DeriveSubAccountAddr(owner, agent) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent"}) + return + } + if !checkAcceptingFmsgID(c, h.idURL, expectedSubAddr) { + return + } + + var in rotateKeyInput + if err := c.ShouldBindJSON(&in); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + expires, err := parseRequiredExpiry(in.KeyExpiresAt) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "key_expires_at must be a future RFC3339 timestamp"}) + return + } + replaceCIDRs := in.AllowedCIDRs != nil + if replaceCIDRs { + if err := apiauth.ValidateCIDRs(in.AllowedCIDRs); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "allowed_cidrs must contain valid CIDR ranges"}) + return + } + } + + key, hash, err := newPlaintextKey() + if err != nil { + log.Printf("sub-account rotate: key generation failed: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate api key"}) + return + } + subAddr, err := h.store.RotateKey(c.Request.Context(), owner, agent, key.ID, hash, expires, in.AllowedCIDRs, replaceCIDRs) + if err != nil { + respondSubAccountStoreError(c, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "agent": agent, + "addr": subAddr, + "key_id": key.ID, + "key_expires_at": expires.UTC().Format(time.RFC3339), + "api_key": key.Value, + }) +} + +func (h *SubAccountHandler) Delete(c *gin.Context) { + owner, ok := requireRS256Owner(c) + if !ok { + return + } + agent := c.Param("agent") + if err := apiauth.ValidateAgent(agent); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid agent"}) + return + } + if err := h.store.Delete(c.Request.Context(), owner, agent); err != nil { + respondSubAccountStoreError(c, err) + return + } + c.Status(http.StatusNoContent) +} + +func requireRS256Owner(c *gin.Context) (string, bool) { + if middleware.GetAuthType(c) != middleware.AuthTypeRS256 || middleware.GetIdentity(c) != middleware.GetOwnerIdentity(c) { + c.JSON(http.StatusForbidden, gin.H{"error": "RS256 owner authentication is required"}) + return "", false + } + return middleware.GetOwnerIdentity(c), true +} + +func parseRequiredExpiry(raw string) (time.Time, error) { + if raw == "" { + return time.Time{}, errors.New("missing expiry") + } + expires, err := time.Parse(time.RFC3339, raw) + if err != nil { + return time.Time{}, err + } + if !expires.After(time.Now()) { + return time.Time{}, errors.New("expiry must be in the future") + } + return expires, nil +} + +func newPlaintextKey() (apiauth.APIKey, []byte, error) { + key, err := apiauth.GenerateAPIKey() + if err != nil { + return apiauth.APIKey{}, nil, err + } + return key, apiauth.HashAPIKey(key.Value), nil +} + +func checkAcceptingFmsgID(c *gin.Context, idURL, addr string) bool { + code, accepting, err := middleware.CheckFmsgID(idURL, addr) + if err != nil { + log.Printf("sub-account fmsgid check: addr=%s: %v", addr, err) + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "identity service unavailable"}) + return false + } + if code == http.StatusNotFound { + c.JSON(http.StatusBadRequest, gin.H{"error": "sub-account not found in fmsgid"}) + return false + } + if code != http.StatusOK { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "identity service unavailable"}) + return false + } + if !accepting { + c.JSON(http.StatusForbidden, gin.H{"error": "sub-account is not accepting new messages"}) + return false + } + return true +} + +func respondSubAccountStoreError(c *gin.Context, err error) { + switch { + case errors.Is(err, apiauth.ErrAlreadyExists): + c.JSON(http.StatusConflict, gin.H{"error": "sub-account already exists"}) + case errors.Is(err, apiauth.ErrLimitExceeded): + c.JSON(http.StatusForbidden, gin.H{"error": "sub-account limit exceeded"}) + case errors.Is(err, apiauth.ErrNotFound): + c.JSON(http.StatusNotFound, gin.H{"error": "sub-account not found"}) + default: + log.Printf("sub-account store error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update sub-account"}) + } +} diff --git a/internal/handlers/token.go b/internal/handlers/token.go new file mode 100644 index 0000000..26ab52e --- /dev/null +++ b/internal/handlers/token.go @@ -0,0 +1,82 @@ +package handlers + +import ( + "errors" + "log" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/markmnl/fmsg-webapi/internal/apiauth" +) + +type TokenHandler struct { + store *apiauth.Store + issuer *apiauth.TokenIssuer + idURL string +} + +func NewTokenHandler(store *apiauth.Store, issuer *apiauth.TokenIssuer, idURL string) *TokenHandler { + return &TokenHandler{store: store, issuer: issuer, idURL: idURL} +} + +func (h *TokenHandler) Exchange(c *gin.Context) { + apiKey, err := bearerTokenStrict(c.GetHeader("Authorization")) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing or malformed Authorization header"}) + return + } + + ident, err := h.store.ValidateAPIKey(c.Request.Context(), apiKey, c.ClientIP()) + if err != nil { + respondTokenError(c, err) + return + } + if !checkAcceptingFmsgID(c, h.idURL, ident.SubAddr) { + return + } + + now := time.Now() + token, expires, err := h.issuer.Mint(ident.OwnerAddr, ident.SubAddr, ident.KeyID, now) + if err != nil { + log.Printf("token exchange: mint failed for key_id=%s sub=%s: %v", ident.KeyID, ident.SubAddr, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mint token"}) + return + } + c.JSON(http.StatusOK, gin.H{ + "access_token": token, + "token_type": "Bearer", + "expires_in": int64(h.issuer.TTL().Seconds()), + "expires_at": expires.UTC().Format(time.RFC3339), + }) +} + +func respondTokenError(c *gin.Context, err error) { + switch { + case errors.Is(err, apiauth.ErrCIDRDenied): + c.JSON(http.StatusForbidden, gin.H{"error": "source IP not allowed"}) + case errors.Is(err, apiauth.ErrKeyExpired): + c.JSON(http.StatusUnauthorized, gin.H{"error": "api key expired"}) + case errors.Is(err, apiauth.ErrInvalidRemoteIP): + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid source IP"}) + case errors.Is(err, apiauth.ErrInvalidAPIKey): + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid api key"}) + default: + log.Printf("token exchange: validation failed: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to validate api key"}) + } +} + +func bearerTokenStrict(header string) (string, error) { + const prefix = "Bearer " + if len(header) <= len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) { + return "", errors.New("missing Bearer prefix") + } + token := strings.TrimSpace(header[len(prefix):]) + if token == "" { + return "", errors.New("empty token") + } + return token, nil +} diff --git a/internal/handlers/ws.go b/internal/handlers/ws.go index e85f941..f492d9a 100644 --- a/internal/handlers/ws.go +++ b/internal/handlers/ws.go @@ -82,11 +82,12 @@ func (h *WSHandler) Connect(c *gin.Context) { return } - addr, status, msg := h.verifier.Authenticate(token) + res, status, msg := h.verifier.AuthenticateRequest(c.Request.Context(), token, c.ClientIP(), c.GetHeader("X-FMSG-Act-As")) if status != http.StatusOK { c.JSON(status, gin.H{"error": msg}) return } + addr := res.Addr conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index ba9e8c9..ac763eb 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -31,7 +31,7 @@ type CORSConfig struct { func DefaultCORSConfig() CORSConfig { return CORSConfig{ AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Authorization", "Content-Type"}, + AllowedHeaders: []string{"Authorization", "Content-Type", "X-FMSG-Act-As"}, MaxAge: 10 * time.Minute, } } diff --git a/internal/middleware/jwt.go b/internal/middleware/jwt.go index 7b42108..cec75fc 100644 --- a/internal/middleware/jwt.go +++ b/internal/middleware/jwt.go @@ -1,7 +1,9 @@ -// Package middleware configures the JWT authentication middleware. +// Package middleware configures authentication middleware. package middleware import ( + "context" + "crypto/ed25519" "errors" "fmt" "log" @@ -13,53 +15,42 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "golang.org/x/sync/singleflight" + + "github.com/markmnl/fmsg-webapi/internal/apiauth" ) -// IdentityKey is the Gin context key under which the authenticated user -// address is stored. -const IdentityKey = "sub" +const ( + IdentityKey = "sub" + OwnerIdentityKey = "owner" + AuthTypeKey = "auth_type" + + AuthTypeRS256 = "rs256" + AuthTypeAPI = "api_token" +) // DefaultClockSkew is the leeway applied to iat/nbf/exp validation to tolerate // minor clock differences between services. const DefaultClockSkew = 10 * time.Second -// Mode selects the JWT verification strategy. -type Mode int - -const ( - // ModeHMAC verifies HS256 tokens with a shared symmetric secret. - // Intended for development and testing. - ModeHMAC Mode = iota - // ModeRS256 verifies RS256 JWTs whose public keys are served via JWKS. - ModeRS256 -) +type APIKeyChecker interface { + ValidateToken(ctx context.Context, keyID, ownerAddr, subAddr, remoteAddr string) error + ValidateActAs(ctx context.Context, ownerAddr, subAddr string) error +} -// Config configures the JWT middleware. +// Config configures authentication. type Config struct { - // Mode selects HMAC (dev) or RS256 (prod) verification. - Mode Mode - - // HMACKey is the symmetric secret bytes (required when Mode == ModeHMAC). - HMACKey []byte - - // JWKS resolves RSA public keys (typically by token header `kid`). - // Required when Mode == ModeRS256. - JWKS jwt.Keyfunc - - // Issuer, when non-empty, is required to match the token `iss` claim. - // Mandatory in RS256 mode. - Issuer string - - // Audience, when non-empty, is required to be present in the token - // `aud` claim. Mandatory in RS256 mode to pin tokens to the configured - // application or API. - Audience string - - // AddressClaim is the JWT claim name carrying the user's fmsg address. - // Mandatory in RS256 mode because external identity providers usually - // put provider-specific identifiers in `sub`. + // RS256/JWKS provider token verification. Enabled when JWKS is non-nil. + JWKS jwt.Keyfunc + Issuer string + Audience string AddressClaim string + // Ed25519 first-party API-token verification. Enabled when APIPublicKey is non-empty. + APIPublicKey ed25519.PublicKey + APIIssuer string + APIAudience string + APIKeys APIKeyChecker + // IDURL is the base URL of the fmsgid identity service. IDURL string @@ -68,46 +59,34 @@ type Config struct { ClockSkew time.Duration } -// Verifier verifies fmsg JWT bearer tokens. It is safe for concurrent use and -// is shared by the Gin authentication middleware and the WebSocket handler, -// which authenticates outside the Gin middleware chain (browsers cannot set an -// Authorization header on a WebSocket connection). +type authResult struct { + Addr string + OwnerAddr string + AuthType string +} + +// Verifier verifies fmsg bearer tokens. It is safe for concurrent use and is +// shared by Gin middleware and the WebSocket handler. type Verifier struct { - mode Mode - parser *jwt.Parser - keyFunc jwt.Keyfunc - idURL string + rsParser *jwt.Parser + rsKeyFunc jwt.Keyfunc + issuer string + audience string addressClaim string + apiParser *jwt.Parser + apiPublicKey ed25519.PublicKey + apiKeys APIKeyChecker + idURL string } -// NewVerifier constructs a Verifier from the given configuration. func NewVerifier(cfg Config) (*Verifier, error) { if cfg.ClockSkew == 0 { cfg.ClockSkew = DefaultClockSkew } - var ( - validMethods []string - keyFunc jwt.Keyfunc - ) + v := &Verifier{idURL: cfg.IDURL} - switch cfg.Mode { - case ModeHMAC: - if len(cfg.HMACKey) == 0 { - return nil, errors.New("middleware: HMAC mode requires a non-empty HMACKey") - } - validMethods = []string{jwt.SigningMethodHS256.Alg()} - key := cfg.HMACKey - keyFunc = func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %s", t.Method.Alg()) - } - return key, nil - } - case ModeRS256: - if cfg.JWKS == nil { - return nil, errors.New("middleware: RS256 mode requires a JWKS keyfunc") - } + if cfg.JWKS != nil { if cfg.Issuer == "" { return nil, errors.New("middleware: RS256 mode requires an Issuer") } @@ -117,110 +96,202 @@ func NewVerifier(cfg Config) (*Verifier, error) { if cfg.AddressClaim == "" { return nil, errors.New("middleware: RS256 mode requires an AddressClaim") } - validMethods = []string{jwt.SigningMethodRS256.Alg()} - jwks := cfg.JWKS - keyFunc = func(t *jwt.Token) (interface{}, error) { + v.rsKeyFunc = func(t *jwt.Token) (interface{}, error) { if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %s", t.Method.Alg()) } - return jwks(t) + return cfg.JWKS(t) } - default: - return nil, fmt.Errorf("middleware: unknown JWT mode %d", cfg.Mode) + v.rsParser = jwt.NewParser( + jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}), + jwt.WithLeeway(cfg.ClockSkew), + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + jwt.WithIssuer(cfg.Issuer), + jwt.WithAudience(cfg.Audience), + ) + v.issuer = cfg.Issuer + v.audience = cfg.Audience + v.addressClaim = cfg.AddressClaim } - parserOpts := []jwt.ParserOption{ - jwt.WithValidMethods(validMethods), - jwt.WithLeeway(cfg.ClockSkew), - jwt.WithExpirationRequired(), - jwt.WithIssuedAt(), + if len(cfg.APIPublicKey) > 0 { + if cfg.APIKeys == nil { + return nil, errors.New("middleware: API token mode requires an API key checker") + } + if cfg.APIIssuer == "" { + cfg.APIIssuer = apiauth.DefaultTokenIssuer + } + if cfg.APIAudience == "" { + cfg.APIAudience = apiauth.DefaultTokenAudience + } + v.apiPublicKey = cfg.APIPublicKey + v.apiKeys = cfg.APIKeys + v.apiParser = jwt.NewParser( + jwt.WithValidMethods([]string{jwt.SigningMethodEdDSA.Alg()}), + jwt.WithLeeway(cfg.ClockSkew), + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + jwt.WithIssuer(cfg.APIIssuer), + jwt.WithAudience(cfg.APIAudience), + ) } - if cfg.Issuer != "" { - parserOpts = append(parserOpts, jwt.WithIssuer(cfg.Issuer)) + + if v.rsParser == nil && v.apiParser == nil { + return nil, errors.New("middleware: at least one auth mode must be configured") } - if cfg.Audience != "" { - parserOpts = append(parserOpts, jwt.WithAudience(cfg.Audience)) + return v, nil +} + +func (v *Verifier) Authenticate(tokenStr string) (addr string, status int, msg string) { + res, status, msg := v.AuthenticateRequest(context.Background(), tokenStr, "127.0.0.1", "") + if status != http.StatusOK { + return "", status, msg } + return res.Addr, http.StatusOK, "" +} - return &Verifier{ - mode: cfg.Mode, - parser: jwt.NewParser(parserOpts...), - keyFunc: keyFunc, - idURL: cfg.IDURL, - addressClaim: cfg.AddressClaim, - }, nil +func (v *Verifier) AuthenticateRequest(ctx context.Context, tokenStr, remoteAddr, actAs string) (authResult, int, string) { + if v.rsParser != nil { + res, err := v.authenticateRS256(ctx, tokenStr, actAs) + if err == nil { + return res, http.StatusOK, "" + } + if status, msg, ok := authFailureFromError(err); ok { + return authResult{}, status, msg + } + } + if v.apiParser != nil { + res, err := v.authenticateAPIToken(ctx, tokenStr, remoteAddr, actAs) + if err == nil { + return res, http.StatusOK, "" + } + if status, msg, ok := authFailureFromError(err); ok { + return authResult{}, status, msg + } + } + log.Printf("auth rejected: reason=parse_error") + return authResult{}, http.StatusUnauthorized, "invalid token" } -// Authenticate parses & verifies a bearer token string, validates its claims, -// derives the user's fmsg address, and confirms via fmsgid that the user is -// known and accepting messages. -// -// The address is derived per mode: in RS256 mode the address comes from the -// configured address claim because `sub` is usually a provider-specific -// identifier; in HMAC dev mode the `sub` claim is the address. -// -// On success it returns the user address and http.StatusOK. On failure it -// returns the empty address, an HTTP status (400/401/403/503), and a -// client-safe error message. -func (v *Verifier) Authenticate(tokenStr string) (addr string, status int, msg string) { +func (v *Verifier) authenticateRS256(ctx context.Context, tokenStr, actAs string) (authResult, error) { claims := jwt.MapClaims{} - if _, err := v.parser.ParseWithClaims(tokenStr, claims, v.keyFunc); err != nil { - log.Printf("auth rejected: reason=parse_error err=%v", err) - return "", http.StatusUnauthorized, "invalid token" - } - - switch v.mode { - case ModeRS256: - // The token is valid and the user authenticated; a missing address - // claim just means no fmsg account exists yet, so respond 403 - // rather than 401 (which would trigger client token refreshes). - addr, _ = claims[v.addressClaim].(string) - if addr == "" { - sub, _ := claims["sub"].(string) - log.Printf("auth rejected: reason=no_address_claim claim=%q sub=%q", v.addressClaim, sub) - return "", http.StatusForbidden, "no fmsg account for this identity" + if _, err := v.rsParser.ParseWithClaims(tokenStr, claims, v.rsKeyFunc); err != nil { + return authResult{}, err + } + + owner, _ := claims[v.addressClaim].(string) + if owner == "" { + sub, _ := claims["sub"].(string) + log.Printf("auth rejected: reason=no_address_claim claim=%q sub=%q", v.addressClaim, sub) + return authResult{}, authError{status: http.StatusForbidden, msg: "no fmsg account for this identity"} + } + if status, msg := validateIdentity(owner, v.idURL); status != http.StatusOK { + return authResult{}, authError{status: status, msg: msg} + } + res := authResult{Addr: owner, OwnerAddr: owner, AuthType: AuthTypeRS256} + + if strings.TrimSpace(actAs) == "" { + return res, nil + } + if v.apiKeys == nil { + return authResult{}, authError{status: http.StatusForbidden, msg: "act-as is not enabled"} + } + actAs = strings.TrimSpace(actAs) + if !IsValidAddr(actAs) { + return authResult{}, authError{status: http.StatusUnauthorized, msg: "invalid act-as identity"} + } + if err := v.apiKeys.ValidateActAs(ctx, owner, actAs); err != nil { + return authResult{}, err + } + if status, msg := validateIdentity(actAs, v.idURL); status != http.StatusOK { + return authResult{}, authError{status: status, msg: msg} + } + res.Addr = actAs + return res, nil +} + +func (v *Verifier) authenticateAPIToken(ctx context.Context, tokenStr, remoteAddr, actAs string) (authResult, error) { + if strings.TrimSpace(actAs) != "" { + return authResult{}, authError{status: http.StatusForbidden, msg: "act-as is only available with RS256 authentication"} + } + claims := &apiauth.TokenClaims{} + _, err := v.apiParser.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("unexpected signing method: %s", t.Method.Alg()) } - default: - addr, _ = claims["sub"].(string) + return v.apiPublicKey, nil + }) + if err != nil { + return authResult{}, err + } + subAddr := claims.Subject + if !IsValidAddr(subAddr) || !IsValidAddr(claims.OwnerAddr) || claims.APIKeyID == "" { + return authResult{}, authError{status: http.StatusUnauthorized, msg: "invalid token identity"} + } + if err := v.apiKeys.ValidateToken(ctx, claims.APIKeyID, claims.OwnerAddr, subAddr, remoteAddr); err != nil { + return authResult{}, err + } + if status, msg := validateIdentity(subAddr, v.idURL); status != http.StatusOK { + return authResult{}, authError{status: status, msg: msg} + } + return authResult{Addr: subAddr, OwnerAddr: claims.OwnerAddr, AuthType: AuthTypeAPI}, nil +} + +type authError struct { + status int + msg string +} + +func (e authError) Error() string { + return e.msg +} + +func authFailureFromError(err error) (int, string, bool) { + var ae authError + if errors.As(err, &ae) { + return ae.status, ae.msg, true + } + switch { + case errors.Is(err, apiauth.ErrCIDRDenied): + return http.StatusForbidden, "source IP not allowed", true + case errors.Is(err, apiauth.ErrKeyExpired): + return http.StatusUnauthorized, "api key expired", true + case errors.Is(err, apiauth.ErrKeyRevoked): + return http.StatusUnauthorized, "api key revoked", true + case errors.Is(err, apiauth.ErrInvalidRemoteIP): + return http.StatusUnauthorized, "invalid source IP", true + case errors.Is(err, apiauth.ErrNotFound): + return http.StatusForbidden, "sub-account not authorised", true } + return 0, "", false +} +func validateIdentity(addr, idURL string) (int, string) { if !IsValidAddr(addr) { log.Printf("auth rejected: reason=invalid_addr addr=%q", addr) - return "", http.StatusUnauthorized, "invalid identity" + return http.StatusUnauthorized, "invalid identity" } - - code, accepting, err := checkFmsgID(v.idURL, addr) + code, accepting, err := CheckFmsgID(idURL, addr) if err != nil { log.Printf("fmsgid check error for %s: %v", addr, err) - return "", http.StatusServiceUnavailable, "identity service unavailable" + return http.StatusServiceUnavailable, "identity service unavailable" } switch { case code == http.StatusNotFound: log.Printf("auth rejected: addr=%s reason=not_found", addr) - return "", http.StatusBadRequest, fmt.Sprintf("User %s not found", addr) + return http.StatusBadRequest, fmt.Sprintf("User %s not found", addr) case code == http.StatusOK && !accepting: log.Printf("auth rejected: addr=%s reason=not_accepting", addr) - return "", http.StatusForbidden, fmt.Sprintf("User %s not authorised to send new messages", addr) + return http.StatusForbidden, fmt.Sprintf("User %s not authorised to send new messages", addr) case code != http.StatusOK: log.Printf("auth rejected: addr=%s reason=fmsgid_status=%d", addr, code) - return "", http.StatusServiceUnavailable, "identity service unavailable" + return http.StatusServiceUnavailable, "identity service unavailable" } - - return addr, http.StatusOK, "" + return http.StatusOK, "" } -// New constructs the JWT verification middleware. -// -// The returned handler: -// - extracts a Bearer token from the Authorization header, -// - parses & verifies the signature according to cfg.Mode, -// - validates iss/aud/exp/nbf claims, -// - derives the user address (RS256: the configured address claim; -// HMAC: the sub claim) and validates its shape, -// - calls fmsgid to confirm the user is known and accepting messages, -// - on success stores the address in the Gin context under IdentityKey. -// -// On failure the response is 400/401/403/503 with a JSON `{"error": "..."}` body. +// New constructs the authentication middleware. func New(cfg Config) (gin.HandlerFunc, error) { verifier, err := NewVerifier(cfg) if err != nil { @@ -234,13 +305,15 @@ func New(cfg Config) (gin.HandlerFunc, error) { return } - addr, status, msg := verifier.Authenticate(tokenStr) + res, status, msg := verifier.AuthenticateRequest(c.Request.Context(), tokenStr, c.ClientIP(), c.GetHeader("X-FMSG-Act-As")) if status != http.StatusOK { respondAuth(c, status, msg) return } - c.Set(IdentityKey, addr) + c.Set(IdentityKey, res.Addr) + c.Set(OwnerIdentityKey, res.OwnerAddr) + c.Set(AuthTypeKey, res.AuthType) c.Next() }, nil } @@ -267,7 +340,7 @@ func extractBearer(header string) (string, error) { return tok, nil } -// GetIdentity retrieves the authenticated user address from the Gin context. +// GetIdentity retrieves the effective authenticated user address from the Gin context. func GetIdentity(c *gin.Context) string { v, exists := c.Get(IdentityKey) if !exists { @@ -277,6 +350,24 @@ func GetIdentity(c *gin.Context) string { return addr } +func GetOwnerIdentity(c *gin.Context) string { + v, exists := c.Get(OwnerIdentityKey) + if !exists { + return GetIdentity(c) + } + addr, _ := v.(string) + return addr +} + +func GetAuthType(c *gin.Context) string { + v, exists := c.Get(AuthTypeKey) + if !exists { + return "" + } + authType, _ := v.(string) + return authType +} + // IsValidAddr checks that the address has the form "@user@domain". func IsValidAddr(addr string) bool { if len(addr) < 3 { @@ -290,14 +381,9 @@ func IsValidAddr(addr string) bool { } // fmsgIDClient is a dedicated HTTP client with a bounded timeout so that a -// slow or hung fmsgid never blocks an API request goroutine indefinitely -// (which would otherwise hold the inbound HTTP connection open and exhaust -// the browser's per-host connection limit). +// slow or hung fmsgid never blocks an API request goroutine indefinitely. var fmsgIDClient = &http.Client{Timeout: 5 * time.Second} -// fmsgIDCacheTTL is how long a positive fmsgid lookup is cached. Tokens are -// re-validated every time, but the relatively expensive network round-trip to -// fmsgid is short-circuited for this window. Negative results are not cached. const fmsgIDCacheTTL = 30 * time.Second type fmsgIDEntry struct { @@ -308,9 +394,6 @@ type fmsgIDEntry struct { var fmsgIDCache sync.Map // map[string]fmsgIDEntry, key = addr -// fmsgIDGroup coalesces concurrent lookups for the same address so that a -// burst of cache misses (e.g. several browser requests arriving before the -// first response is cached) results in a single upstream fmsgid call. var fmsgIDGroup singleflight.Group type fmsgIDResult struct { @@ -318,12 +401,8 @@ type fmsgIDResult struct { acceptingNew bool } -// checkFmsgID queries the fmsgid service for a user address. -// Returns (statusCode, acceptingNew, error). Successful 200 responses are -// cached for fmsgIDCacheTTL to avoid hammering fmsgid when a browser fires -// many concurrent requests with the same JWT. Concurrent cache misses for -// the same address are deduplicated via singleflight. -func checkFmsgID(idURL, addr string) (int, bool, error) { +// CheckFmsgID queries the fmsgid service for a user address. +func CheckFmsgID(idURL, addr string) (int, bool, error) { if v, ok := fmsgIDCache.Load(addr); ok { entry := v.(fmsgIDEntry) if time.Now().Before(entry.expires) { @@ -333,8 +412,6 @@ func checkFmsgID(idURL, addr string) (int, bool, error) { } v, err, _ := fmsgIDGroup.Do(addr, func() (interface{}, error) { - // Re-check inside the singleflight in case another goroutine just - // populated the cache while we were waiting to enter. if v, ok := fmsgIDCache.Load(addr); ok { entry := v.(fmsgIDEntry) if time.Now().Before(entry.expires) { @@ -350,8 +427,6 @@ func checkFmsgID(idURL, addr string) (int, bool, error) { return res.code, res.acceptingNew, nil } -// fetchFmsgID performs the actual HTTP call to fmsgid and stores positive -// results in the cache. func fetchFmsgID(idURL, addr string) (fmsgIDResult, error) { url := strings.TrimRight(idURL, "/") + "/fmsgid/" + addr resp, err := fmsgIDClient.Get(url) //nolint:gosec // URL constructed from trusted config + validated addr @@ -371,7 +446,7 @@ func fetchFmsgID(idURL, addr string) (fmsgIDResult, error) { AcceptingNew bool `json:"acceptingNew"` } if err := decodeJSON(resp.Body, &result); err != nil { - return fmsgIDResult{code: http.StatusOK, acceptingNew: true}, nil // assume accepting if parse fails + return fmsgIDResult{code: http.StatusOK, acceptingNew: true}, nil } fmsgIDCache.Store(addr, fmsgIDEntry{ diff --git a/internal/middleware/jwt_test.go b/internal/middleware/jwt_test.go index 679782e..bbc5b80 100644 --- a/internal/middleware/jwt_test.go +++ b/internal/middleware/jwt_test.go @@ -1,9 +1,12 @@ package middleware import ( + "context" + "crypto/ed25519" "crypto/rand" "crypto/rsa" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -11,9 +14,10 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" + + "github.com/markmnl/fmsg-webapi/internal/apiauth" ) -// Provider values used by the RS256 fixtures. const ( testIssuer = "https://issuer.example.test/" testAudience = "fmsg-web-client" @@ -24,6 +28,25 @@ func init() { gin.SetMode(gin.TestMode) } +type fakeAPIKeys struct { + tokenErr error + actErr error +} + +func (f fakeAPIKeys) ValidateToken(_ context.Context, keyID, ownerAddr, subAddr, remoteAddr string) error { + if keyID == "" || ownerAddr == "" || subAddr == "" || remoteAddr == "" { + return errors.New("missing token validation input") + } + return f.tokenErr +} + +func (f fakeAPIKeys) ValidateActAs(_ context.Context, ownerAddr, subAddr string) error { + if ownerAddr == "" || subAddr == "" { + return errors.New("missing act-as input") + } + return f.actErr +} + func TestIsValidAddr(t *testing.T) { tests := []struct { addr string @@ -47,8 +70,6 @@ func TestIsValidAddr(t *testing.T) { } } -// fakeJWKS returns a jwt.Keyfunc that yields a fixed RSA public key for a -// single known kid. func fakeJWKS(kid string, pub *rsa.PublicKey) jwt.Keyfunc { return func(t *jwt.Token) (interface{}, error) { k, _ := t.Header["kid"].(string) @@ -59,7 +80,6 @@ func fakeJWKS(kid string, pub *rsa.PublicKey) jwt.Keyfunc { } } -// fmsgIDServer returns an httptest server emulating fmsgid responses. func fmsgIDServer(t *testing.T, status int, accepting bool) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -72,18 +92,18 @@ func fmsgIDServer(t *testing.T, status int, accepting bool) *httptest.Server { })) } -// runMiddleware executes the middleware against a synthetic request bearing -// the given token, returning the recorded response. -func runMiddleware(t *testing.T, mw gin.HandlerFunc, token string) *httptest.ResponseRecorder { +func runMiddleware(t *testing.T, mw gin.HandlerFunc, token string, actAs string) *httptest.ResponseRecorder { t.Helper() w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = httptest.NewRequest(http.MethodGet, "/fmsg", nil) + c.Request.RemoteAddr = "127.0.0.1:12345" if token != "" { c.Request.Header.Set("Authorization", "Bearer "+token) } - called := false - c.Set("__test_next__", &called) + if actAs != "" { + c.Request.Header.Set("X-FMSG-Act-As", actAs) + } mw(c) return w } @@ -109,8 +129,6 @@ func signHS256(t *testing.T, secret []byte, claims jwt.MapClaims) string { return s } -// rs256Claims returns provider-token-shaped claims carrying the given fmsg -// address in the configured address claim. func rs256Claims(addr string) jwt.MapClaims { claims := jwt.MapClaims{ "iss": testIssuer, @@ -125,69 +143,6 @@ func rs256Claims(addr string) jwt.MapClaims { return claims } -func TestHMACMode_Happy(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - - secret := []byte("dev-secret") - mw, err := New(Config{Mode: ModeHMAC, HMACKey: secret, IDURL: srv.URL}) - if err != nil { - t.Fatalf("New: %v", err) - } - - tok := signHS256(t, secret, jwt.MapClaims{ - "sub": "@alice@example.com", - "iat": time.Now().Unix(), - "exp": time.Now().Add(time.Hour).Unix(), - }) - w := runMiddleware(t, mw, tok) - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) - } -} - -func TestHMACMode_ClockSkewLeeway(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - secret := []byte("dev-secret") - mw, err := New(Config{Mode: ModeHMAC, HMACKey: secret, IDURL: srv.URL}) - if err != nil { - t.Fatalf("New: %v", err) - } - - // iat/nbf within leeway is accepted. - now := time.Now() - tok := signHS256(t, secret, jwt.MapClaims{ - "sub": "@alice@example.com", - "iat": now.Add(DefaultClockSkew - time.Second).Unix(), - "nbf": now.Add(DefaultClockSkew - time.Second).Unix(), - "exp": now.Add(time.Hour).Unix(), - }) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { - t.Fatalf("within-skew token should be accepted, got %d", w.Code) - } - - // Beyond leeway is rejected. - tok = signHS256(t, secret, jwt.MapClaims{ - "sub": "@alice@example.com", - "nbf": now.Add(DefaultClockSkew + 5*time.Second).Unix(), - "exp": now.Add(time.Hour).Unix(), - }) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("out-of-skew token should be rejected, got %d", w.Code) - } -} - -func TestHMACMode_MissingHeader(t *testing.T) { - mw, err := New(Config{Mode: ModeHMAC, HMACKey: []byte("k"), IDURL: "http://127.0.0.1:0"}) - if err != nil { - t.Fatal(err) - } - if w := runMiddleware(t, mw, ""); w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) - } -} - func newRS256Fixture(t *testing.T) (priv *rsa.PrivateKey, jwks jwt.Keyfunc) { t.Helper() priv, err := rsa.GenerateKey(rand.Reader, 2048) @@ -199,7 +154,6 @@ func newRS256Fixture(t *testing.T) (priv *rsa.PrivateKey, jwks jwt.Keyfunc) { func rs256Config(idURL string, jwks jwt.Keyfunc) Config { return Config{ - Mode: ModeRS256, JWKS: jwks, Issuer: testIssuer, Audience: testAudience, @@ -218,40 +172,53 @@ func TestRS256Mode_Happy(t *testing.T) { } tok := signRS256(t, priv, "prod-1", rs256Claims("@alice@example.com")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { + if w := runMiddleware(t, mw, tok, ""); w.Code != http.StatusOK { t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) } } -func TestRS256Mode_IdentityIsAddressClaim(t *testing.T) { - const addr = "@claim@example.com" - fmsgIDCache.Delete(addr) - defer fmsgIDCache.Delete(addr) +func TestRS256Mode_ActAsSubAccount(t *testing.T) { + fmsgIDCache.Delete("@alice@example.com") + fmsgIDCache.Delete("@alice_bot@example.com") + defer fmsgIDCache.Delete("@alice@example.com") + defer fmsgIDCache.Delete("@alice_bot@example.com") - hits := 0 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hits++ - if r.URL.Path != "/fmsgid/"+addr { - http.Error(w, "wrong address", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]bool{"acceptingNew": true}) - })) + srv := fmsgIDServer(t, http.StatusOK, true) defer srv.Close() priv, jwks := newRS256Fixture(t) - v, err := NewVerifier(rs256Config(srv.URL, jwks)) + apiPub, _, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + mw, err := New(Config{ + JWKS: jwks, + Issuer: testIssuer, + Audience: testAudience, + AddressClaim: testAddressClaim, + IDURL: srv.URL, + APIPublicKey: apiPub, + APIKeys: fakeAPIKeys{}, + }) if err != nil { - t.Fatalf("NewVerifier: %v", err) + t.Fatalf("New: %v", err) } - tok := signRS256(t, priv, "prod-1", rs256Claims(addr)) - gotAddr, status, _ := v.Authenticate(tok) - if status != http.StatusOK || gotAddr != addr { - t.Fatalf("got addr=%q status=%d, want %s/200", gotAddr, status, addr) + tok := signRS256(t, priv, "prod-1", rs256Claims("@alice@example.com")) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/fmsg", nil) + c.Request.RemoteAddr = "127.0.0.1:12345" + c.Request.Header.Set("Authorization", "Bearer "+tok) + c.Request.Header.Set("X-FMSG-Act-As", "@alice_bot@example.com") + mw(c) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) + } + if got := GetIdentity(c); got != "@alice_bot@example.com" { + t.Fatalf("identity=%q", got) } - if hits != 1 { - t.Fatalf("fmsgid hits = %d, want 1", hits) + if got := GetOwnerIdentity(c); got != "@alice@example.com" { + t.Fatalf("owner=%q", got) } } @@ -263,14 +230,13 @@ func TestRS256Mode_MissingAddressClaim(t *testing.T) { if err != nil { t.Fatal(err) } - // A valid ID token whose identity has no fmsg account yet. tok := signRS256(t, priv, "prod-1", rs256Claims("")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusForbidden { + if w := runMiddleware(t, mw, tok, ""); w.Code != http.StatusForbidden { t.Fatalf("expected 403, got %d body=%s", w.Code, w.Body.String()) } } -func TestRS256Mode_MalformedAddressClaim(t *testing.T) { +func TestRS256Mode_WrongIssuerAudienceAndExpired(t *testing.T) { srv := fmsgIDServer(t, http.StatusOK, true) defer srv.Close() priv, jwks := newRS256Fixture(t) @@ -278,69 +244,27 @@ func TestRS256Mode_MalformedAddressClaim(t *testing.T) { if err != nil { t.Fatal(err) } - tok := signRS256(t, priv, "prod-1", rs256Claims("not-an-address")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) - } -} -func TestRS256Mode_WrongIssuer(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) - if err != nil { - t.Fatal(err) - } claims := rs256Claims("@alice@example.com") claims["iss"] = "https://evil.example.com/" - tok := signRS256(t, priv, "prod-1", claims) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) - } -} - -func TestRS256Mode_WrongAudience(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) - if err != nil { - t.Fatal(err) - } - - // Token minted for a different configured application or API. - claims := rs256Claims("@alice@example.com") - claims["aud"] = "SomeOtherClientID" - tok := signRS256(t, priv, "prod-1", claims) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("wrong aud: expected 401, got %d", w.Code) + if w := runMiddleware(t, mw, signRS256(t, priv, "prod-1", claims), ""); w.Code != http.StatusUnauthorized { + t.Fatalf("wrong issuer expected 401, got %d", w.Code) } - // Token with no audience at all. claims = rs256Claims("@alice@example.com") - delete(claims, "aud") - tok = signRS256(t, priv, "prod-1", claims) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("missing aud: expected 401, got %d", w.Code) + claims["aud"] = "other" + if w := runMiddleware(t, mw, signRS256(t, priv, "prod-1", claims), ""); w.Code != http.StatusUnauthorized { + t.Fatalf("wrong audience expected 401, got %d", w.Code) } -} -func TestRS256Mode_UnknownKID(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) - if err != nil { - t.Fatal(err) - } - tok := signRS256(t, priv, "rotated-key", rs256Claims("@alice@example.com")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) + claims = rs256Claims("@alice@example.com") + claims["exp"] = time.Now().Add(-time.Hour).Unix() + if w := runMiddleware(t, mw, signRS256(t, priv, "prod-1", claims), ""); w.Code != http.StatusUnauthorized { + t.Fatalf("expired expected 401, got %d", w.Code) } } -func TestRS256Mode_AlgDowngrade(t *testing.T) { +func TestRS256Mode_RejectsHMACAlg(t *testing.T) { srv := fmsgIDServer(t, http.StatusOK, true) defer srv.Close() _, jwks := newRS256Fixture(t) @@ -348,148 +272,127 @@ func TestRS256Mode_AlgDowngrade(t *testing.T) { if err != nil { t.Fatal(err) } - // Sign with HS256 - must be rejected by an RS256-only middleware. tok := signHS256(t, []byte("anything"), rs256Claims("@alice@example.com")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) - } -} - -func TestRS256Mode_Expired(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) - if err != nil { - t.Fatal(err) - } - claims := rs256Claims("@alice@example.com") - claims["exp"] = time.Now().Add(-time.Hour).Unix() - tok := signRS256(t, priv, "prod-1", claims) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + if w := runMiddleware(t, mw, tok, ""); w.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", w.Code) } } -func TestRS256Mode_Reuse(t *testing.T) { - srv := fmsgIDServer(t, http.StatusOK, true) - defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) - if err != nil { - t.Fatal(err) - } - tok := signRS256(t, priv, "prod-1", rs256Claims("@alice@example.com")) - - if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { - t.Fatalf("first call expected 200, got %d", w.Code) - } - if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { - t.Fatalf("reuse expected 200, got %d", w.Code) - } -} - func TestRS256Mode_ConfigValidation(t *testing.T) { _, jwks := newRS256Fixture(t) - if _, err := NewVerifier(Config{Mode: ModeRS256, Issuer: testIssuer, Audience: testAudience, AddressClaim: testAddressClaim}); err == nil { - t.Error("missing JWKS: expected error") + if _, err := NewVerifier(Config{Issuer: testIssuer, Audience: testAudience, AddressClaim: testAddressClaim}); err == nil { + t.Error("missing auth modes: expected error") } - if _, err := NewVerifier(Config{Mode: ModeRS256, JWKS: jwks, Audience: testAudience, AddressClaim: testAddressClaim}); err == nil { + if _, err := NewVerifier(Config{JWKS: jwks, Audience: testAudience, AddressClaim: testAddressClaim}); err == nil { t.Error("missing Issuer: expected error") } - if _, err := NewVerifier(Config{Mode: ModeRS256, JWKS: jwks, Issuer: testIssuer, AddressClaim: testAddressClaim}); err == nil { + if _, err := NewVerifier(Config{JWKS: jwks, Issuer: testIssuer, AddressClaim: testAddressClaim}); err == nil { t.Error("missing Audience: expected error") } - if _, err := NewVerifier(Config{Mode: ModeRS256, JWKS: jwks, Issuer: testIssuer, Audience: testAudience}); err == nil { + if _, err := NewVerifier(Config{JWKS: jwks, Issuer: testIssuer, Audience: testAudience}); err == nil { t.Error("missing AddressClaim: expected error") } } -func TestRS256Mode_FmsgIDNotFound(t *testing.T) { - fmsgIDCache.Delete("@alice@example.com") - defer fmsgIDCache.Delete("@alice@example.com") +func TestRS256Mode_FmsgIDFailures(t *testing.T) { + priv, jwks := newRS256Fixture(t) + fmsgIDCache.Delete("@alice@example.com") srv := fmsgIDServer(t, http.StatusNotFound, false) - defer srv.Close() - priv, jwks := newRS256Fixture(t) mw, err := New(rs256Config(srv.URL, jwks)) if err != nil { t.Fatal(err) } tok := signRS256(t, priv, "prod-1", rs256Claims("@alice@example.com")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d body=%s", w.Code, w.Body.String()) + if w := runMiddleware(t, mw, tok, ""); w.Code != http.StatusBadRequest { + t.Fatalf("not found expected 400, got %d", w.Code) } -} + srv.Close() -func TestRS256Mode_FmsgIDNotAccepting(t *testing.T) { fmsgIDCache.Delete("@alice@example.com") - defer fmsgIDCache.Delete("@alice@example.com") - - srv := fmsgIDServer(t, http.StatusOK, false) + srv = fmsgIDServer(t, http.StatusOK, false) defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) + mw, err = New(rs256Config(srv.URL, jwks)) if err != nil { t.Fatal(err) } - tok := signRS256(t, priv, "prod-1", rs256Claims("@alice@example.com")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusForbidden { - t.Fatalf("expected 403, got %d body=%s", w.Code, w.Body.String()) + if w := runMiddleware(t, mw, tok, ""); w.Code != http.StatusForbidden { + t.Fatalf("not accepting expected 403, got %d", w.Code) } } -func TestVerifier_Authenticate(t *testing.T) { +func TestAPITokenMode_Happy(t *testing.T) { srv := fmsgIDServer(t, http.StatusOK, true) defer srv.Close() - secret := []byte("dev-secret") - v, err := NewVerifier(Config{Mode: ModeHMAC, HMACKey: secret, IDURL: srv.URL}) + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + issuer := apiauth.NewTokenIssuer(priv, apiauth.DefaultTokenIssuer, apiauth.DefaultTokenAudience, time.Hour) + token, _, err := issuer.Mint("@alice@example.com", "@alice_bot@example.com", "kid1", time.Now()) if err != nil { - t.Fatalf("NewVerifier: %v", err) + t.Fatal(err) } - - // Valid token is accepted and yields the sub address. - tok := signHS256(t, secret, jwt.MapClaims{ - "sub": "@alice@example.com", - "iat": time.Now().Unix(), - "exp": time.Now().Add(time.Hour).Unix(), + mw, err := New(Config{ + APIPublicKey: issuer.PublicKey(), + APIIssuer: issuer.Issuer(), + APIAudience: issuer.Audience(), + APIKeys: fakeAPIKeys{}, + IDURL: srv.URL, }) - addr, status, _ := v.Authenticate(tok) - if status != http.StatusOK || addr != "@alice@example.com" { - t.Fatalf("valid token: got addr=%q status=%d, want @alice@example.com/200", addr, status) + if err != nil { + t.Fatal(err) + } + if w := runMiddleware(t, mw, token, ""); w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestAPITokenMode_RejectsActAsAndRevokedKey(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + issuer := apiauth.NewTokenIssuer(priv, "", "", time.Hour) + token, _, err := issuer.Mint("@alice@example.com", "@alice_bot@example.com", "kid1", time.Now()) + if err != nil { + t.Fatal(err) } - // Token signed with the wrong secret is rejected. - bad := signHS256(t, []byte("wrong-secret"), jwt.MapClaims{ - "sub": "@alice@example.com", - "exp": time.Now().Add(time.Hour).Unix(), + mw, err := New(Config{ + APIPublicKey: issuer.PublicKey(), + APIKeys: fakeAPIKeys{}, + IDURL: srv.URL, }) - if _, status, _ := v.Authenticate(bad); status != http.StatusUnauthorized { - t.Fatalf("bad signature: expected 401, got %d", status) + if err != nil { + t.Fatal(err) + } + if w := runMiddleware(t, mw, token, "@alice_other@example.com"); w.Code != http.StatusForbidden { + t.Fatalf("act-as expected 403, got %d", w.Code) } - // Token with a malformed sub is rejected. - noaddr := signHS256(t, secret, jwt.MapClaims{ - "sub": "not-an-address", - "exp": time.Now().Add(time.Hour).Unix(), + mw, err = New(Config{ + APIPublicKey: issuer.PublicKey(), + APIKeys: fakeAPIKeys{tokenErr: apiauth.ErrKeyRevoked}, + IDURL: srv.URL, }) - if _, status, _ := v.Authenticate(noaddr); status != http.StatusUnauthorized { - t.Fatalf("invalid addr: expected 401, got %d", status) + if err != nil { + t.Fatal(err) + } + if w := runMiddleware(t, mw, token, ""); w.Code != http.StatusUnauthorized { + t.Fatalf("revoked expected 401, got %d", w.Code) } } -func TestRS256Mode_FmsgIDUnavailable(t *testing.T) { - fmsgIDCache.Delete("@alice@example.com") - srv := fmsgIDServer(t, http.StatusInternalServerError, false) - defer srv.Close() - priv, jwks := newRS256Fixture(t) - mw, err := New(rs256Config(srv.URL, jwks)) +func TestAPITokenMode_ConfigValidation(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } - tok := signRS256(t, priv, "prod-1", rs256Claims("@alice@example.com")) - if w := runMiddleware(t, mw, tok); w.Code != http.StatusServiceUnavailable { - t.Fatalf("expected 503, got %d", w.Code) + if _, err := NewVerifier(Config{APIPublicKey: pub}); err == nil { + t.Fatal("missing API key checker: expected error") } } From abb15d59e83dc1b4bfdfd9d6a32d053c5d94b93f Mon Sep 17 00:00:00 2001 From: markmnl Date: Sun, 14 Jun 2026 16:17:29 +0800 Subject: [PATCH 2/2] no underscore --- internal/apiauth/apikey.go | 2 +- internal/apiauth/apikey_test.go | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/internal/apiauth/apikey.go b/internal/apiauth/apikey.go index 26e00dd..f7a2b99 100644 --- a/internal/apiauth/apikey.go +++ b/internal/apiauth/apikey.go @@ -62,5 +62,5 @@ func randomURLToken(n int) (string, error) { if _, err := rand.Read(b); err != nil { return "", err } - return base64.RawURLEncoding.EncodeToString(b), nil + return base64.RawStdEncoding.EncodeToString(b), nil } diff --git a/internal/apiauth/apikey_test.go b/internal/apiauth/apikey_test.go index 186cb6f..73e0060 100644 --- a/internal/apiauth/apikey_test.go +++ b/internal/apiauth/apikey_test.go @@ -14,6 +14,12 @@ func TestGenerateParseAndHashAPIKey(t *testing.T) { if !strings.HasPrefix(key.Value, KeyPrefix+"_") { t.Fatalf("key prefix = %q", key.Value) } + if strings.Contains(key.ID, "_") || strings.Contains(key.Secret, "_") { + t.Fatalf("key components must not contain delimiter: id=%q secret=%q", key.ID, key.Secret) + } + if got := strings.Count(key.Value, "_"); got != 2 { + t.Fatalf("key delimiter count = %d, want 2 in %q", got, key.Value) + } parsed, err := ParseAPIKey(key.Value) if err != nil { t.Fatal(err)