-
Notifications
You must be signed in to change notification settings - Fork 15
Add Infisical KMS provider #88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,211 @@ | ||
| package infisical | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "encoding/base64" | ||
| "encoding/json" | ||
| "errors" | ||
| "fmt" | ||
| "io/ioutil" | ||
| "net/http" | ||
| "net/url" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
| ) | ||
|
|
||
| var ErrInsecureSiteURL = errors.New("infisical-kms: INFISICAL_SITE_URL must use https://") | ||
|
|
||
| const tokenExpiryBuffer = 5 * time.Second | ||
|
|
||
| type kmsEncryptDecrypter interface { | ||
| encrypt(plaintext string) (string, error) | ||
| decrypt(ciphertext string) (string, error) | ||
| } | ||
|
|
||
| type kmsClient struct { | ||
| httpClient *http.Client | ||
| baseURL string | ||
| kmsKeyID string | ||
| clientID string | ||
| clientSecret string | ||
|
|
||
| mu sync.RWMutex | ||
| token string | ||
| expiresAt time.Time | ||
| } | ||
|
|
||
| func newKmsClient(siteURL, kmsKeyID, clientID, clientSecret string) (*kmsClient, error) { | ||
| base := strings.TrimRight(siteURL, "/") | ||
| u, err := url.Parse(base) | ||
| if err != nil || u.Host == "" { | ||
| return nil, fmt.Errorf("infisical-kms: invalid INFISICAL_SITE_URL %q: %w", siteURL, err) | ||
| } | ||
| if !strings.EqualFold(u.Scheme, "https") { | ||
| return nil, ErrInsecureSiteURL | ||
| } | ||
| if !strings.HasSuffix(base, "/api") { | ||
| base += "/api" | ||
| } | ||
| return &kmsClient{ | ||
| httpClient: &http.Client{Timeout: 30 * time.Second}, | ||
| baseURL: base, | ||
| kmsKeyID: kmsKeyID, | ||
| clientID: clientID, | ||
| clientSecret: clientSecret, | ||
| }, nil | ||
| } | ||
|
|
||
| type loginRequest struct { | ||
| ClientID string `json:"clientId"` | ||
| ClientSecret string `json:"clientSecret"` | ||
| } | ||
|
|
||
| type loginResponse struct { | ||
| AccessToken string `json:"accessToken"` | ||
| ExpiresIn int64 `json:"expiresIn"` | ||
| } | ||
|
|
||
| func (c *kmsClient) login() error { | ||
| body, err := json.Marshal(loginRequest{ | ||
| ClientID: c.clientID, | ||
| ClientSecret: c.clientSecret, | ||
| }) | ||
| if err != nil { | ||
| return fmt.Errorf("infisical-kms: failed to marshal login request: %w", err) | ||
| } | ||
|
|
||
| resp, err := c.httpClient.Post( | ||
| c.baseURL+"/v1/auth/universal-auth/login", | ||
| "application/json", | ||
| bytes.NewReader(body), | ||
| ) | ||
| if err != nil { | ||
| return fmt.Errorf("infisical-kms: login request failed: %w", err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| if resp.StatusCode != http.StatusOK { | ||
| msg, _ := ioutil.ReadAll(resp.Body) | ||
| return fmt.Errorf("infisical-kms: login returned %d: %s", resp.StatusCode, msg) | ||
| } | ||
|
|
||
| var result loginResponse | ||
| if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { | ||
| return fmt.Errorf("infisical-kms: failed to decode login response: %w", err) | ||
| } | ||
|
|
||
| c.mu.Lock() | ||
| c.token = result.AccessToken | ||
| c.expiresAt = time.Now().Add(time.Duration(result.ExpiresIn)*time.Second - tokenExpiryBuffer) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We wont renew the token if it is revoked before this expiry time. Can we check for 401 Unauthorized errors in doKmsRequest and renew the token?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we get a 401 or 403 error (also used for expired tokens in Infisical API), we'll log in again and retry once. |
||
| c.mu.Unlock() | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| func (c *kmsClient) ensureToken() error { | ||
| c.mu.RLock() | ||
| valid := c.token != "" && time.Now().Before(c.expiresAt) | ||
| c.mu.RUnlock() | ||
| if valid { | ||
| return nil | ||
| } | ||
| return c.login() | ||
| } | ||
|
|
||
| func (c *kmsClient) doKmsRequest(path string, reqBody, respBody interface{}) error { | ||
| if err := c.ensureToken(); err != nil { | ||
| return err | ||
| } | ||
|
|
||
| body, err := json.Marshal(reqBody) | ||
| if err != nil { | ||
| return fmt.Errorf("infisical-kms: failed to marshal request: %w", err) | ||
| } | ||
|
|
||
| retried := false | ||
| for { | ||
| status, respBytes, err := c.sendKmsRequest(path, body) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if status == http.StatusOK { | ||
| if err := json.Unmarshal(respBytes, respBody); err != nil { | ||
| return fmt.Errorf("infisical-kms: failed to decode response: %w", err) | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| if !retried && (status == http.StatusUnauthorized || status == http.StatusForbidden) { | ||
| retried = true | ||
| if err := c.login(); err != nil { | ||
| return fmt.Errorf("infisical-kms: re-authentication failed: %w", err) | ||
| } | ||
| continue | ||
| } | ||
| return fmt.Errorf("infisical-kms: request returned %d: %s", status, respBytes) | ||
| } | ||
| } | ||
|
|
||
| func (c *kmsClient) sendKmsRequest(path string, body []byte) (int, []byte, error) { | ||
| req, err := http.NewRequest(http.MethodPost, c.baseURL+path, bytes.NewReader(body)) | ||
| if err != nil { | ||
| return 0, nil, fmt.Errorf("infisical-kms: failed to create request: %w", err) | ||
| } | ||
| req.Header.Set("Content-Type", "application/json") | ||
|
|
||
| c.mu.RLock() | ||
| req.Header.Set("Authorization", "Bearer "+c.token) | ||
| c.mu.RUnlock() | ||
|
|
||
| resp, err := c.httpClient.Do(req) | ||
| if err != nil { | ||
| return 0, nil, fmt.Errorf("infisical-kms: request failed: %w", err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| respBytes, err := ioutil.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return resp.StatusCode, nil, fmt.Errorf("infisical-kms: failed to read response: %w", err) | ||
| } | ||
| return resp.StatusCode, respBytes, nil | ||
| } | ||
|
|
||
| type encryptRequest struct { | ||
| Plaintext string `json:"plaintext"` | ||
| } | ||
|
|
||
| type encryptResponse struct { | ||
| Ciphertext string `json:"ciphertext"` | ||
| } | ||
|
|
||
| func (c *kmsClient) encrypt(plaintext string) (string, error) { | ||
| var resp encryptResponse | ||
| path := fmt.Sprintf("/v1/kms/keys/%s/encrypt", c.kmsKeyID) | ||
| encoded := base64.StdEncoding.EncodeToString([]byte(plaintext)) | ||
| if err := c.doKmsRequest(path, encryptRequest{Plaintext: encoded}, &resp); err != nil { | ||
| return "", err | ||
| } | ||
| return resp.Ciphertext, nil | ||
| } | ||
|
|
||
| type decryptRequest struct { | ||
| Ciphertext string `json:"ciphertext"` | ||
| } | ||
|
|
||
| type decryptResponse struct { | ||
| Plaintext string `json:"plaintext"` | ||
| } | ||
|
|
||
| func (c *kmsClient) decrypt(ciphertext string) (string, error) { | ||
| var resp decryptResponse | ||
| path := fmt.Sprintf("/v1/kms/keys/%s/decrypt", c.kmsKeyID) | ||
| if err := c.doKmsRequest(path, decryptRequest{Ciphertext: ciphertext}, &resp); err != nil { | ||
| return "", err | ||
| } | ||
| decoded, err := base64.StdEncoding.DecodeString(resp.Plaintext) | ||
| if err != nil { | ||
| return "", fmt.Errorf("infisical-kms: failed to base64-decode plaintext: %w", err) | ||
| } | ||
| return string(decoded), nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| package infisical | ||
|
|
||
| import ( | ||
| "encoding/json" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "strings" | ||
| "sync/atomic" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func newTestClient(t *testing.T, base string) *kmsClient { | ||
| t.Helper() | ||
| // newKmsClient enforces https, so build the struct directly for httptest URLs. | ||
| return &kmsClient{ | ||
| httpClient: http.DefaultClient, | ||
| baseURL: strings.TrimRight(base, "/") + "/api", | ||
| kmsKeyID: "k-1", | ||
| clientID: "id", | ||
| clientSecret: "secret", | ||
| } | ||
| } | ||
|
|
||
| func writeLogin(w http.ResponseWriter, token string) { | ||
| w.WriteHeader(http.StatusOK) | ||
| _ = json.NewEncoder(w).Encode(loginResponse{AccessToken: token, ExpiresIn: 3600}) | ||
| } | ||
|
|
||
| func writeEncrypt(w http.ResponseWriter, ciphertext string) { | ||
| w.WriteHeader(http.StatusOK) | ||
| _ = json.NewEncoder(w).Encode(encryptResponse{Ciphertext: ciphertext}) | ||
| } | ||
|
|
||
| func retryServer(t *testing.T, firstStatus int, firstBody string) (*httptest.Server, *int32, *int32) { | ||
| t.Helper() | ||
| var loginCalls, encryptCalls int32 | ||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| switch { | ||
| case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): | ||
| n := atomic.AddInt32(&loginCalls, 1) | ||
| writeLogin(w, "token-"+string(rune('0'+n))) | ||
| case strings.Contains(r.URL.Path, "/v1/kms/keys/"): | ||
| n := atomic.AddInt32(&encryptCalls, 1) | ||
| if n == 1 { | ||
| w.WriteHeader(firstStatus) | ||
| _, _ = w.Write([]byte(firstBody)) | ||
| return | ||
| } | ||
| assert.Equal(t, "Bearer token-2", r.Header.Get("Authorization")) | ||
| writeEncrypt(w, "ct-ok") | ||
| default: | ||
| t.Fatalf("unexpected path: %s", r.URL.Path) | ||
| } | ||
| })) | ||
| return srv, &loginCalls, &encryptCalls | ||
| } | ||
|
|
||
| func TestDoKmsRequest_Retries401WithReLogin(t *testing.T) { | ||
| srv, loginCalls, encryptCalls := retryServer(t, http.StatusUnauthorized, | ||
| `{"statusCode":401,"error":"UnauthorizedError","message":"token revoked"}`) | ||
| defer srv.Close() | ||
|
|
||
| c := newTestClient(t, srv.URL) | ||
| require.NoError(t, c.login()) | ||
|
|
||
| ct, err := c.encrypt("hello") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "ct-ok", ct) | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(loginCalls), "should re-login after 401") | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(encryptCalls), "should retry the encrypt call once") | ||
| } | ||
|
|
||
| func TestDoKmsRequest_Retries403TokenErrorWithReLogin(t *testing.T) { | ||
| srv, loginCalls, encryptCalls := retryServer(t, http.StatusForbidden, | ||
| `{"statusCode":403,"error":"TokenError","message":"Your token has expired. Please re-authenticate."}`) | ||
| defer srv.Close() | ||
|
|
||
| c := newTestClient(t, srv.URL) | ||
| require.NoError(t, c.login()) | ||
|
|
||
| ct, err := c.encrypt("hello") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "ct-ok", ct) | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(loginCalls), "should re-login after 403 TokenError") | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(encryptCalls), "should retry once on 403 TokenError") | ||
| } | ||
|
|
||
| func TestDoKmsRequest_DoesNotInfiniteLoopOnRepeated403(t *testing.T) { | ||
| var loginCalls, encryptCalls int32 | ||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| switch { | ||
| case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): | ||
| atomic.AddInt32(&loginCalls, 1) | ||
| writeLogin(w, "tkn") | ||
| case strings.Contains(r.URL.Path, "/v1/kms/keys/"): | ||
| atomic.AddInt32(&encryptCalls, 1) | ||
| w.WriteHeader(http.StatusForbidden) | ||
| _, _ = w.Write([]byte(`{"statusCode":403,"error":"PermissionDenied","message":"missing KMS permission"}`)) | ||
| default: | ||
| t.Fatalf("unexpected path: %s", r.URL.Path) | ||
| } | ||
| })) | ||
| defer srv.Close() | ||
|
|
||
| c := newTestClient(t, srv.URL) | ||
| require.NoError(t, c.login()) | ||
|
|
||
| _, err := c.encrypt("hello") | ||
| require.Error(t, err) | ||
| assert.Contains(t, err.Error(), "403") | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(&encryptCalls), "should retry exactly once on persistent 403") | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(&loginCalls), "should re-login once before giving up") | ||
| } | ||
|
|
||
| func TestDoKmsRequest_DoesNotInfiniteLoopOnRepeated401(t *testing.T) { | ||
| var encryptCalls int32 | ||
|
|
||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| switch { | ||
| case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): | ||
| writeLogin(w, "any-token") | ||
| case strings.Contains(r.URL.Path, "/v1/kms/keys/"): | ||
| atomic.AddInt32(&encryptCalls, 1) | ||
| w.WriteHeader(http.StatusUnauthorized) | ||
| _, _ = w.Write([]byte(`unauthorized`)) | ||
| default: | ||
| t.Fatalf("unexpected path: %s", r.URL.Path) | ||
| } | ||
| })) | ||
| defer srv.Close() | ||
|
|
||
| c := newTestClient(t, srv.URL) | ||
| require.NoError(t, c.login()) | ||
|
|
||
| _, err := c.encrypt("hello") | ||
| require.Error(t, err) | ||
| assert.Contains(t, err.Error(), "401") | ||
| assert.EqualValues(t, 2, atomic.LoadInt32(&encryptCalls), "should retry exactly once on persistent 401") | ||
| } | ||
|
|
||
| func TestDoKmsRequest_HappyPath_NoRetry(t *testing.T) { | ||
| var loginCalls, encryptCalls int32 | ||
|
|
||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| switch { | ||
| case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): | ||
| atomic.AddInt32(&loginCalls, 1) | ||
| writeLogin(w, "good-token") | ||
| case strings.Contains(r.URL.Path, "/v1/kms/keys/"): | ||
| atomic.AddInt32(&encryptCalls, 1) | ||
| writeEncrypt(w, "ct") | ||
| default: | ||
| t.Fatalf("unexpected path: %s", r.URL.Path) | ||
| } | ||
| })) | ||
| defer srv.Close() | ||
|
|
||
| c := newTestClient(t, srv.URL) | ||
| require.NoError(t, c.login()) | ||
|
|
||
| ct, err := c.encrypt("hello") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "ct", ct) | ||
| assert.EqualValues(t, 1, atomic.LoadInt32(&loginCalls), "no extra login when first call succeeds") | ||
| assert.EqualValues(t, 1, atomic.LoadInt32(&encryptCalls)) | ||
| } | ||
|
|
||
| func TestNewKmsClient_RejectsHTTP(t *testing.T) { | ||
| _, err := newKmsClient("http://infisical.internal", "k", "id", "sec") | ||
| assert.ErrorIs(t, err, ErrInsecureSiteURL) | ||
| } | ||
|
|
||
| func TestNewKmsClient_AcceptsHTTPS(t *testing.T) { | ||
| c, err := newKmsClient("https://app.infisical.com", "k", "id", "sec") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "https://app.infisical.com/api", c.baseURL) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does Infisical support TLSConfig? Or can we enforce an "https" URL check to ensure that clientSecret is not passed as plaintext over the wire
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll update the code to add HTTPS enforcement.