diff --git a/cmd/serve.go b/cmd/serve.go index 34052590d..0bb9540d1 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -488,7 +488,7 @@ func buildAPIDependencies( metaschemaRepository := postgres.NewMetaSchemaRepository(logger, dbc) metaschemaService := metaschema.NewService(metaschemaRepository) - userPATService := userpat.NewService(logger, userPATRepo, cfg.App.PAT, organizationService, roleService, policyService, projectService, auditRecordRepository) + userPATService := userpat.NewService(logger, userPATRepo, cfg.App.PAT, organizationService, roleService, membershipService, projectService, auditRecordRepository) membershipService.SetUserPATService(userPATService) patAlertService := userpat.NewAlertService(userPATRepo, userService, organizationService, mailDialer, dbc, cfg.App.PAT.Alert, logger, auditRecordRepository) auditRecordService := auditrecord.NewService(auditRecordRepository, userService, serviceUserService, sessionService, userPATService) diff --git a/core/userpat/errors/errors.go b/core/userpat/errors/errors.go index b403dd7c4..b4e4389ff 100644 --- a/core/userpat/errors/errors.go +++ b/core/userpat/errors/errors.go @@ -17,4 +17,5 @@ var ( ErrScopeMismatch = errors.New("role does not support the specified scope") ErrRoleNotFound = errors.New("one or more requested roles do not exist") ErrProjectForbidden = errors.New("user does not have access to one or more specified projects") + ErrDuplicateScope = errors.New("only one role per resource type is allowed") ) diff --git a/core/userpat/mocks/membership_service.go b/core/userpat/mocks/membership_service.go new file mode 100644 index 000000000..6ea366929 --- /dev/null +++ b/core/userpat/mocks/membership_service.go @@ -0,0 +1,293 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + policy "github.com/raystack/frontier/core/policy" + mock "github.com/stretchr/testify/mock" +) + +// MembershipService is an autogenerated mock type for the MembershipService type +type MembershipService struct { + mock.Mock +} + +type MembershipService_Expecter struct { + mock *mock.Mock +} + +func (_m *MembershipService) EXPECT() *MembershipService_Expecter { + return &MembershipService_Expecter{mock: &_m.Mock} +} + +// ListPoliciesByPrincipal provides a mock function with given fields: ctx, principalID, principalType +func (_m *MembershipService) ListPoliciesByPrincipal(ctx context.Context, principalID string, principalType string) ([]policy.Policy, error) { + ret := _m.Called(ctx, principalID, principalType) + + if len(ret) == 0 { + panic("no return value specified for ListPoliciesByPrincipal") + } + + var r0 []policy.Policy + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]policy.Policy, error)); ok { + return rf(ctx, principalID, principalType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) []policy.Policy); ok { + r0 = rf(ctx, principalID, principalType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]policy.Policy) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, principalID, principalType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListPoliciesByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPoliciesByPrincipal' +type MembershipService_ListPoliciesByPrincipal_Call struct { + *mock.Call +} + +// ListPoliciesByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principalID string +// - principalType string +func (_e *MembershipService_Expecter) ListPoliciesByPrincipal(ctx interface{}, principalID interface{}, principalType interface{}) *MembershipService_ListPoliciesByPrincipal_Call { + return &MembershipService_ListPoliciesByPrincipal_Call{Call: _e.mock.On("ListPoliciesByPrincipal", ctx, principalID, principalType)} +} + +func (_c *MembershipService_ListPoliciesByPrincipal_Call) Run(run func(ctx context.Context, principalID string, principalType string)) *MembershipService_ListPoliciesByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MembershipService_ListPoliciesByPrincipal_Call) Return(_a0 []policy.Policy, _a1 error) *MembershipService_ListPoliciesByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListPoliciesByPrincipal_Call) RunAndReturn(run func(context.Context, string, string) ([]policy.Policy, error)) *MembershipService_ListPoliciesByPrincipal_Call { + _c.Call.Return(run) + return _c +} + +// RemoveAllPATPolicies provides a mock function with given fields: ctx, patID +func (_m *MembershipService) RemoveAllPATPolicies(ctx context.Context, patID string) error { + ret := _m.Called(ctx, patID) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllPATPolicies") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_RemoveAllPATPolicies_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveAllPATPolicies' +type MembershipService_RemoveAllPATPolicies_Call struct { + *mock.Call +} + +// RemoveAllPATPolicies is a helper method to define mock.On call +// - ctx context.Context +// - patID string +func (_e *MembershipService_Expecter) RemoveAllPATPolicies(ctx interface{}, patID interface{}) *MembershipService_RemoveAllPATPolicies_Call { + return &MembershipService_RemoveAllPATPolicies_Call{Call: _e.mock.On("RemoveAllPATPolicies", ctx, patID)} +} + +func (_c *MembershipService_RemoveAllPATPolicies_Call) Run(run func(ctx context.Context, patID string)) *MembershipService_RemoveAllPATPolicies_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MembershipService_RemoveAllPATPolicies_Call) Return(_a0 error) *MembershipService_RemoveAllPATPolicies_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_RemoveAllPATPolicies_Call) RunAndReturn(run func(context.Context, string) error) *MembershipService_RemoveAllPATPolicies_Call { + _c.Call.Return(run) + return _c +} + +// SetOrganizationMemberRole provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID +func (_m *MembershipService) SetOrganizationMemberRole(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, orgID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetOrganizationMemberRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, orgID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetOrganizationMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetOrganizationMemberRole' +type MembershipService_SetOrganizationMemberRole_Call struct { + *mock.Call +} + +// SetOrganizationMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - orgID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetOrganizationMemberRole(ctx interface{}, orgID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetOrganizationMemberRole_Call { + return &MembershipService_SetOrganizationMemberRole_Call{Call: _e.mock.On("SetOrganizationMemberRole", ctx, orgID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetOrganizationMemberRole_Call) Run(run func(ctx context.Context, orgID string, principalID string, principalType string, roleID string)) *MembershipService_SetOrganizationMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetOrganizationMemberRole_Call) Return(_a0 error) *MembershipService_SetOrganizationMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetOrganizationMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetOrganizationMemberRole_Call { + _c.Call.Return(run) + return _c +} + +// SetPATAllProjectsRole provides a mock function with given fields: ctx, orgID, patID, roleID +func (_m *MembershipService) SetPATAllProjectsRole(ctx context.Context, orgID string, patID string, roleID string) error { + ret := _m.Called(ctx, orgID, patID, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetPATAllProjectsRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, orgID, patID, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetPATAllProjectsRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetPATAllProjectsRole' +type MembershipService_SetPATAllProjectsRole_Call struct { + *mock.Call +} + +// SetPATAllProjectsRole is a helper method to define mock.On call +// - ctx context.Context +// - orgID string +// - patID string +// - roleID string +func (_e *MembershipService_Expecter) SetPATAllProjectsRole(ctx interface{}, orgID interface{}, patID interface{}, roleID interface{}) *MembershipService_SetPATAllProjectsRole_Call { + return &MembershipService_SetPATAllProjectsRole_Call{Call: _e.mock.On("SetPATAllProjectsRole", ctx, orgID, patID, roleID)} +} + +func (_c *MembershipService_SetPATAllProjectsRole_Call) Run(run func(ctx context.Context, orgID string, patID string, roleID string)) *MembershipService_SetPATAllProjectsRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MembershipService_SetPATAllProjectsRole_Call) Return(_a0 error) *MembershipService_SetPATAllProjectsRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetPATAllProjectsRole_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MembershipService_SetPATAllProjectsRole_Call { + _c.Call.Return(run) + return _c +} + +// SetProjectMemberRole provides a mock function with given fields: ctx, projectID, principalID, principalType, roleID +func (_m *MembershipService) SetProjectMemberRole(ctx context.Context, projectID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, projectID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetProjectMemberRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, projectID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetProjectMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetProjectMemberRole' +type MembershipService_SetProjectMemberRole_Call struct { + *mock.Call +} + +// SetProjectMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetProjectMemberRole(ctx interface{}, projectID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetProjectMemberRole_Call { + return &MembershipService_SetProjectMemberRole_Call{Call: _e.mock.On("SetProjectMemberRole", ctx, projectID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetProjectMemberRole_Call) Run(run func(ctx context.Context, projectID string, principalID string, principalType string, roleID string)) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetProjectMemberRole_Call) Return(_a0 error) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetProjectMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Return(run) + return _c +} + +// NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMembershipService(t interface { + mock.TestingT + Cleanup(func()) +}) *MembershipService { + mock := &MembershipService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/userpat/mocks/policy_service.go b/core/userpat/mocks/policy_service.go index 56dbe4f9b..c859bb71a 100644 --- a/core/userpat/mocks/policy_service.go +++ b/core/userpat/mocks/policy_service.go @@ -22,110 +22,6 @@ func (_m *PolicyService) EXPECT() *PolicyService_Expecter { return &PolicyService_Expecter{mock: &_m.Mock} } -// Create provides a mock function with given fields: ctx, pol -func (_m *PolicyService) Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) { - ret := _m.Called(ctx, pol) - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 policy.Policy - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, policy.Policy) (policy.Policy, error)); ok { - return rf(ctx, pol) - } - if rf, ok := ret.Get(0).(func(context.Context, policy.Policy) policy.Policy); ok { - r0 = rf(ctx, pol) - } else { - r0 = ret.Get(0).(policy.Policy) - } - - if rf, ok := ret.Get(1).(func(context.Context, policy.Policy) error); ok { - r1 = rf(ctx, pol) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// PolicyService_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' -type PolicyService_Create_Call struct { - *mock.Call -} - -// Create is a helper method to define mock.On call -// - ctx context.Context -// - pol policy.Policy -func (_e *PolicyService_Expecter) Create(ctx interface{}, pol interface{}) *PolicyService_Create_Call { - return &PolicyService_Create_Call{Call: _e.mock.On("Create", ctx, pol)} -} - -func (_c *PolicyService_Create_Call) Run(run func(ctx context.Context, pol policy.Policy)) *PolicyService_Create_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(policy.Policy)) - }) - return _c -} - -func (_c *PolicyService_Create_Call) Return(_a0 policy.Policy, _a1 error) *PolicyService_Create_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *PolicyService_Create_Call) RunAndReturn(run func(context.Context, policy.Policy) (policy.Policy, error)) *PolicyService_Create_Call { - _c.Call.Return(run) - return _c -} - -// Delete provides a mock function with given fields: ctx, id -func (_m *PolicyService) Delete(ctx context.Context, id string) error { - ret := _m.Called(ctx, id) - - if len(ret) == 0 { - panic("no return value specified for Delete") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, id) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// PolicyService_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' -type PolicyService_Delete_Call struct { - *mock.Call -} - -// Delete is a helper method to define mock.On call -// - ctx context.Context -// - id string -func (_e *PolicyService_Expecter) Delete(ctx interface{}, id interface{}) *PolicyService_Delete_Call { - return &PolicyService_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} -} - -func (_c *PolicyService_Delete_Call) Run(run func(ctx context.Context, id string)) *PolicyService_Delete_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) - }) - return _c -} - -func (_c *PolicyService_Delete_Call) Return(_a0 error) *PolicyService_Delete_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *PolicyService_Delete_Call) RunAndReturn(run func(context.Context, string) error) *PolicyService_Delete_Call { - _c.Call.Return(run) - return _c -} - // List provides a mock function with given fields: ctx, flt func (_m *PolicyService) List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) { ret := _m.Called(ctx, flt) diff --git a/core/userpat/mocks/project_service.go b/core/userpat/mocks/project_service.go index 3dd430db6..80b1b68c6 100644 --- a/core/userpat/mocks/project_service.go +++ b/core/userpat/mocks/project_service.go @@ -5,9 +5,8 @@ package mocks import ( context "context" - mock "github.com/stretchr/testify/mock" - project "github.com/raystack/frontier/core/project" + mock "github.com/stretchr/testify/mock" ) // ProjectService is an autogenerated mock type for the ProjectService type @@ -94,4 +93,4 @@ func NewProjectService(t interface { t.Cleanup(func() { mock.AssertExpectations(t) }) return mock -} \ No newline at end of file +} diff --git a/core/userpat/service.go b/core/userpat/service.go index 98472e6b1..3fc2d8a87 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -43,10 +43,12 @@ type RoleService interface { List(ctx context.Context, f role.Filter) ([]role.Role, error) } -type PolicyService interface { - Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) - List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) - Delete(ctx context.Context, id string) error +type MembershipService interface { + SetOrganizationMemberRole(ctx context.Context, orgID, principalID, principalType, roleID string) error + SetPATAllProjectsRole(ctx context.Context, orgID, patID, roleID string) error + SetProjectMemberRole(ctx context.Context, projectID, principalID, principalType, roleID string) error + RemoveAllPATPolicies(ctx context.Context, patID string) error + ListPoliciesByPrincipal(ctx context.Context, principalID, principalType string) ([]policy.Policy, error) } type ProjectService interface { @@ -63,21 +65,22 @@ type Service struct { logger *slog.Logger orgService OrganizationService roleService RoleService - policyService PolicyService + membershipService MembershipService projectService ProjectService auditRecordRepository AuditRecordRepository deniedPerms map[string]struct{} } func NewService(logger *slog.Logger, repo Repository, config Config, orgService OrganizationService, - roleService RoleService, policyService PolicyService, projectService ProjectService, auditRecordRepository AuditRecordRepository) *Service { + roleService RoleService, membershipService MembershipService, + projectService ProjectService, auditRecordRepository AuditRecordRepository) *Service { return &Service{ repo: repo, config: config, logger: logger, orgService: orgService, roleService: roleService, - policyService: policyService, + membershipService: membershipService, projectService: projectService, auditRecordRepository: auditRecordRepository, deniedPerms: config.DeniedPermissionsSet(), @@ -156,7 +159,7 @@ func (s *Service) Delete(ctx context.Context, userID, id string) error { return fmt.Errorf("soft deleting PAT: %w", err) } - if err := s.deletePolicies(ctx, id); err != nil { + if err := s.membershipService.RemoveAllPATPolicies(ctx, id); err != nil { return fmt.Errorf("deleting policies: %w", err) } @@ -232,6 +235,9 @@ func (s *Service) Update(ctx context.Context, toUpdate patmodels.PAT) (patmodels if err != nil { return patmodels.PAT{}, err } + if !existing.ExpiresAt.After(time.Now()) { + return patmodels.PAT{}, paterrors.ErrExpired + } if err := s.validateScopes(ctx, toUpdate.Scopes); err != nil { return patmodels.PAT{}, err @@ -291,7 +297,7 @@ func (s *Service) captureOldScope(ctx context.Context, pat *patmodels.PAT) (stri // replacePolicies deletes existing policies and creates new ones from scopes. // Re-checks PAT existence after delete to guard against concurrent soft-delete. func (s *Service) replacePolicies(ctx context.Context, patID, orgID string, scopes []patmodels.PATScope) error { - if err := s.deletePolicies(ctx, patID); err != nil { + if err := s.membershipService.RemoveAllPATPolicies(ctx, patID); err != nil { return fmt.Errorf("deleting old policies: %w", err) } @@ -317,24 +323,6 @@ func (s *Service) auditUpdate(ctx context.Context, updated patmodels.PAT, toUpda } } -// deletePolicies removes all SpiceDB policies associated with a PAT. -// Each policy.Delete call removes SpiceDB relations first, then hard-deletes the Postgres policy row. -func (s *Service) deletePolicies(ctx context.Context, patID string) error { - policies, err := s.policyService.List(ctx, policy.Filter{ - PrincipalID: patID, - PrincipalType: schema.PATPrincipal, - }) - if err != nil { - return fmt.Errorf("listing policies for PAT %s: %w", patID, err) - } - for _, pol := range policies { - if err := s.policyService.Delete(ctx, pol.ID); err != nil { - return fmt.Errorf("deleting policy %s: %w", pol.ID, err) - } - } - return nil -} - // Create generates a new PAT and returns it with the plaintext value. // The plaintext value is only available at creation time. func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) { @@ -475,6 +463,7 @@ func (s *Service) validateScopes(ctx context.Context, scopes []patmodels.PATScop roleMap[r.ID] = r } + seen := make(map[string]bool, len(scopes)) for _, sc := range scopes { if !slices.Contains(supportedPATResourceTypes, sc.ResourceType) { return fmt.Errorf("resource type %s: %w", sc.ResourceType, paterrors.ErrUnsupportedScope) @@ -483,6 +472,10 @@ func (s *Service) validateScopes(ctx context.Context, scopes []patmodels.PATScop if !slices.Contains(r.Scopes, sc.ResourceType) { return fmt.Errorf("role %s does not support resource type %s: %w", sc.RoleID, sc.ResourceType, paterrors.ErrScopeMismatch) } + if seen[sc.ResourceType] { + return fmt.Errorf("resource type %s: %w", sc.ResourceType, paterrors.ErrDuplicateScope) + } + seen[sc.ResourceType] = true } return nil } @@ -526,17 +519,25 @@ func (s *Service) validateProjectAccess(ctx context.Context, userID, orgID strin return nil } -// createPolicies creates SpiceDB policies from pre-validated scopes. +// createPolicies writes the PAT's scopes via the membership package. func (s *Service) createPolicies(ctx context.Context, patID, orgID string, scopes []patmodels.PATScope) error { for _, sc := range scopes { switch sc.ResourceType { case schema.OrganizationNamespace: - if err := s.createOrgScopedPolicy(ctx, patID, orgID, sc.RoleID); err != nil { - return err + if err := s.membershipService.SetOrganizationMemberRole(ctx, orgID, patID, schema.PATPrincipal, sc.RoleID); err != nil { + return fmt.Errorf("set org role: %w", err) } case schema.ProjectNamespace: - if err := s.createProjectScopedPolicies(ctx, patID, orgID, sc.RoleID, sc.ResourceIDs); err != nil { - return err + if len(sc.ResourceIDs) == 0 { + if err := s.membershipService.SetPATAllProjectsRole(ctx, orgID, patID, sc.RoleID); err != nil { + return fmt.Errorf("set all-projects role: %w", err) + } + continue + } + for _, pid := range sc.ResourceIDs { + if err := s.membershipService.SetProjectMemberRole(ctx, pid, patID, schema.PATPrincipal, sc.RoleID); err != nil { + return fmt.Errorf("set project role on %s: %w", pid, err) + } } default: return fmt.Errorf("unsupported resource type %s: %w", sc.ResourceType, paterrors.ErrUnsupportedScope) @@ -606,51 +607,10 @@ func (s *Service) validateRolePermissions(roles []role.Role) error { return nil } -// createPATPolicy creates a single SpiceDB policy for a PAT. -func (s *Service) createPATPolicy(ctx context.Context, patID, roleID, resourceID, resourceType, grantRelation string) error { - if _, err := s.policyService.Create(ctx, policy.Policy{ - RoleID: roleID, - ResourceID: resourceID, - ResourceType: resourceType, - PrincipalID: patID, - PrincipalType: schema.PATPrincipal, - GrantRelation: grantRelation, - }); err != nil { - s.logger.Error("failed to create PAT policy", - "pat_id", patID, "role_id", roleID, "resource_id", resourceID, - "resource_type", resourceType, "grant_relation", grantRelation, "error", err) - return err - } - return nil -} - -// createOrgScopedPolicy creates a policy on the org with the default "granted" relation. -func (s *Service) createOrgScopedPolicy(ctx context.Context, patID, orgID, roleID string) error { - return s.createPATPolicy(ctx, patID, roleID, orgID, schema.OrganizationNamespace, schema.RoleGrantRelationName) -} - -// createProjectScopedPolicies creates policies for a project-scoped role. -// If resourceIDs is empty, it creates a single policy on the org with "pat_granted" relation -// (cascades to all projects). Otherwise, it creates one policy per project with default "granted". -func (s *Service) createProjectScopedPolicies(ctx context.Context, patID, orgID, roleID string, resourceIDs []string) error { - if len(resourceIDs) == 0 { - return s.createPATPolicy(ctx, patID, roleID, orgID, schema.OrganizationNamespace, schema.PATGrantRelationName) - } - for _, resourceID := range resourceIDs { - if err := s.createPATPolicy(ctx, patID, roleID, resourceID, schema.ProjectNamespace, schema.RoleGrantRelationName); err != nil { - return err - } - } - return nil -} - // enrichWithScope derives scopes from the PAT's policies. // Groups policies by role ID + resource type to reconstruct PATScope entries. func (s *Service) enrichWithScope(ctx context.Context, pat *patmodels.PAT) error { - policies, err := s.policyService.List(ctx, policy.Filter{ - PrincipalID: pat.ID, - PrincipalType: schema.PATPrincipal, - }) + policies, err := s.membershipService.ListPoliciesByPrincipal(ctx, pat.ID, schema.PATPrincipal) if err != nil { return fmt.Errorf("listing policies for PAT %s: %w", pat.ID, err) } diff --git a/core/userpat/service_test.go b/core/userpat/service_test.go index 498f37859..be832e8cf 100644 --- a/core/userpat/service_test.go +++ b/core/userpat/service_test.go @@ -34,7 +34,7 @@ var defaultConfig = userpat.Config{ MaxLifetime: "8760h", } -func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleService, *mocks.PolicyService, *mocks.ProjectService, *mocks.AuditRecordRepository) { +func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleService, *mocks.MembershipService, *mocks.ProjectService, *mocks.AuditRecordRepository) { t.Helper() orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). @@ -52,16 +52,22 @@ func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleServi Name: "test-role", Scopes: []string{schema.OrganizationNamespace}, }, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.Anything). - Return(policy.Policy{}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("RemoveAllPATPolicies", mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return orgSvc, roleSvc, policySvc, projSvc, auditRepo + return orgSvc, roleSvc, membershipSvc, projSvc, auditRepo } func TestService_Create(t *testing.T) { @@ -253,8 +259,8 @@ func TestService_Create(t *testing.T) { ExpiresAt: futureExpiry, CreatedAt: time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC), }, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -285,8 +291,8 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -322,8 +328,8 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -358,13 +364,13 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, userpat.Config{ Enabled: true, Prefix: "custom", MaxPerUserPerOrg: 50, MaxLifetime: "8760h", - }, orgSvc, roleSvc, policySvc, nil, auditRepo) + }, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -389,8 +395,8 @@ func TestService_Create(t *testing.T) { Return(int64(49), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, }, { @@ -409,8 +415,8 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, }, } @@ -444,8 +450,8 @@ func TestService_Create_UniquePATs(t *testing.T) { repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil).Times(2) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) req := userpat.CreateRequest{ UserID: "user-1", @@ -479,8 +485,8 @@ func TestService_Create_HashVerification(t *testing.T) { }). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, tokenValue, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", @@ -532,18 +538,11 @@ func TestService_CreatePolicies_OrgScopedRole(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"org-role-1"}}).Return([]role.Role{orgRole}, nil) roleSvc.On("Get", mock.Anything, "org-role-1").Return(orgRole, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "org-role-1", - ResourceID: "org-1", - ResourceType: schema.OrganizationNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.RoleGrantRelationName, - }).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() - - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetOrganizationMemberRole(mock.Anything, "org-1", "pat-1", schema.PATPrincipal, "org-role-1").Return(nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -579,18 +578,11 @@ func TestService_CreatePolicies_ProjectScopedAllProjects(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{projRole}, nil) roleSvc.On("Get", mock.Anything, "proj-role-1").Return(projRole, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "proj-role-1", - ResourceID: "org-1", - ResourceType: schema.OrganizationNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.PATGrantRelationName, - }).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() - - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetPATAllProjectsRole(mock.Anything, "org-1", "pat-1", "proj-role-1").Return(nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -626,24 +618,10 @@ func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{projRole}, nil) roleSvc.On("Get", mock.Anything, "proj-role-1").Return(projRole, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "proj-role-1", - ResourceID: "proj-a", - ResourceType: schema.ProjectNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.RoleGrantRelationName, - }).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "proj-role-1", - ResourceID: "proj-b", - ResourceType: schema.ProjectNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.RoleGrantRelationName, - }).Return(policy.Policy{ID: "pol-2"}, nil) - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-a", "pat-1", schema.PATPrincipal, "proj-role-1").Return(nil) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-b", "pat-1", schema.PATPrincipal, "proj-role-1").Return(nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { @@ -652,7 +630,7 @@ func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { {ID: "proj-a"}, {ID: "proj-b"}, }, nil).Maybe() - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, projSvc, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -681,12 +659,12 @@ func TestService_CreatePolicies_DeniedPermission(t *testing.T) { Scopes: []string{schema.OrganizationNamespace}, }}, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) cfg := defaultConfig cfg.DeniedPermissions = []string{"app_organization_administer"} - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -714,9 +692,9 @@ func TestService_CreatePolicies_RoleFetchError(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"bad-role"}}). Return(nil, errors.New("role not found")) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -748,9 +726,9 @@ func TestService_CreatePolicies_UnsupportedScope(t *testing.T) { Scopes: []string{schema.GroupNamespace}, }}, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -782,9 +760,9 @@ func TestService_CreatePolicies_MissingRoleID(t *testing.T) { Scopes: []string{schema.OrganizationNamespace}, }}, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -812,8 +790,8 @@ func TestService_CreatePolicies_NoRoles(t *testing.T) { repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", @@ -919,11 +897,31 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {RoleID: "org-viewer-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, }, }, + { + // Both the all-projects pat_granted policy and the org granted policy + // land on (org-1, PAT) — SetOrganizationMemberRole must skip the + // pat_granted row when replacing existing org policies, otherwise the + // project-all-projects access is silently dropped when scopes arrive + // in this order. + name: "ex5: project all-projects first, then org — order does not drop pat_granted", + scopes: []models.PATScope{ + {RoleID: "proj-owner-id", ResourceType: schema.ProjectNamespace}, + {RoleID: "org-mgr-id", ResourceType: schema.OrganizationNamespace}, + }, + roles: []role.Role{ + {ID: "proj-owner-id", Name: "app_project_owner", Permissions: []string{"app_project_get", "app_project_update", "app_project_delete"}, Scopes: []string{schema.ProjectNamespace}}, + {ID: "org-mgr-id", Name: "app_organization_manager", Permissions: []string{"app_organization_get", "app_organization_update"}, Scopes: []string{schema.OrganizationNamespace}}, + }, + want: []wantPolicy{ + {RoleID: "org-mgr-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, + {RoleID: "proj-owner-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "pat_granted"}, + }, + }, - // ── Multiple roles of same scope ───────────────────────────────── + // ── Duplicate scopes rejected (1 role per resource type) ───────── { - name: "multiple org roles create separate org policies", + name: "two org-scoped roles rejected", scopes: []models.PATScope{ {RoleID: "org-viewer-id", ResourceType: schema.OrganizationNamespace}, {RoleID: "org-billing-id", ResourceType: schema.OrganizationNamespace}, @@ -932,13 +930,12 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "org-viewer-id", Name: "app_organization_viewer", Permissions: []string{"app_organization_get"}, Scopes: []string{schema.OrganizationNamespace}}, {ID: "org-billing-id", Name: "app_organization_billing_viewer", Permissions: []string{"app_organization_billingview"}, Scopes: []string{schema.OrganizationNamespace}}, }, - want: []wantPolicy{ - {RoleID: "org-viewer-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, - {RoleID: "org-billing-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, - }, + want: nil, + wantErr: true, + wantErrIs: paterrors.ErrDuplicateScope, }, { - name: "multiple project roles, all projects → separate pat_granted policies", + name: "two project-scoped roles (all projects) rejected", scopes: []models.PATScope{ {RoleID: "proj-viewer-id", ResourceType: schema.ProjectNamespace}, {RoleID: "proj-editor-id", ResourceType: schema.ProjectNamespace}, @@ -947,13 +944,12 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "proj-viewer-id", Name: "app_project_viewer", Permissions: []string{"app_project_get"}, Scopes: []string{schema.ProjectNamespace}}, {ID: "proj-editor-id", Name: "app_project_editor", Permissions: []string{"app_project_get", "app_project_update"}, Scopes: []string{schema.ProjectNamespace}}, }, - want: []wantPolicy{ - {RoleID: "proj-viewer-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "pat_granted"}, - {RoleID: "proj-editor-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "pat_granted"}, - }, + want: nil, + wantErr: true, + wantErrIs: paterrors.ErrDuplicateScope, }, { - name: "multiple project roles, specific projects → policy per role per project", + name: "two project-scoped roles (specific projects) rejected", scopes: []models.PATScope{ {RoleID: "proj-viewer-id", ResourceType: schema.ProjectNamespace, ResourceIDs: []string{"proj-1", "proj-2"}}, {RoleID: "proj-editor-id", ResourceType: schema.ProjectNamespace, ResourceIDs: []string{"proj-1", "proj-2"}}, @@ -962,12 +958,9 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "proj-viewer-id", Name: "app_project_viewer", Permissions: []string{"app_project_get"}, Scopes: []string{schema.ProjectNamespace}}, {ID: "proj-editor-id", Name: "app_project_editor", Permissions: []string{"app_project_get", "app_project_update"}, Scopes: []string{schema.ProjectNamespace}}, }, - want: []wantPolicy{ - {RoleID: "proj-viewer-id", ResourceID: "proj-1", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - {RoleID: "proj-viewer-id", ResourceID: "proj-2", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - {RoleID: "proj-editor-id", ResourceID: "proj-1", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - {RoleID: "proj-editor-id", ResourceID: "proj-2", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - }, + want: nil, + wantErr: true, + wantErrIs: paterrors.ErrDuplicateScope, }, // ── Scope isolation ────────────────────────────────────────────── @@ -1161,15 +1154,45 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { } } - // --- policyService: capture all Create calls + // --- membershipService: capture every write the service makes, translate + // each membership call into the equivalent policy.Policy shape so the + // existing wantPolicy assertions keep working. var captured []policy.Policy - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.AnythingOfType("policy.Policy")). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + captured = append(captured, policy.Policy{ + RoleID: args.String(4), + ResourceID: args.String(1), + ResourceType: schema.OrganizationNamespace, + PrincipalID: args.String(2), + PrincipalType: args.String(3), + GrantRelation: schema.RoleGrantRelationName, + }) + }).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + captured = append(captured, policy.Policy{ + RoleID: args.String(3), + ResourceID: args.String(1), + ResourceType: schema.OrganizationNamespace, + PrincipalID: args.String(2), + PrincipalType: schema.PATPrincipal, + GrantRelation: schema.PATGrantRelationName, + }) + }).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - captured = append(captured, args.Get(1).(policy.Policy)) - }). - Return(policy.Policy{ID: "pol-gen"}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + captured = append(captured, policy.Policy{ + RoleID: args.String(4), + ResourceID: args.String(1), + ResourceType: schema.ProjectNamespace, + PrincipalID: args.String(2), + PrincipalType: args.String(3), + GrantRelation: schema.RoleGrantRelationName, + }) + }).Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { @@ -1178,7 +1201,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "proj-1"}, {ID: "proj-2"}, {ID: "proj-3"}, {ID: "proj-a"}, {ID: "proj-b"}, }, nil).Maybe() - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, policySvc, projSvc, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -1267,32 +1290,29 @@ func TestService_CreatePolicies_PolicyCreateFailure(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) - orgViewerRole := role.Role{ID: "org-viewer-id", Name: "app_organization_viewer", Permissions: []string{"app_organization_get"}, Scopes: []string{schema.OrganizationNamespace}} - orgBillingRole := role.Role{ID: "org-billing-id", Name: "app_organization_billing", Permissions: []string{"app_organization_billingview"}, Scopes: []string{schema.OrganizationNamespace}} + projViewerRole := role.Role{ID: "proj-viewer-id", Name: "app_project_viewer", Permissions: []string{"app_project_get"}, Scopes: []string{schema.ProjectNamespace}} roleSvc := mocks.NewRoleService(t) - roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"org-viewer-id", "org-billing-id"}}). - Return([]role.Role{orgViewerRole, orgBillingRole}, nil) - roleSvc.On("Get", mock.Anything, "org-viewer-id").Return(orgViewerRole, nil).Maybe() - roleSvc.On("Get", mock.Anything, "org-billing-id").Return(orgBillingRole, nil).Maybe() - - // first policy Create succeeds, second fails - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.MatchedBy(func(p policy.Policy) bool { - return p.RoleID == "org-viewer-id" - })).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.On("Create", mock.Anything, mock.MatchedBy(func(p policy.Policy) bool { - return p.RoleID == "org-billing-id" - })).Return(policy.Policy{}, errors.New("spicedb unavailable")) - - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-viewer-id"}}). + Return([]role.Role{projViewerRole}, nil) + roleSvc.On("Get", mock.Anything, "proj-viewer-id").Return(projViewerRole, nil).Maybe() + + // one scope with two project IDs invokes SetProjectMemberRole twice; first succeeds, second fails + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-1", "pat-1", schema.PATPrincipal, "proj-viewer-id").Return(nil) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-2", "pat-1", schema.PATPrincipal, "proj-viewer-id").Return(errors.New("spicedb unavailable")) + + projSvc := mocks.NewProjectService(t) + projSvc.On("List", mock.Anything, mock.Anything). + Return([]project.Project{{ID: "proj-1"}, {ID: "proj-2"}}, nil).Maybe() + + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", Title: "fail-token", Scopes: []models.PATScope{ - {RoleID: "org-viewer-id", ResourceType: schema.OrganizationNamespace}, - {RoleID: "org-billing-id", ResourceType: schema.OrganizationNamespace}, + {RoleID: "proj-viewer-id", ResourceType: schema.ProjectNamespace, ResourceIDs: []string{"proj-1", "proj-2"}}, }, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -1636,10 +1656,10 @@ func TestService_Get(t *testing.T) { patID: "pat-1", setup: func() *userpat.Service { repo := mocks.NewRepository(t) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, userpat.Config{ Enabled: false, - }, orgSvc, nil, policySvc, nil, auditRepo) + }, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, wantErrIs: paterrors.ErrDisabled, @@ -1652,8 +1672,8 @@ func TestService_Get(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(models.PAT{}, paterrors.ErrNotFound) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, wantErrIs: paterrors.ErrNotFound, @@ -1666,8 +1686,8 @@ func TestService_Get(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, wantErrIs: paterrors.ErrNotFound, @@ -1680,12 +1700,12 @@ func TestService_Get(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) - policySvc.On("List", mock.Anything, mock.Anything). + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{ {RoleID: "role-1", ResourceType: "app/organization", ResourceID: "org-1"}, }, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -1698,11 +1718,11 @@ func TestService_Get(t *testing.T) { repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil) orgSvc := mocks.NewOrganizationService(t) - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return(nil, errors.New("spicedb down")) auditRepo := mocks.NewAuditRecordRepository(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, }, @@ -1822,41 +1842,16 @@ func TestService_Delete(t *testing.T) { repo.EXPECT().Delete(mock.Anything, "pat-1"). Return(nil) orgSvc := mocks.NewOrganizationService(t) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return(nil, errors.New("spicedb down")) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1"). + Return(errors.New("spicedb down")) auditRepo := mocks.NewAuditRecordRepository(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, }, { - name: "should return error when policy delete fails after soft-delete", - userID: "user-1", - patID: "pat-1", - setup: func() *userpat.Service { - repo := mocks.NewRepository(t) - repo.EXPECT().GetByID(mock.Anything, "pat-1"). - Return(testPAT, nil) - repo.EXPECT().Delete(mock.Anything, "pat-1"). - Return(nil) - orgSvc := mocks.NewOrganizationService(t) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{{ID: "pol-1"}}, nil) - policySvc.EXPECT().Delete(mock.Anything, "pol-1"). - Return(errors.New("spicedb unavailable")) - auditRepo := mocks.NewAuditRecordRepository(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) - }, - wantErr: true, - }, - { - name: "should delete successfully with policies", + name: "should delete successfully", userID: "user-1", patID: "pat-1", setup: func() *userpat.Service { @@ -1868,45 +1863,12 @@ func TestService_Delete(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{ - {ID: "pol-1"}, - {ID: "pol-2"}, - }, nil) - policySvc.EXPECT().Delete(mock.Anything, "pol-1").Return(nil) - policySvc.EXPECT().Delete(mock.Anything, "pol-2").Return(nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(nil) auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) - }, - wantErr: false, - }, - { - name: "should delete successfully with no policies", - userID: "user-1", - patID: "pat-1", - setup: func() *userpat.Service { - repo := mocks.NewRepository(t) - repo.EXPECT().GetByID(mock.Anything, "pat-1"). - Return(testPAT, nil) - repo.EXPECT().Delete(mock.Anything, "pat-1"). - Return(nil) - orgSvc := mocks.NewOrganizationService(t) - orgSvc.On("GetRaw", mock.Anything, mock.Anything). - Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) - auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything). - Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -1923,15 +1885,12 @@ func TestService_Delete(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(nil) auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, errors.New("audit db down")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2039,6 +1998,20 @@ func TestService_Update(t *testing.T) { wantErr: true, wantErrIs: paterrors.ErrNotFound, }, + { + name: "should return ErrExpired when PAT has already expired", + input: defaultInput, + setup: func() *userpat.Service { + expiredPAT := testPAT + expiredPAT.ExpiresAt = time.Now().Add(-time.Hour) + repo := mocks.NewRepository(t) + repo.EXPECT().GetByID(mock.Anything, "pat-1"). + Return(expiredPAT, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, nil, nil, nil) + }, + wantErr: true, + wantErrIs: paterrors.ErrExpired, + }, { name: "should return error when role validation fails", input: defaultInput, @@ -2064,14 +2037,11 @@ func TestService_Update(t *testing.T) { roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(models.PAT{}, errors.New("db error")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, }, @@ -2085,14 +2055,11 @@ func TestService_Update(t *testing.T) { roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(models.PAT{}, paterrors.ErrConflict) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, wantErrIs: paterrors.ErrConflict, @@ -2107,20 +2074,12 @@ func TestService_Update(t *testing.T) { roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - // captureOldScope call - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil).Once() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) - // deletePolicies call - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return(nil, errors.New("spicedb down")).Once() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(errors.New("spicedb down")) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, }, @@ -2129,23 +2088,20 @@ func TestService_Update(t *testing.T) { input: defaultInput, setup: func() *userpat.Service { repo := mocks.NewRepository(t) - // getOwnedPAT repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil).Once() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(nil) // TOCTOU re-check returns not found (concurrent delete) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(models.PAT{}, paterrors.ErrNotFound).Once() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, }, @@ -2164,14 +2120,13 @@ func TestService_Update(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) // captureOldScope + enrichWithScope (after update) - policySvc.On("List", mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) - policySvc.On("Create", mock.Anything, mock.Anything). - Return(policy.Policy{}, nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) + membershipSvc.On("RemoveAllPATPolicies", mock.Anything, "pat-1").Return(nil) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) // TOCTOU re-check @@ -2180,7 +2135,7 @@ func TestService_Update(t *testing.T) { auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2198,13 +2153,12 @@ func TestService_Update(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) - policySvc.On("Create", mock.Anything, mock.Anything). - Return(policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) + membershipSvc.On("RemoveAllPATPolicies", mock.Anything, "pat-1").Return(nil) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) repo.EXPECT().GetByID(mock.Anything, "pat-1"). @@ -2212,7 +2166,7 @@ func TestService_Update(t *testing.T) { auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, errors.New("audit db down")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2369,13 +2323,13 @@ func TestService_Regenerate(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2410,13 +2364,13 @@ func TestService_Regenerate(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2434,13 +2388,13 @@ func TestService_Regenerate(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, errors.New("audit db down")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2599,9 +2553,11 @@ func TestService_ValidateProjectAccess(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, mock.Anything).Return([]role.Role{ {ID: "role-1", Name: "proj_viewer", Scopes: []string{schema.ProjectNamespace}, Permissions: []string{"app_project_get"}}, }, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.Anything).Return(policy.Policy{}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { return f.OrgID == "org-1" && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal @@ -2611,7 +2567,7 @@ func TestService_ValidateProjectAccess(t *testing.T) { auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, projSvc, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -2638,14 +2594,16 @@ func TestService_ValidateProjectAccess(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, mock.Anything).Return([]role.Role{ {ID: "role-1", Name: "proj_viewer", Scopes: []string{schema.ProjectNamespace}, Permissions: []string{"app_project_get"}}, }, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.Anything).Return(policy.Policy{}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() // No projectService mock needed — all-projects scope has empty ResourceIDs, skips validation - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -2692,13 +2650,10 @@ func TestService_List(t *testing.T) { Return(models.PATList{ PATs: []models.PAT{{ID: "pat-1", UserID: "user-1", OrgID: "org-1"}}, }, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return(nil, errors.New("policy service down")) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return(nil, errors.New("policy service down")) auditRepo := mocks.NewAuditRecordRepository(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, membershipSvc, nil, auditRepo) _, err := svc.List(context.Background(), "user-1", "org-1", nil) if err == nil || !strings.Contains(err.Error(), "enriching PAT scope") { @@ -2715,21 +2670,15 @@ func TestService_List(t *testing.T) { {ID: "pat-2", UserID: "user-1", OrgID: "org-1", Title: "token-2"}, }, }, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) // enrichWithScope for pat-1 - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{ + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{ {ID: "pol-1", RoleID: "role-1", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, GrantRelation: "granted"}, }, nil) // enrichWithScope for pat-2 - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-2", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-2", schema.PATPrincipal).Return([]policy.Policy{}, nil) auditRepo := mocks.NewAuditRecordRepository(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, membershipSvc, nil, auditRepo) result, err := svc.List(context.Background(), "user-1", "org-1", nil) if err != nil { diff --git a/internal/api/v1beta1connect/user_pat.go b/internal/api/v1beta1connect/user_pat.go index 723940713..ed244364f 100644 --- a/internal/api/v1beta1connect/user_pat.go +++ b/internal/api/v1beta1connect/user_pat.go @@ -35,7 +35,8 @@ func (h *ConnectHandler) getLoggedInPrincipalWithUser(ctx context.Context) (*aut // mapPATError maps PAT service errors to Connect RPC error codes. func mapPATError(err error) *connect.Error { switch { - case errors.Is(err, paterrors.ErrDisabled): + case errors.Is(err, paterrors.ErrDisabled), + errors.Is(err, paterrors.ErrExpired): return connect.NewError(connect.CodeFailedPrecondition, err) case errors.Is(err, paterrors.ErrNotFound): return connect.NewError(connect.CodeNotFound, err) @@ -47,6 +48,7 @@ func mapPATError(err error) *connect.Error { errors.Is(err, paterrors.ErrDeniedRole), errors.Is(err, paterrors.ErrUnsupportedScope), errors.Is(err, paterrors.ErrScopeMismatch), + errors.Is(err, paterrors.ErrDuplicateScope), errors.Is(err, paterrors.ErrProjectForbidden), errors.Is(err, paterrors.ErrExpiryInPast), errors.Is(err, paterrors.ErrExpiryExceeded): diff --git a/test/e2e/regression/pat_test.go b/test/e2e/regression/pat_test.go index cec4be759..e372702ab 100644 --- a/test/e2e/regression/pat_test.go +++ b/test/e2e/regression/pat_test.go @@ -748,6 +748,34 @@ func (s *PATRegressionTestSuite) TestPATCRUD_CreateErrors() { s.Assert().Error(err) s.Assert().Equal(connect.CodeInvalidArgument, connect.CodeOf(err)) }) + + s.Run("two org-scoped roles in one request", func() { + _, err := s.testBench.Client.CreateCurrentUserPAT(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ + Title: "two-org-scopes", + OrgId: orgID, + Scopes: []*frontierv1beta1.PATScope{ + {RoleId: s.roleID(schema.RoleOrganizationViewer), ResourceType: schema.OrganizationNamespace}, + {RoleId: s.roleID(schema.RoleOrganizationManager), ResourceType: schema.OrganizationNamespace}, + }, + ExpiresAt: timestamppb.New(time.Now().Add(24 * time.Hour)), + })) + s.Assert().Error(err) + s.Assert().Equal(connect.CodeInvalidArgument, connect.CodeOf(err)) + }) + + s.Run("two project-scoped roles in one request", func() { + _, err := s.testBench.Client.CreateCurrentUserPAT(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ + Title: "two-project-scopes", + OrgId: orgID, + Scopes: []*frontierv1beta1.PATScope{ + {RoleId: s.roleID(schema.RoleProjectViewer), ResourceType: schema.ProjectNamespace}, + {RoleId: s.roleID(schema.RoleProjectOwner), ResourceType: schema.ProjectNamespace}, + }, + ExpiresAt: timestamppb.New(time.Now().Add(24 * time.Hour)), + })) + s.Assert().Error(err) + s.Assert().Equal(connect.CodeInvalidArgument, connect.CodeOf(err)) + }) } func TestEndToEndPATRegressionTestSuite(t *testing.T) {